Commit f8e98d67 authored by Rémi Louf's avatar Rémi Louf

load pretrained embeddings in Bert decoder

In Rothe et al.'s "Leveraging Pre-trained Checkpoints for Sequence
Generation Tasks", Bert2Bert is initialized with pre-trained weights for
the encoder, and only pre-trained embeddings for the decoder. The
current version of the code completely randomizes the weights of the
decoder.

We write a custom function to initialize the weights of the decoder: we
first load the decoder with the pretrained weights and then randomize
everything but the embeddings.
parent 1e68c286
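The trick in isolation, before the diff below: a minimal, self-contained sketch using plain `torch.nn` modules instead of the actual `BertDecoderModel` (the helper name and toy module are illustrative, not part of the patch):

import torch
import torch.nn as nn

def randomize_all_but_embeddings(module, std=0.02):
    """Re-initialize Linear and LayerNorm parameters in place; skip nn.Embedding."""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

# Toy "decoder": the embedding stands in for the pretrained weights we want to keep.
decoder = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.LayerNorm(16))
embeddings_before = decoder[0].weight.clone()
decoder.apply(randomize_all_but_embeddings)
assert torch.equal(decoder[0].weight, embeddings_before)  # embeddings untouched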
@@ -1348,15 +1348,14 @@ class Bert2Rnd(BertPreTrainedModel):
         self.encoder = BertModel(config)
         self.decoder = BertDecoderModel(config)
 
-        self.init_weights()
-
     @classmethod
     def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
         """ Load the pretrained weights in the encoder.
 
-        Since the decoder needs to be initialized with random weights, and the encoder with
-        pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
-        class.
+        The encoder of `Bert2Rnd` is initialized with pretrained weights; the
+        weights of the decoder are initialized at random, except the embeddings,
+        which are initialized with the pretrained embeddings. We thus need to override
+        the base class' `from_pretrained` method.
         """
 
         # Load the configuration
@@ -1374,10 +1373,26 @@ class Bert2Rnd(BertPreTrainedModel):
         )
         model = cls(config)
 
-        # The encoder is loaded with pretrained weights
+        # We load the encoder with pretrained weights
         pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
         model.encoder = pretrained_encoder
 
+        # We load the decoder with pretrained weights and then randomize all weights but the embeddings-related ones.
+        def randomize_decoder_weights(module):
+            if isinstance(module, nn.Linear):
+                # Slightly different from the TF version which uses truncated_normal for initialization
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
+            elif isinstance(module, BertLayerNorm):
+                module.bias.data.zero_()
+                module.weight.data.fill_(1.0)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+
+        pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
+        pretrained_decoder.apply(randomize_decoder_weights)
+        model.decoder = pretrained_decoder
+
         return model
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
@@ -1386,11 +1401,9 @@ class Bert2Rnd(BertPreTrainedModel):
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        encoder_output = encoder_outputs[0]
-
         decoder_outputs = self.decoder(input_ids,
-                                       encoder_output,
+                                       encoder_outputs[0],
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        return decoder_outputs[0]
+        return decoder_outputs
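A hypothetical call site under this patch might look as follows; the import path is an assumption, and only `Bert2Rnd.from_pretrained` itself comes from the diff above:

# Hypothetical usage; the import path is an assumption.
from pytorch_transformers.modeling_bert import Bert2Rnd

model = Bert2Rnd.from_pretrained("bert-base-uncased")
# The encoder carries the full pretrained checkpoint; the decoder keeps only
# the pretrained embeddings, every other weight having been re-randomized.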