language_model.py 25.2 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6
7

"""Transformer based language model."""

import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
8
from megatron import get_args
9
from megatron.core import mpu, tensor_parallel
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
10
11

from .enums import LayerType, AttnMaskType
12
from .module import MegatronModule
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
13
from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
Mostofa Patwary's avatar
Mostofa Patwary committed
14
from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
15
16
17
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal
18

19

20
21
22
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights.

    Projects ``input_`` with ``word_embeddings_weight`` through
    ``tensor_parallel.linear_with_grad_accumulation_and_async_allreduce``.

    Arguments:
        input_: hidden states to project (presumably [s, b, h] -- confirm
            against callers).
        word_embeddings_weight: weight of the projection (the word
            embedding table, per this function's purpose).
        parallel_output: if True, return the per-rank logits as produced
            by the linear op; otherwise gather them with
            ``gather_from_tensor_model_parallel_region``.
        bias: optional bias forwarded to the linear op.

    Returns:
        The logits, either un-gathered (``parallel_output=True``) or
        gathered from the tensor model parallel region.
    """
    args = get_args()

    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce or args.sequence_parallel:
        # Input is used as-is; the async gradient all-reduce is only
        # requested when there is more than one tensor-parallel rank and
        # sequence parallelism is off.
        input_parallel = input_
        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        async_grad_allreduce = (args.async_tensor_model_parallel_allreduce
                                and model_parallel
                                and not args.sequence_parallel)
    else:
        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(
            input_)
        async_grad_allreduce = False

    # Matrix multiply.
    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        async_grad_allreduce=async_grad_allreduce,
        sequence_parallel_enabled=args.sequence_parallel)

    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save.

    Arguments:
        num_tokentypes: size of the token-type embeddings (0 disables them).
        add_pooler: whether the model builds a Pooler.
        encoder_attn_mask_type: self-attention mask type for the encoder.
        init_method: weight initializer; defaults to
            ``init_method_normal(args.init_method_std)``.
        scaled_init_method: output-layer initializer; defaults to
            ``scaled_init_method_normal(args.init_method_std,
            args.num_layers)``.
        add_encoder / add_decoder: which transformer stacks to build.
        decoder_attn_mask_type: self-attention mask type for the decoder.
        pre_process / post_process: forwarded to TransformerLanguageModel.

    Returns:
        ``(language_model, 'language_model')`` -- the model and the key
        under which it is stored in checkpoints.
    """
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process
    )

    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        """Pool ``hidden_states[sequence_index]`` through dense + tanh.

        Arguments:
            hidden_states: [s, b, h] tensor.
            sequence_index: index of the token to pool.
        """
        # gather data along sequence dimensions
        # same pooler is run on all tensor parallel nodes
        if self.sequence_parallel:
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states,
                tensor_parallel_output_grad=False)

        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    Sums the (vocab-parallel) word embeddings, optional learned position
    embeddings and optional token-type embeddings, changes the layout from
    [b, s, h] to [s, b, h] and applies dropout.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method,
            params_dtype=args.params_dtype,
            use_cpu_initialization=args.use_cpu_initialization,
            perform_initialization=args.perform_initialization
        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial); only built when
        # args.add_position_embedding is set.
        self.add_position_embedding = args.add_position_embedding
        if self.add_position_embedding:
            self.position_embeddings = torch.nn.Embedding(
                max_sequence_length, self.hidden_size)
            self._position_embeddings_key = 'position_embeddings'
            # Initialize the position embeddings.
            if args.perform_initialization:
                self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding.

        Every zeroed weight is also tagged with ``shared = True``.
        """
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        if self.add_position_embedding:
            self.position_embeddings.weight.data.fill_(0)
            self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed the inputs and return the [s, b, h] embedding tensor.

        ``tokentype_ids`` must be provided exactly when token-type
        embeddings exist (enforced by asserts). With sequence parallelism
        the result goes through ``scatter_to_sequence_parallel_region``
        and dropout runs under the forked CUDA RNG tracker.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = words_embeddings + position_embeddings
        else:
            embeddings = words_embeddings

        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load.

        Returns a dict keyed by sub-embedding name ('word_embeddings',
        plus 'position_embeddings' / 'tokentype_embeddings' when present).
        """
        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        if self.add_position_embedding:
            state_dict_[self._position_embeddings_key] \
                = self.position_embeddings.state_dict(prefix=prefix,
                                                      keep_vars=keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(prefix=prefix,
                                                       keep_vars=keep_vars)

        return state_dict_

    @staticmethod
    def _legacy_sub_state_dict(state_dict, name):
        """Collect flat 'name.<param>' entries as a {'<param>': value} dict."""
        # Match on the dotted prefix so that a key merely containing `name`
        # cannot cause an IndexError in the split below.
        marker = name + '.'
        return {key.split(marker, 1)[1]: value
                for key, value in state_dict.items()
                if marker in key}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts both the nested layout produced by
        ``state_dict_for_save_checkpoint`` and an older flat layout with
        keys such as 'word_embeddings.weight'.
        """
        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(state_dict,
                                                      'word_embeddings')
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self.add_position_embedding:
            if self._position_embeddings_key in state_dict:
                state_dict_ = state_dict[self._position_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = self._legacy_sub_state_dict(state_dict,
                                                          'position_embeddings')
            self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = self._legacy_sub_state_dict(state_dict,
                                                          'tokentype_embeddings')
            if state_dict_:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                # Best effort: keep the freshly initialized weights.
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Depending on the constructor flags, builds an input Embedding
    (``pre_process``), rotary position embeddings, a Retro retriever
    (``args.retro_add_retriever``), an encoder, a decoder and a Pooler
    (``post_process`` and ``add_pooler``).

    Arguments:
        init_method: weight initialization method passed to the embedding,
            the transformer stacks and the pooler.
        output_layer_init_method: initialization method passed to the
            transformer stacks as their second positional argument.
        encoder_attn_mask_type: self-attention mask type of the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: build an encoder on this stage.
        add_decoder: build a decoder on this stage.
        decoder_attn_mask_type: self-attention mask type of the decoder.
        add_pooler: build a Pooler (only when ``post_process`` is True).
        pre_process: build and run the input embedding on this stage.
        post_process: forwarded to the transformer stacks; also gates the
            pooler.
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Set by set_input_tensor() on decoder-only stages.
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Rotary positional embeddings
        self.use_rotary_position_embeddings = \
            args.use_rotary_position_embeddings
        if args.use_rotary_position_embeddings:
            self.seq_length = args.seq_length
            rotary_dim = args.hidden_size // args.num_attention_heads \
                if args.kv_channels is None else args.kv_channels

            if args.rotary_percent < 1.0:
                rotary_dim = int(rotary_dim * args.rotary_percent)

            # partial rotary embeddings, which is better than full rotary
            # Wang and Komatsuzaki et al
            # https://github.com/kingoflolz/mesh-transformer-jax/
            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)

        # Retriever (bi-directional transformer with cross attention)
        if args.retro_add_retriever:
            self.retriever = ParallelRetroEncoder(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=AttnMaskType.padding,
                pre_process=self.pre_process,
                post_process=False,
            )
            self._retriever_key = 'retriever'
        else:
            self.retriever = None

        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            if args.retro_add_retriever:
                self.encoder = ParallelRetroTransformer(
                    self.init_method,
                    output_layer_init_method,
                    self_attn_mask_type=self.encoder_attn_mask_type,
                    pre_process=self.pre_process,
                    post_process=self.post_process,
                    retriever=self.retriever,
                )
            else:
                self.encoder = ParallelTransformer(
                    self.init_method,
                    output_layer_init_method,
                    self_attn_mask_type=self.encoder_attn_mask_type,
                    pre_process=self.pre_process,
                    post_process=self.post_process,
                )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            # Decoder-only stage: the last (or only) tensor is the encoder
            # hidden state; with two tensors the first feeds the decoder.
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run embedding, encoder and (optionally) decoder.

        Returns:
            Without a decoder, or when ``output_enc_hidden`` is True:
            ``encoder_output`` (plus ``pooled_output`` when the pooler is
            active). Otherwise ``(decoder_output, encoder_output)`` (plus
            ``pooled_output`` when the pooler is active).
        """
        # Retriever embedding.
        if self.retriever is not None and self.pre_process:
            retriever_input = self.embedding(ret_input_ids, ret_position_ids,
                                             tokentype_ids=tokentype_ids)
        else:
            retriever_input = None

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Rotary positional embeddings
        rotary_pos_emb = None
        if self.use_rotary_position_embeddings:
            if inference_params is not None:
                rotary_pos_emb = \
                    self.rotary_pos_emb(inference_params.max_sequence_len)
            else:
                rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                if self.retriever is not None:
                    encoder_output = self.encoder(
                        encoder_input,
                        enc_attn_mask,
                        retriever_output=retriever_input,
                        retriever_attn_mask=ret_attn_mask,
                        inference_params=inference_params)
                else:
                    encoder_output = self.encoder(
                        encoder_input,
                        enc_attn_mask,
                        inference_params=inference_params,
                        rotary_pos_emb=rotary_pos_emb)
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states
            # Only a stage that ran the embedding has a dtype to match;
            # without pre_process encoder_input is None.
            if encoder_input is not None:
                encoder_output = encoder_output.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load.

        Returns a dict keyed by sub-module name ('embedding', 'encoder',
        'pooler', 'decoder'), containing only the parts built on this stage.
        """
        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
                                                                keep_vars=keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts the nested layout from ``state_dict_for_save_checkpoint``
        and several older layouts (flat '*_embeddings' keys, a
        'transformer' sub-dict or 'transformer.'-prefixed keys, and
        '.attention.' instead of '.self_attention.' in encoder keys).
        """
        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {key: value for key, value in state_dict.items()
                               if '_embeddings' in key}
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility: old checkpoints name the
            # self-attention sub-module '.attention.'.
            state_dict_ = {
                key.replace('.attention.', '.self_attention.'): value
                for key, value in state_dict_.items()
            }

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert self._pooler_key in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)