Unverified commit 51d9c569, authored by Julien Plu, committed by GitHub

Fix embeddings resizing in TF models (#8657)

* Resize the biases at the same time as the embeddings

* Trigger CI

* Biases are not reset anymore

* Remove get_output_embeddings + better LM model detection in generation utils

* Apply style

* First test on BERT

* Update docstring + new name

* Apply the new resizing logic to all the models

* fix tests

* Apply style

* Update the template

* Fix naming

* Fix naming

* Apply style

* Apply style

* Remove unused import

* Revert get_output_embeddings

* Trigger CI

* Update num parameters

* Restore get_output_embeddings in TFPretrainedModel and add comments

* Style

* Add decoder resizing

* Style

* Fix tests

* Separate bias and decoder resize

* Fix tests

* Fix tests

* Apply style

* Add bias resizing in MPNet

* Trigger CI

* Apply style
parent 3552d0e0
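
The commits above change how `resize_token_embeddings` behaves for TF models: the LM head bias (and an untied decoder weight, where one exists) is now grown together with the input embedding matrix. A minimal usage sketch, assuming the `transformers` TF classes and a cached `bert-base-cased` checkpoint; the added tokens are illustrative, not taken from the PR:

```python
from transformers import BertTokenizer, TFBertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = TFBertForMaskedLM.from_pretrained("bert-base-cased")

# Extend the vocabulary, then resize the model to match. With this fix the
# input embeddings, the MLM head bias, and any untied decoder are all resized.
tokenizer.add_tokens(["[NEW_TOK_1]", "[NEW_TOK_2]"])
model.resize_token_embeddings(len(tokenizer))
```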
@@ -614,13 +614,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     def get_output_embeddings(self) -> tf.keras.layers.Layer:
         """
         Returns the model's output embeddings.

         Returns:
             :obj:`tf.keras.layers.Layer`: A torch module mapping hidden states to vocabulary.
         """
         return None  # Overwrite for models with output embeddings

+    def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
+        """
+        Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
+        embeddings.
+
+        Return:
+            :obj:`tf.keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
+        """
+        return None
+
+    def get_prefix_bias_name(self) -> Union[None, str]:
+        """
+        Get the concatenated prefix name of the bias from the model name to the parent layer.
+
+        Return:
+            :obj:`str`: The prefix name of the bias.
+        """
+        return None
+
     def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
         """
         Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
@@ -662,7 +681,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             # TFSharedEmbeddings
             return embeddings.weight
         else:
-            raise ValueError("word embedding is not defined.")
+            # Here we build the word embeddings weights if not exists.
+            # And then we retry to get the attribute once built.
+            embeddings.build([])
+            if hasattr(embeddings, "word_embeddings"):
+                # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
+                return embeddings.word_embeddings
+            elif hasattr(embeddings, "weight"):
+                # TFSharedEmbeddings
+                return embeddings.weight
+            else:
+                raise ValueError("word embedding is not defined.")

     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
         """
@@ -684,28 +713,87 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             :obj:`new_num_tokens` is :obj:`None`
         """
         word_embeddings = self._get_word_embeddings(old_embeddings)
+        bias_layer = self.get_output_layer_with_bias()
+
         if new_num_tokens is None:
             return word_embeddings

         old_num_tokens, old_embedding_dim = word_embeddings.shape

         if old_num_tokens == new_num_tokens:
             return word_embeddings

         # initialize new embeddings
         # todo: initializer range is not always passed in config.
         init_range = getattr(self.config, "initializer_range", 0.02)
+        name = (
+            self.name
+            + "/"
+            + self.base_model_prefix
+            + "/"
+            + old_embeddings.name
+            + "/"
+            + word_embeddings.name.split(":")[0]
+        )
         new_embeddings = self.add_weight(
-            "weight",
+            name=name,
             shape=[new_num_tokens, old_embedding_dim],
             initializer=get_initializer(init_range),
             dtype=tf.float32,
         )
-        init_weights = new_embeddings.numpy()
+        init_weights = tf.make_ndarray(tf.make_tensor_proto(new_embeddings.value()))

         # Copy token embeddings from the previous weights
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
-        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
+        init_weights[:num_tokens_to_copy] = word_embeddings.value()[:num_tokens_to_copy, :]
         new_embeddings.assign(init_weights)
+
+        if bias_layer is not None:
+            if not hasattr(bias_layer, "bias"):
+                bias_layer.build([])
+
+            # Second check in order to be sure the attribute has been properly created
+            if not hasattr(bias_layer, "bias"):
+                raise ValueError("bias is not defined.")
+
+            # initialize bias
+            init_bias = np.zeros((new_num_tokens,))
+            init_bias[:num_tokens_to_copy] = bias_layer.bias.value()[
+                :num_tokens_to_copy
+            ]  # tf.make_ndarray(tf.make_tensor_proto(bias_layer.bias.value()))[:num_tokens_to_copy]
+            bias_layer.bias = self.add_weight(
+                shape=(new_num_tokens,),
+                initializer="zeros",
+                trainable=True,
+                name=self.get_prefix_bias_name() + "/bias",
+            )
+            bias_layer.bias.assign(init_bias)
+
+        output_embeddings = self.get_output_embeddings()
+
+        if output_embeddings is not None:
+            if self.get_input_embeddings() != output_embeddings:
+                if not hasattr(output_embeddings, "decoder"):
+                    output_embeddings.build([])
+
+                # Second check in order to be sure the attribute has been properly created
+                if not hasattr(output_embeddings, "decoder"):
+                    raise ValueError("decoder is not defined.")
+
+                # initialize decoder
+                init_weights = np.zeros((new_num_tokens, old_embedding_dim))
+                init_weights[:num_tokens_to_copy] = output_embeddings.decoder.value()[:num_tokens_to_copy, :]
+
+                output_embeddings.decoder = self.add_weight(
+                    shape=(new_num_tokens, old_embedding_dim),
+                    initializer="zeros",
+                    trainable=True,
+                    name=self.get_prefix_bias_name() + "/decoder/weight",
+                )
+                output_embeddings.decoder.assign(init_weights)

         return new_embeddings

     def prune_heads(self, heads_to_prune):
...
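The `_get_resized_embeddings` hunk above stages values in a NumPy buffer and writes them back with `assign`, because TF variables do not support sliced item assignment. A standalone sketch of that pattern with made-up shapes (illustration only, not code from the PR):

```python
import numpy as np
import tensorflow as tf

old = tf.Variable(tf.random.normal([5, 4]))               # old vocabulary: 5 tokens
new = tf.Variable(tf.random.normal([8, 4], stddev=0.02))  # resized vocabulary: 8 tokens

buf = new.numpy()      # start from the freshly initialized values
buf[:5] = old.numpy()  # carry over the rows of the old embedding matrix
new.assign(buf)        # rows 5..7 keep their fresh initialization
```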
@@ -467,6 +467,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
         self.decoder_bias = self.add_weight(
             shape=(self.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
         )
+
         super().build(input_shape)

     def call(self, hidden_states):
@@ -825,6 +826,32 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
     def get_output_embeddings(self):
         return self.albert.embeddings

+    def resize_token_embeddings(self, new_num_tokens):
+        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
+
+        # ALBERT is a special case where there are two bias to update
+        # even though self.bias is not used anywhere and is here
+        # just to make the loading weights from a PT model happy
+        if new_num_tokens is not None:
+            num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
+            self.predictions.vocab_size = num_tokens_to_copy
+            init_bias = tf.zeros((new_num_tokens,))
+            init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
+            name = self.name + "/" + self.predictions.name + "/bias"
+            self.predictions.bias = self.add_weight(
+                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
+            )
+            self.predictions.bias.assign(init_bias)
+
+            init_decoder_bias = tf.zeros((new_num_tokens,))
+            init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
+            name = self.name + "/" + self.predictions.name + "/decoder_bias"
+            self.predictions.decoder_bias = self.add_weight(
+                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
+            )
+            self.predictions.decoder_bias.assign(init_decoder_bias)
+
     @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -933,6 +960,32 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
     def get_output_embeddings(self):
         return self.albert.embeddings

+    def resize_token_embeddings(self, new_num_tokens):
+        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
+
+        # ALBERT is a special case where there are two bias to update
+        # even though self.bias is not used anywhere and is here
+        # just to make the loading weights from a PT model happy
+        if new_num_tokens is not None:
+            num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
+            self.predictions.vocab_size = num_tokens_to_copy
+            init_bias = tf.zeros((new_num_tokens,))
+            init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
+            name = self.name + "/" + self.predictions.name + "/bias"
+            self.predictions.bias = self.add_weight(
+                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
+            )
+            self.predictions.bias.assign(init_bias)
+
+            init_decoder_bias = tf.zeros((new_num_tokens,))
+            init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
+            name = self.name + "/" + self.predictions.name + "/decoder_bias"
+            self.predictions.decoder_bias = self.add_weight(
+                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
+            )
+            self.predictions.decoder_bias.assign(init_decoder_bias)
+
     @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -1049,6 +1049,24 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
             name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
         )

+    def resize_token_embeddings(self, new_num_tokens):
+        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
+
+        # BART is a special case where the bias has two dimensions
+        # and not named just `bias`
+        if new_num_tokens is not None:
+            num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
+            init_bias = tf.zeros((new_num_tokens,))
+            init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
+            name = self.name + "/final_logits_bias"
+            self.final_logits_bias = self.add_weight(
+                shape=(1, new_num_tokens),
+                initializer="zeros",
+                trainable=False,
+                name=name,
+            )
+            self.final_logits_bias.assign(init_bias)
+
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def call(
...
@@ -893,6 +893,12 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
     def get_output_embeddings(self):
         return self.bert.embeddings

+    def get_output_layer_with_bias(self):
+        return self.mlm.predictions
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
+
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1002,6 +1008,12 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
     def get_output_embeddings(self):
         return self.bert.embeddings

+    def get_output_layer_with_bias(self):
+        return self.mlm.predictions
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
+
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -1095,6 +1107,12 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
     def get_output_embeddings(self):
         return self.bert.embeddings

+    def get_output_layer_with_bias(self):
+        return self.mlm.predictions
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
+
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="bert-base-cased",
...
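With the two hooks added to the BERT LM classes above, the bias-carrying layer and its name prefix can be inspected directly. A hedged example; the checkpoint and the exact printed prefix are illustrative, not taken from the PR:

```python
from transformers import TFBertForMaskedLM

model = TFBertForMaskedLM.from_pretrained("bert-base-cased")
print(model.get_output_layer_with_bias())  # the MLM prediction head layer that owns `bias`
print(model.get_prefix_bias_name())        # e.g. "tf_bert_for_masked_lm/mlm___cls/predictions"
```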
@@ -629,6 +629,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
     def get_output_embeddings(self):
         return self.lm_head.input_embeddings

+    def get_output_layer_with_bias(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.lm_head.name
+
     def prepare_inputs_for_generation(self, inputs, past, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
...
@@ -655,6 +655,12 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
     def get_output_embeddings(self):
         return self.vocab_projector.input_embeddings

+    def get_output_layer_with_bias(self):
+        return self.vocab_projector
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.vocab_projector.name
+
     @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -882,7 +882,7 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
         super().build(input_shape)

-    def call(self, hidden_states):
+    def call(self, hidden_states, training=False):
         hidden_states = self.input_embeddings(hidden_states, mode="linear")
         hidden_states = hidden_states + self.bias
@@ -914,8 +914,14 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
         self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

     def get_output_embeddings(self):
+        return self.electra.embeddings
+
+    def get_output_layer_with_bias(self):
         return self.generator_lm_head

+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.generator_lm_head.name
+
     @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -766,6 +766,12 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.input_embeddings

+    def get_output_layer_with_bias(self):
+        return self.pred_layer
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.pred_layer.name
+
     def prepare_inputs_for_generation(self, inputs, **kwargs):
         mask_token_id = self.config.mask_token_id
         lang_id = self.config.lang_id
...
@@ -1320,6 +1320,15 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
         self.funnel = TFFunnelMainLayer(config, name="funnel")
         self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head")

+    def get_output_embeddings(self):
+        return self.funnel.embeddings
+
+    def get_output_layer_with_bias(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.lm_head.name
+
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -2009,6 +2009,12 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
     def get_output_embeddings(self):
         return self.lm_head.decoder

+    def get_output_layer_with_bias(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.lm_head.name
+
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -1257,6 +1257,15 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
             **({"obj_labels": obj_labels} if self.config.task_obj_predict else {}),
         }

+    def get_output_embeddings(self):
+        return self.lxmert.embeddings
+
+    def get_output_layer_with_bias(self):
+        return self.cls.predictions
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name
+
     @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def call(
...
@@ -702,6 +702,10 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings

+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -1024,7 +1028,13 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
         self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls")

     def get_output_embeddings(self):
-        return self.mobilebert.embeddings
+        return self.predictions.predictions
+
+    def get_output_layer_with_bias(self):
+        return self.predictions.predictions
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name

     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1117,7 +1127,13 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
         self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")

     def get_output_embeddings(self):
-        return self.mobilebert.embeddings
+        return self.mlm.predictions
+
+    def get_output_layer_with_bias(self):
+        return self.mlm.predictions
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
...
@@ -830,6 +830,12 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
     def get_output_embeddings(self):
         return self.mpnet.embeddings

+    def get_output_layer_with_bias(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.lm_head.name
+
     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -810,6 +810,12 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
     def get_output_embeddings(self):
         return self.lm_head.decoder

+    def get_output_layer_with_bias(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.lm_head.name
+
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -803,6 +803,12 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.input_embeddings

+    def get_output_layer_with_bias(self):
+        return self.pred_layer
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.pred_layer.name
+
     def prepare_inputs_for_generation(self, inputs, **kwargs):
         mask_token_id = self.config.mask_token_id
         lang_id = self.config.lang_id
...
@@ -1221,6 +1221,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
     def get_output_embeddings(self):
         return self.lm_loss.input_embeddings

+    def get_output_layer_with_bias(self):
+        return self.lm_loss
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.lm_loss.name
+
     def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):
         # Add dummy token at the end (no attention on this one)
...
@@ -772,10 +772,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
         self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
         self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")

     def get_output_embeddings(self):
         return self.{{cookiecutter.lowercase_modelname}}.embeddings

+    def get_output_layer_with_bias(self):
+        return self.mlm.predictions
+
+    def get_prefix_bias_name(self):
+        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
+
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -272,6 +272,17 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_albert_for_question_answering(*config_and_inputs)

+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
@@ -126,6 +126,17 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
         # Should be uncommented during patrick TF refactor
         pass

+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+

 @require_tf
 class TFBartHeadTests(unittest.TestCase):
...
@@ -331,6 +331,25 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
         model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
         self.assertIsNotNone(model)

+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        list_lm_models = [TFBertForMaskedLM, TFBertForPreTraining, TFBertLMHeadModel]
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+            if model_class in list_lm_models:
+                x = model.get_output_layer_with_bias()
+                assert isinstance(x, tf.keras.layers.Layer)
+                name = model.get_prefix_bias_name()
+                assert isinstance(name, str)
+            else:
+                x = model.get_output_layer_with_bias()
+                assert x is None
+                name = model.get_prefix_bias_name()
+                assert x is None
+
     def test_custom_load_tf_weights(self):
         model, output_loading_info = TFBertForTokenClassification.from_pretrained(
             "jplu/tiny-tf-bert-random", output_loading_info=True
...