# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.training import get_args

from .enums import AttnMaskType, LayerType
from .module import MegatronModule
from .transformer import ParallelTransformer
from .utils import get_linear_layer, init_method_normal, scaled_init_method_normal


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
    args = get_args()
    # Parallel logits.
    model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
    if model_parallel or args.sequence_parallel:
        input_parallel = input_
        allreduce_dgrad = model_parallel and not args.sequence_parallel
    else:
        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
        allreduce_dgrad = False

    # Matrix multiply.
    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        sequence_parallel=args.sequence_parallel,
        grad_output_buffer=None,
        allreduce_dgrad=allreduce_dgrad,
    )
    # Gather if needed.

    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
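
# Usage sketch (illustrative only; `language_model`, `tokens`, `position_ids`
# and `attention_mask` are placeholder names): GPT-style heads pass the
# transformer output together with the (usually tied) word-embedding weight,
# keeping logits split across tensor-parallel ranks via `parallel_output=True`.
#
#     lm_output = language_model(tokens, position_ids, attention_mask)  # [s, b, h]
#     logits = parallel_lm_logits(
#         lm_output,
#         language_model.embedding.word_embeddings.weight,
#         parallel_output=True,
#     )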


def get_language_model(
    config,
    num_tokentypes,
    add_pooler,
    encoder_attn_mask_type,
    add_encoder=True,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
):
    """Build language model and return along with the key to save."""
    args = get_args()
    if config.init_method is None:
        config.init_method = init_method_normal(config.init_method_std)

    if config.output_layer_init_method is None:
        config.output_layer_init_method = scaled_init_method_normal(
            config.init_method_std, config.num_layers
        )

    # Language model.
    language_model = TransformerLanguageModel(
        config,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
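
# Usage sketch (illustrative only): a decoder-only (GPT-style) caller builds the
# backbone with a causal mask and no pooler, keeping the returned key for
# checkpoint save/load.
#
#     language_model, lm_key = get_language_model(
#         config=config,
#         num_tokentypes=0,
#         add_pooler=False,
#         encoder_attn_mask_type=AttnMaskType.causal,
#         pre_process=pre_process,
#         post_process=post_process,
#     )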


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Args:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [s, b, h]
        # sequence_index: index of the token to pool.

        # gather data along sequence dimensions
        # same pooler is run on all tensor parallel nodes
        if self.sequence_parallel:
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states, tensor_parallel_output_grad=False
            )

        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
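
    # Shape sketch (illustrative only): given `hidden_states` of shape [s, b, h],
    # `forward(hidden_states, sequence_index=0)` selects the first token's hidden
    # state for every sequence in the batch and returns a [b, h] tensor after the
    # dense projection and tanh, i.e. BERT-style "CLS" pooling.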


class Embedding(MegatronModule):
    """Language model embeddings.

    Args:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
        embedding_dropout_prob,
        config,
        num_tokentypes=0,
    ):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = config.init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.params_dtype = args.params_dtype
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size, config=config, init_method=config.init_method
        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
        if self.add_position_embedding:
            self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
            self._position_embeddings_key = 'position_embeddings'
            # Initialize the position embeddings.
            if args.perform_initialization:
                self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        if self.add_position_embedding:
            self.position_embeddings.weight.data.fill_(0)
            self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
        # Initialize the token-type embeddings.
        args = get_args()
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = words_embeddings + position_embeddings
        else:
            embeddings = words_embeddings

        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert to float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            # `scatter_to_sequence_parallel_region` returns a view, which prevents
            # the original tensor from being garbage collected. Clone to facilitate GC.
            # Has a small runtime cost (~0.5%).
            if self.clone_scatter_output_in_embedding:
                embeddings = embeddings.clone()
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings
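
        # Shape sketch (illustrative only): `input_ids` and `position_ids` come in
        # as [b, s]; the returned embeddings are [s, b, h] after the transpose
        # above, or [s/tp, b, h] per rank once sequence parallelism scatters the
        # sequence dimension.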

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(
            prefix=prefix, keep_vars=keep_vars
        )
        if self.add_position_embedding:
            state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(
                prefix=prefix, keep_vars=keep_vars
            )
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
                prefix=prefix, keep_vars=keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self.add_position_embedding:
            if self._position_embeddings_key in state_dict:
                state_dict_ = state_dict[self._position_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'position_embeddings' in key:
                        state_dict_[key.split('position_embeddings.')[1]] = state_dict[key]
            self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
            else:
                print(
                    '***WARNING*** expected tokentype embeddings in the '
                    'checkpoint but could not find it',
                    flush=True,
                )


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Args:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        config,
        encoder_attn_mask_type,
        num_tokentypes=0,
        add_encoder=True,
        add_decoder=False,
        decoder_attn_mask_type=AttnMaskType.causal,
        add_pooler=False,
        pre_process=True,
        post_process=True,
    ):
        args = get_args()
        # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
        if args.untie_embeddings_and_output_weights:
            assert not add_decoder
        super(TransformerLanguageModel, self).__init__(
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = config.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = config.init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None
        self.add_retriever = args.retro_add_retriever
        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(
                self.hidden_size,
                args.padded_vocab_size,
                args.max_position_embeddings,
                args.hidden_dropout,
                config,
                self.num_tokentypes,
            )
            self._embedding_key = 'embedding'

        # Rotary positional embeddings
        self.use_rotary_position_embeddings = args.position_embedding_type == 'rope'
        if self.use_rotary_position_embeddings:
            self.seq_length = args.seq_length
            rotary_dim = (
                args.hidden_size // args.num_attention_heads
                if args.kv_channels is None
                else args.kv_channels
            )

            # Partial rotary embeddings, which are better than full rotary.
            # Wang and Komatsuzaki et al.
            # https://github.com/kingoflolz/mesh-transformer-jax/
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=rotary_dim,
                rotary_percent=args.rotary_percent,
                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
            )

        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                config,
                model_type=(
                    args.model_type if not args.retro_add_retriever else ModelType.retro_decoder
                ),
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                config,
                model_type=args.model_type,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

            if self.untie_embeddings_and_output_weights:
                self.output_layer = tensor_parallel.ColumnParallelLinear(
                    args.hidden_size,
                    args.padded_vocab_size,
                    config=config,
                    init_method=self.init_method,
                    bias=False,
                )  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
                self._output_layer_key = 'output_layer'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert (
                len(input_tensor) == 1
            ), 'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert (
                len(input_tensor) == 1
            ), 'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')
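
    # Pipeline note (illustrative only): for non-first pipeline stages,
    # schedules pass the previous stage's activations here instead of token ids.
    # An encoder-only or encoder+decoder stage expects a single tensor, while a
    # decoder-only stage may also receive the encoder's hidden state (length 2).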

    def forward(
        self,
        enc_input_ids,
        enc_position_ids,
        enc_attn_mask,
        dec_input_ids=None,
        dec_position_ids=None,
        dec_attn_mask=None,
        retriever_input_ids=None,
        retriever_position_ids=None,
        retriever_attn_mask=None,
        enc_dec_attn_mask=None,
        tokentype_ids=None,
        inference_params=None,
        pooling_sequence_index=0,
        enc_hidden_states=None,
        output_enc_hidden=False,
    ):

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(
                enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids
            )
        else:
            encoder_input = None

        # Retriever embedding.
        if self.add_retriever and self.pre_process:
            retriever_input = self.embedding(
                retriever_input_ids, retriever_position_ids, tokentype_ids=tokentype_ids
            )
        else:
            retriever_input = None

        # Rotary positional embeddings
        rotary_pos_emb = None
        if self.use_rotary_position_embeddings:
            if inference_params is not None:
                rotary_pos_emb = self.rotary_pos_emb(inference_params.max_sequence_length)
            else:
                rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    retriever_input=retriever_input,
                    retriever_attn_mask=retriever_attn_mask,
                    inference_params=inference_params,
                    rotary_pos_emb=rotary_pos_emb,
                )
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output, pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids, dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
        )

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars
            )
        if self.add_encoder:
            state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars
            )
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars
                )
            if self.untie_embeddings_and_output_weights:
                state_dict_[self._output_layer_key] = self.output_layer.state_dict(
                    prefix=prefix, keep_vars=keep_vars
                )

        if self.add_decoder:
            state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = (
                        state_dict_[key]
                    )
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert 'pooler' in state_dict, 'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
            if self.untie_embeddings_and_output_weights:
                assert (
                    'output_layer' in state_dict
                ), 'could not find data for output_layer in the checkpoint'
                self.output_layer.load_state_dict(state_dict[self._output_layer_key], strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, 'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
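

# Checkpointing sketch (illustrative only; `model` is a placeholder): the keys
# used above ('embedding', 'encoder', 'pooler', 'output_layer', 'decoder')
# partition the saved state, so a round trip looks like:
#
#     state = model.language_model.state_dict_for_save_checkpoint()
#     # ... write `state` to disk, later read it back ...
#     model.language_model.load_state_dict(state, strict=True)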