Commit 1724cee8 authored by thomwolf

switch from properties to methods

parent 9b45d0f8
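In short: the `input_embeddings` / `output_embeddings` properties on all models are replaced by explicit `get_input_embeddings()`, `set_input_embeddings()` and `get_output_embeddings()` methods. A minimal before/after sketch of the call sites (the checkpoint name is only for illustration, and the import assumes the `transformers` package name, which the repo was switching to around this time):

    import torch
    from transformers import BertModel  # assumed package/import path

    model = BertModel.from_pretrained('bert-base-uncased')

    # Before this commit (property-based API):
    #   embeddings = model.input_embeddings
    #   model.input_embeddings = torch.nn.Embedding(30522, 768)

    # After this commit (method-based API):
    embeddings = model.get_input_embeddings()      # the word-embedding nn.Embedding
    model.set_input_embeddings(torch.nn.Embedding(embeddings.num_embeddings,
                                                  embeddings.embedding_dim))
    print(model.get_output_embeddings())           # None: a bare BertModel has no LM head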
@@ -281,11 +281,10 @@ class XxxModel(XxxPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.embeddings.word_embeddings = new_embeddings

     def _prune_heads(self, heads_to_prune):
@@ -382,8 +381,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
...
@@ -601,13 +601,11 @@ class BertModel(BertPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
-        self.embeddings.word_embeddings = new_embeddings
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
@@ -753,8 +751,7 @@ class BertForPreTraining(BertPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.cls.predictions.decoder

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -829,8 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.cls.predictions.decoder

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
...
@@ -289,12 +289,10 @@ class CTRLModel(CTRLPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.w

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.w = new_embeddings

     def _prune_heads(self, heads_to_prune):
@@ -454,8 +452,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
...
@@ -421,12 +421,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.embeddings.word_embeddings = new_embeddings

     def _prune_heads(self, heads_to_prune):
@@ -513,8 +511,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.vocab_projector

     def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
...
@@ -357,12 +357,10 @@ class GPT2Model(GPT2PreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.wte

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.wte = new_embeddings

     def _prune_heads(self, heads_to_prune):
@@ -519,8 +517,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -623,8 +620,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
...
@@ -360,12 +360,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.tokens_embed

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.tokens_embed = new_embeddings

     def _prune_heads(self, heads_to_prune):
@@ -494,8 +492,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -584,8 +581,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
...
@@ -169,10 +169,11 @@ class RobertaModel(BertModel):
         self.embeddings = RobertaEmbeddings(config)
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings

+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+

 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
                       ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@@ -218,8 +219,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head.decoder

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
...
@@ -639,12 +639,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.word_emb

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.word_emb = new_embeddings

     def backward_compatible(self):
...
@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
     def base_model(self):
         return getattr(self, self.base_model_prefix, self)

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
+        """ Get model's input embeddings
+        """
         base_model = getattr(self, self.base_model_prefix, self)
-        return base_model.input_embeddings
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError

-    @property
-    def output_embeddings(self):
+    def set_input_embeddings(self, value):
+        """ Set model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError
+
+    def get_output_embeddings(self):
+        """ Get model's output embeddings
+            Return None if the model doesn't have output embeddings
+        """
         return None  # Overwrite for models with output embeddings

     def tie_weights(self):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        if self.output_embeddings is not None:
-            self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)
+        output_embeddings = self.get_output_embeddings()
+        if output_embeddings is not None:
+            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """ Tie or clone module weights depending on whether we are using TorchScript or not
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
         return model_embeds

     def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.input_embeddings
-        self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        return self.input_embeddings
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        self.set_input_embeddings(new_embeddings)
+        return self.get_input_embeddings()

     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
         """ Build a resized Embedding Module from a provided token Embedding Module.
...
@@ -407,12 +407,10 @@ class XLMModel(XLMPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.embeddings = new_embeddings

     def _prune_heads(self, heads_to_prune):
@@ -623,8 +621,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.pred_layer.proj

     def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
...
@@ -611,12 +611,10 @@ class XLNetModel(XLNetPreTrainedModel):
         self.init_weights()

-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.word_embedding

-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.word_embedding = new_embeddings

     def _prune_heads(self, heads_to_prune):
@@ -923,8 +921,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.init_weights()

-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_loss

     def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
...
@@ -429,6 +429,12 @@ class CommonTestCases:
                     list(hidden_states[0].shape[-2:]),
                     [self.model_tester.seq_length, self.model_tester.hidden_size])

+        def test_debug(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                model_embed = model.resize_token_embeddings(config.vocab_size + 10)
+
         def test_resize_tokens_embeddings(self):
             original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             if not self.test_resize_embeddings:

@@ -468,9 +474,9 @@ class CommonTestCases:
             for model_class in self.all_model_classes:
                 model = model_class(config)

-                self.assertTrue(hasattr(model, 'input_embeddings'))
-                setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10))
-                self.assertTrue(hasattr(model, 'output_embeddings'))
+                model.get_input_embeddings()
+                model.set_input_embeddings(torch.nn.Embedding(10, 10))
+                model.get_output_embeddings()

         def test_tie_model_weights(self):
             if not self.test_torchscript:
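The `test_debug` helper added above just drives `resize_token_embeddings`, which after this commit flows through the new accessors (get, build a resized embedding, set, then re-tie). A hedged usage sketch (checkpoint name is illustrative; the import assumes the `transformers` package name):

    from transformers import BertForMaskedLM  # assumed package/import path

    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    old_size = model.get_input_embeddings().num_embeddings

    # Internally: get_input_embeddings() -> _get_resized_embeddings() ->
    # set_input_embeddings() -> tie_weights() against the resized matrix.
    model_embed = model.resize_token_embeddings(old_size + 10)

    assert model_embed.num_embeddings == old_size + 10
    assert model.get_input_embeddings() is model_embed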
...