# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
from megatron import mpu
23
from .module import MegatronModule
24
from megatron.model.enums import LayerType, AttnMaskType
Mohammad's avatar
Mohammad committed
25
26
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
27
from megatron.model.utils import init_method_normal, scaled_init_method_normal
28
29
30
31

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """Compute LM logits by multiplying the input with the word embedding
    weights.

    Arguments:
        input_: tensor fed to the logits matrix multiply.
        word_embeddings_weight: weight handed to the linear op.
        parallel_output: if true, the result of the linear op is returned
            as is; otherwise it is first passed through
            `mpu.gather_from_tensor_model_parallel_region`.
        bias: optional bias forwarded to the linear op.
    """
    args = get_args()

    # Select how the input enters the matmul and whether the gradient
    # all-reduce is requested to run asynchronously.
    if not args.async_tensor_model_parallel_allreduce:
        matmul_input = mpu.copy_to_tensor_model_parallel_region(input_)
        overlap_grad_allreduce = False
    else:
        matmul_input = input_
        # Only request the async all-reduce when there is more than one
        # tensor-model-parallel rank.
        overlap_grad_allreduce = \
            mpu.get_tensor_model_parallel_world_size() > 1

    # Matrix multiply.
    logits = mpu.LinearWithGradAccumulationAndAsyncAllreduce.apply(
        matmul_input,
        word_embeddings_weight,
        bias,
        args.gradient_accumulation_fusion,
        overlap_grad_allreduce)

    # Gather unless the caller asked for the parallel output.
    if not parallel_output:
        logits = mpu.gather_from_tensor_model_parallel_region(logits)
    return logits
Mohammad's avatar
Mohammad committed
50
51


52
def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Construct a `TransformerLanguageModel` and return it together with
    the key under which it is stored in checkpoints.

    When `init_method` or `scaled_init_method` is None, it is created from
    `args.init_method_std` (and `args.num_layers` for the scaled variant).
    All remaining arguments are forwarded to `TransformerLanguageModel`.

    Returns:
        A `(language_model, language_model_key)` tuple.
    """
    args = get_args()

    # Fall back to the initializers derived from the global arguments.
    init_method = (init_method_normal(args.init_method_std)
                   if init_method is None else init_method)
    scaled_init_method = (
        scaled_init_method_normal(args.init_method_std, args.num_layers)
        if scaled_init_method is None else scaled_init_method)

    # Keyword options handed to the language model unchanged.
    model_options = dict(
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )
    model = TransformerLanguageModel(init_method,
                                     scaled_init_method,
                                     encoder_attn_mask_type,
                                     **model_options)

    # The second element is the key used for checkpoints.
    return model, 'language_model'


class Pooler(MegatronModule):
    """Pooler layer.

    Selects the hidden state of one token (for example the start of the
    sequence), then applies a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super().__init__()
        # Square projection: hidden_size -> hidden_size.
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        """Return tanh(dense(h)) for the token at `sequence_index`.

        Arguments:
            hidden_states: [b, s, h]
            sequence_index: index of the token to pool.
        """
        token_state = hidden_states[:, sequence_index, :]
        return torch.tanh(self.dense(token_state))


class Embedding(MegatronModule):
    """Language model embeddings.

    Sums word, position and (optionally) token-type embeddings and applies
    dropout to the result.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding.

        Each zeroed weight is also tagged with `shared = True`.
        NOTE(review): the consumer of the `shared` attribute is not visible
        in this file - confirm its meaning against the callers.
        """
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        # Announce the change once, from rank 0 only.
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed the inputs and apply dropout.

        Arguments:
            input_ids: ids looked up in the word embeddings.
            position_ids: ids looked up in the position embeddings.
            tokentype_ids: ids looked up in the token-type embeddings. Must
                be given if and only if token-type embeddings exist
                (enforced with assertions).

        Returns:
            The summed embeddings after dropout.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict with one nested state dict per sub-embedding, stored
        under the `_*_embeddings_key` names. The token-type entry is only
        present when `num_tokentypes > 0`.
        """

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    @staticmethod
    def _legacy_sub_state_dict(state_dict, name):
        """Collect the entries of a flat, older-format state dict that belong
        to sub-module `name`, keyed by the text following '<name>.'.

        Matching on the dotted prefix (rather than the bare name) means a
        key that merely contains `name` is skipped instead of raising an
        IndexError on the split.
        """
        marker = name + '.'
        return {key.split(marker)[1]: value
                for key, value in state_dict.items()
                if marker in key}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts both the nested layout written by
        `state_dict_for_save_checkpoint` and the older flat layout in which
        keys contain '<sub-module name>.'.
        """

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, 'word_embeddings')
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, 'position_embeddings')
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = self._legacy_sub_state_dict(
                    state_dict, 'tokentype_embeddings')
            if len(state_dict_) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                # Missing token-type weights are tolerated: the freshly
                # initialized embedding is kept and a warning is printed.
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


276
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Hidden size, padded vocabulary size, maximum number of position
    embeddings and the embedding dropout probability are read from the
    global arguments (`get_args()`).

    Arguments:
        init_method: weight initialization method, used for the embedding,
                     the transformer stacks and the pooler
        output_layer_init_method: second initialization method forwarded
                                  to `ParallelTransformer`
        encoder_attn_mask_type: self-attention mask type of the encoder
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: build the encoder stack
        add_decoder: build the decoder stack
        decoder_attn_mask_type: self-attention mask type of the decoder
        add_pooler: build a `Pooler` (only when `post_process` is true)
        pre_process: build the embedding; also forwarded to the
                     transformer stacks
        post_process: allows the pooler to be built; also forwarded to the
                      transformer stacks
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Encoder output handed in through set_input_tensor() on stages
        # that only hold a decoder.
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                # [decoder input, encoder hidden state]
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                # Only the encoder hidden state was provided.
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run the embedding, encoder, optional pooler and optional decoder.

        Returns:
            Without a decoder, or when `output_enc_hidden` is true:
            `encoder_output`, or `(encoder_output, pooled_output)` when a
            pooler exists on a post-process stage.
            Otherwise `(decoder_output, encoder_output)`, with
            `pooled_output` appended when a pooler exists on a
            post-process stage.
        """

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        elif encoder_input is not None:
            # Caller supplied the encoder output; match the embedding dtype.
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
        else:
            # No embedding on this stage (pre_process is False), so there is
            # no reference dtype to cast to; use the given states unchanged
            # instead of failing on `None.dtype`.
            encoder_output = enc_hidden_states

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict holding one nested state dict for each sub-module
        that exists on this stage (embedding, encoder, pooler, decoder).
        """

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts the nested layout written by
        `state_dict_for_save_checkpoint` as well as older layouts (flat
        '*_embeddings' keys, a 'transformer' entry or 'transformer.'
        prefixed keys, and '.attention.' instead of '.self_attention.').
        """

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {key: value
                               for key, value in state_dict.items()
                               if '_embeddings' in key}
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {key.split('transformer.')[1]: value
                               for key, value in state_dict.items()
                               if 'transformer.' in key}

            # For backward compatibility: rename '.attention.' to
            # '.self_attention.' (keys without it are kept unchanged).
            state_dict_ = {
                key.replace('.attention.', '.self_attention.'): value
                for key, value in state_dict_.items()}

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert self._pooler_key in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)