"router/client/src/v3/client.rs" did not exist on "e520d5b34917f02504a280f87420fd2a5ce0eeb2"
language_model.py 20.5 KB
Newer Older
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
from megatron import mpu
23
from .module import MegatronModule
24
from megatron.model.enums import LayerType, AttnMaskType
Mohammad's avatar
Mohammad committed
25
26
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
27
from megatron.model.utils import init_method_normal, scaled_init_method_normal
28
29
30
31
32

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """Project hidden states onto the word-embedding matrix to get LM logits.

    Args:
        input_: hidden states to project.
        word_embeddings_weight: embedding weight used as the output
            projection matrix (passed as ``weight`` to ``F.linear``).
        parallel_output: when truthy, return this rank's logits as-is;
            otherwise gather them with
            ``mpu.gather_from_tensor_model_parallel_region``.
        bias: optional bias forwarded to ``F.linear``.

    Returns:
        The (possibly gathered) logits tensor.
    """
    # Enter the tensor-model-parallel region before the projection.
    projected_input = mpu.copy_to_tensor_model_parallel_region(input_)

    # F.linear is called with or without a bias argument, mirroring `bias`.
    linear_args = (projected_input, word_embeddings_weight)
    if bias is not None:
        linear_args = linear_args + (bias,)
    local_logits = F.linear(*linear_args)

    if not parallel_output:
        local_logits = mpu.gather_from_tensor_model_parallel_region(
            local_logits)
    return local_logits


46
def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build a TransformerLanguageModel together with its checkpoint key.

    Initializers left as ``None`` fall back to defaults derived from the
    global args. ``init_method`` falls back to
    ``init_method_normal(args.init_method_std)``. ``scaled_init_method``
    falls back to
    ``scaled_init_method_normal(args.init_method_std, args.num_layers)``.
    All remaining arguments are forwarded to ``TransformerLanguageModel``.

    Returns:
        tuple: ``(language_model, 'language_model')``, where the string is
        the key used for checkpoints.
    """
    args = get_args()

    # Default initializers.
    init_method = (init_method if init_method is not None
                   else init_method_normal(args.init_method_std))
    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

    model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )

    checkpoint_key = 'language_model'
    return model, checkpoint_key


class Pooler(MegatronModule):
    """Pooler layer.

    Picks the hidden state at a single sequence position (for example the
    start of the sequence), runs it through a dense layer and applies tanh.

    Arguments:
        hidden_size: hidden size (both input and output width of the
            dense layer).
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super().__init__()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        """Return ``tanh(dense(hidden_states[:, sequence_index, :]))``.

        hidden_states is indexed as [b, s, h]; ``sequence_index`` selects
        the position along the second axis to pool.
        """
        token_states = hidden_states[:, sequence_index, :]
        return torch.tanh(self.dense(token_states))


class Embedding(MegatronModule):
    """Language model embeddings.

    ``forward`` returns dropout(word + position [+ token-type] embeddings).
    Word embeddings use ``mpu.VocabParallelEmbedding``; position and
    token-type embeddings are plain ``torch.nn.Embedding`` modules.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method (passed to the word
                     embeddings and applied to the position and
                     token-type embedding weights)
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel). The initializer is handed to the
        # embedding module instead of being applied here.
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.data.fill_(0)
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Sum word, position and (optional) token-type embeddings, then
        apply dropout.

        ``tokentype_ids`` must be provided exactly when token-type
        embeddings exist; both mismatches trip an assertion.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict with one nested state dict per sub-embedding, keyed
        by ``'word_embeddings'``, ``'position_embeddings'`` and (only when
        ``num_tokentypes > 0``) ``'tokentype_embeddings'``.
        """

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    @staticmethod
    def _legacy_sub_state_dict(state_dict, name):
        """Extract the entries for sub-module ``name`` from a flat checkpoint.

        Older checkpoints stored keys such as ``'word_embeddings.weight'``
        directly. Every key containing ``'<name>.'`` is re-keyed to the
        text after the first occurrence of that marker. Matching on the
        dotted marker avoids the ``IndexError`` the old ``'<name>' in key``
        check could hit. Splitting at most once keeps suffixes that
        repeat the marker intact.
        """
        marker = name + '.'
        return {key.split(marker, 1)[1]: value
                for key, value in state_dict.items() if marker in key}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts both the nested layout written by
        ``state_dict_for_save_checkpoint`` and older flat layouts (for
        backward compatibility).
        """

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, 'word_embeddings')
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, 'position_embeddings')
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = self._legacy_sub_state_dict(
                    state_dict, 'tokentype_embeddings')
            if len(state_dict_) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                # A missing token-type table is tolerated: it may be added
                # later via add_tokentype_embeddings on a pretrained model.
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


267
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Depending on its flags, an instance holds some subset of the following
    parts:
        - the input embedding (``pre_process``)
        - an encoder stack (``add_encoder``)
        - a decoder stack (``add_decoder``)
        - a pooler (``post_process`` and ``add_pooler``)
    Hidden size, padded vocabulary size, maximum position embeddings and
    embedding dropout are read from ``get_args()``.

    Arguments:
        init_method: weight initialization method for the embeddings,
            transformer stacks and pooler.
        output_layer_init_method: initialization method passed as the
            second positional argument to each ``ParallelTransformer``.
        encoder_attn_mask_type: self-attention mask type for the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: whether to build the encoder.
        add_decoder: whether to build the decoder (currently requires
            ``args.pipeline_model_parallel_size == 1``).
        decoder_attn_mask_type: self-attention mask type for the decoder.
        add_pooler: whether to build a pooler; it is only built when
            ``post_process`` is also True.
        pre_process: whether this instance owns the embedding layer.
        post_process: forwarded to the transformer stacks; also gates the
            pooler.
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Filled by set_input_tensor() on stages without an encoder.
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            # Temporary assertion until we verify correctness of pipeline parallelism
            # implementation of T5.
            assert args.pipeline_model_parallel_size == 1, \
                'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()

        ``input_tensor`` is a sequence. Stages holding an encoder expect
        exactly one element, which is forwarded to the encoder. Decoder-only
        stages accept one element (the encoder hidden state; the decoder
        gets ``None``) or two elements (decoder input, encoder hidden
        state).
        """
        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run embedding, encoder, optional pooler and optional decoder.

        Returns:
            When there is no decoder or ``output_enc_hidden`` is set:
            ``encoder_output``, or ``(encoder_output, pooled_output)`` when
            the pooler is active (``add_pooler`` and ``post_process``).
            Otherwise: ``(decoder_output, encoder_output)``, or
            ``(decoder_output, encoder_output, pooled_output)`` when the
            pooler is active.
        """
        pooler_active = self.add_pooler and self.post_process

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(encoder_input,
                                              enc_attn_mask,
                                              layer_past=layer_past,
                                              get_key_value=get_key_value)
            else:
                encoder_output = self.encoder_hidden_state
        elif encoder_input is not None:
            # Match precomputed hidden states to the embedding output dtype.
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
        else:
            # No embedding on this stage (pre_process=False), so there is no
            # reference dtype; use the provided hidden states unchanged
            # instead of dereferencing None.
            encoder_output = enc_hidden_states

        if pooler_active:
            pooled_output = self.pooler(encoder_output,
                                        pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if pooler_active:
                return encoder_output, pooled_output
            return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(decoder_input,
                                      dec_attn_mask,
                                      layer_past=layer_past,
                                      get_key_value=get_key_value,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask)

        if pooler_active:
            return decoder_output, encoder_output, pooled_output
        return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict with one nested entry per sub-module present on this
        instance: ``'embedding'``, ``'encoder'``, ``'pooler'``, ``'decoder'``.
        """

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts the nested layout produced by
        ``state_dict_for_save_checkpoint`` and, for the embedding and
        encoder, older layouts (flat keys, a ``'transformer'`` entry and
        ``.attention.`` instead of ``.self_attention.`` in parameter names).
        """

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {key: value for key, value in state_dict.items()
                               if '_embeddings' in key}
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility: strip everything up to the
                # first 'transformer.' from flat keys.
                state_dict_ = {key.split('transformer.', 1)[1]: value
                               for key, value in state_dict.items()
                               if 'transformer.' in key}

            # For backward compatibility: '.attention.' was renamed to
            # '.self_attention.'.
            state_dict_ = {
                key.replace(".attention.", ".self_attention.")
                if '.attention.' in key else key: value
                for key, value in state_dict_.items()
            }

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process and self.add_pooler:
            assert self._pooler_key in state_dict, \
                'could not find data for pooler in the checkpoint'
            self.pooler.load_state_dict(state_dict[self._pooler_key],
                                        strict=strict)

        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)