# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
8
from megatron import get_args
9
from megatron.core import mpu, tensor_parallel
liangjing's avatar
v1  
liangjing committed
10
11
from megatron.core.enums import ModelType
from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
12

liangjing's avatar
v1  
liangjing committed
13
from .enums import AttnMaskType, LayerType
14
from .module import MegatronModule
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
15
16
17
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal
18

19

20
21
22
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """Compute LM logits from hidden states using the word-embedding weight.

    Arguments:
        input_: activations fed to the output projection.
        word_embeddings_weight: weight used as the projection matrix.
        parallel_output: if truthy, the result of the parallel linear op is
            returned directly; otherwise it is first gathered across the
            tensor-model-parallel region.
        bias: optional bias forwarded to the linear op.

    Returns:
        The logits, either as produced by the parallel linear op or
        gathered from the tensor-model-parallel region.
    """
    args = get_args()
    use_sequence_parallel = args.sequence_parallel
    wants_async_allreduce = args.async_tensor_model_parallel_allreduce

    if wants_async_allreduce or use_sequence_parallel:
        # The input is used as-is; the linear op below is told how to
        # handle the gradient all-reduce / sequence parallelism.
        hidden = input_
        tp_is_active = mpu.get_tensor_model_parallel_world_size() > 1
        overlap_grad_allreduce = (wants_async_allreduce
                                  and tp_is_active
                                  and not use_sequence_parallel)
    else:
        hidden = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
        overlap_grad_allreduce = False

    # Matrix multiply against the embedding weight.
    linear = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce
    logits_parallel = linear(
        input=hidden,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        async_grad_allreduce=overlap_grad_allreduce,
        sequence_parallel=use_sequence_parallel)

    # Gather only when the caller did not ask for the parallel output.
    if not parallel_output:
        return tensor_parallel.gather_from_tensor_model_parallel_region(
            logits_parallel)
    return logits_parallel
Mohammad's avatar
Mohammad committed
49
50


liangjing's avatar
v1  
liangjing committed
51
52
53
def get_language_model(config, num_tokentypes, add_pooler,
                       encoder_attn_mask_type,
                       add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build a TransformerLanguageModel and return it with its checkpoint key.

    Fills in ``config.init_method`` and ``config.output_layer_init_method``
    (in place) when they are ``None``, then constructs the model with the
    remaining arguments passed straight through.

    Returns:
        A ``(language_model, language_model_key)`` tuple, where the key is
        the string used for checkpoints.
    """
    # The returned namespace is not consulted in this function; the call
    # itself is kept.
    get_args()

    # Default initializers derived from config.init_method_std.
    if config.init_method is None:
        config.init_method = init_method_normal(config.init_method_std)
    if config.output_layer_init_method is None:
        config.output_layer_init_method = scaled_init_method_normal(
            config.init_method_std, config.num_layers)

    # Language model.
    model_options = dict(
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )
    language_model = TransformerLanguageModel(
        config, encoder_attn_mask_type, **model_options)

    # Second element is the key used for checkpoints.
    return language_model, 'language_model'


class Pooler(MegatronModule):
    """Pooler layer.

    Selects the hidden state of one token (for example the start of the
    sequence), applies a linear transformation and then a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super().__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        """Pool the token at ``sequence_index`` out of ``hidden_states``.

        hidden_states: [s, b, h]
        sequence_index: index of the token to pool.
        """
        if self.sequence_parallel:
            # Gather data along the sequence dimension; the same pooler is
            # run on all tensor parallel nodes.
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states,
                tensor_parallel_output_grad=False)

        token_state = hidden_states[sequence_index, :, :]
        return torch.tanh(self.dense(token_state))


class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        config: configuration object; its ``init_method`` is used to
                initialize the embedding weights and the object itself is
                forwarded to the word embedding layer
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        embedding_weights_in_fp32: casts word embedding weights to
                                   fp32 before sampling. Required to
                                   maintain reproducibility when
                                   training in bf16.
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 config,
                 num_tokentypes=0,
                 embedding_weights_in_fp32=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.init_method = config.init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.embedding_weights_in_fp32 = embedding_weights_in_fp32
        self.params_dtype = args.params_dtype
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size,
            self.hidden_size,
            config=config,
            init_method=config.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial); only built for learned absolute
        # position embeddings.
        self.add_position_embedding = \
            args.position_embedding_type == 'learned_absolute'
        if self.add_position_embedding:
            self.position_embeddings = torch.nn.Embedding(
                max_sequence_length, self.hidden_size)
            self._position_embeddings_key = 'position_embeddings'
            if args.perform_initialization:
                self.init_method(self.position_embeddings.weight)

        # Token type embedding.  Kept optional (and addable later through
        # add_tokentype_embeddings) so a model pretrained without token
        # types can be loaded and extended as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(
                self.num_tokentypes, self.hidden_size)
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    @staticmethod
    def _legacy_entries(state_dict, marker):
        """Collect entries whose key contains ``marker``, keyed by the text
        following ``marker + '.'`` (old flat checkpoint layout)."""
        separator = marker + '.'
        return {name.split(separator)[1]: tensor
                for name, tensor in state_dict.items()
                if marker in name}

    def zero_parameters(self):
        """Zero out all parameters in embedding and flag them as shared."""
        tables = [self.word_embeddings]
        if self.add_position_embedding:
            tables.append(self.position_embeddings)
        if self.num_tokentypes > 0:
            tables.append(self.tokentype_embeddings)
        for table in tables:
            table.weight.data.fill_(0)
            table.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Attach a token-type embedding after construction.

        Provided so that a pretrained model without token-type embeddings
        can be loaded normally and then extended with this embedding.
        Raises an Exception when token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # The returned namespace is not consulted here; the call is kept.
        get_args()
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed the inputs and return the result in [s, b, h] layout."""
        # Word embeddings, optionally looked up with fp32 weights and then
        # cast back to the parameter dtype.
        if self.embedding_weights_in_fp32:
            self.word_embeddings = self.word_embeddings.to(torch.float32)
            words_embeddings = \
                self.word_embeddings(input_ids).to(self.params_dtype)
            self.word_embeddings = self.word_embeddings.to(self.params_dtype)
        else:
            words_embeddings = self.word_embeddings(input_ids)

        embeddings = words_embeddings
        if self.add_position_embedding:
            embeddings = embeddings + self.position_embeddings(position_ids)

        # Token-type ids must be given exactly when the table exists.
        if tokentype_ids is None:
            assert self.tokentype_embeddings is None
        else:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)

        # Data format change to avoid explicit transposes:
        # [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert
        # to float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if not self.sequence_parallel:
            return self.embedding_dropout(embeddings)

        embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
            embeddings)
        with tensor_parallel.get_cuda_rng_tracker().fork():
            embeddings = self.embedding_dropout(embeddings)
        return embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load: one nested state dict per embedding table."""
        checkpoint = {
            self._word_embeddings_key: self.word_embeddings.state_dict(
                prefix=prefix, keep_vars=keep_vars),
        }
        if self.add_position_embedding:
            checkpoint[self._position_embeddings_key] = \
                self.position_embeddings.state_dict(prefix=prefix,
                                                    keep_vars=keep_vars)
        if self.num_tokentypes > 0:
            checkpoint[self._tokentype_embeddings_key] = \
                self.tokentype_embeddings.state_dict(prefix=prefix,
                                                     keep_vars=keep_vars)
        return checkpoint

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts both the nested layout written by
        state_dict_for_save_checkpoint and, for backward compatibility, a
        flat layout whose keys contain the embedding names.
        """
        # Word embedding.
        if self._word_embeddings_key in state_dict:
            word_state = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            word_state = self._legacy_entries(state_dict, 'word_embeddings')
        self.word_embeddings.load_state_dict(word_state, strict=strict)

        # Position embedding.
        if self.add_position_embedding:
            if self._position_embeddings_key in state_dict:
                position_state = state_dict[self._position_embeddings_key]
            else:
                # for backward compatibility.
                position_state = self._legacy_entries(state_dict,
                                                      'position_embeddings')
            self.position_embeddings.load_state_dict(position_state,
                                                     strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            if self._tokentype_embeddings_key in state_dict:
                tokentype_state = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                tokentype_state = self._legacy_entries(state_dict,
                                                       'tokentype_embeddings')
            if len(tokentype_state.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(tokentype_state,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


320
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Depending on the flags below this module holds an embedding, an encoder
    ``ParallelTransformer``, a decoder ``ParallelTransformer``, a pooler and
    an untied output layer.  Other hyperparameters (padded vocabulary size,
    maximum position embeddings, hidden dropout, position-embedding type,
    ...) are read from the global ``get_args()`` namespace.

    Arguments:
        config: configuration object; supplies ``hidden_size`` and
                ``init_method`` and is forwarded to the sub-modules
        encoder_attn_mask_type: self-attention mask type for the encoder
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: build the encoder stack
        add_decoder: build the decoder stack
        decoder_attn_mask_type: self-attention mask type for the decoder
        add_pooler: build a Pooler (only when ``post_process`` is set)
        pre_process: build and run the embedding in this module
        post_process: build the pooler / untied output layer in this module
    """

    def __init__(self,
                 config,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        args = get_args()
        # TODO: passing share_embeddings_and_output_weights=False will not
        # work correctly for T5 and embeddings will not be synced. Fix later
        # for T5.
        if args.untie_embeddings_and_output_weights:
            assert not add_decoder
        super(TransformerLanguageModel, self).__init__(
            share_embeddings_and_output_weights=
            not args.untie_embeddings_and_output_weights)

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = config.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = config.init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Set by set_input_tensor() on decoder-only stages.
        self.encoder_hidden_state = None
        self.add_retriever = args.retro_add_retriever
        self.untie_embeddings_and_output_weights = \
            args.untie_embeddings_and_output_weights

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       config,
                                       self.num_tokentypes,
                                       args.embedding_weights_in_fp32)
            self._embedding_key = 'embedding'

        # Rotary positional embeddings
        self.use_rotary_position_embeddings = \
            args.position_embedding_type == 'rope'
        if self.use_rotary_position_embeddings:
            self.seq_length = args.seq_length
            rotary_dim = args.hidden_size // args.num_attention_heads \
                if args.kv_channels is None else args.kv_channels

            if args.rotary_percent < 1.0:
                rotary_dim = int(rotary_dim * args.rotary_percent)

            # partial rotary embeddings, which is better than full rotary
            # Wang and Komatsuzaki et al
            # https://github.com/kingoflolz/mesh-transformer-jax/
            self.rotary_pos_emb = RotaryEmbedding(
                rotary_dim,
                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
            )

        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                config,
                model_type=args.model_type if not args.retro_add_retriever \
                    else ModelType.retro_decoder,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                config,
                model_type=args.model_type,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

            if self.untie_embeddings_and_output_weights:
                # bias is always False to stay consistent with embedding
                # tying, which also does not have a bias.
                self.output_layer = tensor_parallel.ColumnParallelLinear(
                    args.hidden_size,
                    args.padded_vocab_size,
                    config=config,
                    init_method=self.init_method,
                    bias=False)
                self._output_layer_key = 'output_layer'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                # [decoder input, encoder hidden state]
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                # Only the encoder hidden state was provided.
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                retriever_input_ids=None,
                retriever_position_ids=None,
                retriever_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run the embedding, encoder and (optionally) decoder.

        Returns:
            Without a decoder, or when ``output_enc_hidden`` is set:
            ``encoder_output``, or ``(encoder_output, pooled_output)`` when
            a pooler is present on a post-process stage.
            Otherwise: ``(decoder_output, encoder_output)``, with
            ``pooled_output`` appended when a pooler is present on a
            post-process stage.
        """

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Retriever embedding.
        if self.add_retriever and self.pre_process:
            retriever_input = self.embedding(retriever_input_ids,
                                             retriever_position_ids,
                                             tokentype_ids=tokentype_ids)
        else:
            retriever_input = None

        # Rotary positional embeddings
        rotary_pos_emb = None
        if self.use_rotary_position_embeddings:
            if inference_params is not None:
                rotary_pos_emb = \
                    self.rotary_pos_emb(inference_params.max_sequence_length)
            else:
                rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    retriever_input=retriever_input,
                    retriever_attn_mask=retriever_attn_mask,
                    inference_params=inference_params,
                    rotary_pos_emb=rotary_pos_emb)
            else:
                encoder_output = self.encoder_hidden_state
        else:
            # Caller supplied the encoder output.  Match the embedding dtype
            # when an embedding was computed; without pre_process there is
            # no encoder_input to take a dtype from, so the tensor is used
            # unchanged instead of failing on None.
            encoder_output = enc_hidden_states
            if encoder_input is not None:
                encoder_output = encoder_output.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load: one nested state dict per sub-module present."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
                                                                keep_vars=keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
            if self.untie_embeddings_and_output_weights:
                state_dict_[self._output_layer_key] \
                    = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)

        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Loads each sub-module from its nested entry in ``state_dict``; the
        embedding and encoder additionally accept older flat layouts.
        """

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility: older checkpoints name the
            # sub-module ".attention." instead of ".self_attention.".
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.",
                        ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert self._pooler_key in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
            if self.untie_embeddings_and_output_weights:
                assert self._output_layer_key in state_dict, \
                    'could not find data for output_layer in the checkpoint'
                self.output_layer.load_state_dict(state_dict[self._output_layer_key],
                                                  strict=strict)
        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)