"tools/vscode:/vscode.git/clone" did not exist on "85aeae712ab3ec1193c6c503ae06352892dcd9df"
Commit 9b45d0f8 authored by thomwolf

Add common properties input_embeddings and output_embeddings

parent 8a628355
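Note: in short, every base model now exposes an input_embeddings property (with a setter) pointing at its token embedding module, every model with an LM head exposes an output_embeddings property, and PreTrainedModel implements base_model, tie_weights and embedding resizing once on top of these. A minimal sketch of the pattern, using hypothetical Toy classes rather than anything in this commit:

    import torch.nn as nn

    class ToyPreTrained(nn.Module):
        # Mirrors the contract this commit gives PreTrainedModel.
        @property
        def input_embeddings(self):
            raise NotImplementedError

        @property
        def output_embeddings(self):
            return None  # overridden by models with an LM head

        def tie_weights(self):
            # Share the LM head weight with the token embedding matrix.
            if self.output_embeddings is not None:
                self.output_embeddings.weight = self.input_embeddings.weight

    class ToyLMModel(ToyPreTrained):
        def __init__(self, vocab_size=100, hidden=16):
            super().__init__()
            self.wte = nn.Embedding(vocab_size, hidden)
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
            self.tie_weights()

        @property
        def input_embeddings(self):
            return self.wte

        @property
        def output_embeddings(self):
            return self.lm_head

    model = ToyLMModel()
    assert model.lm_head.weight is model.wte.weight  # weights are tied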
@@ -280,12 +280,14 @@ class XxxModel(XxxPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.embeddings.word_embeddings
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        self.embeddings.word_embeddings = new_embeddings
-        return self.embeddings.word_embeddings
+    @property
+    def input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.embeddings.word_embeddings = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -376,17 +378,13 @@ class XxxForMaskedLM(XxxPreTrainedModel):
         super(XxxForMaskedLM, self).__init__(config)
 
         self.transformer = XxxModel(config)
-        self.cls = XxxOnlyMLMHead(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.cls.predictions.decoder,
-                                   self.transformer.embeddings.word_embeddings)
+    @property
+    def output_embeddings(self):
+        return self.lm_head
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None):
@@ -601,12 +601,14 @@ class BertModel(BertPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.embeddings.word_embeddings
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        self.embeddings.word_embeddings = new_embeddings
-        return self.embeddings.word_embeddings
+    @property
+    def input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.embeddings.word_embeddings = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -750,14 +752,10 @@ class BertForPreTraining(BertPreTrainedModel):
         self.cls = BertPreTrainingHeads(config)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.cls.predictions.decoder,
-                                   self.bert.embeddings.word_embeddings)
+    @property
+    def output_embeddings(self):
+        return self.cls.predictions.decoder
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None, next_sentence_label=None):
@@ -830,14 +828,10 @@ class BertForMaskedLM(BertPreTrainedModel):
         self.cls = BertOnlyMLMHead(config)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.cls.predictions.decoder,
-                                   self.bert.embeddings.word_embeddings)
+    @property
+    def output_embeddings(self):
+        return self.cls.predictions.decoder
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
@@ -289,10 +289,14 @@ class CTRLModel(CTRLPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        self.w = self._get_resized_embeddings(self.w, new_num_tokens)
-        return self.w
+    @property
+    def input_embeddings(self):
+        return self.w
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.w = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -449,13 +453,10 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head, self.transformer.w)
+    @property
+    def output_embeddings(self):
+        return self.lm_head
 
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
@@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = None
     base_model_prefix = "distilbert"
 
-    def __init__(self, *inputs, **kwargs):
-        super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weights(self, module):
         """ Initialize the weights.
         """
@@ -424,12 +421,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.embeddings.word_embeddings
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        self.embeddings.word_embeddings = new_embeddings
-        return self.embeddings.word_embeddings
+    @property
+    def input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.embeddings.word_embeddings = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -511,16 +510,12 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
 
         self.init_weights()
-        self.tie_weights()
 
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.vocab_projector,
-                                   self.distilbert.embeddings.word_embeddings)
+    @property
+    def output_embeddings(self):
+        return self.vocab_projector
 
     def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
         dlbrt_output = self.distilbert(input_ids=input_ids,
@@ -357,10 +357,14 @@ class GPT2Model(GPT2PreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
-        return self.wte
+    @property
+    def input_embeddings(self):
+        return self.wte
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.wte = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -514,14 +518,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.wte)
+    @property
+    def output_embeddings(self):
+        return self.lm_head
 
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
@@ -622,14 +622,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.multiple_choice_head = SequenceSummary(config)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.wte)
+    @property
+    def output_embeddings(self):
+        return self.lm_head
 
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
@@ -360,10 +360,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
-        return self.tokens_embed
+    @property
+    def input_embeddings(self):
+        return self.tokens_embed
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.tokens_embed = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -489,14 +493,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.tokens_embed)
+    @property
+    def output_embeddings(self):
+        return self.lm_head
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
@@ -583,14 +583,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.multiple_choice_head = SequenceSummary(config)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.tokens_embed)
+    @property
+    def output_embeddings(self):
+        return self.lm_head
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
@@ -169,6 +169,10 @@ class RobertaModel(BertModel):
         self.embeddings = RobertaEmbeddings(config)
         self.init_weights()
 
+    @property
+    def input_embeddings(self):
+        return self.embeddings.word_embeddings
+
 
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
                       ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@@ -213,13 +217,10 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         self.lm_head = RobertaLMHead(config)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings)
+    @property
+    def output_embeddings(self):
+        return self.lm_head.decoder
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None):
@@ -639,9 +639,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        return self.word_emb
+    @property
+    def input_embeddings(self):
+        return self.word_emb
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.word_emb = new_embeddings
 
     def backward_compatible(self):
         self.sample_softmax = -1
@@ -826,7 +831,6 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                 config.cutoffs, div_val=config.div_val)
         self.init_weights()
-        self.tie_weights()
 
     def tie_weights(self):
         """
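Note: Transfo-XL is the one model that keeps its own tie_weights override, since its ProjectedAdaptiveLogSoftmax appears to tie several cluster and projection weights rather than a single LM head. Only the explicit constructor call is dropped here, because init_weights in PreTrainedModel now calls tie_weights for every model (see the modeling_utils hunks below).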
@@ -83,55 +83,43 @@ class PreTrainedModel(nn.Module):
         # Save config in model
         self.config = config
 
-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
-        """ Build a resized Embedding Module from a provided token Embedding Module.
-            Increasing the size will add newly initialized vectors at the end
-            Reducing the size will remove vectors from the end
-
-            Args:
-                new_num_tokens: (`optional`) int
-                    New number of tokens in the embedding matrix.
-                    Increasing the size will add newly initialized vectors at the end
-                    Reducing the size will remove vectors from the end
-                    If not provided or None: return the provided token Embedding Module.
-            Return: ``torch.nn.Embeddings``
-                Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
-        """
-        if new_num_tokens is None:
-            return old_embeddings
-
-        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
-        if old_num_tokens == new_num_tokens:
-            return old_embeddings
-
-        # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(old_embeddings.weight.device)
-
-        # initialize all new embeddings (in particular added tokens)
-        self._init_weights(new_embeddings)
-
-        # Copy word embeddings from the previous weights
-        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
-        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
-
-        return new_embeddings
-
-    def _tie_or_clone_weights(self, first_module, second_module):
+    @property
+    def base_model(self):
+        return getattr(self, self.base_model_prefix, self)
+
+    @property
+    def input_embeddings(self):
+        base_model = getattr(self, self.base_model_prefix, self)
+        return base_model.input_embeddings
+
+    @property
+    def output_embeddings(self):
+        return None  # Overwrite for models with output embeddings
+
+    def tie_weights(self):
+        """ Make sure we are sharing the input and output embeddings.
+            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+        """
+        if self.output_embeddings is not None:
+            self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)
+
+    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """ Tie or clone module weights depending of weither we are using TorchScript or not
         """
         if self.config.torchscript:
-            first_module.weight = nn.Parameter(second_module.weight.clone())
+            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
         else:
-            first_module.weight = second_module.weight
+            output_embeddings.weight = input_embeddings.weight
 
-        if hasattr(first_module, 'bias') and first_module.bias is not None:
-            first_module.bias.data = torch.nn.functional.pad(
-                first_module.bias.data,
-                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
+        if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
+            output_embeddings.bias.data = torch.nn.functional.pad(
+                output_embeddings.bias.data,
+                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                 'constant',
                 0
             )
+        if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
+            output_embeddings.out_features = input_embeddings.num_embeddings
 
     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
@@ -161,6 +149,45 @@ class PreTrainedModel(nn.Module):
 
         return model_embeds
 
+    def _resize_token_embeddings(self, new_num_tokens):
+        old_embeddings = self.input_embeddings
+        self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        return self.input_embeddings
+
+    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+        """ Build a resized Embedding Module from a provided token Embedding Module.
+            Increasing the size will add newly initialized vectors at the end
+            Reducing the size will remove vectors from the end
+
+            Args:
+                new_num_tokens: (`optional`) int
+                    New number of tokens in the embedding matrix.
+                    Increasing the size will add newly initialized vectors at the end
+                    Reducing the size will remove vectors from the end
+                    If not provided or None: return the provided token Embedding Module.
+            Return: ``torch.nn.Embeddings``
+                Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
+        """
+        if new_num_tokens is None:
+            return old_embeddings
+
+        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        if old_num_tokens == new_num_tokens:
+            return old_embeddings
+
+        # Build new embeddings
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+        new_embeddings.to(old_embeddings.weight.device)
+
+        # initialize all new embeddings (in particular added tokens)
+        self._init_weights(new_embeddings)
+
+        # Copy word embeddings from the previous weights
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
+
+        return new_embeddings
+
     def init_weights(self):
         """ Initialize and prunes weights if needed. """
         # Initialize weights
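Note: _get_resized_embeddings is moved here unchanged; resize_token_embeddings keeps routing through it, but now reads and writes the matrix via the input_embeddings property instead of model-specific attributes. A minimal, self-contained illustration of the copy step (plain nn.Embedding modules, not tied to any model in this diff): existing rows are copied over and the extra rows keep their fresh initialization.

    import torch
    import torch.nn as nn

    old = nn.Embedding(10, 4)   # old vocabulary: 10 tokens
    new = nn.Embedding(12, 4)   # resized vocabulary: 12 tokens
    num_to_copy = min(old.num_embeddings, new.num_embeddings)
    new.weight.data[:num_to_copy, :] = old.weight.data[:num_to_copy, :]
    assert torch.equal(new.weight.data[:10], old.weight.data)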
@@ -170,6 +197,9 @@ class PreTrainedModel(nn.Module):
         if self.config.pruned_heads:
             self.prune_heads(self.config.pruned_heads)
 
+        # Tie weights if needed
+        self.tie_weights()
+
     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
@@ -178,14 +208,12 @@ class PreTrainedModel(nn.Module):
             heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
             E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
         """
-        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
-
         # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
         for layer, heads in heads_to_prune.items():
             union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
             self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON
 
-        base_model._prune_heads(heads_to_prune)
+        self.base_model._prune_heads(heads_to_prune)
 
     def save_pretrained(self, save_directory):
         """ Save a model and its configuration file to a directory, so that it
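Note: the base_model property simply replaces the ad-hoc getattr lookup used here. A toy sketch of the lookup rule (hypothetical class, not part of the diff): a head model resolves to the wrapped base transformer, a bare base model falls back to itself.

    class Toy:
        base_model_prefix = "encoder"

        @property
        def base_model(self):
            return getattr(self, self.base_model_prefix, self)

    toy = Toy()
    assert toy.base_model is toy          # bare model: no .encoder attribute, falls back to self
    toy.encoder = object()
    assert toy.base_model is toy.encoder  # head model: returns the wrapped base model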
@@ -407,10 +407,14 @@ class XLMModel(XLMPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
-        return self.embeddings
+    @property
+    def input_embeddings(self):
+        return self.embeddings
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.embeddings = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -618,12 +622,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.pred_layer = XLMPredLayer(config)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the embeddings
-        """
-        self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)
+    @property
+    def output_embeddings(self):
+        return self.pred_layer.proj
 
     def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                 lengths=None, cache=None, head_mask=None, labels=None):
@@ -611,10 +611,14 @@ class XLNetModel(XLNetPreTrainedModel):
         self.init_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
-        return self.word_embedding
+    @property
+    def input_embeddings(self):
+        return self.word_embedding
+
+    @input_embeddings.setter
+    def input_embeddings(self, new_embeddings):
+        self.word_embedding = new_embeddings
 
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError
@@ -918,12 +922,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
 
         self.init_weights()
-        self.tie_weights()
 
-    def tie_weights(self):
-        """ Make sure we are sharing the embeddings
-        """
-        self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)
+    @property
+    def output_embeddings(self):
+        return self.lm_loss
 
     def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                 token_type_ids=None, input_mask=None, head_mask=None, labels=None):
@@ -463,6 +463,15 @@ class CommonTestCases:
 
             self.assertTrue(models_equal)
 
+        def test_model_common_attributes(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                self.assertTrue(hasattr(model, 'input_embeddings'))
+                setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10))
+                self.assertTrue(hasattr(model, 'output_embeddings'))
+
         def test_tie_model_weights(self):
            if not self.test_torchscript:
                 return
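Note: the new common test checks that every model class exposes a settable input_embeddings and an output_embeddings attribute. Assuming a standard pytest setup for this repository (the exact invocation is not part of the diff), it can be selected by name:

    python -m pytest -k test_model_common_attributes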