# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
23
from megatron import mpu
from megatron.module import MegatronModule
Mohammad's avatar
Mohammad committed
24
25
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
26
from megatron.model.utils import init_method_normal, scaled_init_method_normal
27
28
29
30
31

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights.

    The hidden states are projected with `F.linear` using the word-embedding
    weight matrix as the projection weight (i.e. the output layer is tied to
    the input embedding).

    Arguments:
        input_: hidden states to project.
        word_embeddings_weight: weight used for the projection (presumably
            this rank's vocab-parallel shard of the embedding -- confirm
            against callers).
        parallel_output: if True, return the per-rank logits as produced by
            the matmul; otherwise gather them with
            `mpu.gather_from_tensor_model_parallel_region`.
        bias: optional bias for the projection; None means no bias.

    Returns:
        The logits, either per-rank or gathered, according to
        `parallel_output`.
    """
    # Parallel logits.
    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply. F.linear treats bias=None as "no bias", so a single
    # call covers both the biased and unbiased cases.
    logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel
    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)


def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                       init_method=None, scaled_init_method=None):
    """Build language model and return along with the key to save.

    The concrete class is chosen from this rank's pipeline position:
      * first and last stage  -> TransformerLanguageModel
      * first stage only      -> TransformerLanguageModelFirstStage
      * last stage only       -> TransformerLanguageModelLastStage
      * neither               -> TransformerLanguageModelIntermediateStage

    Arguments:
        attention_mask_func: masking function forwarded to the model.
        num_tokentypes: size of the token-type embeddings; only passed to
            classes built on the first pipeline stage.
        add_pooler: whether to add a pooler; only passed to classes built
            on the last pipeline stage.
        init_method: weight initialization method. Defaults to
            `init_method_normal(args.init_method_std)`.
        scaled_init_method: initialization method forwarded as the model's
            `output_layer_init_method`. Defaults to
            `scaled_init_method_normal(args.init_method_std, args.num_layers)`.

    Returns:
        (language_model, language_model_key) where `language_model_key` is
        the string 'language_model' used for checkpoints.
    """
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Positional constructor arguments shared by every stage class. Kept in
    # a separate name so the global `args` namespace is not shadowed.
    model_args = [attention_mask_func, init_method, scaled_init_method]
    model_kwargs = {}

    is_first_stage = mpu.is_pipeline_first_stage()
    is_last_stage = mpu.is_pipeline_last_stage()

    if is_first_stage and is_last_stage:
        cls = TransformerLanguageModel
        model_kwargs['num_tokentypes'] = num_tokentypes
        model_kwargs['add_pooler'] = add_pooler
    elif is_first_stage:
        cls = TransformerLanguageModelFirstStage
        model_kwargs['num_tokentypes'] = num_tokentypes
    elif is_last_stage:
        cls = TransformerLanguageModelLastStage
        model_kwargs['add_pooler'] = add_pooler
    else:
        cls = TransformerLanguageModelIntermediateStage

    # Language model.
    language_model = cls(*model_args, **model_kwargs)

    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        # Square hidden_size -> hidden_size projection.
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        """Return tanh(dense(hidden_states[:, sequence_index, :])).

        Arguments:
            hidden_states: [b, s, h] hidden states; the token is selected
                along dimension 1.
            sequence_index: index of the token to pool.
        """
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    The output is word embeddings (vocab-parallel) + position embeddings
    (+ token-type embeddings when enabled), followed by dropout.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Sum word, position and (optional) token-type embeddings.

        `tokentype_ids` must be supplied exactly when token-type embeddings
        exist; both directions are asserted.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict with one nested state dict per embedding, keyed by
        'word_embeddings', 'position_embeddings' and (when enabled)
        'tokentype_embeddings'.
        """

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    @staticmethod
    def _legacy_sub_state_dict(state_dict, name):
        """Extract sub-module `name`'s entries from a flat legacy state dict.

        Every key containing `name` is kept and renamed to the part after
        '<name>.' (e.g. 'embedding.word_embeddings.weight' -> 'weight').
        """
        return {key.split(name + '.')[1]: value
                for key, value in state_dict.items() if name in key}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts either the nested layout produced by
        `state_dict_for_save_checkpoint` or, for backward compatibility, a
        flat dict whose keys contain the embedding names.
        """

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, self._word_embeddings_key)
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, self._position_embeddings_key)
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = self._legacy_sub_state_dict(
                    state_dict, self._tokentype_embeddings_key)
            if len(state_dict_) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                # Missing token-type weights are tolerated (e.g. a model
                # pretrained without them); only a warning is emitted.
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


class TransformerLanguageModelBase(MegatronModule):
    """Transformer language model.

    Shared implementation for every pipeline stage. Which sub-modules are
    built depends on this rank's pipeline position: the embedding only on
    the first stage, the pooler only on the last stage (and only when
    `add_pooler` is set), and the transformer on every stage. The hidden
    size, padded vocabulary size, maximum position embeddings and embedding
    dropout are read from the global arguments (`get_args()`).

    Arguments:
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
          masked-attention-scores = attention_mask_func(
                                     unmasked-attention-scores, attention-mask)
        init_method: weight initialization method used for the embedding,
            the transformer and the pooler.
        output_layer_init_method: second initialization method forwarded to
            ParallelTransformer (get_language_model passes a scaled normal
            initializer here).
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_pooler: if True, the last pipeline stage adds a Pooler and
            forward returns (transformer_output, pooled_output).
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 add_pooler=False):
        super(TransformerLanguageModelBase, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_pooler = add_pooler

        # Embeddings (first pipeline stage only).
        if mpu.is_pipeline_first_stage():
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        self.transformer = ParallelTransformer(
            attention_mask_func, self.init_method,
            output_layer_init_method)
        self._transformer_key = 'transformer'

        # Pooler (last pipeline stage only, when requested).
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            self.pooler = Pooler(self.hidden_size, self.init_method)
            self._pooler_key = 'pooler'

    def forward(self, language_model_input, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                pooling_sequence_index=0):
        """Run embedding (first stage), transformer, and pooler (last stage).

        On the first pipeline stage `language_model_input` is unpacked as
        `(input_ids, position_ids)`; on other stages it is passed straight
        to the transformer. Returns the transformer output, or
        `(transformer_output, pooled_output)` when the pooler is active.
        """

        # Embeddings.
        if mpu.is_pipeline_first_stage():
            (input_ids, position_ids) = language_model_input
            embedding_output = self.embedding(input_ids, position_ids,
                                              tokentype_ids=tokentype_ids)
            transformer_input = embedding_output
        else:
            transformer_input = language_model_input

        # Transformer.
        transformer_output = self.transformer(transformer_input,
                                              attention_mask,
                                              layer_past=layer_past,
                                              get_key_value=get_key_value)

        if mpu.is_pipeline_last_stage() and self.add_pooler:
            pooled_output = self.pooler(transformer_output,
                                        pooling_sequence_index)
            return transformer_output, pooled_output

        return transformer_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict of nested state dicts keyed by 'embedding' (first
        stage), 'transformer' (always) and 'pooler' (last stage with pooler).
        """

        state_dict_ = {}
        if mpu.is_pipeline_first_stage():
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        state_dict_[self._transformer_key] \
            = self.transformer.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            state_dict_[self._pooler_key] \
                = self.pooler.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts the nested layout from `state_dict_for_save_checkpoint`, a
        T5-style checkpoint with an 'encoder' entry, or (for backward
        compatibility) a flat dict with prefixed keys.
        """

        # Embedding.
        if mpu.is_pipeline_first_stage():
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {key: value
                               for key, value in state_dict.items()
                               if '_embeddings' in key}
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Transformer.
        if self._transformer_key in state_dict:
            state_dict_ = state_dict[self._transformer_key]
        # for compatibility with t5 architecture
        # this is temporary until t5_main is merged
        elif 'encoder' in state_dict:
            # for forward compatibility for t5 architecture: rename
            # '.self_attention.' to '.attention.' (str.replace leaves keys
            # without that substring unchanged).
            state_dict_ = {
                key.replace(".self_attention.", ".attention."): value
                for key, value in state_dict['encoder'].items()}
        else:
            # for backward compatibility.
            state_dict_ = {key.split('transformer.')[1]: value
                           for key, value in state_dict.items()
                           if 'transformer.' in key}
        self.transformer.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            # Use the same key attribute the save path writes.
            assert self._pooler_key in state_dict, \
                'could not find data for pooler in the checkpoint'
            self.pooler.load_state_dict(state_dict[self._pooler_key],
                                        strict=strict)


class TransformerLanguageModel(TransformerLanguageModelBase):
    """Single-stage transformer language model.

    Takes raw token and position ids and may return pooled output. The
    constructor arguments have the same meaning as in
    TransformerLanguageModelBase.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 add_pooler=False):
        super().__init__(attention_mask_func, init_method,
                         output_layer_init_method,
                         num_tokentypes=num_tokentypes,
                         add_pooler=add_pooler)

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                pooling_sequence_index=0):
        # The base class expects ids packed as a single (ids, positions) pair.
        packed_ids = (input_ids, position_ids)
        return super().forward(packed_ids, attention_mask,
                               tokentype_ids=tokentype_ids,
                               layer_past=layer_past,
                               get_key_value=get_key_value,
                               pooling_sequence_index=pooling_sequence_index)


class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
    """First pipeline stage of the transformer language model.

    Owns the embedding but no pooler; the constructor arguments have the
    same meaning as in TransformerLanguageModelBase.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0):
        super().__init__(attention_mask_func, init_method,
                         output_layer_init_method,
                         num_tokentypes=num_tokentypes)

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False):
        # Pack ids as the base class expects on the first stage.
        packed_ids = (input_ids, position_ids)
        return super().forward(packed_ids, attention_mask,
                               tokentype_ids=tokentype_ids,
                               layer_past=layer_past,
                               get_key_value=get_key_value)


class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
    """Intermediate pipeline stage of the transformer language model.

    Consumes hidden states from the previous stage and owns neither the
    embedding nor the pooler; the constructor arguments have the same
    meaning as in TransformerLanguageModelBase.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method):
        super().__init__(attention_mask_func, init_method,
                         output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                layer_past=None, get_key_value=False):
        return super().forward(hidden_states, attention_mask,
                               layer_past=layer_past,
                               get_key_value=get_key_value)


class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
    """Final pipeline stage of the transformer language model.

    Consumes hidden states from the previous stage and optionally pools the
    output; the constructor arguments have the same meaning as in
    TransformerLanguageModelBase.
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 add_pooler=False):
        super().__init__(attention_mask_func, init_method,
                         output_layer_init_method,
                         add_pooler=add_pooler)

    def forward(self, hidden_states, attention_mask,
                layer_past=None, get_key_value=False,
                pooling_sequence_index=0):
        return super().forward(hidden_states, attention_mask,
                               layer_past=layer_past,
                               get_key_value=get_key_value,
                               pooling_sequence_index=pooling_sequence_index)