# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
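"""Fully convolutional sequence-to-sequence model ("fconv").

Implements the convolutional architecture of Gehring et al. (2017),
"Convolutional Sequence to Sequence Learning".
"""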

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import LanguagePairDataset
from fairseq.modules import (
    BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution,
)

from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel


class FConvModel(FairseqModel):
    """Convolutional encoder-decoder model."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
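        # the encoder scales gradients by 1/(2 * num_attention_layers) in the
        # backward pass (see GradMultiply in FConvEncoder.forward), so it
        # needs to know how many decoder layers use attention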
        self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)


class FConvEncoder(FairseqEncoder):
    """Convolutional encoder"""
    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
                 convolutions=((512, 3),) * 20, dropout=0.1):
        super().__init__(dictionary)
        self.dropout = dropout
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
                                                   left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)

        in_channels = convolutions[0][0]
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        for (out_channels, kernel_size) in convolutions:
            # integer division: Conv1d padding must be an int, and this
            # preserves the sequence length for odd kernel sizes
            pad = (kernel_size - 1) // 2
            self.projections.append(Linear(in_channels, out_channels)
                                    if in_channels != out_channels else None)
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, padding=pad,
                        dropout=dropout))
            in_channels = out_channels
        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens):
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)
        input_embedding = x

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv in zip(self.projections, self.convolutions):
            residual = x if proj is None else proj(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x)
            x = F.glu(x, dim=2)
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        # scale gradients (this only affects backward, not forward)
        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding) * math.sqrt(0.5)

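        # x (the encoder output) is used by the decoder to compute attention
        # scores; y (the output combined with the input embedding) provides
        # the values that attention averages over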
        return x, y

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()


class AttentionLayer(nn.Module):
    """Dot-product attention between decoder states and encoder outputs."""

    def __init__(self, conv_channels, embed_dim, bmm=None):
        super().__init__()
        # projects from output of convolution to embedding dimension
        self.in_projection = Linear(conv_channels, embed_dim)
        # projects from embedding dimension to convolution size
        self.out_projection = Linear(embed_dim, conv_channels)

        self.bmm = bmm if bmm is not None else torch.bmm

    def forward(self, x, target_embedding, encoder_out):
        residual = x

        # attention
        x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
        x = self.bmm(x, encoder_out[0])

        # softmax over last dim
        sz = x.size()
        x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
        x = x.view(sz)
        attn_scores = x

        x = self.bmm(x, encoder_out[1])

        # scale attention output
        s = encoder_out[1].size(1)
        x = x * (s * math.sqrt(1.0 / s))
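        # s * sqrt(1/s) == sqrt(s): averaging over s source positions divides
        # the std by sqrt(s) for roughly uniform attention weights, so this
        # rescaling keeps the output variance roughly unchanged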

        # project back
        x = (self.out_projection(x) + residual) * math.sqrt(0.5)
        return x, attn_scores

    def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
        """Replace torch.bmm with BeamableMM."""
        if beamable_mm_beam_size is not None:
            del self.bmm
            self.add_module('bmm', BeamableMM(beamable_mm_beam_size))


class FConvDecoder(FairseqIncrementalDecoder):
    """Convolutional decoder"""
    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                 max_positions=1024, convolutions=((512, 3),) * 20,
                 attention=True, dropout=0.1, share_embed=False):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([2]))
        self.dropout = dropout

        in_channels = convolutions[0][0]
        if isinstance(attention, bool):
            # expand True into [True, True, ...] and do the same with False
            attention = [attention] * len(convolutions)
        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError('Attention is expected to be a list of booleans of '
                             'length equal to the number of layers.')

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
                                                   left_pad=LanguagePairDataset.LEFT_PAD_TARGET)

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
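            # pad both sides by (kernel_size - 1); the extra future timesteps
            # are trimmed after the convolution (remove_future_timesteps in
            # forward), which keeps the decoder causal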
            pad = kernel_size - 1
            self.projections.append(Linear(in_channels, out_channels)
                                    if in_channels != out_channels else None)
            self.convolutions.append(
                LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
                                 padding=pad, dropout=dropout))
            self.attention.append(AttentionLayer(out_channels, embed_dim)
                                  if attention[i] else None)
            in_channels = out_channels
        self.fc2 = Linear(in_channels, out_embed_dim)
        if share_embed:
            assert out_embed_dim == embed_dim, \
                "Shared embed weights implies same dimensions " \
                " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
            self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
            self.fc3.weight = self.embed_tokens.weight
        else:
            self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

    def forward(self, input_tokens, encoder_out):
        # split and transpose encoder outputs
        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed positions
        positions = self.embed_positions(input_tokens)

        if self._is_incremental_eval:
            # keep only the last token for incremental forward pass
            input_tokens = input_tokens[:, -1:]

        # embed tokens and positions
        x = self.embed_tokens(input_tokens) + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        target_embedding = x

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = self._transpose_unless_incremental_eval(x)

        # temporal convolutions
        avg_attn_scores = None
        num_attn_layers = len(self.attention)
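        # each attention layer contributes attn_scores / num_attn_layers to
        # avg_attn_scores below, producing an average over layers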
        for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
            residual = x if proj is None else proj(x)

            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x)
            x = conv.remove_future_timesteps(x)
            x = F.glu(x, dim=2)

            # attention
            if attention is not None:
                x = self._transpose_unless_incremental_eval(x)

                x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
                attn_scores = attn_scores / num_attn_layers
                if avg_attn_scores is None:
                    avg_attn_scores = attn_scores
                else:
                    avg_attn_scores.add_(attn_scores)

                x = self._transpose_unless_incremental_eval(x)

            # residual
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = self._transpose_unless_incremental_eval(x)

        # project back to size of output embedding
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # project to size of vocabulary
        x = self.fc3(x)

        return x, avg_attn_scores

    def reorder_incremental_state(self, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        super().reorder_incremental_state(new_order)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
            # old models use incorrect weight norm dimension
            for i, conv in enumerate(self.convolutions):
                # reconfigure weight norm
                nn.utils.remove_weight_norm(conv)
                self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
            # record the upgraded version so the fix is not re-applied
            state_dict['decoder.version'] = torch.Tensor([2])
        return state_dict

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs.

        This is cached when doing incremental inference.
        """
        cached_result = self.get_incremental_state('encoder_out')
        if cached_result:
            return cached_result

        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(1, 2).contiguous()
        result = (encoder_a, encoder_b)

        return self.set_incremental_state('encoder_out', result)

    def _transpose_unless_incremental_eval(self, x):
        if self._is_incremental_eval:
            return x
        return x.transpose(0, 1)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
    return m


def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
    m.weight.data.normal_(0, 0.1)
    return m


def Linear(in_features, out_features, dropout=0):
    """Weight-normalized Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return nn.utils.weight_norm(m, dim=2)


def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer"""
    from fairseq.modules import ConvTBC
    m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return nn.utils.weight_norm(m, dim=2)


def get_archs():
    return [
        'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
    ]


def _check_arch(args):
    """Check that the specified architecture is valid and not ambiguous."""
    if args.arch not in get_archs():
        raise ValueError('Unknown fconv model architecture: {}'.format(args.arch))
    if args.arch != 'fconv':
        # check that architecture is not ambiguous
        for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers',
                  'decoder_out_embed_dim']:
            if hasattr(args, a):
                raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch))


def parse_arch(args):
    _check_arch(args)

    if args.arch == 'fconv_iwslt_de_en':
        args.encoder_embed_dim = 256
        args.encoder_layers = '[(256, 3)] * 4'
        args.decoder_embed_dim = 256
        args.decoder_layers = '[(256, 3)] * 3'
        args.decoder_out_embed_dim = 256
    elif args.arch == 'fconv_wmt_en_ro':
        args.encoder_embed_dim = 512
        args.encoder_layers = '[(512, 3)] * 20'
        args.decoder_embed_dim = 512
        args.decoder_layers = '[(512, 3)] * 20'
        args.decoder_out_embed_dim = 512
    elif args.arch == 'fconv_wmt_en_de':
        convs = '[(512, 3)] * 9'       # first 9 layers have 512 units
        convs += ' + [(1024, 3)] * 4'  # next 4 layers have 1024 units
        convs += ' + [(2048, 1)] * 2'  # final 2 layers use 1x1 convolutions
        args.encoder_embed_dim = 768
        args.encoder_layers = convs
        args.decoder_embed_dim = 768
        args.decoder_layers = convs
        args.decoder_out_embed_dim = 512
    elif args.arch == 'fconv_wmt_en_fr':
        convs = '[(512, 3)] * 6'       # first 6 layers have 512 units
        convs += ' + [(768, 3)] * 4'   # next 4 layers have 768 units
        convs += ' + [(1024, 3)] * 3'  # next 3 layers have 1024 units
        convs += ' + [(2048, 1)] * 1'  # next 1 layer uses 1x1 convolutions
        convs += ' + [(4096, 1)] * 1'  # final 1 layer uses 1x1 convolutions
        args.encoder_embed_dim = 768
        args.encoder_layers = convs
        args.decoder_embed_dim = 768
        args.decoder_layers = convs
        args.decoder_out_embed_dim = 512
    else:
        assert args.arch == 'fconv'

    # default architecture
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
    args.decoder_attention = getattr(args, 'decoder_attention', 'True')
    args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
    return args


def build_model(args, src_dict, dst_dict):
    encoder = FConvEncoder(
        src_dict,
        embed_dim=args.encoder_embed_dim,
        convolutions=eval(args.encoder_layers),
        dropout=args.dropout,
        max_positions=args.max_source_positions,
    )
    decoder = FConvDecoder(
        dst_dict,
        embed_dim=args.decoder_embed_dim,
        convolutions=eval(args.decoder_layers),
        out_embed_dim=args.decoder_out_embed_dim,
        attention=eval(args.decoder_attention),
        dropout=args.dropout,
        max_positions=args.max_target_positions,
        share_embed=args.share_input_output_embed,
    )
    return FConvModel(encoder, decoder)
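

# Example usage (a minimal sketch): assumes `args` comes from fairseq's
# training argument parser and that FairseqModel.forward chains the encoder
# and decoder as model(src_tokens, input_tokens).
#
#   args = parse_arch(args)
#   model = build_model(args, src_dict, dst_dict)
#   logits, avg_attn_scores = model(src_tokens, input_tokens)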