Unverified commit 443f67e8, authored by Patrick von Platen, committed by GitHub

[PyTorch] Refactor Resize Token Embeddings (#8880)

* fix resize tokens

* correct mobile_bert

* move embedding fix into modeling_utils.py

* refactor

* fix lm head resize

* refactor

* break lines to make sylvain happy

* add new tests

* fix typo

* improve test

* skip bart-like for now

* check if base_model = get(...) is necessary

* clean files

* improve test

* fix tests

* revert style templates

* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
parent e52f9c0a
@@ -605,14 +605,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         Return:
             :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
         """
-        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
-        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
+        model_embeds = self._resize_token_embeddings(new_num_tokens)
         if new_num_tokens is None:
             return model_embeds
 
         # Update base model and current model config
         self.config.vocab_size = new_num_tokens
-        base_model.vocab_size = new_num_tokens
+        self.vocab_size = new_num_tokens
 
         # Tie weights again if needed
         self.tie_weights()
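For context, the public entry point touched above is typically called after extending a tokenizer's vocabulary. A minimal usage sketch (the checkpoint name and added token are illustrative, and `from_pretrained` downloads weights):

from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# register a new domain-specific token, then grow the embedding matrix to match
tokenizer.add_tokens(["[NEW_TOKEN]"])
new_embeddings = model.resize_token_embeddings(len(tokenizer))
assert new_embeddings.weight.shape[0] == len(tokenizer)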
@@ -623,6 +622,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         old_embeddings = self.get_input_embeddings()
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
         self.set_input_embeddings(new_embeddings)
+
+        # if word embeddings are not tied, make sure that lm head is resized as well
+        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
+            old_lm_head = self.get_output_embeddings()
+            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
+            self.set_output_embeddings(new_lm_head)
+
         return self.get_input_embeddings()
 
     def _get_resized_embeddings(
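The new branch above is what the added tests exercise: when `config.tie_word_embeddings` is `False`, the output projection has to grow together with the input embeddings. A hedged sketch of that check (the tiny config values are illustrative assumptions):

from transformers import BertConfig, BertForMaskedLM

config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    tie_word_embeddings=False,  # lm head weights are independent of the word embeddings
)
model = BertForMaskedLM(config)

model.resize_token_embeddings(110)
assert model.get_input_embeddings().weight.shape[0] == 110
assert model.get_output_embeddings().weight.shape[0] == 110  # lm head resized explicitly
assert model.config.vocab_size == 110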
@@ -653,9 +659,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         if old_num_tokens == new_num_tokens:
             return old_embeddings
 
+        if not isinstance(old_embeddings, nn.Embedding):
+            raise TypeError(
+                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. "
+                f"You should either use a different resize function or make sure that `old_embeddings` is an instance of {nn.Embedding}."
+            )
+
         # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(old_embeddings.weight.device)
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device)
 
         # initialize all new embeddings (in particular added tokens)
         self._init_weights(new_embeddings)
@@ -666,6 +677,68 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
 
         return new_embeddings
 
+    def _get_resized_lm_head(
+        self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
+    ) -> torch.nn.Linear:
+        """
+        Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly
+        initialized vectors at the end. Reducing the size will remove vectors from the end.
+
+        Args:
+            old_lm_head (:obj:`torch.nn.Linear`):
+                Old lm head linear layer to be resized.
+            new_num_tokens (:obj:`int`, `optional`):
+                New number of tokens in the linear matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
+                :obj:`torch.nn.Linear` module of the model without doing anything.
+            transposed (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether ``old_lm_head`` is transposed or not. If ``True``, ``old_lm_head.weight.size()`` is
+                ``lm_head_dim, vocab_size``, else ``vocab_size, lm_head_dim``.
+
+        Return:
+            :obj:`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if
+            :obj:`new_num_tokens` is :obj:`None`
+        """
+        if new_num_tokens is None:
+            return old_lm_head
+
+        old_num_tokens, old_lm_head_dim = (
+            old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
+        )
+
+        if old_num_tokens == new_num_tokens:
+            return old_lm_head
+
+        if not isinstance(old_lm_head, nn.Linear):
+            raise TypeError(
+                f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. "
+                f"You should either use a different resize function or make sure that `old_lm_head` is an instance of {nn.Linear}."
+            )
+
+        # Build new lm head
+        new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
+        has_new_lm_head_bias = old_lm_head.bias is not None
+        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
+
+        # initialize new lm head (in particular added tokens)
+        self._init_weights(new_lm_head)
+
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+
+        # Copy old lm head weights to new lm head
+        if not transposed:
+            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+        else:
+            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
+
+        # Copy bias weights to new lm head
+        if has_new_lm_head_bias:
+            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+
+        return new_lm_head
 
     def init_weights(self):
         """
         Initializes and prunes weights if needed.
...
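The `transposed` flag matters for heads whose weight is laid out as `(dim, vocab)` rather than the usual `(vocab, dim)` of an `nn.Linear` lm head (MobileBERT's extra `dense` projection further down is the motivating case). A toy sketch of the two copy patterns, independent of the library code above:

import torch.nn as nn

old_vocab, new_vocab, dim = 6, 8, 4

# standard lm head: maps dim -> vocab, weight shape (vocab, dim); rows track the vocab
old_head = nn.Linear(dim, old_vocab)
new_head = nn.Linear(dim, new_vocab)
new_head.weight.data[:old_vocab, :] = old_head.weight.data      # copy rows
new_head.bias.data[:old_vocab] = old_head.bias.data

# transposed head: maps vocab -> dim, weight shape (dim, vocab); columns track the vocab
old_head_t = nn.Linear(old_vocab, dim, bias=False)
new_head_t = nn.Linear(new_vocab, dim, bias=False)
new_head_t.weight.data[:, :old_vocab] = old_head_t.weight.data  # copy columns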
@@ -632,12 +632,6 @@ class AlbertModel(AlbertPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.embeddings.word_embeddings
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        self.embeddings.word_embeddings = new_embeddings
-        return self.embeddings.word_embeddings
-
     def _prune_heads(self, heads_to_prune):
         """
         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
@@ -748,6 +742,9 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
     def get_output_embeddings(self):
         return self.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.predictions.decoder = new_embeddings
+
     def get_input_embeddings(self):
         return self.albert.embeddings.word_embeddings
 
@@ -889,6 +886,9 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
     def get_output_embeddings(self):
         return self.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.predictions.decoder = new_embeddings
+
     def get_input_embeddings(self):
         return self.albert.embeddings.word_embeddings
...
@@ -905,6 +905,9 @@ class BertForPreTraining(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1010,6 +1013,9 @@ class BertLMHeadModel(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1131,6 +1137,9 @@ class BertForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -422,6 +422,9 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
...
@@ -496,6 +496,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
...
@@ -508,6 +508,9 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.vocab_projector
 
+    def set_output_embeddings(self, new_embeddings):
+        self.vocab_projector = new_embeddings
+
     @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -1003,6 +1003,9 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
     def get_output_embeddings(self):
         return self.generator_lm_head
 
+    def set_output_embeddings(self, word_embeddings):
+        self.generator_lm_head = word_embeddings
+
     @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -194,6 +194,9 @@ class EncoderDecoderModel(PreTrainedModel):
     def get_output_embeddings(self):
         return self.decoder.get_output_embeddings()
 
+    def set_output_embeddings(self, new_embeddings):
+        return self.decoder.set_output_embeddings(new_embeddings)
+
     @classmethod
     def from_encoder_decoder_pretrained(
         cls,
...
@@ -1167,6 +1167,9 @@ class FunnelForMaskedLM(FunnelPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -816,6 +816,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
@@ -945,6 +948,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
...
@@ -781,6 +781,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -1632,6 +1632,9 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
...
@@ -641,7 +641,7 @@ class MobileBertLMPredictionHead(nn.Module):
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
-        hidden_states += self.bias
+        hidden_states += self.decoder.bias
         return hidden_states
 
@@ -949,26 +949,16 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def tie_weights(self):
-        """
-        Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
-        configuration, can't handle parameter sharing so we are cloning the weights instead.
-        """
-        output_embeddings = self.get_output_embeddings()
-        input_embeddings = self.get_input_embeddings()
-
-        resized_dense = nn.Linear(
-            input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
-        )
-        kept_data = self.cls.predictions.dense.weight.data[
-            ..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
-        ]
-        resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
-        self.cls.predictions.dense = resized_dense
-        self.cls.predictions.dense.to(self.device)
-
-        if output_embeddings is not None and self.config.tie_word_embeddings:
-            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
+        # resize the dense output embeddings first
+        self.cls.predictions.dense = self._get_resized_lm_head(
+            self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
+        )
+
+        return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
 
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1067,26 +1057,15 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def tie_weights(self):
-        """
-        Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
-        configuration, can't handle parameter sharing so we are cloning the weights instead.
-        """
-        output_embeddings = self.get_output_embeddings()
-        input_embeddings = self.get_input_embeddings()
-
-        resized_dense = nn.Linear(
-            input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
-        )
-        kept_data = self.cls.predictions.dense.weight.data[
-            ..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
-        ]
-        resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
-        self.cls.predictions.dense = resized_dense
-        self.cls.predictions.dense.to(self.device)
-
-        if output_embeddings is not None and self.config.tie_word_embeddings:
-            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
+        # resize the dense output embeddings first
+        self.cls.predictions.dense = self._get_resized_lm_head(
+            self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
+        )
+
+        return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
 
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
...
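Because MobileBERT's prediction head splits the projection into a `decoder` (embedding_size to vocab) and a transposed `dense` (vocab to hidden_size minus embedding_size), the override above resizes `dense` with `transposed=True` before delegating to the base class. A hedged end-to-end check, with shapes inferred from the forward pass shown above and illustrative small config values:

from transformers import MobileBertConfig, MobileBertForMaskedLM

config = MobileBertConfig(vocab_size=100, num_hidden_layers=2)
model = MobileBertForMaskedLM(config)

model.resize_token_embeddings(120)
# decoder weight is (vocab, embedding_size): rows follow the vocab
assert model.cls.predictions.decoder.weight.shape[0] == 120
# dense weight is (hidden_size - embedding_size, vocab): columns follow the vocab
assert model.cls.predictions.dense.weight.shape[1] == 120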
@@ -542,6 +542,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -628,6 +631,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
...
@@ -1703,6 +1703,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     def get_input_embeddings(self):
         return self.prophetnet.word_embeddings
 
@@ -1901,6 +1904,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
...
@@ -1459,6 +1459,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
     def get_output_embeddings(self):
         return self.rag.generator.get_output_embeddings()
 
+    def set_output_embeddings(self, new_embeddings):
+        return self.rag.generator.set_output_embeddings(new_embeddings)
+
     def shift_tokens_right(self, input_ids, start_token_id=None):
         """Shift input ids one token to the right, and pad with start_token_id"""
         if start_token_id is None:
...
@@ -2197,6 +2197,9 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -2309,6 +2312,9 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -752,6 +752,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -873,6 +876,9 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -655,6 +655,9 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -1363,6 +1363,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     def get_output_embeddings(self):
         return self.lm_head
...
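Taken together, any model that exposes `get_input_embeddings`/`set_input_embeddings` and `get_output_embeddings`/`set_output_embeddings` now gets resizing from the base class for free, tied or untied. A quick sanity sketch with a randomly initialized GPT-2 (the small config values are illustrative):

from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config)

embeddings = model.resize_token_embeddings(110)
assert embeddings.weight.shape[0] == 110
assert model.get_output_embeddings().weight.shape[0] == 110  # lm_head follows, since weights are tied by default
assert model.config.vocab_size == 110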