Unverified commit c8f27121, authored by Thomas Wolf, committed by GitHub

Merge pull request #1721 from huggingface/common_attributes

Add common getter and setter for input_embeddings & output_embeddings
parents 1d4d0702 b340a910
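
In short: instead of each model defining its own `tie_weights` and `_resize_token_embeddings`, every base model now exposes `get_input_embeddings`/`set_input_embeddings` and every model with an LM head exposes `get_output_embeddings`; the shared `PreTrainedModel` then handles tying and resizing. A minimal sketch of that contract follows — the `Toy*` classes are hypothetical stand-ins, not code from this commit.

import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in for a base model such as BertModel."""

    def __init__(self, vocab_size=100, hidden_size=16):
        super(ToyModel, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    # The two accessors every base model now exposes:
    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.word_embeddings = new_embeddings


class ToyModelWithLMHead(nn.Module):
    """Stand-in for a head model such as BertForMaskedLM."""

    def __init__(self, vocab_size=100, hidden_size=16):
        super(ToyModelWithLMHead, self).__init__()
        self.transformer = ToyModel(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    # Head models only declare where their output embeddings live; the shared
    # PreTrainedModel.tie_weights() (see modeling_utils.py below) does the tying.
    def get_output_embeddings(self):
        return self.lm_head
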
@@ -280,12 +280,13 @@ class XxxModel(XxxPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.embeddings.word_embeddings
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        self.embeddings.word_embeddings = new_embeddings
+    @property
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings

+    def set_input_embeddings(self, new_embeddings):
+        self.embeddings.word_embeddings = new_embeddings
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -376,17 +377,12 @@ class XxxForMaskedLM(XxxPreTrainedModel):
         super(XxxForMaskedLM, self).__init__(config)

         self.transformer = XxxModel(config)
-        self.cls = XxxOnlyMLMHead(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.cls.predictions.decoder,
-                                   self.transformer.embeddings.word_embeddings)
+    def get_output_embeddings(self):
+        return self.lm_head

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None):
...
@@ -601,12 +601,12 @@ class BertModel(BertPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.embeddings.word_embeddings
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        self.embeddings.word_embeddings = new_embeddings
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings

+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -750,14 +750,9 @@ class BertForPreTraining(BertPreTrainedModel):
         self.cls = BertPreTrainingHeads(config)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.cls.predictions.decoder,
-                                   self.bert.embeddings.word_embeddings)
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None, next_sentence_label=None):
@@ -830,14 +825,9 @@ class BertForMaskedLM(BertPreTrainedModel):
         self.cls = BertOnlyMLMHead(config)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.cls.predictions.decoder,
-                                   self.bert.embeddings.word_embeddings)
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
...
@@ -289,10 +289,12 @@ class CTRLModel(CTRLPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        self.w = self._get_resized_embeddings(self.w, new_num_tokens)
+    def get_input_embeddings(self):
         return self.w

+    def set_input_embeddings(self, new_embeddings):
+        self.w = new_embeddings
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -449,13 +451,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head, self.transformer.w)
+    def get_output_embeddings(self):
+        return self.lm_head

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
...
@@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = None
     base_model_prefix = "distilbert"

-    def __init__(self, *inputs, **kwargs):
-        super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weights(self, module):
         """ Initialize the weights.
         """
@@ -424,12 +421,12 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.embeddings.word_embeddings
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        self.embeddings.word_embeddings = new_embeddings
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings

+    def set_input_embeddings(self, new_embeddings):
+        self.embeddings.word_embeddings = new_embeddings
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -511,16 +508,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

         self.init_weights()
-        self.tie_weights()

         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.vocab_projector,
-                                   self.distilbert.embeddings.word_embeddings)
+    def get_output_embeddings(self):
+        return self.vocab_projector

     def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
         dlbrt_output = self.distilbert(input_ids=input_ids,
...
@@ -357,10 +357,12 @@ class GPT2Model(GPT2PreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
+    def get_input_embeddings(self):
         return self.wte

+    def set_input_embeddings(self, new_embeddings):
+        self.wte = new_embeddings
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -514,14 +516,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.wte)
+    def get_output_embeddings(self):
+        return self.lm_head

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
@@ -622,14 +619,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.multiple_choice_head = SequenceSummary(config)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.wte)
+    def get_output_embeddings(self):
+        return self.lm_head

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
...
@@ -360,10 +360,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
+    def get_input_embeddings(self):
         return self.tokens_embed

+    def set_input_embeddings(self, new_embeddings):
+        self.tokens_embed = new_embeddings
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -489,14 +491,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.tokens_embed)
+    def get_output_embeddings(self):
+        return self.lm_head

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
@@ -583,14 +580,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.multiple_choice_head = SequenceSummary(config)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.tokens_embed)
+    def get_output_embeddings(self):
+        return self.lm_head

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
...
@@ -169,6 +169,11 @@ class RobertaModel(BertModel):
         self.embeddings = RobertaEmbeddings(config)
         self.init_weights()

+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+

 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
                       ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@@ -213,13 +218,9 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         self.lm_head = RobertaLMHead(config)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings)
+    def get_output_embeddings(self):
+        return self.lm_head.decoder

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None):
...
@@ -639,9 +639,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
+    def get_input_embeddings(self):
         return self.word_emb

+    def set_input_embeddings(self, new_embeddings):
+        self.word_emb = new_embeddings
+
     def backward_compatible(self):
         self.sample_softmax = -1
@@ -826,7 +829,6 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                 config.cutoffs, div_val=config.div_val)
         self.init_weights()
-        self.tie_weights()

     def tie_weights(self):
         """
...
@@ -83,55 +83,59 @@ class PreTrainedModel(nn.Module):
         # Save config in model
         self.config = config

-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
-        """ Build a resized Embedding Module from a provided token Embedding Module.
-            Increasing the size will add newly initialized vectors at the end
-            Reducing the size will remove vectors from the end
-
-        Args:
-            new_num_tokens: (`optional`) int
-                New number of tokens in the embedding matrix.
-                Increasing the size will add newly initialized vectors at the end
-                Reducing the size will remove vectors from the end
-                If not provided or None: return the provided token Embedding Module.
-        Return: ``torch.nn.Embeddings``
-            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
-        """
-        if new_num_tokens is None:
-            return old_embeddings
-
-        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
-        if old_num_tokens == new_num_tokens:
-            return old_embeddings
-
-        # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(old_embeddings.weight.device)
-
-        # initialize all new embeddings (in particular added tokens)
-        self._init_weights(new_embeddings)
-
-        # Copy word embeddings from the previous weights
-        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
-        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
-
-        return new_embeddings
+    @property
+    def base_model(self):
+        return getattr(self, self.base_model_prefix, self)
+
+    def get_input_embeddings(self):
+        """ Get model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError
+
+    def set_input_embeddings(self, value):
+        """ Set model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError
+
+    def get_output_embeddings(self):
+        """ Get model's output embeddings
+            Return None if the model doesn't have output embeddings
+        """
+        return None  # Overwrite for models with output embeddings
+
+    def tie_weights(self):
+        """ Make sure we are sharing the input and output embeddings.
+            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+        """
+        output_embeddings = self.get_output_embeddings()
+        if output_embeddings is not None:
+            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

-    def _tie_or_clone_weights(self, first_module, second_module):
+    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """ Tie or clone module weights depending of weither we are using TorchScript or not
         """
         if self.config.torchscript:
-            first_module.weight = nn.Parameter(second_module.weight.clone())
+            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
         else:
-            first_module.weight = second_module.weight
+            output_embeddings.weight = input_embeddings.weight

-        if hasattr(first_module, 'bias') and first_module.bias is not None:
-            first_module.bias.data = torch.nn.functional.pad(
-                first_module.bias.data,
-                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
+        if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
+            output_embeddings.bias.data = torch.nn.functional.pad(
+                output_embeddings.bias.data,
+                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                 'constant',
                 0
             )
+        if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
+            output_embeddings.out_features = input_embeddings.num_embeddings

     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
@@ -161,6 +165,46 @@ class PreTrainedModel(nn.Module):

         return model_embeds

+    def _resize_token_embeddings(self, new_num_tokens):
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        self.set_input_embeddings(new_embeddings)
+        return self.get_input_embeddings()
+
+    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+        """ Build a resized Embedding Module from a provided token Embedding Module.
+            Increasing the size will add newly initialized vectors at the end
+            Reducing the size will remove vectors from the end
+
+        Args:
+            new_num_tokens: (`optional`) int
+                New number of tokens in the embedding matrix.
+                Increasing the size will add newly initialized vectors at the end
+                Reducing the size will remove vectors from the end
+                If not provided or None: return the provided token Embedding Module.
+        Return: ``torch.nn.Embeddings``
+            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
+        """
+        if new_num_tokens is None:
+            return old_embeddings
+
+        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        if old_num_tokens == new_num_tokens:
+            return old_embeddings
+
+        # Build new embeddings
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+        new_embeddings.to(old_embeddings.weight.device)
+
+        # initialize all new embeddings (in particular added tokens)
+        self._init_weights(new_embeddings)
+
+        # Copy word embeddings from the previous weights
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
+
+        return new_embeddings
+
     def init_weights(self):
         """ Initialize and prunes weights if needed. """
         # Initialize weights
@@ -170,6 +214,9 @@ class PreTrainedModel(nn.Module):
         if self.config.pruned_heads:
             self.prune_heads(self.config.pruned_heads)

+        # Tie weights if needed
+        self.tie_weights()
+
     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
@@ -178,14 +225,12 @@ class PreTrainedModel(nn.Module):
             heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
             E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
         """
-        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
-
         # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
         for layer, heads in heads_to_prune.items():
             union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
             self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

-        base_model._prune_heads(heads_to_prune)
+        self.base_model._prune_heads(heads_to_prune)

     def save_pretrained(self, save_directory):
         """ Save a model and its configuration file to a directory, so that it
...
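
With the accessors centralized in modeling_utils, resizing and weight tying run through one generic path on the base class. A hedged usage sketch, not taken from this diff — `bert-base-uncased` is only an example checkpoint and is assumed to be downloadable:

from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Accessors introduced above: the getter delegates to the base model,
# while the head model overrides get_output_embeddings().
input_emb = model.get_input_embeddings()    # bert.embeddings.word_embeddings
output_emb = model.get_output_embeddings()  # cls.predictions.decoder

# init_weights() now calls tie_weights(), which shares the two weight
# matrices (or clones them when config.torchscript is set).
assert output_emb.weight is input_emb.weight

# Resizing goes through get_input_embeddings()/set_input_embeddings(),
# so no per-model _resize_token_embeddings override is needed any more.
model.resize_token_embeddings(input_emb.num_embeddings + 8)
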
@@ -407,10 +407,12 @@ class XLMModel(XLMPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
+    def get_input_embeddings(self):
         return self.embeddings

+    def set_input_embeddings(self, new_embeddings):
+        self.embeddings = new_embeddings
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -618,12 +620,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.pred_layer = XLMPredLayer(config)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the embeddings
-        """
-        self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)
+    def get_output_embeddings(self):
+        return self.pred_layer.proj

     def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                 lengths=None, cache=None, head_mask=None, labels=None):
...
@@ -611,10 +611,12 @@ class XLNetModel(XLNetPreTrainedModel):
         self.init_weights()

-    def _resize_token_embeddings(self, new_num_tokens):
-        self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
+    def get_input_embeddings(self):
         return self.word_embedding

+    def set_input_embeddings(self, new_embeddings):
+        self.word_embedding = new_embeddings
+
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError
@@ -918,12 +920,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

         self.init_weights()
-        self.tie_weights()

-    def tie_weights(self):
-        """ Make sure we are sharing the embeddings
-        """
-        self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)
+    def get_output_embeddings(self):
+        return self.lm_loss

     def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                 token_type_ids=None, input_mask=None, head_mask=None, labels=None):
...
@@ -38,6 +38,7 @@ else:

 class AutoModelTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -52,6 +53,7 @@ class AutoModelTest(unittest.TestCase):
             for value in loading_info.values():
                 self.assertEqual(len(value), 0)

+    @pytest.mark.slow
     def test_lmhead_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -64,6 +66,7 @@ class AutoModelTest(unittest.TestCase):
             self.assertIsNotNone(model)
             self.assertIsInstance(model, BertForMaskedLM)

+    @pytest.mark.slow
     def test_sequence_classification_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -76,6 +79,7 @@ class AutoModelTest(unittest.TestCase):
             self.assertIsNotNone(model)
             self.assertIsInstance(model, BertForSequenceClassification)

+    @pytest.mark.slow
     def test_question_answering_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
...
@@ -463,6 +463,15 @@ class CommonTestCases:

             self.assertTrue(models_equal)

+        def test_model_common_attributes(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                model.get_input_embeddings()
+                model.set_input_embeddings(torch.nn.Embedding(10, 10))
+                model.get_output_embeddings()
+
         def test_tie_model_weights(self):
             if not self.test_torchscript:
                 return
@@ -477,11 +486,11 @@ class CommonTestCases:
                 return equal

             for model_class in self.all_model_classes:
-                if not hasattr(model_class, 'tie_weights'):
-                    continue
-
                 config.torchscript = True
                 model_not_tied = model_class(config)
+                if model_not_tied.get_output_embeddings() is None:
+                    continue
+
                 params_not_tied = list(model_not_tied.parameters())

                 config_tied = copy.deepcopy(config)
@@ -688,6 +697,7 @@ class CommonTestCases:
            config_and_inputs = self.prepare_config_and_inputs()
            self.create_and_check_presents(*config_and_inputs)

+        @pytest.mark.slow
         def run_slow_tests(self):
             self.create_and_check_model_from_pretrained()
@@ -761,6 +771,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):

 class ModelUtilsTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
...
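
For reference, the rewritten test_tie_model_weights no longer checks for a per-class tie_weights method; it instantiates the model and skips it when get_output_embeddings() returns None (the base-class default), so only models with an output projection are checked for tying. A minimal sketch of that gate, using a hypothetical helper name not present in the diff:

def participates_in_weight_tying(model):
    # Mirrors the common test's skip condition: bare encoders inherit
    # get_output_embeddings() -> None from PreTrainedModel, while
    # LM-head models return their output projection layer.
    return model.get_output_embeddings() is not None
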
@@ -27,6 +27,7 @@ else:

 class EncoderDecoderModelTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_model2model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
...
@@ -26,6 +26,7 @@ from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CON

 class AutoTokenizerTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
...
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera

 import os
 import unittest
+import pytest
 from io import open

 from transformers.tokenization_bert import (BasicTokenizer,
@@ -125,6 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertFalse(_is_punctuation(u"A"))
         self.assertFalse(_is_punctuation(u" "))

+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
...
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera

 import os
 import unittest
+import pytest
 from io import open

 from transformers.tokenization_distilbert import (DistilBertTokenizer)
@@ -30,6 +31,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
     def get_tokenizer(self, **kwargs):
         return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera

 import os
 import json
 import unittest
+import pytest
 from io import open

 from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
@@ -78,6 +79,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
             [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
         )

+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
...
@@ -18,11 +18,13 @@ from __future__ import print_function

 import unittest
 import six
+import pytest

 from transformers import PreTrainedTokenizer
 from transformers.tokenization_gpt2 import GPT2Tokenizer


 class TokenizerUtilsTest(unittest.TestCase):
+    @pytest.mark.slow
     def check_tokenizer_from_pretrained(self, tokenizer_class):
         s3_models = list(tokenizer_class.max_model_input_sizes.keys())
         for model_name in s3_models[:1]:
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera

 import os
 import unittest
 import json
+import pytest

 from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
@@ -66,6 +67,7 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
...