# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import math
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import LanguagePairDataset
from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution

from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel


def make_positions(tokens, padding_idx, left_pad, offset=0):
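    """Replace non-padding symbols with their position numbers.

    Positions are numbered from padding_idx + 1 upward; padding symbols keep
    the value padding_idx.
    """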
    seqlen = tokens.size(1)
    if not hasattr(make_positions, 'range'):
        make_positions.range = tokens.new()
    if make_positions.range.numel() < offset + seqlen:
        # offset positions by the padding index
        torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
                     out=make_positions.range)
    mask = tokens.ne(padding_idx)
    positions = make_positions.range[offset:offset+seqlen].expand_as(tokens)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tokens.clone().masked_scatter_(mask, positions[mask])
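
# Worked example (hypothetical values; padding_idx=1, left_pad=False):
#   tokens = [[4, 5, 6, 1]]  ->  [[2, 3, 4, 1]]
# With left_pad=True the pad stays on the left and numbering still starts at
# padding_idx + 1 for the first real token:
#   tokens = [[1, 4, 5, 6]]  ->  [[1, 2, 3, 4]]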


class FConvModel(FairseqModel):
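    """Fully convolutional sequence-to-sequence model (Gehring et al., 2017)."""
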
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
        # tell the encoder how many attention layers exist so it can scale gradients
        self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)


class FConvEncoder(FairseqEncoder):
    """Convolutional encoder"""
    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
                 convolutions=((512, 3),) * 20, dropout=0.1):
        super().__init__()
        self.dictionary = dictionary
        self.dropout = dropout
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)

        in_channels = convolutions[0][0]
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        for (out_channels, kernel_size) in convolutions:
            pad = (kernel_size - 1) // 2  # integer "same" padding; / would give a float in Python 3
            self.projections.append(Linear(in_channels, out_channels)
                                    if in_channels != out_channels else None)
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, padding=pad,
                        dropout=dropout))
            in_channels = out_channels
        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens):
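        """Return a pair (x, y) of B x T x embed_dim tensors: x is the output
        of the top convolutional layer and y is x plus the input embeddings;
        the decoder's attention computes scores against x and pools y as values.
        """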
        positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
                                            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))

        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
        x = F.dropout(x, p=self.dropout, training=self.training)
        input_embedding = x

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv in zip(self.projections, self.convolutions):
            residual = x if proj is None else proj(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x)
            x = F.glu(x, dim=2)  # GLU halves 2*out_channels back to out_channels
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        # scale gradients (this only affects backward, not forward)
        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding) * math.sqrt(0.5)

        return x, y

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1


class AttentionLayer(nn.Module):
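    """Dot-product attention as used in the ConvS2S architecture.

    Expects decoder states x of shape B x T x conv_channels and encoder_out
    as a pair (keys, values), with keys of shape B x embed_dim x src_len and
    values of shape B x src_len x embed_dim (see FConvDecoder._split_encoder_out).
    """
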
    def __init__(self, conv_channels, embed_dim, bmm=None):
        super().__init__()
        # projects from output of convolution to embedding dimension
        self.in_projection = Linear(conv_channels, embed_dim)
        # projects from embedding dimension to convolution size
        self.out_projection = Linear(embed_dim, conv_channels)

        self.bmm = bmm if bmm is not None else torch.bmm

    def forward(self, x, target_embedding, encoder_out):
        residual = x

        # attention
        x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
        x = self.bmm(x, encoder_out[0])

        # softmax over last dim
        sz = x.size()
        x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
        x = x.view(sz)
        attn_scores = x

        x = self.bmm(x, encoder_out[1])

        # scale attention output: s * sqrt(1/s) == sqrt(s), restoring the
        # variance lost by averaging over s source vectors (assuming roughly
        # uniform attention weights)
        s = encoder_out[1].size(1)
        x = x * (s * math.sqrt(1.0 / s))

        # project back
        x = (self.out_projection(x) + residual) * math.sqrt(0.5)
        return x, attn_scores

    def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
        """Replace torch.bmm with BeamableMM."""
        if beamable_mm_beam_size is not None:
            del self.bmm
            self.add_module('bmm', BeamableMM(beamable_mm_beam_size))


class FConvDecoder(FairseqIncrementalDecoder):
    """Convolutional decoder"""
    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                 max_positions=1024, convolutions=((512, 3),) * 20,
                 attention=True, dropout=0.1):
        super().__init__()
        self.register_buffer('version', torch.Tensor([2]))
        self.dictionary = dictionary
        self.dropout = dropout

        in_channels = convolutions[0][0]
        if isinstance(attention, bool):
            # expand True into [True, True, ...] and do the same with False
            attention = [attention] * len(convolutions)
        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError('Attention is expected to be a list of booleans of '
                             'length equal to the number of layers.')

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            pad = kernel_size - 1
            self.projections.append(Linear(in_channels, out_channels)
                                    if in_channels != out_channels else None)
            self.convolutions.append(
                LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
                                 padding=pad, dropout=dropout))
            self.attention.append(AttentionLayer(out_channels, embed_dim)
                                  if attention[i] else None)
            in_channels = out_channels
        self.fc2 = Linear(in_channels, out_embed_dim)
        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

    def forward(self, input_tokens, encoder_out):
        if self._is_incremental_eval:
            return self.incremental_forward(input_tokens, encoder_out)
        else:
            return self.batch_forward(input_tokens, encoder_out)

    def batch_forward(self, input_tokens, encoder_out):
        """Forward pass for decoding multiple time steps in batch mode."""
        positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
                                            left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
        return self._forward(input_tokens, positions, encoder_out)

    def incremental_forward(self, input_tokens, encoder_out):
        """Forward pass for one time step."""
        # positions is the same for every token when decoding a single step:
        # padding_idx + (number of tokens generated so far)
        positions = Variable(input_tokens.data.new(1, 1).fill_(
            self.dictionary.pad() + input_tokens.size(1)))

        # keep only the last token for incremental forward pass
        return self._forward(input_tokens[:, -1:], positions, encoder_out)

    def _forward(self, input_tokens, positions, encoder_out):
        # split and transpose encoder outputs
        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed tokens and positions
        x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
        x = F.dropout(x, p=self.dropout, training=self.training)
        target_embedding = x

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = self._transpose_unless_incremental_eval(x)

        # temporal convolutions
        avg_attn_scores = None
        num_attn_layers = len(self.attention)
        for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
            residual = x if proj is None else proj(x)

            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x)
            # drop the extra right-padded timesteps so the convolution stays causal
            x = conv.remove_future_timesteps(x)
            x = F.glu(x, dim=2)  # GLU halves 2*out_channels back to out_channels

            # attention
            if attention is not None:
                x = self._transpose_unless_incremental_eval(x)

                x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
                attn_scores = attn_scores / num_attn_layers
                if avg_attn_scores is None:
                    avg_attn_scores = attn_scores
                else:
                    avg_attn_scores.add_(attn_scores)

                x = self._transpose_unless_incremental_eval(x)

            # residual
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = self._transpose_unless_incremental_eval(x)

        # project back to size of vocabulary (fc2 -> out_embed_dim, fc3 -> vocab)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc3(x)

        return x, avg_attn_scores

    def reorder_incremental_state(self, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        super().reorder_incremental_state(new_order)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1

    def upgrade_state_dict(self, state_dict):
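        """Upgrade state dicts saved by older versions of the model."""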
        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
            # old models use incorrect weight norm dimension
            for i, conv in enumerate(self.convolutions):
                # reconfigure weight norm
                nn.utils.remove_weight_norm(conv)
                self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
            state_dict['decoder.version'] = torch.Tensor([1])
        return state_dict

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs.

        This is cached when doing incremental inference.
        """
        cached_result = self.get_incremental_state('encoder_out')
        if cached_result:
            return cached_result

        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(1, 2).contiguous()
        result = (encoder_a, encoder_b)

        return self.set_incremental_state('encoder_out', result)

    def _transpose_unless_incremental_eval(self, x):
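        """In batch mode x is transposed between B x T x C and T x B x C for
        the ConvTBC-based convolutions; incremental eval keeps B x T x C
        throughout, so the transpose is skipped."""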
        if self._is_incremental_eval:
            return x
        return x.transpose(0, 1)


def Embedding(num_embeddings, embedding_dim, padding_idx):
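    """Embedding layer with weights initialized from N(0, 0.1)."""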
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
    return m


def Linear(in_features, out_features, dropout=0):
    """Weight-normalized Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features)
    # ConvS2S init: N(0, sqrt(p / fan_in)) with p = 1 - dropout (retain prob.)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    # ConvS2S init; the factor of 4 offsets the variance reduction of the GLU
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return nn.utils.weight_norm(m, dim=2)


def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer"""
    from fairseq.modules import ConvTBC
    m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))  # same GLU-aware init as LinearizedConv1d
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return nn.utils.weight_norm(m, dim=2)


def get_archs():
    return [
        'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
    ]


def _check_arch(args):
    """Check that the specified architecture is valid and not ambiguous."""
    if args.arch not in get_archs():
        raise ValueError('Unknown fconv model architecture: {}'.format(args.arch))
    if args.arch != 'fconv':
        # check that architecture is not ambiguous
        for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers',
                  'decoder_out_embed_dim']:
            if hasattr(args, a):
                raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch))


def parse_arch(args):
    _check_arch(args)

    if args.arch == 'fconv_iwslt_de_en':
        args.encoder_embed_dim = 256
        args.encoder_layers = '[(256, 3)] * 4'
        args.decoder_embed_dim = 256
        args.decoder_layers = '[(256, 3)] * 3'
        args.decoder_out_embed_dim = 256
    elif args.arch == 'fconv_wmt_en_ro':
        args.encoder_embed_dim = 512
        args.encoder_layers = '[(512, 3)] * 20'
        args.decoder_embed_dim = 512
        args.decoder_layers = '[(512, 3)] * 20'
        args.decoder_out_embed_dim = 512
    elif args.arch == 'fconv_wmt_en_de':
        convs = '[(512, 3)] * 9'       # first 9 layers have 512 units
        convs += ' + [(1024, 3)] * 4'  # next 4 layers have 1024 units
        convs += ' + [(2048, 1)] * 2'  # final 2 layers use 1x1 convolutions
        args.encoder_embed_dim = 768
        args.encoder_layers = convs
        args.decoder_embed_dim = 768
        args.decoder_layers = convs
        args.decoder_out_embed_dim = 512
    elif args.arch == 'fconv_wmt_en_fr':
        convs = '[(512, 3)] * 6'       # first 6 layers have 512 units
        convs += ' + [(768, 3)] * 4'   # next 4 layers have 768 units
        convs += ' + [(1024, 3)] * 3'  # next 3 layers have 1024 units
        convs += ' + [(2048, 1)] * 1'  # next 1 layer uses 1x1 convolutions
        convs += ' + [(4096, 1)] * 1'  # final 1 layer uses 1x1 convolutions
        args.encoder_embed_dim = 768
        args.encoder_layers = convs
        args.decoder_embed_dim = 768
        args.decoder_layers = convs
        args.decoder_out_embed_dim = 512
    else:
        assert args.arch == 'fconv'

    # default architecture
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
    args.decoder_attention = getattr(args, 'decoder_attention', 'True')
    return args
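
# The *_layers strings above are Python expressions; build_model() eval()s
# them into lists of (out_channels, kernel_size) tuples. For example:
#   eval('[(256, 3)] * 4') == [(256, 3), (256, 3), (256, 3), (256, 3)]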


def build_model(args, src_dict, dst_dict):
    encoder = FConvEncoder(
        src_dict,
        embed_dim=args.encoder_embed_dim,
        convolutions=eval(args.encoder_layers),
        dropout=args.dropout,
        max_positions=args.max_source_positions,
    )
    decoder = FConvDecoder(
        dst_dict,
        embed_dim=args.decoder_embed_dim,
        convolutions=eval(args.decoder_layers),
        out_embed_dim=args.decoder_out_embed_dim,
        attention=eval(args.decoder_attention),
        dropout=args.dropout,
        max_positions=args.max_target_positions,
    )
    return FConvModel(encoder, decoder)
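

# Minimal usage sketch (hypothetical; assumes `args` came from fairseq's
# argument parser with dropout, max_source_positions and max_target_positions
# set, and that src_dict/dst_dict are fairseq Dictionary objects):
#
#   args = parse_arch(args)                        # fill architecture defaults
#   model = build_model(args, src_dict, dst_dict)  # FConvModel(encoder, decoder)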