Unverified Commit 1243ee7d authored by Julien Plu, committed by GitHub

Full rework of the TF input/output embeddings and bias resizing (#9193)

* Start rework resizing

* Rework bias/decoder resizing

* Full resizing rework

* Full resizing rework

* Start to update the models with the new approach

* Finish to update the models

* Update all the tests

* Update the template

* Fix tests

* Fix tests

* Test a new approach

* Refactoring

* Refactoring

* Refactoring

* New rework

* Rework BART

* Rework bert+blenderbot

* Rework CTRL

* Rework Distilbert

* Rework DPR

* Rework Electra

* Rework Flaubert

* Rework Funnel

* Rework GPT2

* Rework Longformer

* Rework Lxmert

* Rework marian+mbart

* Rework mobilebert

* Rework mpnet

* Rework openai

* Rework pegasus

* Rework Roberta

* Rework T5

* Rework xlm+xlnet

* Rework template

* Fix TFT5EncoderOnly + DPRs

* Restore previous methods

* Fix Funnel

* Fix CTRL and TransforXL

* Apply style

* Apply Sylvain's comments

* Restore a test in DPR

* Address the comments

* Fix bug

* Apply style

* remove unused import

* Fix test

* Forgot a method

* missing test

* Trigger CI

* naming update

* Rebase

* Trigger CI
parent cf416764
@@ -530,6 +530,46 @@ def load_tf_weights(model, resolved_archive_file):
return missing_layers, unexpected_layers
def init_copy_embeddings(old_embeddings, new_num_tokens):
r"""
This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
kept or not. Example:
- if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]
- mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
- if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]
- mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
"""
old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
size_diff = new_num_tokens - old_num_tokens
# initialize new embeddings
# Copy token embeddings from the previous ones
if tf.math.greater(size_diff, 0):
# if the new size is greater than the old one, we extend the current embeddings with a padding until getting new size
# and we create a mask to properly identify the padded values and be replaced by the values of the newly created
# embeddings
current_weights = tf.pad(
old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
)
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
else:
# if the new size if lower than the old one, we take the current embeddings until the new size
current_weights = tf.slice(
old_embeddings.value(),
tf.convert_to_tensor([0, 0]),
tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
)
mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)
return mask, current_weights
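For readers skimming the diff, the pad/truncate behaviour documented above can be reproduced with plain TensorFlow. This is an illustrative sketch only (not part of the commit), with made-up sizes:

import tensorflow as tf

old_embeddings = tf.Variable(tf.random.normal((4, 8)))  # old_num_tokens=4, embedding_dim=8

# Growing to 6 tokens: pad the weights with -1 and mark the padded rows False in the mask.
size_diff = 6 - 4
current_weights = tf.pad(old_embeddings.value(), [[0, size_diff], [0, 0]], constant_values=-1)
mask = tf.pad(tf.fill((4, 1), True), [[0, size_diff], [0, 0]], constant_values=False)
print(current_weights.shape, mask.shape)  # (6, 8) (6, 1)

# Shrinking to 3 tokens: slice the first 3 rows and keep an all-True mask.
current_weights = tf.slice(old_embeddings.value(), [0, 0], [3, 8])
mask = tf.fill((3, 1), True)
print(current_weights.shape, mask.shape)  # (3, 8) (3, 1)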
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
r"""
Base class for all TF models.
@@ -615,58 +655,132 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
def get_input_embeddings(self) -> tf.keras.layers.Layer:
"""
Returns the model's input embeddings.
Returns the model's input embeddings layer.
Returns:
:obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states.
:obj:`tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
main_layer = getattr(self, self.base_model_prefix, self)
if main_layer is not self:
return main_layer.get_input_embeddings()
else:
raise NotImplementedError
def set_input_embeddings(self, value):
"""
Set model's input embeddings.
Set model's input embeddings
Args:
value (:obj:`tf.keras.layers.Layer`):
A module mapping vocabulary to hidden states.
value (:obj:`tf.Variable`):
The new weights mapping hidden states to vocabulary.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
else:
raise NotImplementedError
main_layer = getattr(self, self.base_model_prefix)
if main_layer is None:
raise NotImplementedError("The model does not implements the base_model_prefix attribute.")
try:
main_layer.set_input_embeddings(value)
except AttributeError:
logger.info("Building the model")
self(self.dummy_inputs)
main_layer.set_input_embeddings(value)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
"""
Returns the model's output embeddings
Returns:
:obj:`tf.keras.layers.Layer`: A torch module mapping hidden states to vocabulary.
:obj:`tf.Variable`: The new weights mapping vocabulary to hidden states.
"""
if self.get_lm_head() is not None:
lm_head = self.get_lm_head()
return lm_head.get_output_embeddings()
return None # Overwrite for models with output embeddings
def set_output_embeddings(self, value):
"""
Set model's output embeddings
Args:
value (:obj:`tf.Variable`):
The new weights mapping hidden states to vocabulary.
"""
if self.get_lm_head() is not None:
lm_head = self.get_lm_head()
try:
lm_head.set_output_embeddings(value)
except AttributeError:
logger.info("Building the model")
self(self.dummy_inputs)
lm_head.set_output_embeddings(value)
def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
"""
Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
embeddings.
embeddings
Return:
:obj:`tf.keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
"""
return None
warnings.warn(
"The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning
)
return self.get_lm_head()
def get_prefix_bias_name(self) -> Union[None, str]:
"""
Get the concatenated prefix name of the bias from the model name to the parent layer.
Get the concatenated prefix name of the bias from the model name to the parent layer
Return:
:obj:`str`: The prefix name of the bias.
"""
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return None
def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
"""
Dict of bias attached to an LM head. The key represents the name of the bias attribute.
Return:
:obj:`tf.Variable`: The weights representing the bias, None if not an LM model.
"""
if self.get_lm_head() is not None:
lm_head = self.get_lm_head()
try:
return lm_head.get_bias()
except AttributeError:
self(self.dummy_inputs)
return lm_head.get_bias()
return None
def set_bias(self, value):
"""
Set all the bias in the LM head.
Args:
value (:obj:`Dict[tf.Variable]`):
All the new bias attached to an LM head.
"""
if self.get_lm_head() is not None:
lm_head = self.get_lm_head()
try:
lm_head.set_bias(value)
except AttributeError:
self(self.dummy_inputs)
lm_head.set_bias(value)
def get_lm_head(self) -> tf.keras.layers.Layer:
"""
The LM Head layer. This method must be overwritten by all the models that have a lm head.
Return:
:obj:`tf.keras.layers.Layer`: The LM head layer if the model has one, None if not.
"""
return None
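To make the new contract concrete, here is a hypothetical sketch (not from the commit) of the methods an LM head layer now exposes so that the generic resizing code in TFPreTrainedModel can find its weights; names like MyLMHead are made up, and plain .shape is used where the library uses shape_list:

import tensorflow as tf

class MyLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.input_embeddings = input_embeddings  # shared word-embedding layer

    def build(self, input_shape):
        # The bias lives on the head, not on the shared embeddings.
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def get_output_embeddings(self):
        return self.input_embeddings

    def set_output_embeddings(self, value):
        self.input_embeddings.word_embeddings = value
        self.input_embeddings.vocab_size = value.shape[0]

    def get_bias(self):
        return {"bias": self.bias}

    def set_bias(self, value):
        self.bias = value["bias"]
        self.vocab_size = value["bias"].shape[0]

# A model wired this way only has to point the base class at its head:
#     def get_lm_head(self):
#         return self.my_lm_head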
def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
@@ -685,143 +799,179 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
Return:
:obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
"""
model_embeds = self._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
return model_embeds
def _resize_token_embeddings(self, new_num_tokens):
# get_input_embeddings and set_input_embeddings need to be implemented in base layer.
base_model = getattr(self, self.base_model_prefix, self)
old_embeddings = base_model.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
base_model.set_input_embeddings(new_embeddings)
# Update base model and current model config
self.config.vocab_size = new_num_tokens
base_model.vocab_size = new_num_tokens
return base_model.get_input_embeddings()
def _get_word_embeddings(self, embeddings):
if hasattr(embeddings, "word_embeddings"):
# TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
return embeddings.word_embeddings
elif hasattr(embeddings, "weight"):
# TFSharedEmbeddings
return embeddings.weight
else:
# Here we build the word embeddings weights if not exists.
# And then we retry to get the attribute once built.
embeddings.build([])
if hasattr(embeddings, "word_embeddings"):
# TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
return embeddings.word_embeddings
elif hasattr(embeddings, "weight"):
# TFSharedEmbeddings
return embeddings.weight
else:
raise ValueError("word embedding is not defined.")
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
"""
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
initialized vectors at the end. Reducing the size will remove vectors from the end
Args:
old_embeddings (:obj:`tf.Variable`):
Old embeddings to be resized.
new_num_tokens (:obj:`int`, `optional`):
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
:obj:`tf.Variable`` module of the model without doing anything.
Return:
:obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
:obj:`new_num_tokens` is :obj:`None`
"""
word_embeddings = self._get_word_embeddings(old_embeddings)
bias_layer = self.get_output_layer_with_bias()
if new_num_tokens is None:
return word_embeddings
old_num_tokens, old_embedding_dim = word_embeddings.shape
if old_num_tokens == new_num_tokens:
return word_embeddings
# initialize new embeddings
# todo: initializer range is not always passed in config.
init_range = getattr(self.config, "initializer_range", 0.02)
name = (
self.name
+ "/"
+ self.base_model_prefix
+ "/"
+ old_embeddings.name
+ "/"
+ word_embeddings.name.split(":")[0]
)
new_embeddings = self.add_weight(
name=name,
shape=[new_num_tokens, old_embedding_dim],
initializer=get_initializer(init_range),
dtype=tf.float32,
)
init_weights = tf.make_ndarray(tf.make_tensor_proto(new_embeddings.value()))
# Copy token embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
init_weights[:num_tokens_to_copy] = word_embeddings.value()[:num_tokens_to_copy, :]
new_embeddings.assign(init_weights)
if bias_layer is not None:
if not hasattr(bias_layer, "bias"):
bias_layer.build([])
# Second check in order to be sure the attribute has been properly created
if not hasattr(bias_layer, "bias"):
raise ValueError("bias is not defined.")
# initialize bias
init_bias = np.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = bias_layer.bias.value()[
:num_tokens_to_copy
] # tf.make_ndarray(tf.make_tensor_proto(bias_layer.bias.value()))[:num_tokens_to_copy]
bias_layer.bias = self.add_weight(
shape=(new_num_tokens,),
initializer="zeros",
trainable=True,
name=self.get_prefix_bias_name() + "/bias",
)
bias_layer.bias.assign(init_bias)
output_embeddings = self.get_output_embeddings()
if output_embeddings is not None:
if self.get_input_embeddings() != output_embeddings:
if not hasattr(output_embeddings, "decoder"):
output_embeddings.build([])
# Second check in order to be sure the attribute has been properly created
if not hasattr(output_embeddings, "decoder"):
raise ValueError("decoder is not defined.")
# initialize decoder
init_weights = np.zeros((new_num_tokens, old_embedding_dim))
init_weights[:num_tokens_to_copy] = output_embeddings.decoder.value()[:num_tokens_to_copy, :]
output_embeddings.decoder = self.add_weight(
shape=(new_num_tokens, old_embedding_dim),
initializer="zeros",
trainable=True,
name=self.get_prefix_bias_name() + "/decoder/weight",
)
output_embeddings.decoder.assign(init_weights)
return new_embeddings
if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
return self._get_word_embedding_weight(self.get_input_embeddings())
model_embeds = self._resize_token_embeddings(new_num_tokens)
# Update base model and current model config
self.config.vocab_size = new_num_tokens
return model_embeds
def _get_word_embedding_weight(self, embedding_layer):
if hasattr(embedding_layer, "word_embeddings"):
return embedding_layer.word_embeddings
elif hasattr(embedding_layer, "weight"):
return embedding_layer.weight
elif hasattr(embedding_layer, "decoder"):
return embedding_layer.decoder
else:
# Here we build the word embeddings weights if not exists.
# And then we retry to get the attribute once built.
self(self.dummy_inputs)
if hasattr(embedding_layer, "word_embeddings"):
return embedding_layer.word_embeddings
elif hasattr(embedding_layer, "weight"):
return embedding_layer.weight
elif hasattr(embedding_layer, "decoder"):
return embedding_layer.decoder
else:
return None
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
# if word embeddings are not tied, make sure that lm head bias is resized as well
if self.get_bias() is not None:
old_lm_head_bias = self.get_bias()
new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
self.set_bias(new_lm_head_bias)
# if word embeddings are not tied, make sure that lm head decoder is resized as well
if self.get_output_embeddings() is not None:
old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
self.set_output_embeddings(new_lm_head_decoder)
self.set_input_embeddings(new_embeddings)
return self.get_input_embeddings()
def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
"""
Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
Reducing the size will remove vectors from the end
Args:
old_lm_head_bias (:obj:`tf.Variable`):
Old lm head bias to be resized.
new_num_tokens (:obj:`int`, `optional`):
New number of tokens in the linear matrix.
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
vectors from the end. If not provided or :obj:`None`, just returns None
Return:
:obj:`tf.Variable`: Pointer to the resized bias.
"""
new_lm_head_bias = {}
for attr, weight in old_lm_head_bias.items():
first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
size_diff = new_num_tokens - old_num_tokens
final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens]
# initialize new bias
if tf.math.greater(size_diff, 0):
padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1)
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy]
bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
else:
slice_from = [0] if first_dim is None else [0, 0]
current_bias = tf.slice(
weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
)
bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)
new_bias = self.add_weight(
shape=final_shape,
initializer="zeros",
trainable=True,
name=weight.name.split(":")[0],
)
init_bias = tf.where(bias_mask, current_bias, new_bias.value())
new_bias.assign(init_bias)
new_lm_head_bias[attr] = new_bias
return new_lm_head_bias
def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
"""
Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
Reducing the size will remove vectors from the end
Args:
old_lm_head_decoder (:obj:`tf.Variable`):
Old lm head decoder to be resized.
new_num_tokens (:obj:`int`, `optional`):
New number of tokens in the linear matrix.
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
vectors from the end. If not provided or :obj:`None`, just returns None
Return:
:obj:`tf.Variable`: Pointer to the resized decoder or None if the output embeddings are differents of the
input ones.
"""
new_lm_head_decoder = old_lm_head_decoder
is_input_output_equals = tf.reduce_any(
self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
)
if old_lm_head_decoder is not None and not is_input_output_equals:
old_embedding_dim = shape_list(old_lm_head_decoder)[1]
decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
new_lm_head_decoder = self.add_weight(
shape=(new_num_tokens, old_embedding_dim),
initializer="zeros",
trainable=True,
name=old_lm_head_decoder.name.split(":")[0],
)
init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())
new_lm_head_decoder.assign(init_decoder)
return new_lm_head_decoder
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
"""
Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly
initialized vectors at the end. Reducing the size will remove vectors from the end
Args:
old_embeddings (:obj:`tf.Variable`):
Old embeddings to be resized.
new_num_tokens (:obj:`int`, `optional`):
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
:obj:`tf.Variable`` module of the model without doing anything.
Return:
:obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
:obj:`new_num_tokens` is :obj:`None`
"""
old_embedding_dim = shape_list(old_embeddings)[1]
init_range = getattr(self.config, "initializer_range", 0.02)
embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
new_embeddings = self.add_weight(
name=old_embeddings.name.split(":")[0],
shape=[new_num_tokens, old_embedding_dim],
initializer=get_initializer(init_range),
dtype=tf.float32,
)
init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())
new_embeddings.assign(init_embeddings)
return new_embeddings
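As a usage sketch (assuming a standard checkpoint such as bert-base-uncased; exact sizes in the comments are only indicative), the public entry point is still resize_token_embeddings, which now routes through the helpers above:

from transformers import TFBertForMaskedLM

model = TFBertForMaskedLM.from_pretrained("bert-base-uncased")
model.resize_token_embeddings(model.config.vocab_size + 8)

# Input embeddings, the tied output embeddings and the MLM bias now share the new size.
print(model.get_input_embeddings().word_embeddings.shape)  # e.g. (30530, 768)
print(model.get_bias()["bias"].shape)                      # e.g. (30530,)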
...
@@ -470,6 +470,21 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias, "decoder_bias": self.decoder_bias}
def set_bias(self, value):
self.bias = value["bias"]
self.decoder_bias = value["decoder_bias"]
self.vocab_size = shape_list(value["bias"])[0]
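ALBERT keeps two bias vectors in its MLM head, so the dict-based accessors above surface both entries. A small sketch of what that looks like in practice (checkpoint name and sizes only illustrative, not part of the commit):

from transformers import TFAlbertForMaskedLM

model = TFAlbertForMaskedLM.from_pretrained("albert-base-v2")
print(sorted(model.get_bias()))  # ['bias', 'decoder_bias']

model.resize_token_embeddings(model.config.vocab_size + 2)
print(model.get_bias()["decoder_bias"].shape)  # e.g. (30002,) for albert-base-v2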
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
@@ -505,10 +520,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
"""
@@ -835,34 +847,8 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")
def get_output_embeddings(self):
return self.albert.embeddings
def get_lm_head(self):
return self.predictions
def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)
init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.decoder_bias.assign(init_decoder_bias)
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -980,34 +966,8 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
def get_output_embeddings(self):
return self.albert.embeddings
def get_lm_head(self):
return self.predictions
def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)
init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.decoder_bias.assign(init_decoder_bias)
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
...
@@ -481,6 +481,29 @@ class TFBartPretrainedModel(TFPreTrainedModel):
}
return dummy_inputs
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function(
input_signature=[
{
@@ -634,6 +657,9 @@ class TFBartEncoder(tf.keras.layers.Layer):
else None
)
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call(
self,
input_ids=None,
@@ -791,6 +817,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call(
self,
input_ids=None,
@@ -1009,6 +1038,9 @@ class TFBartModel(TFBartPretrainedModel):
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@@ -1134,15 +1166,6 @@ class TFBartModel(TFBartPretrainedModel):
encoder_attentions=enc_attns,
)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
def get_output_embeddings(self):
return self.shared
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.",
@@ -1166,22 +1189,20 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
def get_decoder(self):
return self.model.decoder
def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
# BART is a special case where the bias has two dimensions
# and not named just `bias`
if new_num_tokens is not None:
num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
self.final_logits_bias = self.add_weight(
shape=(1, new_num_tokens),
initializer="zeros",
trainable=False,
name="final_logits_bias",
)
self.final_logits_bias.assign(init_bias)
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
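For BART the output embeddings are simply the shared input embeddings and the only head-specific weight is final_logits_bias, which the new get_bias/set_bias pair exposes. A small sketch (checkpoint name and sizes only illustrative, not part of the commit):

from transformers import TFBartForConditionalGeneration

model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")
print(model.get_output_embeddings() is model.get_input_embeddings())  # True: tied to model.shared
print(model.get_bias()["final_logits_bias"].shape)                    # e.g. (1, 50265)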
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1356,12 +1377,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
else:
return logits
def get_output_embeddings(self):
return self.model.shared
def get_encoder(self):
return self.model.encoder
def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
...
@@ -15,6 +15,7 @@
# limitations under the License.
""" TF 2.0 BERT model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -526,6 +527,20 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.input_embeddings(hidden_states, mode="linear")
@@ -582,7 +597,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
@@ -918,13 +933,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
self.nsp = TFBertNSPHead(config, name="nsp___cls")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
return self.mlm.predictions
def get_lm_head(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1044,13 +1057,11 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
return self.mlm.predictions
def get_lm_head(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1149,13 +1160,11 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
return self.mlm.predictions
def get_lm_head(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_code_sample_docstrings(
...
@@ -15,6 +15,8 @@
# limitations under the License.
""" TF 2.0 CTRL model."""
import warnings
import numpy as np
import tensorflow as tf
@@ -242,10 +244,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.w.weight = value
self.w.vocab_size = value.shape[0]
self.w.vocab_size = shape_list(value)[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
"""
@@ -618,6 +617,20 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -638,13 +651,11 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.input_embeddings
def get_output_layer_with_bias(self):
return self.lm_head
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
...
@@ -16,6 +16,8 @@
TF 2.0 DistilBERT model
"""
import warnings
import tensorflow as tf
from ...activations_tf import get_tf_activation
@@ -39,7 +41,6 @@ from ...modeling_tf_utils import (
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
input_processing,
@@ -72,9 +73,6 @@ class TFEmbeddings(tf.keras.layers.Layer):
self.vocab_size = config.vocab_size
self.dim = config.dim
self.initializer_range = config.initializer_range
self.word_embeddings = TFSharedEmbeddings(
config.vocab_size, config.dim, initializer_range=config.initializer_range, name="word_embeddings"
) # padding_idx=0)
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
config.dim,
@@ -648,6 +646,20 @@ class TFDistilBertLMHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -671,13 +683,11 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
def get_output_embeddings(self):
return self.vocab_projector.input_embeddings
def get_output_layer_with_bias(self):
return self.vocab_projector
def get_lm_head(self):
return self.vocab_projector
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.vocab_projector.name
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -577,6 +577,10 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")
def get_input_embeddings(self):
try:
return self.ctx_encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
return self.ctx_encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
@@ -671,6 +675,10 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")
def get_input_embeddings(self):
try:
return self.question_encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
return self.question_encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
@@ -764,6 +772,10 @@ class TFDPRReader(TFDPRPretrainedReader):
self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")
def get_input_embeddings(self):
try:
return self.span_predictor.encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
return self.span_predictor.encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING)
...
@@ -14,6 +14,7 @@
# limitations under the License.
""" TF Electra model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -511,10 +512,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
"""
@@ -912,6 +910,20 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states, training=False):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -943,13 +955,11 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
def get_output_embeddings(self):
return self.electra.embeddings
def get_output_layer_with_bias(self):
return self.generator_lm_head
def get_lm_head(self):
return self.generator_lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.generator_lm_head.name
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -18,6 +18,7 @@
import itertools
import random
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -478,6 +479,10 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.weight = value
self.embeddings.vocab_size = shape_list(value)[0]
def call(
self,
input_ids=None,
@@ -721,6 +726,20 @@ class TFFlaubertPredLayer(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -767,13 +786,11 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
self.transformer = TFFlaubertMainLayer(config, name="transformer")
self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
def get_output_embeddings(self):
return self.pred_layer.input_embeddings
def get_output_layer_with_bias(self):
return self.pred_layer
def get_lm_head(self):
return self.pred_layer
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.pred_layer.name
def prepare_inputs_for_generation(self, inputs, **kwargs):
...
@@ -14,6 +14,7 @@
# limitations under the License.
""" TF 2.0 Funnel model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -787,7 +788,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
@@ -873,7 +874,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
@@ -992,6 +993,20 @@ class TFFunnelMaskedLMHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states, training=False):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -1349,13 +1364,11 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
self.funnel = TFFunnelMainLayer(config, name="funnel")
self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.funnel.embeddings
def get_output_layer_with_bias(self):
return self.lm_head
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -243,7 +243,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.wte.weight = value
self.wte.vocab_size = self.wte.weight.shape[0]
self.wte.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
@@ -653,7 +653,10 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
self.transformer = TFGPT2MainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.wte
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
@@ -771,9 +774,6 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
config, initializer_range=config.initializer_range, name="multiple_choice_head"
)
def get_output_embeddings(self):
return self.transformer.wte
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -945,9 +945,6 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
)
self.transformer = TFGPT2MainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.wte
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -1182,6 +1182,44 @@ class TFLEDPreTrainedModel(TFPreTrainedModel):
}
return dummy_inputs
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
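Editor's note: the empty `with tf.compat.v1.variable_scope("model.shared")` block above is not dead code. It only captures an absolute scope handle so that `TFWrappedEmbeddings` re-registers the (possibly freshly resized) shared weight under the `model.shared` name that existing checkpoints expect, and the `try/except AttributeError` builds the model once via `self.dummy_inputs` if `shared` has not been created yet. A tiny standalone sketch of the scope-capture trick, TensorFlow only, no model involved:

```python
import tensorflow as tf

# Capture an absolute variable scope without creating anything inside it; the wrapper
# later re-enters this handle so the shared embedding is named "model.shared/..."
# exactly as saved checkpoints expect.
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
    pass

print(shared_abs_scope_name.name)  # -> "model.shared"
```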
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}
]
)
def serving(self, inputs):
output = self.call(inputs)
return self.serving_output(output)
@dataclass @dataclass
# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder # Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder
...@@ -1483,6 +1521,9 @@ class TFLEDEncoder(tf.keras.layers.Layer): ...@@ -1483,6 +1521,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):
self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -1714,6 +1755,9 @@ class TFLEDDecoder(tf.keras.layers.Layer): ...@@ -1714,6 +1755,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -1921,6 +1965,9 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -1921,6 +1965,9 @@ class TFLEDModel(TFLEDPreTrainedModel):
self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder") self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder")
self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder") self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder")
def get_encoder(self):
return self.encoder
def get_decoder(self): def get_decoder(self):
return self.decoder return self.decoder
...@@ -2047,15 +2094,6 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2047,15 +2094,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
encoder_global_attentions=enc_g_attns, encoder_global_attentions=enc_g_attns,
) )
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
def get_output_embeddings(self):
return self.shared
@add_start_docstrings( @add_start_docstrings(
"The LED Model with a language modeling head. Can be used for summarization.", "The LED Model with a language modeling head. Can be used for summarization.",
...@@ -2079,22 +2117,20 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2079,22 +2117,20 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def get_decoder(self): def get_decoder(self):
return self.led.decoder return self.led.decoder
def resize_token_embeddings(self, new_num_tokens): def get_encoder(self):
super().resize_token_embeddings(new_num_tokens=new_num_tokens) return self.led.encoder
# LED is a special case where the bias has two dimensions def get_bias(self):
# and not named just `bias` return {"final_logits_bias": self.final_logits_bias}
if new_num_tokens is not None:
num_tokens_to_copy = min(shape_list(self.final_logits_bias), new_num_tokens) def set_bias(self, value):
init_bias = tf.zeros((new_num_tokens,)) self.final_logits_bias = value["final_logits_bias"]
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
self.final_logits_bias = self.add_weight( def get_output_embeddings(self):
shape=(1, new_num_tokens), return self.get_input_embeddings()
initializer="zeros",
trainable=False, def set_output_embeddings(self, value):
name="final_logits_bias", self.set_input_embeddings(value)
)
self.final_logits_bias.assign(init_bias)
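Editor's note: the removed override shows why LED needed special handling. Its bias is a `(1, vocab_size)` weight called `final_logits_bias` rather than a plain `bias` vector, and the old in-place resize was fragile (it even item-assigned into an immutable `tf.zeros` tensor). Exposing the weight through the dict-based `get_bias`/`set_bias` lets the generic resizing in `TFPreTrainedModel` take over. A toy sketch, not the library implementation, of what resizing such a bias amounts to:

```python
import tensorflow as tf

def resize_final_logits_bias(bias: tf.Variable, new_num_tokens: int) -> tf.Tensor:
    """Toy illustration only: grow a (1, vocab) bias with zeros or truncate it,
    keeping the already trained values."""
    old_num_tokens = bias.shape[-1]
    if new_num_tokens > old_num_tokens:
        padding = tf.zeros((1, new_num_tokens - old_num_tokens), dtype=bias.dtype)
        return tf.concat([bias, padding], axis=-1)
    return bias[:, :new_num_tokens]

old_bias = tf.Variable(tf.zeros((1, 4)), name="final_logits_bias")
print(resize_final_logits_bias(old_bias, 6).shape)  # (1, 6)
print(resize_final_logits_bias(old_bias, 3).shape)  # (1, 3)
```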
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
...@@ -2266,12 +2302,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2266,12 +2302,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
) )
return (past[0], reordered_past) return (past[0], reordered_past)
def get_output_embeddings(self):
return self.led.shared
def get_encoder(self):
return self.led.encoder
def compute_loss(self, labels, logits): def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens""" """CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""Tensorflow Longformer model. """ """Tensorflow Longformer model. """
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
...@@ -437,6 +438,20 @@ class TFLongformerLMHead(tf.keras.layers.Layer): ...@@ -437,6 +438,20 @@ class TFLongformerLMHead(tf.keras.layers.Layer):
super().build(input_shape) super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
...@@ -1602,7 +1617,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1602,7 +1617,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0] self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" """
...@@ -2040,13 +2055,11 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -2040,13 +2055,11 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head") self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
def get_output_embeddings(self): def get_lm_head(self):
return self.lm_head.decoder
def get_output_layer_with_bias(self):
return self.lm_head return self.lm_head
def get_prefix_bias_name(self): def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 LXMERT model. """ """ TF 2.0 LXMERT model. """
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
...@@ -706,10 +707,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -706,10 +707,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0] self.embeddings.vocab_size = shape_list(value)[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError raise NotImplementedError
...@@ -1103,6 +1101,20 @@ class TFLxmertLMPredictionHead(tf.keras.layers.Layer): ...@@ -1103,6 +1101,20 @@ class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape) super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.transform(hidden_states) hidden_states = self.transform(hidden_states)
hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = self.input_embeddings(hidden_states, mode="linear")
...@@ -1292,13 +1304,11 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel): ...@@ -1292,13 +1304,11 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
**({"obj_labels": obj_labels} if self.config.task_obj_predict else {}), **({"obj_labels": obj_labels} if self.config.task_obj_predict else {}),
} }
def get_output_embeddings(self): def get_lm_head(self):
return self.lxmert.embeddings
def get_output_layer_with_bias(self):
return self.cls.predictions return self.cls.predictions
def get_prefix_bias_name(self): def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 MobileBERT model. """ """ TF 2.0 MobileBERT model. """
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
...@@ -665,6 +666,20 @@ class TFMobileBertLMPredictionHead(tf.keras.layers.Layer): ...@@ -665,6 +666,20 @@ class TFMobileBertLMPredictionHead(tf.keras.layers.Layer):
) )
super().build(input_shape) super().build(input_shape)
def get_output_embeddings(self):
return self
def set_output_embeddings(self, value):
self.decoder = value
self.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.transform(hidden_states) hidden_states = self.transform(hidden_states)
hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0)) hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0))
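Editor's note: MobileBERT is the odd one out, which is why `get_output_embeddings` returns the head itself here. Its embedding size is smaller than its hidden size (128 vs. 512 by default), so the prediction head cannot simply reuse the transposed word embeddings; it concatenates them with an extra learned block before the matmul shown above. A shape-only sketch of that concat, using MobileBERT's default dimensions and random stand-in tensors:

```python
import tensorflow as tf

batch, seq, hidden, embed, vocab = 2, 8, 512, 128, 30522

hidden_states = tf.random.normal((batch, seq, hidden))
decoder = tf.random.normal((vocab, embed))          # tied word embeddings
dense = tf.random.normal((hidden - embed, vocab))   # extra trainable block of the head

# (embed, vocab) stacked on (hidden - embed, vocab) -> a full (hidden, vocab) projection
projection = tf.concat([tf.transpose(decoder), dense], axis=0)
logits = tf.matmul(hidden_states, projection)
print(logits.shape)                                 # (batch, seq, vocab)
```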
...@@ -704,10 +719,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): ...@@ -704,10 +719,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0] self.embeddings.vocab_size = shape_list(value)[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" """
...@@ -1039,13 +1051,11 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): ...@@ -1039,13 +1051,11 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
self.predictions = TFMobileBertMLMHead(config, name="predictions___cls") self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls") self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls")
def get_output_embeddings(self): def get_lm_head(self):
return self.predictions.predictions
def get_output_layer_with_bias(self):
return self.predictions.predictions return self.predictions.predictions
def get_prefix_bias_name(self): def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...@@ -1149,13 +1159,11 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel ...@@ -1149,13 +1159,11 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls") self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")
def get_output_embeddings(self): def get_lm_head(self):
return self.mlm.predictions
def get_output_layer_with_bias(self):
return self.mlm.predictions return self.mlm.predictions
def get_prefix_bias_name(self): def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import math import math
import warnings
import tensorflow as tf import tensorflow as tf
...@@ -541,7 +542,7 @@ class TFMPNetMainLayer(tf.keras.layers.Layer): ...@@ -541,7 +542,7 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0] self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
...@@ -840,6 +841,20 @@ class TFMPNetLMHead(tf.keras.layers.Layer): ...@@ -840,6 +841,20 @@ class TFMPNetLMHead(tf.keras.layers.Layer):
super().build(input_shape) super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, features): def call(self, features):
x = self.dense(features) x = self.dense(features)
x = self.act(x) x = self.act(x)
...@@ -862,13 +877,11 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss): ...@@ -862,13 +877,11 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
self.mpnet = TFMPNetMainLayer(config, name="mpnet") self.mpnet = TFMPNetMainLayer(config, name="mpnet")
self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head") self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head")
def get_output_embeddings(self): def get_lm_head(self):
return self.mpnet.embeddings
def get_output_layer_with_bias(self):
return self.lm_head return self.lm_head
def get_prefix_bias_name(self): def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
...@@ -219,7 +219,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): ...@@ -219,7 +219,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.tokens_embed.weight = value self.tokens_embed.weight = value
self.tokens_embed.vocab_size = value.shape[0] self.tokens_embed.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" """
...@@ -577,7 +577,10 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin ...@@ -577,7 +577,10 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
def get_output_embeddings(self): def get_output_embeddings(self):
return self.transformer.tokens_embed return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
...@@ -682,9 +685,6 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -682,9 +685,6 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
config, initializer_range=config.initializer_range, name="multiple_choice_head" config, initializer_range=config.initializer_range, name="multiple_choice_head"
) )
def get_output_embeddings(self):
return self.transformer.tokens_embed
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
...@@ -841,9 +841,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc ...@@ -841,9 +841,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
) )
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.tokens_embed
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 RoBERTa model. """ """ TF 2.0 RoBERTa model. """
import warnings
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
...@@ -502,7 +504,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -502,7 +504,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0] self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
...@@ -827,6 +829,20 @@ class TFRobertaLMHead(tf.keras.layers.Layer): ...@@ -827,6 +829,20 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
super().build(input_shape) super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
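Editor's note: the same accessor quartet (get/set output embeddings, get/set bias) appears on every tied LM head in this diff (Longformer, LXMERT, MobileBERT, MPNet, RoBERTa). What it buys is that a single `resize_token_embeddings` call now keeps the tied decoder weights and the bias in lock-step. A hedged sketch with a deliberately tiny RobertaConfig (values are illustrative):

```python
from transformers import RobertaConfig, TFRobertaForMaskedLM

config = RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                       num_attention_heads=2, intermediate_size=64)
model = TFRobertaForMaskedLM(config)
model(model.dummy_inputs)  # build the weights

model.resize_token_embeddings(120)
# Input embeddings, the tied decoder of the LM head and its bias all follow the new size.
assert model.config.vocab_size == 120
assert model.get_bias()["bias"].shape[0] == 120
```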
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
...@@ -849,13 +865,11 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -849,13 +865,11 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head") self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
def get_output_embeddings(self): def get_lm_head(self):
return self.lm_head.decoder
def get_output_layer_with_bias(self):
return self.lm_head return self.lm_head
def get_prefix_bias_name(self): def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
...@@ -573,15 +573,6 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -573,15 +573,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm") self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate) self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def get_input_embeddings(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
...@@ -839,6 +830,26 @@ class TFT5PreTrainedModel(TFPreTrainedModel): ...@@ -839,6 +830,26 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
return self.serving_output(output) return self.serving_output(output)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
try:
self.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
self.shared.weight = value
self.shared.vocab_size = shape_list(value)[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.embed_tokens = embed_tokens
if hasattr(self, "decoder"):
self.decoder.embed_tokens = embed_tokens
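Editor's note: moving this plumbing onto `TFT5PreTrainedModel` means `TFT5Model`, `TFT5ForConditionalGeneration` and `TFT5EncoderModel` no longer carry their own copies (they are deleted further down), and one resize call updates the shared matrix that both encoder and decoder see through the wrapper. A hedged sketch with small illustrative config values:

```python
from transformers import T5Config, TFT5ForConditionalGeneration

config = T5Config(vocab_size=100, d_model=32, d_kv=8, d_ff=64,
                  num_layers=2, num_heads=2)
model = TFT5ForConditionalGeneration(config)
model(model.dummy_inputs)  # build the shared embedding

# The centralized setter rewires encoder and decoder onto the same wrapped layer,
# so a single resize is reflected everywhere the shared matrix is used.
model.resize_token_embeddings(120)
assert model.config.vocab_size == 120
assert model.get_input_embeddings().vocab_size == 120
```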
def _shift_right(self, input_ids): def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id pad_token_id = self.config.pad_token_id
...@@ -1050,20 +1061,6 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -1050,20 +1061,6 @@ class TFT5Model(TFT5PreTrainedModel):
decoder_config.is_decoder = True decoder_config.is_decoder = True
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder") self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
def get_encoder(self): def get_encoder(self):
return self.encoder return self.encoder
...@@ -1222,24 +1219,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1222,24 +1219,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
if not config.tie_word_embeddings: if not config.tie_word_embeddings:
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head") self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
def get_input_embeddings(self):
return self.shared
def get_output_embeddings(self): def get_output_embeddings(self):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
return self.shared return self.get_input_embeddings()
else: else:
return self.lm_head # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
# value has a shape (num_tokens, dim) and thus needs to be transposed
return tf.transpose(self.lm_head.kernel)
def set_input_embeddings(self, new_embeddings): def set_output_embeddings(self, value):
self.shared.weight = new_embeddings if self.config.tie_word_embeddings:
# retrieve correct absolute scope for embed token wrapper self.set_input_embeddings(value)
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: else:
pass self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) # value has a shape (num_tokens, dim) then needs to be transposed
self.encoder.set_embed_tokens(embed_tokens) transposed_value = tf.transpose(value)
self.decoder.set_embed_tokens(embed_tokens) self.lm_head.kernel = transposed_value
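Editor's note: the transpose in the new untied branch is easy to get backwards. A `tf.keras.layers.Dense(units)` stores its kernel as `(input_dim, units)`, i.e. `(d_model, vocab_size)` for `lm_head`, while embedding matrices in the library are laid out as `(vocab_size, d_model)`. A standalone sketch of the round-trip (shapes only, names are illustrative):

```python
import tensorflow as tf

d_model, vocab_size = 32, 100
lm_head = tf.keras.layers.Dense(vocab_size, use_bias=False)
lm_head.build((None, d_model))                 # kernel: (d_model, vocab_size) = (32, 100)

# "Output embeddings" in the library's convention have shape (vocab_size, d_model).
output_embeddings = tf.transpose(lm_head.kernel)
print(output_embeddings.shape)                 # (100, 32)

# Writing new output embeddings back means transposing again before touching the kernel.
new_embeddings = tf.random.normal((vocab_size, d_model))
lm_head.kernel.assign(tf.transpose(new_embeddings))
```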
def get_encoder(self): def get_encoder(self):
return self.encoder return self.encoder
...@@ -1358,9 +1354,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1358,9 +1354,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# T5v1.1 does not tie output word embeddings and thus does not require downscaling # T5v1.1 does not tie output word embeddings and thus does not require downscaling
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
sequence_output = sequence_output * (self.model_dim ** -0.5) sequence_output = sequence_output * (self.model_dim ** -0.5)
logits = self.get_output_embeddings()(sequence_output, mode="linear") logits = self.shared(sequence_output, mode="linear")
else: else:
logits = self.get_output_embeddings()(sequence_output) logits = self.lm_head(sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
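Editor's note: the `self.model_dim ** -0.5` factor in the tied branch mirrors the original Mesh TensorFlow T5: when the logits come from the shared embedding matrix, they are rescaled by 1/sqrt(d_model), while the untied T5v1.1 `lm_head` learns its own scale. A toy sketch of that branch, not the library code (all names and shapes are illustrative):

```python
import tensorflow as tf

def t5_logits(sequence_output, shared_embedding, lm_head=None, tie_word_embeddings=True):
    """Toy sketch of the tied/untied branch above; not the library implementation."""
    if tie_word_embeddings:
        d_model = shared_embedding.shape[1]
        sequence_output = sequence_output * (d_model ** -0.5)  # downscale before the tied projection
        return tf.matmul(sequence_output, shared_embedding, transpose_b=True)
    return lm_head(sequence_output)  # T5v1.1 style: separate dense head, no rescaling

sequence_output = tf.random.normal((2, 8, 32))   # (batch, seq, d_model)
shared_embedding = tf.random.normal((100, 32))   # (vocab, d_model)
print(t5_logits(sequence_output, shared_embedding).shape)  # (2, 8, 100)
```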
...@@ -1488,19 +1484,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel): ...@@ -1488,19 +1484,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
encoder_config.use_cache = False encoder_config.use_cache = False
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder") self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
def get_encoder(self): def get_encoder(self):
return self.encoder return self.encoder
......
...@@ -468,9 +468,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -468,9 +468,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
raise NotImplementedError raise NotImplementedError
def _resize_token_embeddings(self, new_num_tokens):
return self.word_emb
def backward_compatible(self): def backward_compatible(self):
self.sample_softmax = -1 self.sample_softmax = -1
...@@ -909,25 +906,6 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel): ...@@ -909,25 +906,6 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
) )
class TFTransfoXLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
@add_start_docstrings( @add_start_docstrings(
""" """
The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive
...@@ -948,6 +926,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -948,6 +926,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit" config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
) )
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError()
def get_output_embeddings(self): def get_output_embeddings(self):
"""Double-check if you are using adaptive softmax.""" """Double-check if you are using adaptive softmax."""
if len(self.crit.out_layers) > 0: if len(self.crit.out_layers) > 0:
......