"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "164c794eb356917237512a9755f26f3caf0a6255"
Commit 877ef2c6 authored by Rémi Louf's avatar Rémi Louf
Browse files

override `from_pretrained` in Bert2Rnd

In the seq2seq model we need to both load pretrained weights in the
encoder and initialize the decoder randomly. Because the
`from_pretrained` method defined in the base class relies on module
names to assign weights, it would also initialize the decoder with
pretrained weights. To avoid this we override the method to only
initialize the encoder with pretrained weights.
parent 851ef592
......@@ -1455,6 +1455,37 @@ class Bert2Rnd(BertPreTrainedModel):
self.init_weights()
@classmethod
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
    """Build a Bert2Rnd whose encoder carries pretrained weights.

    The base ``PreTrainedModel.from_pretrained`` assigns weights by module
    name and would therefore also fill the decoder with pretrained
    parameters. This override restricts pretrained loading to the encoder;
    the decoder keeps the random initialization it received in ``cls(config)``.
    """
    # Resolve the configuration, then build a fully randomly-initialized model.
    config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs)
    model = cls(config)
    # Replace the random encoder with one loaded from the pretrained checkpoint.
    model.encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
    return model
@classmethod
def _load_config(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Resolve the configuration used by ``from_pretrained``.

    Returns the ``config`` keyword argument when the caller supplied one;
    otherwise loads a configuration through ``cls.config_class.from_pretrained``.

    NOTE(fix): this must be a classmethod. ``from_pretrained`` invokes it as
    ``cls._load_config(...)`` before any instance exists; as a plain instance
    method the path argument was misbound to ``self``, so ``self.config_class``
    was looked up on a string and the call failed.
    """
    config = kwargs.pop('config', None)
    if config is None:
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        # `return_unused_kwargs=True` makes from_pretrained return a
        # (config, unused_kwargs) pair; only the config is needed here.
        config, _ = cls.config_class.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            **kwargs
        )
    return config
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
encoder_outputs = self.encoder(input_ids,
attention_mask=attention_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment