language_model.py 22 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6
7

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
8
from megatron import get_args
9
from megatron.core import mpu, tensor_parallel
10
from .module import MegatronModule
11
from megatron.model.enums import LayerType, AttnMaskType
Mohammad's avatar
Mohammad committed
12
13
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
14
from megatron.model.utils import init_method_normal, scaled_init_method_normal
15

16

17
18
19
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights.

    Projects ``input_`` onto ``word_embeddings_weight`` with the tensor
    parallel linear op to obtain vocabulary logits.

    Args:
        input_: hidden states to project.
        word_embeddings_weight: word embedding weight used as the output
            projection matrix.
        parallel_output: if True, return this rank's (tensor-parallel)
            logits unchanged; otherwise gather them across the tensor model
            parallel region.
        bias: optional bias forwarded to the linear op.

    Returns:
        The logits tensor, gathered unless ``parallel_output`` is True.
    """
    args = get_args()

    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce or args.sequence_parallel:
        # No explicit copy-to-region here: the linear op below is told via
        # ``async_grad_allreduce`` / ``sequence_parallel_enabled`` how to
        # handle the input gradient.
        input_parallel = input_
        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        async_grad_allreduce = (args.async_tensor_model_parallel_allreduce
                                and model_parallel
                                and not args.sequence_parallel)
    else:
        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
        async_grad_allreduce = False

    # Matrix multiply.
    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        async_grad_allreduce=async_grad_allreduce,
        sequence_parallel_enabled=args.sequence_parallel)

    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
Mohammad's avatar
Mohammad committed
46
47


48
def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save.

    Args:
        num_tokentypes: number of token types (0 disables token-type
            embeddings).
        add_pooler: whether to build a pooler.
        encoder_attn_mask_type: self-attention mask type for the encoder.
        init_method: weight init method; defaults to
            ``init_method_normal(args.init_method_std)``.
        scaled_init_method: init method handed to the transformer stacks as
            their output-layer init; defaults to
            ``scaled_init_method_normal(args.init_method_std, args.num_layers)``.
        add_encoder: whether to build the encoder stack.
        add_decoder: whether to build the decoder stack.
        decoder_attn_mask_type: self-attention mask type for the decoder.
        pre_process: whether to build the input embedding.
        post_process: whether to build post-processing modules (the pooler).

    Returns:
        Tuple ``(language_model, language_model_key)`` where the key is the
        string ``'language_model'`` used in checkpoints.
    """
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process
    )

    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        """Pool the token at ``sequence_index`` and apply dense + tanh.

        Args:
            hidden_states: [s, b, h] hidden states (sequence-sharded when
                sequence parallelism is enabled).
            sequence_index: index of the token to pool.

        Returns:
            Pooled tensor obtained from ``hidden_states[sequence_index]``.
        """
        # Gather data along the sequence dimension so that the same pooler
        # is run on all tensor parallel nodes.
        if self.sequence_parallel:
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states,
                tensor_parallel_output_grad=False)

        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    Sums word (vocab-parallel), position, and optional token-type
    embeddings, converts the result from [b, s, h] to [s, b, h], and
    applies dropout.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method,
            params_dtype=args.params_dtype,
            use_cpu_initialization=args.use_cpu_initialization,
            perform_initialization=args.perform_initialization
        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        if args.perform_initialization:
            self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding and flag them as shared."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings. Unlike __init__, this always
        # runs init_method (args.perform_initialization is not consulted).
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed ``input_ids`` / ``position_ids`` (and ``tokentype_ids``).

        ``tokentype_ids`` must be given if and only if this module has
        token-type embeddings (enforced by asserts).

        Returns:
            Embeddings in [s, b, h] layout after dropout; scattered along the
            sequence dimension when sequence parallelism is enabled.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            # Dropout on the sequence shard runs under the forked CUDA RNG
            # tracker state.
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load: nest each sub-embedding's state dict under its key."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(prefix=prefix,
                                                       keep_vars=keep_vars)

        return state_dict_

    @staticmethod
    def _sub_state_dict(state_dict, key):
        """Return the sub state dict stored under ``key``.

        For backward compatibility with flat checkpoints, when ``key`` is not
        present, collect entries whose name contains ``key + '.'`` and strip
        everything up to and including that marker from the name.
        """
        if key in state_dict:
            return state_dict[key]
        marker = key + '.'
        return {name.split(marker)[1]: value
                for name, value in state_dict.items()
                if marker in name}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        state_dict_ = self._sub_state_dict(state_dict,
                                           self._word_embeddings_key)
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        state_dict_ = self._sub_state_dict(state_dict,
                                           self._position_embeddings_key)
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = self._sub_state_dict(state_dict,
                                               self._tokentype_embeddings_key)
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


303
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Combines an optional input embedding (``pre_process``), an optional
    encoder and/or decoder ``ParallelTransformer`` stack, and an optional
    pooler (``post_process`` and ``add_pooler``).

    Arguments:
        init_method: weight initialization method used for the embedding,
            the transformer stacks and the pooler.
        output_layer_init_method: passed to each ``ParallelTransformer`` as
            its second argument (``get_language_model`` supplies the scaled
            init method here).
        encoder_attn_mask_type: self-attention mask type of the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: build the encoder stack.
        add_decoder: build the decoder stack.
        decoder_attn_mask_type: self-attention mask type of the decoder.
        add_pooler: build a pooler (only when ``post_process``).
        pre_process: build the input embedding.
        post_process: build post-processing modules (the pooler).
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Filled by set_input_tensor on decoder-only stages.
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()

        Routes the incoming tensor(s) to the encoder and/or decoder. On a
        decoder-only stage, a length-2 list is (decoder input, encoder hidden
        state) and a length-1 list is the encoder hidden state alone.
        """

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run embedding, encoder, optional pooler and optional decoder.

        Args:
            enc_hidden_states: precomputed encoder output; when given, the
                encoder is not run.
            output_enc_hidden: return right after the encoder even if a
                decoder exists.

        Returns:
            Without a decoder (or with ``output_enc_hidden``):
            ``encoder_output`` or ``(encoder_output, pooled_output)``.
            Otherwise ``(decoder_output, encoder_output)`` or
            ``(decoder_output, encoder_output, pooled_output)``. The pooled
            output is included when ``add_pooler`` and ``post_process``.
        """

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        elif encoder_input is not None:
            # Match the dtype of the embedding output.
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
        else:
            # No embedding on this stage (pre_process=False), so there is no
            # reference dtype; use the provided hidden states as-is instead
            # of dereferencing None.
            encoder_output = enc_hidden_states

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load: nest each submodule's state dict under its key."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
                                                                keep_vars=keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load (accepts older flat / 'transformer' layouts)."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility: rename '.attention.' to
            # '.self_attention.' in parameter names.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.",
                        ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert self._pooler_key in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)