"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "85a1269e19af022e04bc2aad82572cd5a9e8cdd9"
Unverified commit 1243ee7d, authored by Julien Plu and committed by GitHub

Full rework of the TF input/output embeddings and bias resizing (#9193)

* Start rework resizing

* Rework bias/decoder resizing

* Full resizing rework

* Full resizing rework

* Start updating the models with the new approach

* Finish updating the models

* Update all the tests

* Update the template

* Fix tests

* Fix tests

* Test a new approach

* Refactoring

* Refactoring

* Refactoring

* New rework

* Rework BART

* Rework bert+blenderbot

* Rework CTRL

* Rework Distilbert

* Rework DPR

* Rework Electra

* Rework Flaubert

* Rework Funnel

* Rework GPT2

* Rework Longformer

* Rework Lxmert

* Rework marian+mbart

* Rework mobilebert

* Rework mpnet

* Rework openai

* Rework pegasus

* Rework Roberta

* Rework T5

* Rework xlm+xlnet

* Rework template

* Fix TFT5EncoderOnly + DPRs

* Restore previous methods

* Fix Funnel

* Fix CTRL and TransfoXL

* Apply style

* Apply Sylvain's comments

* Restore a test in DPR

* Address the comments

* Fix bug

* Apply style

* Remove unused import

* Fix test

* Forgot a method

* Missing test

* Trigger CI

* Naming update

* Rebase

* Trigger CI
parent cf416764
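For orientation before the diff: the commit replaces the per-model `resize_token_embeddings` overrides and the old `get_output_embeddings`/`get_output_layer_with_bias` pair with a uniform set of accessors (`get_lm_head`, `get_bias`/`set_bias`, `get_output_embeddings`/`set_output_embeddings`). The sketch below is illustrative only and is not part of the commit; it assumes the accessors as added in the diff and uses BERT and a standard checkpoint name purely as an example.

    # Hypothetical usage sketch: grow the vocabulary of a TF model and read the
    # LM-head bias through the new accessors instead of rebuilding weights by hand.
    from transformers import BertTokenizer, TFBertForMaskedLM

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertForMaskedLM.from_pretrained("bert-base-uncased")

    tokenizer.add_tokens(["[NEW_TOK]"])
    model.resize_token_embeddings(len(tokenizer))  # input embeddings, output embeddings and bias are resized together

    lm_head = model.get_lm_head()        # replaces get_output_layer_with_bias()
    bias = lm_head.get_bias()["bias"]    # biases are now exposed as a dict of named weights
    print(bias.shape)                    # expected: (len(tokenizer),)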
@@ -470,6 +470,21 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.decoder
+
+    def set_output_embeddings(self, value):
+        self.decoder.word_embeddings = value
+        self.decoder.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias, "decoder_bias": self.decoder_bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.decoder_bias = value["decoder_bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)

@@ -505,10 +520,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
-
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """

@@ -835,34 +847,8 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
        self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
        self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")

-    def get_output_embeddings(self):
-        return self.albert.embeddings
-
-    def resize_token_embeddings(self, new_num_tokens):
-        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
-        # ALBERT is a special case where there are two bias to update
-        # even though self.bias is not used anywhere and is here
-        # just to make the loading weights from a PT model happy
-        if new_num_tokens is not None:
-            num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
-            self.predictions.vocab_size = num_tokens_to_copy
-            init_bias = tf.zeros((new_num_tokens,))
-            init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
-            name = self.name + "/" + self.predictions.name + "/bias"
-            self.predictions.bias = self.add_weight(
-                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
-            )
-            self.predictions.bias.assign(init_bias)
-            init_decoder_bias = tf.zeros((new_num_tokens,))
-            init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
-            name = self.name + "/" + self.predictions.name + "/decoder_bias"
-            self.predictions.decoder_bias = self.add_weight(
-                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
-            )
-            self.predictions.decoder_bias.assign(init_decoder_bias)
+    def get_lm_head(self):
+        return self.predictions

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -980,34 +966,8 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
        self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
        self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")

-    def get_output_embeddings(self):
-        return self.albert.embeddings
-
-    def resize_token_embeddings(self, new_num_tokens):
-        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
-        # ALBERT is a special case where there are two bias to update
-        # even though self.bias is not used anywhere and is here
-        # just to make the loading weights from a PT model happy
-        if new_num_tokens is not None:
-            num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
-            self.predictions.vocab_size = num_tokens_to_copy
-            init_bias = tf.zeros((new_num_tokens,))
-            init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
-            name = self.name + "/" + self.predictions.name + "/bias"
-            self.predictions.bias = self.add_weight(
-                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
-            )
-            self.predictions.bias.assign(init_bias)
-            init_decoder_bias = tf.zeros((new_num_tokens,))
-            init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
-            name = self.name + "/" + self.predictions.name + "/decoder_bias"
-            self.predictions.decoder_bias = self.add_weight(
-                shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
-            )
-            self.predictions.decoder_bias.assign(init_decoder_bias)
+    def get_lm_head(self):
+        return self.predictions

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
...
@@ -481,6 +481,29 @@ class TFBartPretrainedModel(TFPreTrainedModel):
        }
        return dummy_inputs

+    def get_input_embeddings(self):
+        base_model = getattr(self, self.base_model_prefix, self)
+        return base_model.shared
+
+    def set_input_embeddings(self, value):
+        base_model = getattr(self, self.base_model_prefix, self)
+
+        try:
+            base_model.shared.weight = value
+        except AttributeError:
+            self(self.dummy_inputs)
+            base_model.shared.weight = value
+
+        base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
+
+        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
+            pass
+
+        embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
+        base_model.encoder.set_embed_tokens(embed_tokens)
+        base_model.decoder.set_embed_tokens(embed_tokens)
+
    @tf.function(
        input_signature=[
            {

@@ -634,6 +657,9 @@ class TFBartEncoder(tf.keras.layers.Layer):
            else None
        )

+    def set_embed_tokens(self, embed_tokens):
+        self.embed_tokens = embed_tokens
+
    def call(
        self,
        input_ids=None,

@@ -791,6 +817,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm

+    def set_embed_tokens(self, embed_tokens):
+        self.embed_tokens = embed_tokens
+
    def call(
        self,
        input_ids=None,

@@ -1009,6 +1038,9 @@ class TFBartModel(TFBartPretrainedModel):
        self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
        self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")

+    def get_encoder(self):
+        return self.encoder
+
    def get_decoder(self):
        return self.decoder

@@ -1134,15 +1166,6 @@ class TFBartModel(TFBartPretrainedModel):
            encoder_attentions=enc_attns,
        )

-    def get_input_embeddings(self):
-        return self.shared
-
-    def set_input_embeddings(self, value):
-        self.shared = value
-
-    def get_output_embeddings(self):
-        return self.shared

@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.",

@@ -1166,22 +1189,20 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
    def get_decoder(self):
        return self.model.decoder

-    def resize_token_embeddings(self, new_num_tokens):
-        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
-        # BART is a special case where the bias has two dimensions
-        # and not named just `bias`
-        if new_num_tokens is not None:
-            num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
-            init_bias = tf.zeros((new_num_tokens,))
-            init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
-            self.final_logits_bias = self.add_weight(
-                shape=(1, new_num_tokens),
-                initializer="zeros",
-                trainable=False,
-                name="final_logits_bias",
-            )
-            self.final_logits_bias.assign(init_bias)
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_output_embeddings(self):
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)
+
+    def get_bias(self):
+        return {"final_logits_bias": self.final_logits_bias}
+
+    def set_bias(self, value):
+        self.final_logits_bias = value["final_logits_bias"]

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)

@@ -1356,12 +1377,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
        else:
            return logits

-    def get_output_embeddings(self):
-        return self.model.shared
-
-    def get_encoder(self):
-        return self.model.encoder

    def compute_loss(self, labels, logits):
        """CrossEntropyLoss that ignores pad tokens"""
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
...
@@ -15,6 +15,7 @@
# limitations under the License.
""" TF 2.0 BERT model. """

+import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -526,6 +527,20 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.input_embeddings
+
+    def set_output_embeddings(self, value):
+        self.input_embeddings.word_embeddings = value
+        self.input_embeddings.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.input_embeddings(hidden_states, mode="linear")

@@ -582,7 +597,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """

@@ -918,13 +933,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

-    def get_output_embeddings(self):
-        return self.bert.embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.mlm.predictions

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))

@@ -1044,13 +1057,11 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
        self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

-    def get_output_embeddings(self):
-        return self.bert.embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.mlm.predictions

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))

@@ -1149,13 +1160,11 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
        self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

-    def get_output_embeddings(self):
-        return self.bert.embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.mlm.predictions

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

    @add_code_sample_docstrings(
...
@@ -15,6 +15,8 @@
# limitations under the License.
""" TF 2.0 CTRL model."""

+import warnings
+
import numpy as np
import tensorflow as tf

@@ -242,10 +244,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.w.weight = value
-        self.w.vocab_size = value.shape[0]
-
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+        self.w.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """

@@ -618,6 +617,20 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.input_embeddings
+
+    def set_output_embeddings(self, value):
+        self.input_embeddings.weight = value
+        self.input_embeddings.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias

@@ -638,13 +651,11 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
        self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")

-    def get_output_embeddings(self):
-        return self.lm_head.input_embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.lm_head

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.lm_head.name

    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
...
@@ -16,6 +16,8 @@
 TF 2.0 DistilBERT model
"""

+import warnings
+
import tensorflow as tf

from ...activations_tf import get_tf_activation

@@ -39,7 +41,6 @@ from ...modeling_tf_utils import (
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
-    TFSharedEmbeddings,
    TFTokenClassificationLoss,
    get_initializer,
    input_processing,

@@ -72,9 +73,6 @@ class TFEmbeddings(tf.keras.layers.Layer):
        self.vocab_size = config.vocab_size
        self.dim = config.dim
        self.initializer_range = config.initializer_range
-        self.word_embeddings = TFSharedEmbeddings(
-            config.vocab_size, config.dim, initializer_range=config.initializer_range, name="word_embeddings"
-        )  # padding_idx=0)
        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.dim,

@@ -648,6 +646,20 @@ class TFDistilBertLMHead(tf.keras.layers.Layer):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.input_embeddings
+
+    def set_output_embeddings(self, value):
+        self.input_embeddings.word_embeddings = value
+        self.input_embeddings.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias

@@ -671,13 +683,11 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
        self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
        self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

-    def get_output_embeddings(self):
-        return self.vocab_projector.input_embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.vocab_projector

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.vocab_projector.name

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -577,7 +577,11 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
        self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")

    def get_input_embeddings(self):
-        return self.ctx_encoder.bert_model.get_input_embeddings()
+        try:
+            return self.ctx_encoder.bert_model.get_input_embeddings()
+        except AttributeError:
+            self(self.dummy_inputs)
+            return self.ctx_encoder.bert_model.get_input_embeddings()

    @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)

@@ -671,7 +675,11 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
        self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")

    def get_input_embeddings(self):
-        return self.question_encoder.bert_model.get_input_embeddings()
+        try:
+            return self.question_encoder.bert_model.get_input_embeddings()
+        except AttributeError:
+            self(self.dummy_inputs)
+            return self.question_encoder.bert_model.get_input_embeddings()

    @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)

@@ -764,7 +772,11 @@ class TFDPRReader(TFDPRPretrainedReader):
        self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")

    def get_input_embeddings(self):
-        return self.span_predictor.encoder.bert_model.get_input_embeddings()
+        try:
+            return self.span_predictor.encoder.bert_model.get_input_embeddings()
+        except AttributeError:
+            self(self.dummy_inputs)
+            return self.span_predictor.encoder.bert_model.get_input_embeddings()

    @add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
...
@@ -14,6 +14,7 @@
# limitations under the License.
""" TF Electra model. """

+import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -511,10 +512,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
-
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """

@@ -912,6 +910,20 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.input_embeddings
+
+    def set_output_embeddings(self, value):
+        self.input_embeddings.word_embeddings = value
+        self.input_embeddings.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states, training=False):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias

@@ -943,13 +955,11 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
        self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

-    def get_output_embeddings(self):
-        return self.electra.embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.generator_lm_head

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.generator_lm_head.name

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -18,6 +18,7 @@
import itertools
import random
+import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -478,6 +479,10 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.embeddings

+    def set_input_embeddings(self, value):
+        self.embeddings.weight = value
+        self.embeddings.vocab_size = shape_list(value)[0]
+
    def call(
        self,
        input_ids=None,

@@ -721,6 +726,20 @@ class TFFlaubertPredLayer(tf.keras.layers.Layer):
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.input_embeddings
+
+    def set_output_embeddings(self, value):
+        self.input_embeddings.weight = value
+        self.input_embeddings.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias

@@ -767,13 +786,11 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
        self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")

-    def get_output_embeddings(self):
-        return self.pred_layer.input_embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.pred_layer

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.pred_layer.name

    def prepare_inputs_for_generation(self, inputs, **kwargs):
...
@@ -14,6 +14,7 @@
# limitations under the License.
""" TF 2.0 Funnel model. """

+import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -787,7 +788,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library fr TF 2.0 models

@@ -873,7 +874,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library fr TF 2.0 models

@@ -992,6 +993,20 @@ class TFFunnelMaskedLMHead(tf.keras.layers.Layer):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.input_embeddings
+
+    def set_output_embeddings(self, value):
+        self.input_embeddings.word_embeddings = value
+        self.input_embeddings.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states, training=False):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias

@@ -1349,13 +1364,11 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
        self.funnel = TFFunnelMainLayer(config, name="funnel")
        self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head")

-    def get_output_embeddings(self):
-        return self.funnel.embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.lm_head

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.lm_head.name

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -243,7 +243,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.wte.weight = value
-        self.wte.vocab_size = self.wte.weight.shape[0]
+        self.wte.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """

@@ -653,7 +653,10 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    def get_output_embeddings(self):
-        return self.transformer.wte
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)

    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs

@@ -771,9 +774,6 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
            config, initializer_range=config.initializer_range, name="multiple_choice_head"
        )

-    def get_output_embeddings(self):
-        return self.transformer.wte

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
    def call(

@@ -945,9 +945,6 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
        )
        self.transformer = TFGPT2MainLayer(config, name="transformer")

-    def get_output_embeddings(self):
-        return self.transformer.wte

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -1182,6 +1182,44 @@ class TFLEDPreTrainedModel(TFPreTrainedModel):
        }
        return dummy_inputs

+    def get_input_embeddings(self):
+        base_model = getattr(self, self.base_model_prefix, self)
+        return base_model.shared
+
+    def set_input_embeddings(self, value):
+        base_model = getattr(self, self.base_model_prefix, self)
+
+        try:
+            base_model.shared.weight = value
+        except AttributeError:
+            self(self.dummy_inputs)
+            base_model.shared.weight = value
+
+        base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
+
+        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
+            pass
+
+        embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
+        base_model.encoder.set_embed_tokens(embed_tokens)
+        base_model.decoder.set_embed_tokens(embed_tokens)
+
+    @tf.function(
+        input_signature=[
+            {
+                "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+                "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+                "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
+                "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
+            }
+        ]
+    )
+    def serving(self, inputs):
+        output = self.call(inputs)
+        return self.serving_output(output)

@dataclass
# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder

@@ -1483,6 +1521,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):
        self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
        self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

+    def set_embed_tokens(self, embed_tokens):
+        self.embed_tokens = embed_tokens
+
    def call(
        self,
        input_ids=None,

@@ -1714,6 +1755,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
        self.dropout = tf.keras.layers.Dropout(config.dropout)

+    def set_embed_tokens(self, embed_tokens):
+        self.embed_tokens = embed_tokens
+
    def call(
        self,
        input_ids=None,

@@ -1921,6 +1965,9 @@ class TFLEDModel(TFLEDPreTrainedModel):
        self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder")
        self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder")

+    def get_encoder(self):
+        return self.encoder
+
    def get_decoder(self):
        return self.decoder

@@ -2047,15 +2094,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
            encoder_global_attentions=enc_g_attns,
        )

-    def get_input_embeddings(self):
-        return self.shared
-
-    def set_input_embeddings(self, value):
-        self.shared = value
-
-    def get_output_embeddings(self):
-        return self.shared

@add_start_docstrings(
    "The LED Model with a language modeling head. Can be used for summarization.",

@@ -2079,22 +2117,20 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
    def get_decoder(self):
        return self.led.decoder

-    def resize_token_embeddings(self, new_num_tokens):
-        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
-        # LED is a special case where the bias has two dimensions
-        # and not named just `bias`
-        if new_num_tokens is not None:
-            num_tokens_to_copy = min(shape_list(self.final_logits_bias), new_num_tokens)
-            init_bias = tf.zeros((new_num_tokens,))
-            init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
-            self.final_logits_bias = self.add_weight(
-                shape=(1, new_num_tokens),
-                initializer="zeros",
-                trainable=False,
-                name="final_logits_bias",
-            )
-            self.final_logits_bias.assign(init_bias)
+    def get_encoder(self):
+        return self.led.encoder
+
+    def get_bias(self):
+        return {"final_logits_bias": self.final_logits_bias}
+
+    def set_bias(self, value):
+        self.final_logits_bias = value["final_logits_bias"]
+
+    def get_output_embeddings(self):
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)

    @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)

@@ -2266,12 +2302,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
        )
        return (past[0], reordered_past)

-    def get_output_embeddings(self):
-        return self.led.shared
-
-    def get_encoder(self):
-        return self.led.encoder

    def compute_loss(self, labels, logits):
        """CrossEntropyLoss that ignores pad tokens"""
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
...
@@ -14,6 +14,7 @@
# limitations under the License.
"""Tensorflow Longformer model. """

+import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -437,6 +438,20 @@ class TFLongformerLMHead(tf.keras.layers.Layer):
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.decoder
+
+    def set_output_embeddings(self, value):
+        self.decoder.word_embeddings = value
+        self.decoder.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.act(hidden_states)

@@ -1602,7 +1617,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """

@@ -2040,13 +2055,11 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
        self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")

-    def get_output_embeddings(self):
-        return self.lm_head.decoder
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.lm_head

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.lm_head.name

    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -16,6 +16,7 @@
# limitations under the License.
""" TF 2.0 LXMERT model. """

+import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

@@ -706,10 +707,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
-
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

@@ -1103,6 +1101,20 @@ class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.input_embeddings
+
+    def set_output_embeddings(self, value):
+        self.input_embeddings.word_embeddings = value
+        self.input_embeddings.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.input_embeddings(hidden_states, mode="linear")

@@ -1292,13 +1304,11 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
            **({"obj_labels": obj_labels} if self.config.task_obj_predict else {}),
        }

-    def get_output_embeddings(self):
-        return self.lxmert.embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.cls.predictions

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name

    @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
...
@@ -15,6 +15,7 @@
# limitations under the License.
""" TF 2.0 MobileBERT model. """

+import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -665,6 +666,20 @@ class TFMobileBertLMPredictionHead(tf.keras.layers.Layer):
        )
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self
+
+    def set_output_embeddings(self, value):
+        self.decoder = value
+        self.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0))

@@ -704,10 +719,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
-
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """

@@ -1039,13 +1051,11 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
        self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
        self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls")

-    def get_output_embeddings(self):
-        return self.predictions.predictions
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.predictions.predictions

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name

    @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))

@@ -1149,13 +1159,11 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
        self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
        self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")

-    def get_output_embeddings(self):
-        return self.mlm.predictions
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.mlm.predictions

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

    @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -17,6 +17,7 @@
import math
+import warnings

import tensorflow as tf

@@ -541,7 +542,7 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
    # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
+        self.embeddings.vocab_size = shape_list(value)[0]

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
    def _prune_heads(self, heads_to_prune):

@@ -840,6 +841,20 @@ class TFMPNetLMHead(tf.keras.layers.Layer):
        super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.decoder
+
+    def set_output_embeddings(self, value):
+        self.decoder.word_embeddings = value
+        self.decoder.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
    def call(self, features):
        x = self.dense(features)
        x = self.act(x)

@@ -862,13 +877,11 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
        self.mpnet = TFMPNetMainLayer(config, name="mpnet")
        self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head")

-    def get_output_embeddings(self):
-        return self.mpnet.embeddings
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
        return self.lm_head

    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.lm_head.name

    @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...
@@ -219,7 +219,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
     def set_input_embeddings(self, value):
         self.tokens_embed.weight = value
-        self.tokens_embed.vocab_size = value.shape[0]
+        self.tokens_embed.vocab_size = shape_list(value)[0]

     def _prune_heads(self, heads_to_prune):
         """
@@ -577,7 +577,10 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
         self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")

     def get_output_embeddings(self):
-        return self.transformer.tokens_embed
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)

     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
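Side note (not part of the diff): since OpenAI GPT ties its output projection to the token embedding matrix, `get_output_embeddings` now simply forwards to `get_input_embeddings` (and `set_output_embeddings` to `set_input_embeddings`), so resizing the input embeddings implicitly resizes the output projection as well. A quick sanity check, assuming the public `openai-gpt` TF checkpoint is reachable:

from transformers import TFOpenAIGPTLMHeadModel

model = TFOpenAIGPTLMHeadModel.from_pretrained("openai-gpt")

# With the delegation shown above, both accessors return the very same embedding layer object.
assert model.get_output_embeddings() is model.get_input_embeddings()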
@@ -682,9 +685,6 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
             config, initializer_range=config.initializer_range, name="multiple_choice_head"
         )

-    def get_output_embeddings(self):
-        return self.transformer.tokens_embed
-
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -841,9 +841,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
         )
         self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")

-    def get_output_embeddings(self):
-        return self.transformer.tokens_embed
-
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
...
@@ -15,6 +15,8 @@
 # limitations under the License.
 """ TF 2.0 RoBERTa model. """

+import warnings
+
 import tensorflow as tf

 from ...activations_tf import get_tf_activation
@@ -502,7 +504,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
     # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value
-        self.embeddings.vocab_size = value.shape[0]
+        self.embeddings.vocab_size = shape_list(value)[0]

     # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
     def _prune_heads(self, heads_to_prune):
@@ -827,6 +829,20 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
         super().build(input_shape)

+    def get_output_embeddings(self):
+        return self.decoder
+
+    def set_output_embeddings(self, value):
+        self.decoder.word_embeddings = value
+        self.decoder.vocab_size = shape_list(value)[0]
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def set_bias(self, value):
+        self.bias = value["bias"]
+        self.vocab_size = shape_list(value["bias"])[0]
+
     def call(self, hidden_states):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.act(hidden_states)
@@ -849,13 +865,11 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
         self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
         self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")

-    def get_output_embeddings(self):
-        return self.lm_head.decoder
-
-    def get_output_layer_with_bias(self):
+    def get_lm_head(self):
         return self.lm_head

     def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.lm_head.name

     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
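For orientation (not part of the diff): across MobileBERT, MPNet and RoBERTa the pattern is the same, the model now exposes the whole prediction head through `get_lm_head()` instead of returning an embedding layer from `get_output_embeddings()`/`get_output_layer_with_bias()`, and `get_prefix_bias_name()` only survives as a deprecated shim. A minimal usage sketch, assuming the public `roberta-base` TF checkpoint is reachable:

import warnings

from transformers import TFRobertaForMaskedLM

model = TFRobertaForMaskedLM.from_pretrained("roberta-base")

lm_head = model.get_lm_head()     # the full TFRobertaLMHead (decoder weights + bias)
print(lm_head.get_bias().keys())  # dict_keys(['bias'])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.get_prefix_bias_name()  # still available, but now emits a FutureWarning
assert any(issubclass(w.category, FutureWarning) for w in caught)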
...
@@ -573,15 +573,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_embed_tokens(self, embed_tokens):
-        self.embed_tokens = embed_tokens
-
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError  # Not implemented yet in the library fr TF 2.0 models
-
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError  # Not implemented yet in the library fr TF 2.0 models
@@ -839,6 +830,26 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
         return self.serving_output(output)

+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, value):
+        try:
+            self.shared.weight = value
+        except AttributeError:
+            self(self.dummy_inputs)
+            self.shared.weight = value
+
+        self.shared.vocab_size = shape_list(value)[0]
+
+        # retrieve correct absolute scope for embed token wrapper
+        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
+            pass
+
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
+        self.encoder.embed_tokens = embed_tokens
+
+        if hasattr(self, "decoder"):
+            self.decoder.embed_tokens = embed_tokens
+
     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
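An aside on the `try/except AttributeError` above (not part of the diff): a Keras layer only creates its variables when it is built, so assigning to `self.shared.weight` fails on a model that has never been called; running the model on `dummy_inputs` once forces the build, after which the assignment succeeds. The underlying Keras behaviour in isolation:

import tensorflow as tf

embedding = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
print(len(embedding.weights))            # 0 -- no variables exist before the layer is built

_ = embedding(tf.constant([[1, 2, 3]]))  # a forward pass builds the layer
print(len(embedding.weights))            # 1 -- the embedding matrix now exists
print(embedding.weights[0].shape)        # (10, 4)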
@@ -1050,20 +1061,6 @@ class TFT5Model(TFT5PreTrainedModel):
         decoder_config.is_decoder = True
         self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")

-    def get_input_embeddings(self):
-        return self.shared
-
-    def set_input_embeddings(self, new_embeddings):
-        self.shared.weight = new_embeddings
-        self.shared.vocab_size = self.shared.weight.shape[0]
-        # retrieve correct absolute scope for embed token wrapper
-        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
-            pass
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        self.encoder.set_embed_tokens(embed_tokens)
-        self.decoder.set_embed_tokens(embed_tokens)
-
     def get_encoder(self):
         return self.encoder
@@ -1222,24 +1219,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         if not config.tie_word_embeddings:
             self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")

-    def get_input_embeddings(self):
-        return self.shared
-
     def get_output_embeddings(self):
         if self.config.tie_word_embeddings:
-            return self.shared
+            return self.get_input_embeddings()
         else:
-            return self.lm_head
+            # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+            # value has a shape (num_tokens, dim) then needs to be transposed
+            return tf.transpose(self.lm_head.kernel)

-    def set_input_embeddings(self, new_embeddings):
-        self.shared.weight = new_embeddings
-        # retrieve correct absolute scope for embed token wrapper
-        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
-            pass
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        self.encoder.set_embed_tokens(embed_tokens)
-        self.decoder.set_embed_tokens(embed_tokens)
+    def set_output_embeddings(self, value):
+        if self.config.tie_word_embeddings:
+            self.set_input_embeddings(value)
+        else:
+            self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
+            # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+            # value has a shape (num_tokens, dim) then needs to be transposed
+            transposed_value = tf.transpose(value)
+            self.lm_head.kernel = transposed_value

     def get_encoder(self):
         return self.encoder
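To unpack the transpose comments above (not part of the diff): a Keras `Dense` layer stores its kernel as `(input_dim, units)`, which for an untied LM head means `(hidden_size, vocab_size)`, whereas embedding matrices are passed around as `(vocab_size, hidden_size)`; hence the `tf.transpose` in both directions. A standalone sketch with made-up sizes:

import tensorflow as tf

hidden_size, vocab_size = 8, 32            # made-up sizes, purely for illustration
lm_head = tf.keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
_ = lm_head(tf.zeros((1, hidden_size)))    # build the layer so the kernel exists

new_embeddings = tf.random.normal((vocab_size, hidden_size))  # (num_tokens, dim)
print(lm_head.kernel.shape)                # (8, 32) -- (dim, num_tokens)
print(tf.transpose(new_embeddings).shape)  # (8, 32) -- matches the kernel layout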
@@ -1358,9 +1354,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             # T5v1.1 does not tie output word embeddings and thus does not require downscaling
             if self.config.tie_word_embeddings:
                 sequence_output = sequence_output * (self.model_dim ** -0.5)
-                logits = self.get_output_embeddings()(sequence_output, mode="linear")
+                logits = self.shared(sequence_output, mode="linear")
             else:
-                logits = self.get_output_embeddings()(sequence_output)
+                logits = self.lm_head(sequence_output)

         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
@@ -1488,19 +1484,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
         encoder_config.use_cache = False
         self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")

-    def get_input_embeddings(self):
-        return self.shared
-
-    def set_input_embeddings(self, new_embeddings):
-        self.shared.weight = new_embeddings
-        self.shared.vocab_size = self.shared.weight.shape[0]
-        # retrieve correct absolute scope for embed token wrapper
-        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
-            pass
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        self.encoder.set_embed_tokens(embed_tokens)
-
     def get_encoder(self):
         return self.encoder
...
@@ -468,9 +468,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def set_input_embeddings(self, value):
         raise NotImplementedError

-    def _resize_token_embeddings(self, new_num_tokens):
-        return self.word_emb
-
     def backward_compatible(self):
         self.sample_softmax = -1
@@ -909,25 +906,6 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
     )

-class TFTransfoXLMHead(tf.keras.layers.Layer):
-    def __init__(self, config, input_embeddings, **kwargs):
-        super().__init__(**kwargs)
-        self.vocab_size = config.vocab_size
-
-        # The output weights are the same as the input embeddings, but there is
-        # an output-only bias for each token.
-        self.input_embeddings = input_embeddings
-
-    def build(self, input_shape):
-        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
-
-        super().build(input_shape)
-
-    def call(self, hidden_states):
-        hidden_states = self.input_embeddings(hidden_states, mode="linear")
-        hidden_states = hidden_states + self.bias
-
-        return hidden_states
-
-
 @add_start_docstrings(
     """
     The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive
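Context for the removal above (not in the diff itself): `TFTransfoXLMHead` implemented the classic "output weights tied to the input embeddings, plus an output-only bias per token" projection. It disappears here because Transformer-XL's actual LM head is the adaptive-softmax `crit` module, but the same tied projection is what the reworked `get_bias`/`set_bias`-style heads elsewhere in this PR compute. A bare-TF sketch of that pattern:

import tensorflow as tf

vocab_size, hidden_size, seq_len = 100, 16, 5        # made-up sizes
embedding_matrix = tf.Variable(tf.random.normal((vocab_size, hidden_size)))
output_bias = tf.Variable(tf.zeros((vocab_size,)))   # the "output-only bias for each token"

hidden_states = tf.random.normal((1, seq_len, hidden_size))
# Tied projection: reuse the embedding matrix as the output weight, then add the bias.
logits = tf.matmul(hidden_states, embedding_matrix, transpose_b=True) + output_bias
print(logits.shape)                                  # (1, 5, 100)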
@@ -948,6 +926,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
             config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
         )

+    def _resize_token_embeddings(self, new_num_tokens):
+        raise NotImplementedError()
+
     def get_output_embeddings(self):
         """Double-check if you are using adaptive softmax."""
         if len(self.crit.out_layers) > 0:
...