# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
8
from megatron import get_args
9
from megatron import mpu
10
from .module import MegatronModule
11
from megatron.model.enums import LayerType, AttnMaskType
Mohammad's avatar
Mohammad committed
12
13
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
14
from megatron.model.utils import init_method_normal, scaled_init_method_normal
15

16

17
18
19
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """Compute LM logits by projecting `input_` with the word embedding weights.

    Arguments:
        input_: hidden states to project.
        word_embeddings_weight: word embedding weight matrix, used as the
            weight of the output projection.
        parallel_output: when True, return the logits as produced on this
            tensor-model-parallel rank; when False, gather them with
            mpu.gather_from_tensor_model_parallel_region before returning.
        bias: optional bias forwarded to the linear op.

    Returns:
        The logits tensor (gathered unless `parallel_output` is True).
    """
    args = get_args()
    use_async_allreduce = args.async_tensor_model_parallel_allreduce
    sequence_parallel = args.sequence_parallel

    # Parallel logits.
    if not (use_async_allreduce or sequence_parallel):
        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
        async_grad_allreduce = False
    else:
        # The input is used as-is. The async gradient all-reduce is only
        # requested when there is more than one tensor-parallel rank and
        # sequence parallelism is disabled.
        input_parallel = input_
        has_tp_peers = mpu.get_tensor_model_parallel_world_size() > 1
        async_grad_allreduce = (use_async_allreduce and has_tp_peers
                                and not sequence_parallel)

    # Matrix multiply.
    matmul = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply
    logits_parallel = matmul(input_parallel,
                             word_embeddings_weight,
                             bias,
                             args.gradient_accumulation_fusion,
                             async_grad_allreduce,
                             sequence_parallel)

    # Gather across tensor-parallel ranks unless the caller wants them split.
    if not parallel_output:
        return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
    return logits_parallel


45
def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build a TransformerLanguageModel and return it with its checkpoint key.

    Arguments:
        num_tokentypes: size of the token-type embeddings (0 disables them).
        add_pooler: whether the model builds a pooler.
        encoder_attn_mask_type: self-attention mask type for the encoder.
        init_method: weight initializer; defaults to
            init_method_normal(args.init_method_std).
        scaled_init_method: initializer for output layers; defaults to
            scaled_init_method_normal(args.init_method_std, args.num_layers).
        add_encoder / add_decoder: which transformer stacks to build.
        decoder_attn_mask_type: self-attention mask type for the decoder.
        pre_process / post_process: forwarded to TransformerLanguageModel.

    Returns:
        A tuple (language_model, 'language_model'); the string is the key
        used for the model in checkpoints.
    """
    args = get_args()

    # Fill in default initializers derived from args.init_method_std.
    init_method = (init_method if init_method is not None
                   else init_method_normal(args.init_method_std))
    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    model_kwargs = dict(
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )
    language_model = TransformerLanguageModel(init_method,
                                              scaled_init_method,
                                              encoder_attn_mask_type,
                                              **model_kwargs)

    # Key used for checkpoints.
    return language_model, 'language_model'


class Pooler(MegatronModule):
    """Pooler layer.

    Selects the hidden state at one sequence position (for example the start
    of the sequence), then applies a linear transformation followed by tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        # When set, forward() first gathers the input along the sequence dim.
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        """Pool `hidden_states` ([s, b, h]) at position `sequence_index`.

        Returns tanh(dense(hidden_states[sequence_index])), shape [b, h].
        """
        if self.sequence_parallel:
            # Gather data along the sequence dimension; the same pooler is
            # run on all tensor parallel nodes.
            hidden_states = mpu.gather_from_sequence_parallel_region(
                hidden_states, tensor_parallel_output_grad=False)

        token_states = hidden_states[sequence_index, :, :]
        return torch.tanh(self.dense(token_states))


class Embedding(MegatronModule):
    """Language model embeddings.

    Sums word, position and (optionally) token-type embeddings, converts the
    result from [b, s, h] to [s, b, h], and applies dropout. Position and
    token-type embeddings are initialized with `init_method` only when
    args.perform_initialization is set.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        if args.perform_initialization:
            self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding.

        Each weight is also tagged with `shared = True`.
        NOTE(review): the consumer of that flag is outside this file; it
        presumably marks the weights as shared/tied -- confirm.
        """
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        # Report once per job; also works when no process group is set up,
        # where get_rank() would raise.
        if not torch.distributed.is_initialized() or \
                torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        # NOTE(review): unlike __init__, this does not consult
        # args.perform_initialization -- confirm that is intended.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed `input_ids` ([b, s]) and return embeddings laid out [s, b, h].

        `tokentype_ids` must be given iff the module has token-type
        embeddings (asserted).
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
            # Dropout runs inside the CUDA RNG tracker's fork on this path.
            with mpu.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load.

        Returns a dict keyed by sub-embedding name ('word_embeddings',
        'position_embeddings' and, if present, 'tokentype_embeddings').
        """

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(prefix=prefix,
                                                       keep_vars=keep_vars)

        return state_dict_

    @staticmethod
    def _legacy_sub_state_dict(state_dict, name):
        """Extract `name`'s entries from an old-format (flat) state dict.

        A key of the form '<anything>name.<suffix>' is returned as '<suffix>'.
        Keys that mention `name` without a following '.' are skipped instead
        of raising IndexError.
        """
        marker = name + '.'
        return {key.split(marker)[1]: value
                for key, value in state_dict.items() if marker in key}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts both the nested format produced by
        state_dict_for_save_checkpoint and older flat formats.
        """

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(state_dict,
                                                      'word_embeddings')
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(state_dict,
                                                      'position_embeddings')
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = self._legacy_sub_state_dict(
                    state_dict, 'tokentype_embeddings')
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


296
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Depending on its flags, this stage owns an input Embedding (pre_process),
    an encoder and/or a decoder ParallelTransformer, and a Pooler
    (post_process and add_pooler).

    Arguments:
        init_method: weight initialization method, passed to the embedding,
            the transformer stacks and the pooler.
        output_layer_init_method: initialization method passed to the
            ParallelTransformer stacks as their second argument.
        encoder_attn_mask_type: self-attention mask type for the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: build the encoder stack on this stage.
        add_decoder: build the decoder stack on this stage.
        decoder_attn_mask_type: self-attention mask type for the decoder.
        add_pooler: build a pooler (only when post_process is True).
        pre_process: build the input embedding on this stage.
        post_process: forwarded to the transformer stacks; also gates the
            pooler.

    Vocabulary size, maximum sequence length and embedding dropout come from
    the global args (padded_vocab_size, max_position_embeddings,
    hidden_dropout).
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Encoder output supplied via set_input_tensor on decoder-only stages.
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            # Decoder-only stage: either (decoder input, encoder output) or
            # just the encoder output.
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run the embedding, the encoder and (optionally) the decoder.

        Returns:
            If there is no decoder or `output_enc_hidden` is True:
                encoder_output, or (encoder_output, pooled_output) when the
                pooler is built.
            Otherwise:
                (decoder_output, encoder_output), plus pooled_output as a
                third element when the pooler is built.
        """

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        elif encoder_input is not None:
            # Match the dtype of the locally computed embeddings.
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
        else:
            # No embedding on this stage (pre_process is False), so there is
            # no local dtype to match; use the provided states as given.
            encoder_output = enc_hidden_states

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load.

        Returns a dict with one entry per sub-module built on this stage
        ('embedding', 'encoder', 'pooler', 'decoder').
        """

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
                                                                keep_vars=keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts the nested format produced by state_dict_for_save_checkpoint
        as well as older layouts ('transformer' key, flat keys, '.attention.'
        instead of '.self_attention.').
        """

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility: old checkpoints name the
            # self-attention sub-module '.attention.'.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.",
                        ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert self._pooler_key in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)