# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron.core import mpu, tensor_parallel

from .enums import LayerType, AttnMaskType
from .module import MegatronModule
from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using the word-embedding weights as the output projection.

    Args:
        input_: hidden states to project onto the vocabulary.
        word_embeddings_weight: word-embedding weight matrix, reused as the
            projection weight.
        parallel_output: if True, return the logits as produced on this
            tensor-model-parallel rank; otherwise gather them across the
            tensor-model-parallel region.
        bias: optional bias forwarded to the projection.

    Returns:
        The (possibly gathered) logits tensor.
    """
    args = get_args()

    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce or args.sequence_parallel:
        # The input is fed to the fused linear layer unchanged; the async
        # gradient all-reduce is only requested when there is more than one
        # tensor-parallel rank and sequence parallelism is off.
        input_parallel = input_
        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        async_grad_allreduce = (args.async_tensor_model_parallel_allreduce
                                and model_parallel
                                and not args.sequence_parallel)
    else:
        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
        async_grad_allreduce = False

    # Matrix multiply.
    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        async_grad_allreduce=async_grad_allreduce,
        sequence_parallel_enabled=args.sequence_parallel)

    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build the language model and return it along with its checkpoint key.

    Args:
        num_tokentypes: number of token types (0 disables token-type
            embeddings).
        add_pooler: whether to build a pooler on the encoder output.
        encoder_attn_mask_type: attention mask type for the encoder.
        init_method: weight initializer; defaults to
            ``init_method_normal(args.init_method_std)``.
        scaled_init_method: output-layer initializer; defaults to
            ``scaled_init_method_normal(args.init_method_std, args.num_layers)``.
        add_encoder: whether this stage builds the encoder.
        add_decoder: whether this stage builds the decoder.
        decoder_attn_mask_type: attention mask type for the decoder.
        pre_process: whether this stage owns the embeddings.
        post_process: whether this stage owns the post-processing (pooler).

    Returns:
        ``(language_model, 'language_model')`` - the model and the key used
        for it in checkpoints.
    """
    args = get_args()

    # Fall back to initializers derived from the global arguments.
    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )

    # Key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Pool the hidden state of a specific token (for example the start of the
    sequence), then apply a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        # Cached so forward() knows whether the input must be gathered first.
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        """Pool one token position.

        Args:
            hidden_states: [s, b, h] hidden states.
            sequence_index: index of the token to pool.

        Returns:
            tanh(dense(hidden_states[sequence_index])), shape [b, h].
        """
        # Gather data along the sequence dimension; the same pooler is run
        # on all tensor-parallel nodes.
        if self.sequence_parallel:
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states,
                tensor_parallel_output_grad=False)

        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    Sums word, position and (optionally) token-type embeddings, converts the
    result from [b, s, h] to [s, b, h] and applies dropout.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method,
            params_dtype=args.params_dtype,
            use_cpu_initialization=args.use_cpu_initialization,
            perform_initialization=args.perform_initialization
        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        if args.perform_initialization:
            self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding and mark them as shared."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.

        Raises:
            Exception: if token-type embeddings already exist.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings. Unlike __init__, this does not
        # consult args.perform_initialization.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed a batch of tokens.

        Args:
            input_ids: token ids laid out as [b, s].
            position_ids: position ids laid out as [b, s].
            tokentype_ids: optional token-type ids; must be given exactly when
                token-type embeddings exist (asserted).

        Returns:
            Embeddings in [s, b, h] layout, scattered along the sequence
            dimension when sequence parallelism is enabled.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            # Dropout runs under the forked CUDA RNG tracker state (presumably
            # so each sequence shard draws its own mask).
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Return a nested state dict keyed by embedding name, for easy load."""
        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(prefix=prefix,
                                                       keep_vars=keep_vars)

        return state_dict_

    @staticmethod
    def _legacy_sub_state_dict(state_dict, name):
        """Extract `name`'s entries from an old-format flat state dict.

        Keys of the form '<anything>name.<rest>' are re-keyed to '<rest>'.
        Keys that mention `name` without the trailing '.' (for example
        'word_embeddings_for_head.weight') are skipped rather than raising
        IndexError.
        """
        marker = name + '.'
        return {key.split(marker, 1)[1]: value
                for key, value in state_dict.items()
                if marker in key}

    def load_state_dict(self, state_dict, strict=True):
        """Customized load accepting both nested and legacy flat layouts."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, self._word_embeddings_key)
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = self._legacy_sub_state_dict(
                state_dict, self._position_embeddings_key)
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = self._legacy_sub_state_dict(
                    state_dict, self._tokentype_embeddings_key)
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    One pipeline stage built from: embeddings (when pre_process), an optional
    Retro retriever, an optional encoder, an optional decoder and an optional
    pooler (when post_process).

    Arguments:
        init_method: weight initialization method; also used for the
            embeddings and the pooler.
        output_layer_init_method: second initialization method forwarded to
            the retriever / encoder / decoder stacks (get_language_model
            passes the scaled init method here).
        encoder_attn_mask_type: self-attention mask type of the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_encoder: whether this stage builds the encoder.
        add_decoder: whether this stage builds the decoder.
        decoder_attn_mask_type: self-attention mask type of the decoder.
        add_pooler: whether to build a pooler (only when post_process).
        pre_process: whether this stage builds the embeddings.
        post_process: whether this stage builds the pooler; also forwarded
            to the transformer stacks.
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        # Set by set_input_tensor() on decoder-only stages.
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Retriever (bi-directional transformer with cross attention)
        if args.retro_add_retriever:
            self.retriever = ParallelRetroEncoder(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=AttnMaskType.padding,
                pre_process=self.pre_process,
                post_process=False,
            )
            self._retriever_key = 'retriever'
        else:
            self.retriever = None

        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            if args.retro_add_retriever:
                self.encoder = ParallelRetroTransformer(
                    self.init_method,
                    output_layer_init_method,
                    self_attn_mask_type=self.encoder_attn_mask_type,
                    pre_process=self.pre_process,
                    post_process=self.post_process,
                    retriever=self.retriever,
                )
            else:
                self.encoder = ParallelTransformer(
                    self.init_method,
                    output_layer_init_method,
                    self_attn_mask_type=self.encoder_attn_mask_type,
                    pre_process=self.pre_process,
                    post_process=self.post_process,
                )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        # Pooler.
        if self.post_process and self.add_pooler:
            self.pooler = Pooler(self.hidden_size, self.init_method)
            self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            # Decoder-only stage: the last element is the encoder's hidden
            # state; a leading element, if present, is the decoder input.
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run the configured embeddings, retriever, encoder, pooler and decoder.

        Returns:
            Without a decoder (or with output_enc_hidden=True): encoder_output,
            or (encoder_output, pooled_output) when the pooler is active.
            Otherwise: (decoder_output, encoder_output), or
            (decoder_output, encoder_output, pooled_output) when the pooler
            is active.
        """
        # Retriever embedding.
        if self.retriever is not None and self.pre_process:
            retriever_input = self.embedding(ret_input_ids, ret_position_ids,
                                             tokentype_ids=tokentype_ids)
        else:
            retriever_input = None

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                if self.retriever is not None:
                    encoder_output = self.encoder(
                        encoder_input,
                        enc_attn_mask,
                        retriever_output=retriever_input,
                        retriever_attn_mask=ret_attn_mask,
                        inference_params=inference_params)
                else:
                    encoder_output = self.encoder(
                        encoder_input,
                        enc_attn_mask,
                        inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        elif encoder_input is not None:
            # Precomputed encoder states: match the embedding dtype.
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
        else:
            # No embeddings on this stage (pre_process is False), so there is
            # no dtype to match; use the supplied states unchanged instead of
            # failing on encoder_input being None.
            encoder_output = enc_hidden_states

        if self.post_process and self.add_pooler:
            pooled_output = self.pooler(encoder_output,
                                        pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Return a state dict nested by submodule key, for easy load."""
        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
                                                                keep_vars=keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process and self.add_pooler:
            state_dict_[self._pooler_key] \
                = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                             keep_vars=keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load accepting both nested and legacy layouts.

        Raises:
            AssertionError: if the pooler or decoder entry is missing while
                the corresponding submodule exists.
        """

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {key: value for key, value in state_dict.items()
                               if '_embeddings' in key}
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            elif 'transformer' in state_dict:
                # For backward compatibility.
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility: older checkpoints name the
            # self-attention submodule '.attention.'. str.replace leaves
            # keys without that substring unchanged.
            state_dict_ = {key.replace('.attention.', '.self_attention.'): value
                           for key, value in state_dict_.items()}

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process and self.add_pooler:
            assert self._pooler_key in state_dict, \
                'could not find data for pooler in the checkpoint'
            self.pooler.load_state_dict(state_dict[self._pooler_key],
                                        strict=strict)

        # Decoder.
        if self.add_decoder:
            assert self._decoder_key in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)