# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.training import get_args

from .enums import AttnMaskType, LayerType
from .module import MegatronModule
from .transformer import ParallelTransformer
from .utils import get_linear_layer, init_method_normal, scaled_init_method_normal

import torch._dynamo
torch._dynamo.config.suppress_errors = True
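# Note: with suppress_errors enabled, TorchDynamo logs graph-compilation failures and
# falls back to eager execution instead of raising, so training can continue even if
# parts of this legacy model do not compile cleanly.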


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
    args = get_args()
    # Parallel logits.
    model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
    if model_parallel or args.sequence_parallel:
        input_parallel = input_
        allreduce_dgrad = model_parallel and not args.sequence_parallel
    else:
        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
        allreduce_dgrad = False

    # Matrix multiply.
    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        sequence_parallel=args.sequence_parallel,
        grad_output_buffer=None,
        allreduce_dgrad=allreduce_dgrad,
    )
    # Gather if needed.

    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
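# Example (sketch only): a GPT-style head that ties the output projection to the input
# embedding would typically call
#     logits = parallel_lm_logits(lm_output, word_embeddings_weight, parallel_output=True)
# where `lm_output` is [s, b, h]; with parallel_output=True the returned logits remain
# split along the vocabulary dimension across tensor-parallel ranks.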


def get_language_model(
    config,
    num_tokentypes,
    add_pooler,
    encoder_attn_mask_type,
    add_encoder=True,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
):
    """Build language model and return along with the key to save."""
    args = get_args()
    if config.init_method is None:
        config.init_method = init_method_normal(config.init_method_std)

    if config.output_layer_init_method is None:
        config.output_layer_init_method = scaled_init_method_normal(
            config.init_method_std, config.num_layers
        )

    # Language model.
    language_model = TransformerLanguageModel(
        config,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
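# Example (sketch, mirroring how the legacy GPT model builds its backbone):
#     language_model, _ = get_language_model(
#         config=config,
#         num_tokentypes=0,
#         add_pooler=False,
#         encoder_attn_mask_type=AttnMaskType.causal,
#         pre_process=pre_process,
#         post_process=post_process,
#     )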


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Args:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [s, b, h]
        # sequence_index: index of the token to pool.

        # gather data along sequence dimensions
        # same pooler is run on all tensor parallel nodes
        if self.sequence_parallel:
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states, tensor_parallel_output_grad=False
            )
        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
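    # Shape note: forward() takes hidden_states of shape [s, b, h] and returns a [b, h]
    # tensor for the token at `sequence_index`, after the dense projection and tanh.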


class Embedding(MegatronModule):
    """Language model embeddings.

    Args:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
        embedding_dropout_prob,
        config,
        num_tokentypes=0,
    ):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = config.init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.params_dtype = args.params_dtype
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size, config=config, init_method=config.init_method
        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
        if self.add_position_embedding:
            self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
            self._position_embeddings_key = 'position_embeddings'
            # Initialize the position embeddings.
            if args.perform_initialization:
                self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # a method call so we can load a pretrained model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        if self.add_position_embedding:
            self.position_embeddings.weight.data.fill_(0)
            self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
        # Initialize the token-type embeddings.
        args = get_args()
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = words_embeddings + position_embeddings
        else:
            embeddings = words_embeddings

        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert to float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            # `scatter_to_sequence_parallel_region` returns a view, which prevents
            # the original tensor from being garbage collected. Clone to facilitate GC.
            # Has a small runtime cost (~0.5%).
            if self.clone_scatter_output_in_embedding:
                embeddings = embeddings.clone()
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings
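        # Note: the returned embeddings are in [s, b, h] layout with dropout already
        # applied; under sequence parallelism each rank holds only its slice of the
        # sequence dimension.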

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(
            prefix=prefix, keep_vars=keep_vars
        )
        if self.add_position_embedding:
            state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(
                prefix=prefix, keep_vars=keep_vars
            )
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
                prefix=prefix, keep_vars=keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self.add_position_embedding:
            if self._position_embeddings_key in state_dict:
                state_dict_ = state_dict[self._position_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'position_embeddings' in key:
                        state_dict_[key.split('position_embeddings.')[1]] = state_dict[key]
            self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
            else:
                print(
                    '***WARNING*** expected tokentype embeddings in the '
                    'checkpoint but could not find it',
                    flush=True,
                )


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Args:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        config,
        encoder_attn_mask_type,
        num_tokentypes=0,
        add_encoder=True,
        add_decoder=False,
        decoder_attn_mask_type=AttnMaskType.causal,
        add_pooler=False,
        pre_process=True,
        post_process=True,
    ):
        args = get_args()
        # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
        if args.untie_embeddings_and_output_weights:
            assert not add_decoder
        super(TransformerLanguageModel, self).__init__(
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = config.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = config.init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None
        self.add_retriever = args.retro_add_retriever
        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(
                self.hidden_size,
                args.padded_vocab_size,
                args.max_position_embeddings,
                args.hidden_dropout,
                config,
                self.num_tokentypes,
            )
            self._embedding_key = 'embedding'

        # Rotary positional embeddings
        self.use_rotary_position_embeddings = args.position_embedding_type == 'rope'
        if self.use_rotary_position_embeddings:
            self.seq_length = args.seq_length
            rotary_dim = (
                args.hidden_size // args.num_attention_heads
                if args.kv_channels is None
                else args.kv_channels
            )
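            # e.g. (hypothetical sizes) hidden_size=4096 with num_attention_heads=32
            # gives rotary_dim=128, unless --kv-channels overrides it.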

            # Partial rotary embeddings, which work better than full rotary
            # (Wang and Komatsuzaki et al.)
            # https://github.com/kingoflolz/mesh-transformer-jax/
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=rotary_dim,
                rotary_percent=args.rotary_percent,
                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
            )

        # Encoder (usually set to True; False if part of an encoder-decoder
        # architecture and in a decoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                config,
                model_type=(
                    args.model_type if not args.retro_add_retriever else ModelType.retro_decoder
                ),
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                config,
                model_type=args.model_type,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

            if self.untie_embeddings_and_output_weights:
                self.output_layer = tensor_parallel.ColumnParallelLinear(
                    args.hidden_size,
                    args.padded_vocab_size,
                    config=config,
                    init_method=self.init_method,
                    bias=False,
                )  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
                self._output_layer_key = 'output_layer'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert (
                len(input_tensor) == 1
            ), 'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert (
                len(input_tensor) == 1
            ), 'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    # @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(
        self,
        enc_input_ids,
        enc_position_ids,
        enc_attn_mask,
        dec_input_ids=None,
        dec_position_ids=None,
        dec_attn_mask=None,
        retriever_input_ids=None,
        retriever_position_ids=None,
        retriever_attn_mask=None,
        enc_dec_attn_mask=None,
        tokentype_ids=None,
        inference_params=None,
        pooling_sequence_index=0,
        enc_hidden_states=None,
        output_enc_hidden=False,
    ):

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(
                enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids
            )
        else:
            encoder_input = None

        # Retriever embedding.
        if self.add_retriever and self.pre_process:
            retriever_input = self.embedding(
                retriever_input_ids, retriever_position_ids, tokentype_ids=tokentype_ids
            )
        else:
            retriever_input = None

        # Rotary positional embeddings
        rotary_pos_emb = None
        if self.use_rotary_position_embeddings:
            if inference_params is not None:
                rotary_pos_emb = self.rotary_pos_emb(inference_params.max_sequence_length)
            else:
                rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    retriever_input=retriever_input,
                    retriever_attn_mask=retriever_attn_mask,
                    inference_params=inference_params,
                    rotary_pos_emb=rotary_pos_emb,
                )
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output, pooling_sequence_index)

        # output_enc_hidden is set when we only need the encoder's output,
        # e.g. to compute similarity between two sequences via average pooling.
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids, dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
        )

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars
            )
        if self.add_encoder:
            state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars
            )
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars
                )
            if self.untie_embeddings_and_output_weights:
                state_dict_[self._output_layer_key] = self.output_layer.state_dict(
                    prefix=prefix, keep_vars=keep_vars
                )

        if self.add_decoder:
            state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = (
                        state_dict_[key]
                    )
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert 'pooler' in state_dict, 'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
            if self.untie_embeddings_and_output_weights:
                assert (
                    'output_layer' in state_dict
                ), 'could not find data for output_layer in the checkpoint'
                self.output_layer.load_state_dict(state_dict[self._output_layer_key], strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, 'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
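# Example usage (sketch only; assumes initialize_megatron() has already set up args and
# model-parallel state, and that a transformer config was built from those args, e.g.
# via core_transformer_config_from_args):
#
#     language_model, _ = get_language_model(
#         config=config,
#         num_tokentypes=0,
#         add_pooler=False,
#         encoder_attn_mask_type=AttnMaskType.causal,
#     )
#     hidden = language_model(tokens, position_ids, attention_mask)
#     logits = parallel_lm_logits(
#         hidden, language_model.embedding.word_embeddings.weight, parallel_output=True
#     )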