rec_nrtr_optim_head.py 33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Topdu's avatar
Topdu committed
15
16
17
import math
import paddle
import copy
18
from paddle import nn
Topdu's avatar
Topdu committed
19
20
21
22
23
24
25
26
27
28
29
30
import paddle.nn.functional as F
from paddle.nn import LayerList
from paddle.nn.initializer import XavierNormal as xavier_uniform_
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
import numpy as np
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

# Constant-initializer instances (fill values 0 and 1). `zeros_` is called on
# Conv2D biases in `TransformerOptim._init_weights`; `ones_` is not referenced
# in the visible part of this file.
zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)

31

Topdu's avatar
Topdu committed
32
class TransformerOptim(nn.Layer):
33
    """A transformer model. User is able to modify the attributes as needed. The architechture
Topdu's avatar
Topdu committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

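    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder; when it is
            0 and no custom encoder is given, no encoder is built (default=6).
        beam_size: beam width used at inference; a value > 0 selects beam search,
            otherwise greedy decoding is used (default=0).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=1024).
        attention_dropout_rate: the dropout value passed to the attention layers (default=0.0).
        residual_dropout_rate: the dropout value used on the residual branches and in the
            positional encoding (default=0.1).
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        in_channels: accepted but not referenced by the constructor (default=0).
        out_channels: accepted but not referenced by the constructor (default=0).
        dst_vocab_size: size of the target vocabulary; sets the embedding table size and
            the width of the output projection (default=99).
        scale_embedding: forwarded to the token embedding layer (default=True).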

    """

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 beam_size=0,
                 num_decoder_layers=6,
                 dim_feedforward=1024,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 custom_encoder=None,
                 custom_decoder=None,
                 in_channels=0,
                 out_channels=0,
                 dst_vocab_size=99,
                 scale_embedding=True):
        """Build the embedding, positional encoding, encoder, decoder and output projection.

        ``in_channels`` and ``out_channels`` are accepted but not used here.
        The statement order below is kept deliberately: it fixes the order in
        which sublayers are registered and random numbers are drawn.
        """
        super().__init__()

        self.embedding = Embeddings(
            d_model=d_model,
            vocab=dst_vocab_size,
            padding_idx=0,
            scale_embedding=scale_embedding)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        # Positional arguments shared by the encoder and decoder layer prototypes.
        layer_args = (d_model, nhead, dim_feedforward, attention_dropout_rate,
                      residual_dropout_rate)

        # Encoder: user-supplied, built from a prototype layer, or absent.
        if custom_encoder is not None:
            self.encoder = custom_encoder
        elif num_encoder_layers > 0:
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(*layer_args), num_encoder_layers)
        else:
            self.encoder = None

        # Decoder: user-supplied or built from a prototype layer.
        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            self.decoder = TransformerDecoder(
                TransformerDecoderLayer(*layer_args), num_decoder_layers)

        # Runs before the output projection exists, so it only sees the
        # parameters created above.
        self._reset_parameters()

        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead

        # Bias-free projection from decoder features to vocabulary logits,
        # filled with N(0, d_model ** -0.5) samples.
        self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False)
        projection_init = np.random.normal(
            0.0, d_model**-0.5, (d_model, dst_vocab_size)).astype(np.float32)
        self.tgt_word_prj.weight.set_value(projection_init)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a single sublayer; used as the callback of ``self.apply``.

        Only ``nn.Conv2D`` layers are touched: the weight gets Xavier-normal
        values and the bias, when present, is zeroed.

        Args:
            m: the sublayer handed over by ``apply``.
        """
        if not isinstance(m, nn.Conv2D):
            return

        # `xavier_normal_` is the XavierNormal *class* (see the import alias),
        # so it has to be instantiated before being called on the parameter.
        # The previous `xavier_normal_(m.weight)` only constructed an
        # initializer object and left the weight untouched.
        xavier_normal_()(m.weight)
        if m.bias is not None:
            # `zeros_` is already an initializer instance, so it is applied directly.
            zeros_(m.bias)

112
113
    def forward_train(self, src, tgt):
        """Run the decoder over the target prefix and return logits.

        Args:
            src: visual features. With an encoder they are transposed with
                ``[1, 0, 2]`` (so a 3-D tensor is expected); without one,
                axis 2 is squeezed and the result transposed with ``[2, 0, 1]``.
            tgt: target token ids, indexed as ``tgt[:, :-1]`` -- presumably
                ``[batch, length]``; confirm against the data pipeline.

        Returns:
            Output of ``self.tgt_word_prj`` applied to the batch-first decoder output.
        """
        # The last target column is dropped before it is fed to the decoder.
        decoder_tokens = tgt[:, :-1]

        pad_mask = self.generate_padding_mask(decoder_tokens)
        decoder_input = self.embedding(decoder_tokens).transpose([1, 0, 2])
        decoder_input = self.positional_encoding(decoder_input)
        causal_mask = self.generate_square_subsequent_mask(
            decoder_input.shape[0])

        if self.encoder is None:
            memory = src.squeeze(2).transpose([2, 0, 1])
        else:
            encoder_input = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(encoder_input)

        decoded = self.decoder(
            decoder_input,
            memory,
            tgt_mask=causal_mask,
            memory_mask=None,
            tgt_key_padding_mask=pad_mask,
            memory_key_padding_mask=None)

        # Back to batch-first before projecting onto the vocabulary.
        return self.tgt_word_prj(decoded.transpose([1, 0, 2]))

    def forward(self, src, targets=None):
        """Dispatch to the training, greedy or beam-search path.

        Args:
            src: the features passed on to the selected path (required).
            targets: only read in training mode. ``targets[0]`` is sliced as
                ``[:, :2 + max_len]`` and ``targets[1].max()`` supplies
                ``max_len`` -- presumably padded label ids and label lengths;
                confirm against the data pipeline.

        Returns:
            Training: the logits from ``forward_train``.
            Evaluation: the decoded ids from ``forward_beam`` when
            ``self.beam_size > 0``, otherwise from ``forward_test``.

        Examples:
            >>> output = transformer_model(src, targets)
        """
        if not self.training:
            if self.beam_size > 0:
                return self.forward_beam(src)
            return self.forward_test(src)

        # Trim the padded targets to the longest label in the batch (+2 columns).
        max_len = targets[1].max()
        tgt = targets[0][:, :2 + max_len]
        return self.forward_train(src, tgt)

    def forward_test(self, src):
        """Greedy decoding: append the arg-max token at every step.

        Decoding starts from a column of token id 2 and runs for at most 24
        steps. It stops early when, at some step, every sequence in the batch
        predicts token id 3; that final step is not appended.

        Args:
            src: the features to decode. With an encoder they are transposed
                with ``[1, 0, 2]``; without one, axis 2 is squeezed and the
                result transposed with ``[2, 0, 1]``. ``src.shape[0]`` is
                used as the batch size.

        Returns:
            int64 tensor ``[batch, 1 + n_steps]`` with the start column followed
            by the predicted ids.
        """
        # Ids 2 / 3 are presumably the start / end symbols of the label
        # dictionary -- confirm against the label encoder.
        start_id, stop_id, max_steps = 2, 3, 25

        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(src)
        else:
            memory = src.squeeze(2).transpose([2, 0, 1])

        dec_seq = paddle.full((bs, 1), start_id, dtype=paddle.int64)
        for _ in range(1, max_steps):
            # A fresh copy of the memory is handed to the decoder at every step.
            src_enc = memory.clone()
            tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
            dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2])
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[0])
            output = self.decoder(
                dec_seq_embed,
                src_enc,
                tgt_mask=tgt_mask,
                memory_mask=None,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=None)

            # Only the newest position is needed: [batch, d_model].
            last_step = output.transpose([1, 0, 2])[:, -1, :]
            word_prob = F.log_softmax(self.tgt_word_prj(last_step), axis=1)
            preds_idx = word_prob.argmax(axis=1)  # [batch]

            if paddle.equal_all(
                    preds_idx,
                    paddle.full(
                        preds_idx.shape, stop_id, dtype='int64')):
                break

            dec_seq = paddle.concat(
                [dec_seq, preds_idx.reshape([-1, 1])], axis=1)

        return dec_seq
Topdu's avatar
Topdu committed
199

200
    def forward_beam(self, images):
        """Decode one batch with beam search of width ``self.beam_size``.

        Every instance owns one ``Beam`` (a class defined outside the visible
        part of this file). At each step the decoder runs only on the beams of
        instances that are still active, and the encoder memory is re-packed
        accordingly. Decoding runs for at most 24 steps.

        Args:
            images: the features to decode. With an encoder they are transposed
                with ``[1, 0, 2]`` before encoding; without one, axis 2 is
                squeezed and the result transposed with ``[0, 2, 1]``.

        Returns:
            int64 tensor with one row per instance holding its best hypothesis,
            right-padded with token id 3 up to length 25.
        """

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            """Map every instance index to its position in the packed tensor."""
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            """Keep the slices of ``beamed_tensor`` that belong to active instances.

            ``curr_active_inst_idx`` is an int64 tensor of positions.
            """
            _, *d_hs = beamed_tensor.shape
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = [n_curr_active_inst * n_bm, *d_hs]

            # One row per previously active instance, so a whole instance
            # (all of its beams) is selected or dropped at once.
            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
            # The index is already a tensor; it is passed as-is instead of
            # being wrapped in paddle.to_tensor a second time.
            beamed_tensor = beamed_tensor.index_select(
                curr_active_inst_idx, axis=0)
            return beamed_tensor.reshape(new_shape)

        def collate_active_info(src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            """Re-pack the encoder memory so finished instances are not decoded again."""
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
            # Memory is sequence-first here; go batch-first for the selection
            # and back afterwards.
            active_src_enc = collect_active_part(
                src_enc.transpose([1, 0, 2]), active_inst_idx,
                n_prev_active_inst, n_bm).transpose([1, 0, 2])
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            return active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm,
                             memory_key_padding_mask):
            """Decode one step, advance the beams and return the still-active instance indices."""

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                """Stack the partial sequences of all unfinished beams."""
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = paddle.stack(dec_partial_seq)
                return dec_partial_seq.reshape([-1, len_dec_seq])

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                             memory_key_padding_mask):
                """Return log-probabilities of the next token, ``[n_active_inst, n_bm, vocab]``."""
                tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
                dec_seq = self.embedding(dec_seq).transpose([1, 0, 2])
                dec_seq = self.positional_encoding(dec_seq)
                tgt_mask = self.generate_square_subsequent_mask(
                    dec_seq.shape[0])
                dec_output = self.decoder(
                    dec_seq,
                    enc_output,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask,
                ).transpose([1, 0, 2])
                # Pick the last step: (n_active_inst * n_bm) x d_model
                dec_output = dec_output[:, -1, :]
                word_prob = F.log_softmax(
                    self.tgt_word_prj(dec_output), axis=1)
                return word_prob.reshape([n_active_inst, n_bm, -1])

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                """Advance every beam and list the instances that are not complete."""
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(
                        word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list.append(inst_idx)
                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            # No padding mask is applied to the encoder memory; whatever the
            # caller passed is overridden.
            memory_key_padding_mask = None
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                     memory_key_padding_mask)

            # Update the beams with the predicted word probabilities and
            # collect the incomplete instances.
            return collect_active_inst_idx_list(inst_dec_beams, word_prob,
                                                inst_idx_to_position_map)

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            """Return the ``n_best`` hypotheses and scores of every instance."""
            all_hyp, all_scores = [], []
            for beam in inst_dec_beams:
                scores, tail_idxs = beam.sort_scores()
                all_scores.append(scores[:n_best])
                all_hyp.append(
                    [beam.get_hypothesis(i) for i in tail_idxs[:n_best]])
            return all_hyp, all_scores

        with paddle.no_grad():
            # -- Encode
            if self.encoder is not None:
                src = self.positional_encoding(images.transpose([1, 0, 2]))
                src_enc = self.encoder(src).transpose([1, 0, 2])
            else:
                src_enc = images.squeeze(2).transpose([0, 2, 1])

            # -- Repeat data for beam search: every instance is replicated
            #    n_bm times, then the memory goes back to sequence-first.
            n_bm = self.beam_size
            n_inst, len_s, d_h = src_enc.shape
            src_enc = paddle.concat([src_enc for _ in range(n_bm)], axis=1)
            src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
                [1, 0, 2])

            # -- Prepare beams
            inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]

            # -- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            # -- Decode
            for len_dec_seq in range(1, 25):
                src_enc_copy = src_enc.clone()
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_enc_copy,
                    inst_idx_to_position_map, n_bm, None)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                src_enc, inst_idx_to_position_map = collate_active_info(
                    src_enc_copy, inst_idx_to_position_map,
                    active_inst_idx_list)

        # Only the best hypothesis per instance is kept; scores are not returned.
        batch_hyp, _ = collect_hypothesis_and_scores(inst_dec_beams, 1)

        # Right-pad every hypothesis with id 3 so the rows form a rectangle.
        result_hyp = [
            bs_hyp[0] + [3] * (25 - len(bs_hyp[0])) for bs_hyp in batch_hyp
        ]
        return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64)
Topdu's avatar
Topdu committed
366
367

    def generate_square_subsequent_mask(self, sz):
        """Generate a square additive mask that hides future positions.

        Args:
            sz: side length of the mask.

        Returns:
            float32 tensor ``[sz, sz]`` with ``-inf`` strictly above the main
            diagonal and ``0.0`` everywhere else.
        """
        # triu(diagonal=1) keeps the strictly-upper entries and zeroes the
        # rest, so no separate zero matrix has to be added. The fill value is
        # passed as a number rather than as the string '-inf'.
        neg_inf = paddle.full(
            shape=[sz, sz], fill_value=float('-inf'), dtype='float32')
        return paddle.triu(neg_inf, diagonal=1)

    def generate_padding_mask(self, x):
        """Return the element-wise comparison of ``x`` with 0.

        The result is true where ``x`` holds 0, the id that is also passed as
        ``padding_idx`` to the embedding in ``__init__``.
        """
        pad_value = paddle.to_tensor(0, dtype=x.dtype)
        return x.equal(pad_value)

    def _reset_parameters(self):
        """Initiate parameters in the transformer model.

        Every parameter with more than one dimension that exists when this
        method is called is re-initialised with the ``xavier_uniform_``
        initializer.
        """
        # NOTE(review): despite its name, `xavier_uniform_` is an alias of the
        # XavierNormal *class* (see the imports). The previous
        # `xavier_uniform_(p)` only constructed an initializer object and never
        # touched `p`; the initializer has to be instantiated and then called
        # on the parameter. This also re-initialises the embedding weights --
        # confirm that is intended before relying on it.
        initializer = xavier_uniform_()
        for p in self.parameters():
            if p.dim() > 1:
                initializer(p)


class TransformerEncoder(nn.Layer):
    """A stack of ``num_layers`` encoder layers applied one after another.

    Args:
        encoder_layer: the layer instance that is cloned with ``_get_clones`` (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
    """

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src):
        """Pass the input through the encoder layers in turn.

        Every layer is called with ``src_mask=None`` and
        ``src_key_padding_mask=None``.

        Args:
            src: the sequence to the encoder (required).

        Returns:
            The output of the last layer.
        """
        hidden = src
        for layer_idx in range(self.num_layers):
            layer = self.layers[layer_idx]
            hidden = layer(hidden, src_mask=None, src_key_padding_mask=None)
        return hidden


class TransformerDecoder(nn.Layer):
    """A stack of ``num_layers`` decoder layers applied one after another.

    Args:
        decoder_layer: the layer instance that is cloned with ``_get_clones`` (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
    """

    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and masks) through the decoder layers in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Returns:
            The output of the last layer.
        """
        # The same memory and masks are handed to every layer.
        shared_kwargs = dict(
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)

        hidden = tgt
        for layer_idx in range(self.num_layers):
            hidden = self.layers[layer_idx](hidden, memory, **shared_kwargs)
        return hidden

465

Topdu's avatar
Topdu committed
466
class TransformerEncoderLayer(nn.Layer):
    """TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    The feedforward network is implemented with two 1x1 ``Conv2D`` layers.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        attention_dropout_rate: the dropout value passed to the attention layer (default=0.0).
        residual_dropout_rate: the dropout value applied to both residual branches (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttentionOptim(
            d_model, nhead, dropout=attention_dropout_rate)

        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=(1, 1))
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required). Its last axis
                must hold ``d_model`` features, since it becomes the channel
                axis of the 1x1 convolutions.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            A tensor laid out like ``src``.
        """
        # Self-attention block with residual connection and layer norm.
        src2 = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout1(src2))

        # Feedforward block: axes (0, 1, 2) -> (1, 2, 0) plus a singleton axis
        # so the 1x1 convolutions see the feature axis as channels. Only the
        # convolution input is reshaped; `src` itself keeps its layout, which
        # avoids transposing it forth and back.
        ffn_in = paddle.unsqueeze(src.transpose([1, 2, 0]), 2)
        src2 = self.conv2(F.relu(self.conv1(ffn_in)))
        src2 = paddle.squeeze(src2, 2).transpose([2, 0, 1])

        src = self.norm2(src + self.dropout2(src2))
        return src

534

Topdu's avatar
Topdu committed
535
class TransformerDecoderLayer(nn.Layer):
    """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    The feedforward network is implemented with two 1x1 ``Conv2D`` layers.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        attention_dropout_rate: the dropout value passed to both attention layers (default=0.0).
        residual_dropout_rate: the dropout value applied to the three residual branches (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttentionOptim(
            d_model, nhead, dropout=attention_dropout_rate)
        self.multihead_attn = MultiheadAttentionOptim(
            d_model, nhead, dropout=attention_dropout_rate)

        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=(1, 1))
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)
        self.dropout3 = Dropout(residual_dropout_rate)

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and masks) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required). Its last axis
                must hold ``d_model`` features, since it becomes the channel
                axis of the 1x1 convolutions.
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Returns:
            A tensor laid out like ``tgt``.
        """
        # Self-attention over the target sequence.
        tgt2 = self.self_attn(
            tgt,
            tgt,
            tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = self.norm1(tgt + self.dropout1(tgt2))

        # Attention of the target over the encoder memory.
        tgt2 = self.multihead_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)[0]
        tgt = self.norm2(tgt + self.dropout2(tgt2))

        # Feedforward block: axes (0, 1, 2) -> (1, 2, 0) plus a singleton axis
        # so the 1x1 convolutions see the feature axis as channels. Only the
        # convolution input is reshaped; `tgt` itself keeps its layout, which
        # avoids transposing it forth and back.
        ffn_in = paddle.unsqueeze(tgt.transpose([1, 2, 0]), 2)
        tgt2 = self.conv2(F.relu(self.conv1(ffn_in)))
        tgt2 = paddle.squeeze(tgt2, 2).transpose([2, 0, 1])

        tgt = self.norm3(tgt + self.dropout3(tgt2))
        return tgt


def _get_clones(module, N):
    """Return a ``LayerList`` holding ``N`` independent deep copies of ``module``."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return LayerList(clones)


class PositionalEncoding(nn.Layer):
    r"""Inject some information about the relative or absolute position of the tokens
        in the sequence. The positional encodings have the same dimension as
        the embeddings, so that the two can be summed. Here, we use sine and cosine
        functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/dim))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/dim))
        \text{where pos is the word position and i is the embed idx}
    Args:
        dropout: the dropout value applied to the summed output (required).
        dim: the embed dim (required).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        angles = position * div_term  # [max_len, ceil(dim / 2)]
        pe[:, 0::2] = paddle.sin(angles)
        # For an odd `dim` there is one odd-indexed column fewer than there
        # are frequencies, so the cosine table is trimmed to fit. For an even
        # `dim` the slice keeps every column.
        pe[:, 1::2] = paddle.cos(angles[:, :dim // 2])
        # Final layout is [max_len, 1, dim] so it broadcasts over axis 1 of the input.
        pe = pe.unsqueeze(0)
        pe = pe.transpose([1, 0, 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        # Only the first x.shape[0] positions of the table are used.
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class PositionalEncoding_2d(nn.Layer):
    r"""Add a content-scaled 2-D sinusoidal positional encoding to a feature map.

    A single sine/cosine table is precomputed and shared by both spatial axes:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / dim})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / dim})

    where pos is the position along an axis and i is the embedding index.
    In ``forward`` the table is sliced once for the width and once for the
    height; each slice is multiplied by a per-sample, per-channel gain that a
    linear layer computes from the globally average-pooled input, and both
    results are added to the input.

    Args:
        dropout: the dropout probability (required).
        dim: the embed dim, i.e. the channel count of the input (required).
        max_len: the number of positions to precompute; an upper bound on
            both the height and the width of the input (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        # exp(-2i * ln(10000) / dim) == 1 / 10000^(2i / dim)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        # [max_len, dim] -> [max_len, 1, dim]
        pe = pe.unsqueeze(0).transpose([1, 0, 2])
        self.register_buffer('pe', pe)

        # Gain branch for the width encoding; weights start as all ones.
        self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.)
        # Gain branch for the height encoding; weights start as all ones.
        self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.)

    def forward(self, x):
        """Add width and height encodings to ``x`` and flatten it to a sequence.

        Args:
            x: the feature map fed to the positional encoder (required).
        Shape:
            x: [batch size, embed dim, height, width]
            output: [height * width, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        # Squeeze only the two pooled spatial axes. A bare ``squeeze()`` would
        # also drop the batch axis when batch size is 1 (and the channel axis
        # when dim is 1), leaving the result to rely on broadcasting.
        w1 = self.linear1(self.avg_pool_1(x).squeeze([2, 3])).unsqueeze(0)
        w_pe = self.pe[:x.shape[-1], :] * w1  # [W, B, C]
        w_pe = w_pe.transpose([1, 2, 0]).unsqueeze(2)  # [B, C, 1, W]

        w2 = self.linear2(self.avg_pool_2(x).squeeze([2, 3])).unsqueeze(0)
        h_pe = self.pe[:x.shape[-2], :] * w2  # [H, B, C]
        h_pe = h_pe.transpose([1, 2, 0]).unsqueeze(3)  # [B, C, H, 1]

        x = x + w_pe + h_pe
        # [B, C, H, W] -> [B, C, H * W] -> [H * W, B, C]
        x = x.reshape(
            [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose(
                [2, 0, 1])

        return self.dropout(x)


class Embeddings(nn.Layer):
    """Token-embedding lookup with optional sqrt(d_model) output scaling.

    Args:
        d_model: width of each embedding vector.
        vocab: number of rows in the embedding table.
        padding_idx: forwarded to ``nn.Embedding`` as ``padding_idx``.
        scale_embedding: when truthy, ``forward`` multiplies the looked-up
            vectors by ``sqrt(d_model)``.
    """

    def __init__(self, d_model, vocab, padding_idx, scale_embedding):
        super().__init__()
        self.d_model = d_model
        self.scale_embedding = scale_embedding
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        # Overwrite the default initialisation with draws from a normal
        # distribution of mean 0 and standard deviation d_model ** -0.5.
        init_weight = np.random.normal(
            0.0, d_model**-0.5, (vocab, d_model)).astype(np.float32)
        self.embedding.weight.set_value(init_weight)

    def forward(self, x):
        """Look up ``x`` in the table, scaling the result when configured to."""
        embedded = self.embedding(x)
        if not self.scale_embedding:
            return embedded
        return embedded * math.sqrt(self.d_model)


class Beam:
    """Beam-search state for a single sequence being decoded.

    For every step it keeps the ``size`` best partial hypotheses: their
    cumulative scores, the token chosen at that step, and a back-pointer to
    the beam slot each one extends.

    Args:
        size: beam width.
        device: unused; kept so existing callers keep working.
        bos_idx: token id written to slot 0 at step 0 and prepended to every
            tentative hypothesis (default=2).
        eos_idx: token id that ends the search once it is the top-of-beam
            token (default=3).
    """

    def __init__(self, size, device=False, bos_idx=2, eos_idx=3):
        self.size = size
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self._done = False
        # The score for each translation on the beam.
        self.scores = paddle.zeros((size, ), dtype=paddle.float32)
        self.all_scores = []
        # The backpointers at each time-step.
        self.prev_ks = []
        # The outputs at each time-step; only slot 0 starts from BOS.
        first_step = paddle.full((size, ), 0, dtype=paddle.int64)
        first_step[0] = bos_idx
        self.next_ys = [first_step]

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        "Whether the top-of-beam token has been EOS."
        return self._done

    def advance(self, word_prob):
        """Update beam status and check if finished or not.

        Args:
            word_prob: 2-D tensor of scores with one row per beam slot and one
                column per candidate token (``shape[1]`` is taken as the
                number of words). Scores are summed across steps, so they are
                presumably log-probabilities -- confirm against the caller.

        Returns:
            True once the best hypothesis ends in ``eos_idx``.
        """
        num_words = word_prob.shape[1]

        # Sum the previous scores.
        if len(self.prev_ks) > 0:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            # First step: only row 0 is expanded.
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.reshape([-1])
        best_scores, best_scores_id = flat_beam_lk.topk(
            k=self.size, axis=0, largest=True, sorted=True)  # 1st sort
        self.all_scores.append(self.scores)
        self.scores = best_scores
        # best_scores_id indexes a flattened (beam x word) array, so recover
        # which beam slot and which word each score came from.
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0] == self.eos_idx:
            self._done = True
            self.all_scores.append(self.scores)

        return self._done

    def sort_scores(self):
        """Return the current scores and the beam-slot index of each one.

        No sorting happens here: the scores are returned as stored, paired
        with the identity indices ``0 .. n - 1``.
        """
        return self.scores, paddle.arange(self.scores.shape[0], dtype='int32')

    def get_the_best_score_and_idx(self):
        """Get the score and beam-slot index of the best hypothesis.

        ``advance`` stores scores in descending order, so the best entry is
        slot 0. Reading slot 1 would return the runner-up and raise for a
        beam of size 1.
        """
        scores, ids = self.sort_scores()
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        "Get the decoded sequence for the current timestep."
        if len(self.next_ys) == 1:
            # Nothing decoded yet: one column holding the initial tokens.
            return self.next_ys[0].unsqueeze(1)
        _, keys = self.sort_scores()
        hyps = [[self.bos_idx] + self.get_hypothesis(k) for k in keys]
        return paddle.to_tensor(hyps, dtype='int64')

    def get_hypothesis(self, k):
        """Walk back to construct the full hypothesis.

        Args:
            k: beam slot, at the latest step, of the hypothesis to rebuild.

        Returns:
            The token ids of that hypothesis as a list of Python scalars,
            oldest first, without the leading BOS.
        """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k])
            # Follow the back-pointer to the slot this token extended.
            k = self.prev_ks[j][k]
        return [token.item() for token in reversed(hyp)]