# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """Compute LM logits by projecting hidden states with the word
    embedding weights.

    Arguments:
        input_: hidden states to project.
        word_embeddings_weight: projection weight passed to ``F.linear``
            (presumably this rank's vocab-parallel shard of the word
            embedding matrix -- confirm against callers).
        parallel_output: when True, return the logits without gathering
            them across the tensor-model-parallel group.
        bias: optional bias for the projection; ``None`` means no bias.

    Returns:
        The per-rank logits if ``parallel_output`` is True, otherwise the
        logits gathered from the tensor-model-parallel region.
    """
    # Move the input into the tensor-model-parallel region before the
    # sharded matrix multiply.
    parallel_input = mpu.copy_to_tensor_model_parallel_region(input_)

    # ``bias=None`` is F.linear's default, so one call covers both cases.
    local_logits = F.linear(parallel_input, word_embeddings_weight, bias)

    if not parallel_output:
        return mpu.gather_from_tensor_model_parallel_region(local_logits)
    return local_logits


def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                       add_decoder=False, init_method=None,
                       scaled_init_method=None,
                       self_attn_mask_type=AttnMaskType.padding):
    """Build language model and return along with the key to save.

    The concrete class depends on this rank's position in the pipeline:
    a rank that is both the first and last stage gets
    ``TransformerLanguageModel``; otherwise the first-, intermediate- or
    last-stage variant is built.

    Arguments:
        attention_mask_func: function applied to attention scores to mask
            them (forwarded to the transformer).
        num_tokentypes: size of the token-type embedding (0 disables it);
            only stages that own the embedding receive it.
        add_pooler: whether to build a pooler (only stages that are the
            last pipeline stage receive it).
        add_decoder: whether to build a decoder stack (only passed to the
            single-stage model).
        init_method: weight initializer; defaults to
            ``init_method_normal(args.init_method_std)``.
        scaled_init_method: initializer for output layers; defaults to
            ``scaled_init_method_normal(args.init_method_std,
            args.num_layers)``.
        self_attn_mask_type: self-attention mask type, forwarded to every
            pipeline stage.

    Returns:
        ``(language_model, language_model_key)`` where the key is the name
        under which the model is stored in checkpoints.
    """
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Positional constructor arguments shared by every stage class. Named
    # distinctly so it does not shadow the global ``args`` read above.
    model_args = [attention_mask_func, init_method, scaled_init_method]
    # Every stage builds a ParallelTransformer whose self-attention depends
    # on the mask type, so all stages must receive the same value;
    # otherwise non-first stages silently fall back to the padding default.
    kwargs = {'self_attn_mask_type': self_attn_mask_type}

    first_stage = mpu.is_pipeline_first_stage()
    last_stage = mpu.is_pipeline_last_stage()
    if first_stage and last_stage:
        cls = TransformerLanguageModel
        kwargs['num_tokentypes'] = num_tokentypes
        kwargs['add_decoder'] = add_decoder
        kwargs['add_pooler'] = add_pooler
    elif first_stage:
        cls = TransformerLanguageModelFirstStage
        kwargs['num_tokentypes'] = num_tokentypes
    elif last_stage:
        cls = TransformerLanguageModelLastStage
        kwargs['add_pooler'] = add_pooler
    else:
        cls = TransformerLanguageModelIntermediateStage

    # Language model.
    language_model = cls(*model_args, **kwargs)
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Picks the hidden state of one token (for example the start of the
    sequence), passes it through a dense layer and squashes it with tanh.

    Arguments:
        hidden_size: hidden size (both input and output width of the
            dense layer).
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super().__init__()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        """Pool ``hidden_states`` ([b, s, h]) at token ``sequence_index``."""
        selected = hidden_states[:, sequence_index, :]
        return torch.tanh(self.dense(selected))


class Embedding(MegatronModule):
    """Language model embeddings.

    Sums word embeddings, position embeddings and (optionally) token-type
    embeddings, then applies dropout.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Return dropout(word + position [+ token-type] embeddings)."""
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load: one nested state dict per embedding table."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    @staticmethod
    def _sub_state_dict(state_dict, key):
        """Return the nested state dict stored under ``key``.

        For backward compatibility, when ``key`` is absent this collects
        every entry whose name contains ``key`` and keeps only the part of
        the name after the first ``key + '.'``.
        """
        if key in state_dict:
            return state_dict[key]
        prefix = key + '.'
        return {name.split(prefix)[1]: value
                for name, value in state_dict.items() if key in name}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        self.word_embeddings.load_state_dict(
            self._sub_state_dict(state_dict, self._word_embeddings_key),
            strict=strict)

        # Position embedding.
        self.position_embeddings.load_state_dict(
            self._sub_state_dict(state_dict, self._position_embeddings_key),
            strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = self._sub_state_dict(
                state_dict, self._tokentype_embeddings_key)
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                # Missing token-type weights are tolerated (the embedding
                # keeps its initialized values) but reported.
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)
class TransformerLanguageModelBase(MegatronModule):
    """Transformer language model.

    Builds the parts owned by this pipeline stage: the embedding (first
    stage only), the encoder transformer, an optional decoder transformer,
    and an optional pooler (last stage only).

    Arguments:
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
          masked-attention-scores = attention_mask_func(
                                     unmasked-attention-scores, attention-mask)
        init_method: weight initialization method.
        output_layer_init_method: initialization method passed to the
            transformer(s) for their output layers.
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        self_attn_mask_type: attention mask type for the encoder's
            self-attention.
        add_decoder: build a decoder transformer (requires
            pipeline_model_parallel_size == 1).
        add_pooler: build a pooler on the last pipeline stage.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 self_attn_mask_type=AttnMaskType.padding,
                 add_decoder=False,
                 add_pooler=False):
        super(TransformerLanguageModelBase, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.self_attn_mask_type = self_attn_mask_type
        self.add_decoder = add_decoder
        self.add_pooler = add_pooler

        # Embeddings (only the first pipeline stage owns them).
        if mpu.is_pipeline_first_stage():
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        self.encoder = ParallelTransformer(
            attention_mask_func,
            self.init_method,
            output_layer_init_method,
            self_attn_mask_type=self_attn_mask_type)
        self._encoder_key = 'encoder'

        # Decoder
        if self.add_decoder:
            assert args.pipeline_model_parallel_size == 1, \
                'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                attention_mask_func,
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=AttnMaskType.causal)
            self._decoder_key = 'decoder'

        if mpu.is_pipeline_last_stage():
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def forward(self, enc_language_model_input, enc_attention_mask,
                dec_language_model_input=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run the encoder (and optionally the decoder).

        On the first pipeline stage ``enc_language_model_input`` is an
        ``(input_ids, position_ids)`` pair that is embedded; on later stages
        it is used directly as the encoder input. If ``enc_hidden_states``
        is given, the encoder is skipped and those states (cast to the
        encoder input's dtype) are used as its output.

        Returns:
            Without a decoder (or with ``output_enc_hidden``):
            ``encoder_output``, or ``(encoder_output, pooled_output)`` on
            the last stage when a pooler exists.
            With a decoder: ``(decoder_output, encoder_output)``, plus
            ``pooled_output`` on the last stage when a pooler exists.
        """

        # Embeddings.
        if mpu.is_pipeline_first_stage():
            (input_ids, position_ids) = enc_language_model_input
            embedding_output = self.embedding(input_ids, position_ids,
                                              tokentype_ids=tokentype_ids)
            encoder_input = embedding_output
        else:
            encoder_input = enc_language_model_input

        # encoder.
        if enc_hidden_states is None:
            encoder_output = self.encoder(encoder_input,
                                          enc_attention_mask,
                                          layer_past=layer_past,
                                          get_key_value=get_key_value)
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if mpu.is_pipeline_last_stage():
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and mpu.is_pipeline_last_stage():
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder Embedding
        (dec_input_ids, dec_position_ids) = dec_language_model_input
        dec_embedding_output = self.embedding(dec_input_ids,
                                              dec_position_ids)
        # decoder
        decoder_output = self.decoder(dec_embedding_output,
                                      dec_attn_mask,
                                      layer_past=layer_past,
                                      get_key_value=get_key_value,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask)

        if self.add_pooler and mpu.is_pipeline_last_stage():
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load: one nested state dict per owned submodule."""

        state_dict_ = {}
        if mpu.is_pipeline_first_stage():
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        state_dict_[self._encoder_key] \
            = self.encoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
            if self.add_decoder:
                state_dict_[self._decoder_key] \
                    = self.decoder.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if mpu.is_pipeline_first_stage():
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self._encoder_key in state_dict:
            state_dict_ = state_dict[self._encoder_key]
        # for backward compatibility.
        elif 'transformer' in state_dict:
            state_dict_ = state_dict['transformer']
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'encoder.' in key:
                    state_dict_[key.split('encoder.')[1]] = state_dict[key]

        # for backward compatibility: rename '.attention.' parameter names
        # to '.self_attention.'.
        state_dict_self_attention = {}
        for key in state_dict_.keys():
            if '.attention.' in key:
                state_dict_self_attention[key.replace(".attention.",
                    ".self_attention.")] = state_dict_[key]
            else:
                state_dict_self_attention[key] = state_dict_[key]
        state_dict_ = state_dict_self_attention

        self.encoder.load_state_dict(state_dict_, strict=strict)

        if mpu.is_pipeline_last_stage():
            # pooler
            if self.add_pooler:
                assert self._pooler_key in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # decoder
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)


class TransformerLanguageModel(TransformerLanguageModelBase):
    """Transformer language model owning every part of the network (see
       TransformerLanguageModelBase for description of arguments).

    ``get_language_model`` builds this class when the rank is both the
    first and the last pipeline stage.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 self_attn_mask_type=AttnMaskType.padding,
                 add_decoder=False,
                 add_pooler=False):
        super().__init__(attention_mask_func,
                         init_method,
                         output_layer_init_method,
                         num_tokentypes=num_tokentypes,
                         self_attn_mask_type=self_attn_mask_type,
                         add_decoder=add_decoder,
                         add_pooler=add_pooler)

    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Pack ids into (ids, positions) pairs and run the base forward."""
        encoder_inputs = (enc_input_ids, enc_position_ids)
        decoder_inputs = (dec_input_ids, dec_position_ids)
        return super().forward(
            encoder_inputs,
            enc_attention_mask,
            dec_language_model_input=decoder_inputs,
            dec_attn_mask=dec_attn_mask,
            enc_dec_attn_mask=enc_dec_attn_mask,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value,
            pooling_sequence_index=pooling_sequence_index,
            enc_hidden_states=enc_hidden_states,
            output_enc_hidden=output_enc_hidden,
        )


class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
    """Transformer language model, first stage (see
       TransformerLanguageModelBase for description of arguments).

    Owns the embedding and the encoder; never builds a decoder or pooler.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 self_attn_mask_type=AttnMaskType.padding):
        super().__init__(attention_mask_func,
                         init_method,
                         output_layer_init_method,
                         num_tokentypes=num_tokentypes,
                         self_attn_mask_type=self_attn_mask_type)

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False):
        """Embed ``(input_ids, position_ids)`` and run the encoder."""
        embedding_inputs = (input_ids, position_ids)
        return super().forward(embedding_inputs,
                               attention_mask,
                               tokentype_ids=tokentype_ids,
                               layer_past=layer_past,
                               get_key_value=get_key_value)


class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
    """Transformer language model, intermediate stage (see
       TransformerLanguageModelBase for description of arguments).

    Owns only the encoder; consumes hidden states from the previous stage.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 self_attn_mask_type=AttnMaskType.padding):
        super().__init__(attention_mask_func,
                         init_method,
                         output_layer_init_method,
                         self_attn_mask_type=self_attn_mask_type)

    def forward(self, hidden_states, attention_mask,
                layer_past=None, get_key_value=False):
        """Run the encoder directly on incoming ``hidden_states``."""
        return super().forward(hidden_states,
                               attention_mask,
                               layer_past=layer_past,
                               get_key_value=get_key_value)


class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
    """Transformer language model, final stage (see
       TransformerLanguageModelBase for description of arguments).

    Owns the encoder and, if ``add_pooler`` is set, the pooler; consumes
    hidden states from the previous stage.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 self_attn_mask_type=AttnMaskType.padding,
                 add_pooler=False):
        super(TransformerLanguageModelLastStage, self).__init__(
            attention_mask_func,
            init_method,
            output_layer_init_method,
            # Forward the caller's mask type. Hard-coding padding here made
            # this stage ignore the argument and mask differently from the
            # earlier pipeline stages.
            self_attn_mask_type=self_attn_mask_type,
            add_pooler=add_pooler)

    def forward(self, hidden_states, attention_mask,
                layer_past=None, get_key_value=False,
                pooling_sequence_index=0):
        """Run the encoder on ``hidden_states`` (and the pooler, if any)."""
        return super(TransformerLanguageModelLastStage, self).forward(
            hidden_states,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
            pooling_sequence_index=pooling_sequence_index,
        )