# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
from megatron import mpu
23
from .module import MegatronModule
24
from megatron.model.enums import LayerType, AttnMaskType
Mohammad's avatar
Mohammad committed
25
26
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
27
from megatron.model.utils import init_method_normal, scaled_init_method_normal
28

29

30
31
32
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """Compute LM logits by multiplying hidden states with the word embeddings.

    Arguments:
        input_: hidden states to project.
        word_embeddings_weight: word embedding weight, reused as the output
            projection matrix.
        parallel_output: if True, return the per-rank logits as produced;
            otherwise gather them across the tensor-model-parallel region.
        bias: optional bias forwarded to the linear op.
    """
    args = get_args()
    skip_input_copy = (args.async_tensor_model_parallel_allreduce
                       or args.sequence_parallel)

    if skip_input_copy:
        # The input is handed to the fused linear op unchanged; gradient
        # communication is presumably driven by the async_grad_allreduce /
        # sequence_parallel flags passed to it below.
        input_parallel = input_
        is_model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        async_grad_allreduce = (args.async_tensor_model_parallel_allreduce
                                and is_model_parallel
                                and not args.sequence_parallel)
    else:
        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
        async_grad_allreduce = False

    # Matrix multiply against the (parallel) word embedding weight.
    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
        input_parallel, word_embeddings_weight, bias,
        args.gradient_accumulation_fusion,
        async_grad_allreduce, args.sequence_parallel)

    if not parallel_output:
        # Gather the per-rank logits across the tensor-model-parallel region.
        logits_parallel = mpu.gather_from_tensor_model_parallel_region(
            logits_parallel)
    return logits_parallel


58
def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build a TransformerLanguageModel and return it with its checkpoint key.

    When ``init_method`` / ``scaled_init_method`` are not supplied they are
    built with ``init_method_normal(args.init_method_std)`` and
    ``scaled_init_method_normal(args.init_method_std, args.num_layers)``.

    Returns:
        A ``(language_model, 'language_model')`` tuple.
    """
    args = get_args()

    init_method = (init_method if init_method is not None
                   else init_method_normal(args.init_method_std))
    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

    model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process)

    # Key under which the language model is stored in checkpoints.
    checkpoint_key = 'language_model'
    return model, checkpoint_key


class Pooler(MegatronModule):
    """Pooler layer.

    Picks the hidden state at a single sequence position (for example the
    start of the sequence), then applies a linear layer followed by tanh.

    Arguments:
        hidden_size: hidden size (input and output width of the linear layer)
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        """Pool ``hidden_states`` ([s, b, h]) at position ``sequence_index``."""
        # Under sequence parallelism, gather along the sequence dimension
        # first; the same pooler then runs on every tensor-parallel rank.
        full_states = (mpu.gather_from_sequence_parallel_region(hidden_states)
                       if self.sequence_parallel else hidden_states)
        token_states = full_states[sequence_index, :, :]
        return torch.tanh(self.dense(token_states))


class Embedding(MegatronModule):
    """Language model embeddings.

    Sums word, position and (optionally) token-type embeddings, changes the
    layout from [b s h] to [s b h], optionally casts to fp32, and applies
    dropout (after scattering to the sequence-parallel region when sequence
    parallelism is enabled).

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding.

        Each zeroed weight is also tagged with ``shared = True``; that
        attribute is presumably read by code elsewhere (e.g. weight-sharing
        or gradient-sync logic) -- confirm against its consumers.
        """
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed ``input_ids`` / ``position_ids`` (and ``tokentype_ids``).

        ``tokentype_ids`` must be given exactly when token-type embeddings
        exist. Returns embeddings in [s b h] layout.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
            # Dropout on the scattered slice runs under the forked CUDA RNG
            # tracker state.
            with mpu.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load: one sub-dict per embedding, keyed by its name."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts either the nested layout written by
        ``state_dict_for_save_checkpoint`` or an older flat layout whose keys
        contain ``<name>.``; the prefix up to and including ``<name>.`` is
        stripped in the flat case.
        """

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] \
                        = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] \
                        = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding. Missing data only warns, so a checkpoint
        # without token types can still be loaded.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] \
                            = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


305
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Depending on its flags, this module holds an embedding, an encoder, a
    decoder and a pooler.

    Arguments:
        init_method: weight initialization method for the embedding,
            encoder/decoder and pooler.
        output_layer_init_method: second initialization method passed to
            the encoder/decoder ParallelTransformer.
        encoder_attn_mask_type: self-attention mask type for the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: if True, build an encoder.
        add_decoder: if True, build a decoder.
        decoder_attn_mask_type: self-attention mask type for the decoder.
        add_pooler: if True (and post_process is True), build a pooler.
        pre_process: if True, build the embedding and embed the inputs in
            forward(); also forwarded to the transformers.
        post_process: if True, the pooler (when requested) is built and
            applied; also forwarded to the transformers.
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Encoder output received via set_input_tensor() on stages that have
        # a decoder but no encoder.
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()

        For a decoder-only stage, a 2-element list is
        (decoder input, encoder hidden state); a 1-element list is the
        encoder hidden state only.
        """

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run the configured embedding / encoder / pooler / decoder.

        Returns:
            Without a decoder, or when ``output_enc_hidden`` is True:
            ``encoder_output`` (or ``(encoder_output, pooled_output)`` when
            the pooler is active). Otherwise
            ``(decoder_output, encoder_output)`` (with ``pooled_output``
            appended when the pooler is active).
        """

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        elif encoder_input is not None:
            # Match caller-supplied encoder states to the embedding dtype.
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
        else:
            # No embedding on this stage (pre_process is False), so there is
            # no reference dtype: use the supplied states as given.
            encoder_output = enc_hidden_states

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load: one sub-dict per present sub-module, keyed by name."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Also accepts older checkpoint layouts: flat ``*_embeddings`` keys, a
        ``transformer`` sub-dict or ``transformer.``-prefixed keys for the
        encoder, and ``.attention.`` parameter names, which are renamed to
        ``.self_attention.``.
        """

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.",
                        ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert self._pooler_key in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)