Unverified Commit 1243ee7d authored by Julien Plu, committed by GitHub

Full rework of the TF input/output embeddings and bias resizing (#9193)

* Start rework resizing

* Rework bias/decoder resizing

* Full resizing rework

* Full resizing rework

* Start to update the models with the new approach

* Finish updating the models

* Update all the tests

* Update the template

* Fix tests

* Fix tests

* Test a new approach

* Refactoring

* Refactoring

* Refactoring

* New rework

* Rework BART

* Rework bert+blenderbot

* Rework CTRL

* Rework Distilbert

* Rework DPR

* Rework Electra

* Rework Flaubert

* Rework Funnel

* Rework GPT2

* Rework Longformer

* Rework Lxmert

* Rework marian+mbart

* Rework mobilebert

* Rework mpnet

* Rework openai

* Rework pegasus

* Rework Roberta

* Rework T5

* Rework xlm+xlnet

* Rework template

* Fix TFT5EncoderOnly + DPRs

* Restore previous methods

* Fix Funnel

* Fix CTRL and TransfoXL

* Apply style

* Apply Sylvain's comments

* Restore a test in DPR

* Address the comments

* Fix bug

* Apply style

* remove unused import

* Fix test

* Forgot a method

* missing test

* Trigger CI

* naming update

* Rebase

* Trigger CI
parent cf416764
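The diff below standardizes how the TF models expose their embedding and bias weights (`get_input_embeddings`/`set_input_embeddings`, `get_output_embeddings`/`set_output_embeddings`, `get_bias`/`set_bias`, `get_lm_head`) so that `resize_token_embeddings` no longer needs hand-written per-model overrides. A minimal sketch of the user-facing flow this enables (checkpoint name and added tokens are illustrative):

```python
from transformers import AutoTokenizer, TFAutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = TFAutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")

# Grow the vocabulary on the tokenizer side, then resize the model to match.
tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])
model.resize_token_embeddings(len(tokenizer))

# The reworked hooks give a single place to reach the resized pieces.
head = model.get_lm_head()               # the LM head layer (vocab_projector for DistilBERT)
print(head.get_bias()["bias"].shape)     # expected: (len(tokenizer),)
```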
......@@ -470,6 +470,21 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias, "decoder_bias": self.decoder_bias}
def set_bias(self, value):
self.bias = value["bias"]
self.decoder_bias = value["decoder_bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
......@@ -505,10 +520,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -835,34 +847,8 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")
def get_output_embeddings(self):
return self.albert.embeddings
def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)
init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.decoder_bias.assign(init_decoder_bias)
def get_lm_head(self):
return self.predictions
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
......@@ -980,34 +966,8 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
def get_output_embeddings(self):
return self.albert.embeddings
def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)
init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.decoder_bias.assign(init_decoder_bias)
def get_lm_head(self):
return self.predictions
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
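ALBERT is the one head in this diff that carries two bias vectors (`bias` and `decoder_bias`), which is why `get_bias`/`set_bias` exchange dictionaries rather than single variables. A minimal sketch of how such a dictionary could be grown to a new vocabulary size (illustrative helper only, not the library's actual resizing code):

```python
import tensorflow as tf

def grow_bias(old_bias: tf.Variable, new_num_tokens: int) -> tf.Variable:
    # Keep the overlapping slice of the old bias, pad new positions with zeros.
    num_to_copy = min(int(old_bias.shape[0]), new_num_tokens)
    padded = tf.concat(
        [old_bias[:num_to_copy], tf.zeros((max(new_num_tokens - num_to_copy, 0),))], axis=0
    )
    return tf.Variable(padded, trainable=True, name=old_bias.name.split(":")[0])

# Applied to a dict-shaped head such as TFAlbertMLMHead:
#   head = model.get_lm_head()
#   head.set_bias({name: grow_bias(b, new_num_tokens) for name, b in head.get_bias().items()})
```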
......@@ -481,6 +481,29 @@ class TFBartPretrainedModel(TFPreTrainedModel):
}
return dummy_inputs
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function(
input_signature=[
{
......@@ -634,6 +657,9 @@ class TFBartEncoder(tf.keras.layers.Layer):
else None
)
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call(
self,
input_ids=None,
......@@ -791,6 +817,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call(
self,
input_ids=None,
......@@ -1009,6 +1038,9 @@ class TFBartModel(TFBartPretrainedModel):
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
......@@ -1134,15 +1166,6 @@ class TFBartModel(TFBartPretrainedModel):
encoder_attentions=enc_attns,
)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
def get_output_embeddings(self):
return self.shared
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.",
......@@ -1166,22 +1189,20 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
def get_decoder(self):
return self.model.decoder
def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
# BART is a special case where the bias has two dimensions
# and not named just `bias`
if new_num_tokens is not None:
num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
self.final_logits_bias = self.add_weight(
shape=(1, new_num_tokens),
initializer="zeros",
trainable=False,
name="final_logits_bias",
)
self.final_logits_bias.assign(init_bias)
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
......@@ -1356,12 +1377,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
else:
return logits
def get_output_embeddings(self):
return self.model.shared
def get_encoder(self):
return self.model.encoder
def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
......
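BART keeps a non-trainable `(1, vocab_size)` `final_logits_bias` that is not tied to the embeddings, so it is now surfaced through `get_bias`/`set_bias` instead of the hand-written `resize_token_embeddings` override above (which slice-assigned into an immutable `tf.zeros` tensor). A hedged sketch of resizing such a two-dimensional bias (helper name made up for illustration):

```python
import tensorflow as tf

def grow_final_logits_bias(bias: tf.Variable, new_num_tokens: int) -> tf.Variable:
    # Keep the existing values, pad any new vocabulary slots with zeros.
    num_to_copy = min(int(bias.shape[-1]), new_num_tokens)
    padded = tf.concat(
        [bias[:, :num_to_copy], tf.zeros((1, max(new_num_tokens - num_to_copy, 0)))], axis=-1
    )
    return tf.Variable(padded, trainable=False, name="final_logits_bias")

# With the new hooks on TFBartForConditionalGeneration:
#   old = model.get_bias()["final_logits_bias"]
#   model.set_bias({"final_logits_bias": grow_final_logits_bias(old, new_num_tokens)})
```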
......@@ -15,6 +15,7 @@
# limitations under the License.
""" TF 2.0 BERT model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -526,6 +527,20 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.input_embeddings(hidden_states, mode="linear")
......@@ -582,7 +597,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -918,13 +933,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
self.nsp = TFBertNSPHead(config, name="nsp___cls")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......@@ -1044,13 +1057,11 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......@@ -1149,13 +1160,11 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_code_sample_docstrings(
......
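Across BERT and the other encoder models below, the old `get_output_layer_with_bias` hook becomes `get_lm_head`, and `get_prefix_bias_name` is kept only as a deprecated shim. A short sketch of the accessors a caller would use now (behaviour inferred from this diff):

```python
from transformers import TFBertForMaskedLM

model = TFBertForMaskedLM.from_pretrained("bert-base-uncased")

head = model.get_lm_head()               # replaces get_output_layer_with_bias()
print(head.get_bias()["bias"].shape)     # (vocab_size,)

# Still callable, but emits a FutureWarning pointing to get_bias():
print(model.get_prefix_bias_name())
```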
......@@ -15,6 +15,8 @@
# limitations under the License.
""" TF 2.0 CTRL model."""
import warnings
import numpy as np
import tensorflow as tf
......@@ -242,10 +244,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.w.weight = value
self.w.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.w.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -618,6 +617,20 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
......@@ -638,13 +651,11 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.input_embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
......
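Several of the `set_input_embeddings` implementations above switch from `value.shape[0]` to `shape_list(value)[0]`. `shape_list` is the library helper that falls back to the dynamic shape when a static dimension is unknown, which keeps these setters usable from graph-traced code. A simplified re-implementation for reference (the real helper lives in `modeling_tf_utils`):

```python
import tensorflow as tf

def shape_list(tensor):
    # Prefer static dimensions; fall back to dynamic ones when a dimension is None.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

weights = tf.Variable(tf.zeros((246534, 1280)))   # CTRL-sized vocabulary, for illustration
print(shape_list(weights)[0])                     # 246534
```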
......@@ -16,6 +16,8 @@
TF 2.0 DistilBERT model
"""
import warnings
import tensorflow as tf
from ...activations_tf import get_tf_activation
......@@ -39,7 +41,6 @@ from ...modeling_tf_utils import (
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
input_processing,
......@@ -72,9 +73,6 @@ class TFEmbeddings(tf.keras.layers.Layer):
self.vocab_size = config.vocab_size
self.dim = config.dim
self.initializer_range = config.initializer_range
self.word_embeddings = TFSharedEmbeddings(
config.vocab_size, config.dim, initializer_range=config.initializer_range, name="word_embeddings"
) # padding_idx=0)
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
config.dim,
......@@ -648,6 +646,20 @@ class TFDistilBertLMHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
......@@ -671,13 +683,11 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
def get_output_embeddings(self):
return self.vocab_projector.input_embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.vocab_projector
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.vocab_projector.name
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
......@@ -577,7 +577,11 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")
def get_input_embeddings(self):
return self.ctx_encoder.bert_model.get_input_embeddings()
try:
return self.ctx_encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
return self.ctx_encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
......@@ -671,7 +675,11 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")
def get_input_embeddings(self):
return self.question_encoder.bert_model.get_input_embeddings()
try:
return self.question_encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
return self.question_encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
......@@ -764,7 +772,11 @@ class TFDPRReader(TFDPRPretrainedReader):
self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")
def get_input_embeddings(self):
return self.span_predictor.encoder.bert_model.get_input_embeddings()
try:
return self.span_predictor.encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
return self.span_predictor.encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
......
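The DPR accessors above (like the BART and LED setters) guard against the model not being built yet: if the inner layer does not exist, a forward pass on `dummy_inputs` builds it, then the lookup is retried. The same pattern in miniature, using a toy layer rather than library code:

```python
import tensorflow as tf

class LazyInner(tf.keras.layers.Layer):
    def build(self, input_shape):
        # The sub-layer only exists once build() has run.
        self.inner = tf.keras.layers.Dense(4, name="inner")
        super().build(input_shape)

    def call(self, inputs):
        return self.inner(inputs)

    def get_inner(self):
        try:
            return self.inner
        except AttributeError:
            # Not built yet: a dummy forward pass triggers build(), then retry.
            self(tf.zeros((1, 4)))
            return self.inner

layer = LazyInner()
print(layer.get_inner().name)   # works even before the layer has seen real data
```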
......@@ -14,6 +14,7 @@
# limitations under the License.
""" TF Electra model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -511,10 +512,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -912,6 +910,20 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states, training=False):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
......@@ -943,13 +955,11 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
def get_output_embeddings(self):
return self.electra.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.generator_lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.generator_lm_head.name
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
......@@ -18,6 +18,7 @@
import itertools
import random
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -478,6 +479,10 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.weight = value
self.embeddings.vocab_size = shape_list(value)[0]
def call(
self,
input_ids=None,
......@@ -721,6 +726,20 @@ class TFFlaubertPredLayer(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
......@@ -767,13 +786,11 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
self.transformer = TFFlaubertMainLayer(config, name="transformer")
self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
def get_output_embeddings(self):
return self.pred_layer.input_embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.pred_layer
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.pred_layer.name
def prepare_inputs_for_generation(self, inputs, **kwargs):
......
......@@ -14,6 +14,7 @@
# limitations under the License.
""" TF 2.0 Funnel model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -787,7 +788,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
......@@ -873,7 +874,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
......@@ -992,6 +993,20 @@ class TFFunnelMaskedLMHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states, training=False):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
......@@ -1349,13 +1364,11 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
self.funnel = TFFunnelMainLayer(config, name="funnel")
self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.funnel.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
......@@ -243,7 +243,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.wte.weight = value
self.wte.vocab_size = self.wte.weight.shape[0]
self.wte.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -653,7 +653,10 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
self.transformer = TFGPT2MainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.wte
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
......@@ -771,9 +774,6 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
config, initializer_range=config.initializer_range, name="multiple_choice_head"
)
def get_output_embeddings(self):
return self.transformer.wte
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -945,9 +945,6 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
)
self.transformer = TFGPT2MainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.wte
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
......
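GPT-2 (and OpenAI GPT below) tie the output projection to the input embeddings, so the per-head overrides on the double-heads and sequence-classification models are dropped and the LM head model simply delegates. A quick check of what that delegation implies (expected behaviour, inferred from this diff):

```python
from transformers import TFGPT2LMHeadModel

model = TFGPT2LMHeadModel.from_pretrained("gpt2")
# With tied weights, both accessors should hand back the same shared embedding layer.
print(model.get_output_embeddings() is model.get_input_embeddings())   # expected: True
```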
......@@ -1182,6 +1182,44 @@ class TFLEDPreTrainedModel(TFPreTrainedModel):
}
return dummy_inputs
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}
]
)
def serving(self, inputs):
output = self.call(inputs)
return self.serving_output(output)
@dataclass
# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder
......@@ -1483,6 +1521,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):
self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call(
self,
input_ids=None,
......@@ -1714,6 +1755,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout)
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def call(
self,
input_ids=None,
......@@ -1921,6 +1965,9 @@ class TFLEDModel(TFLEDPreTrainedModel):
self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder")
self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder")
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
......@@ -2047,15 +2094,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
encoder_global_attentions=enc_g_attns,
)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
def get_output_embeddings(self):
return self.shared
@add_start_docstrings(
"The LED Model with a language modeling head. Can be used for summarization.",
......@@ -2079,22 +2117,20 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def get_decoder(self):
return self.led.decoder
def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
# LED is a special case where the bias has two dimensions
# and not named just `bias`
if new_num_tokens is not None:
num_tokens_to_copy = min(shape_list(self.final_logits_bias), new_num_tokens)
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
self.final_logits_bias = self.add_weight(
shape=(1, new_num_tokens),
initializer="zeros",
trainable=False,
name="final_logits_bias",
)
self.final_logits_bias.assign(init_bias)
def get_encoder(self):
return self.led.encoder
def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
def get_output_embeddings(self):
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
......@@ -2266,12 +2302,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
)
return (past[0], reordered_past)
def get_output_embeddings(self):
return self.led.shared
def get_encoder(self):
return self.led.encoder
def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
......
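Besides the embedding hooks, the LED base class gains an explicit `serving` signature over the four `int32` input tensors. A sketch of the SavedModel export this enables (checkpoint name and export path are illustrative):

```python
import tensorflow as tf
from transformers import TFLEDForConditionalGeneration

model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
# The input_signature declared above fixes the tensors the exported model accepts.
tf.saved_model.save(model, "led_saved_model", signatures=model.serving)
```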
......@@ -14,6 +14,7 @@
# limitations under the License.
"""Tensorflow Longformer model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -437,6 +438,20 @@ class TFLongformerLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states)
......@@ -1602,7 +1617,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -2040,13 +2055,11 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.decoder
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
......@@ -16,6 +16,7 @@
# limitations under the License.
""" TF 2.0 LXMERT model. """
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
......@@ -706,10 +707,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
......@@ -1103,6 +1101,20 @@ class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.input_embeddings(hidden_states, mode="linear")
......@@ -1292,13 +1304,11 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
**({"obj_labels": obj_labels} if self.config.task_obj_predict else {}),
}
def get_output_embeddings(self):
return self.lxmert.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.cls.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING)
......
......@@ -15,6 +15,7 @@
# limitations under the License.
""" TF 2.0 MobileBERT model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -665,6 +666,20 @@ class TFMobileBertLMPredictionHead(tf.keras.layers.Layer):
)
super().build(input_shape)
def get_output_embeddings(self):
return self
def set_output_embeddings(self, value):
self.decoder = value
self.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0))
......@@ -704,10 +719,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -1039,13 +1051,11 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls")
def get_output_embeddings(self):
return self.predictions.predictions
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.predictions.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......@@ -1149,13 +1159,11 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")
def get_output_embeddings(self):
return self.mlm.predictions
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
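MobileBERT's prediction head is the one case above where `get_output_embeddings` returns the head itself: its projection is not a plain transposed embedding matrix but the transposed decoder weights concatenated with an extra dense block. A shape-level sketch of that projection, with dimensions taken from the default MobileBERT config (illustrative values only):

```python
import tensorflow as tf

batch, seq_len = 2, 8
hidden_size, embedding_size, vocab_size = 512, 128, 30522

hidden_states = tf.random.normal((batch, seq_len, hidden_size))
decoder = tf.random.normal((vocab_size, embedding_size))               # tied word embeddings
dense = tf.random.normal((hidden_size - embedding_size, vocab_size))   # extra projection rows

# Same contraction as TFMobileBertLMPredictionHead.call above.
logits = tf.matmul(hidden_states, tf.concat([tf.transpose(decoder), dense], axis=0))
print(logits.shape)   # (2, 8, 30522)
```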
......@@ -17,6 +17,7 @@
import math
import warnings
import tensorflow as tf
......@@ -541,7 +542,7 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune):
......@@ -840,6 +841,20 @@ class TFMPNetLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, features):
x = self.dense(features)
x = self.act(x)
......@@ -862,13 +877,11 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
self.mpnet = TFMPNetMainLayer(config, name="mpnet")
self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.mpnet.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
......@@ -219,7 +219,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.tokens_embed.weight = value
self.tokens_embed.vocab_size = value.shape[0]
self.tokens_embed.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
......@@ -577,7 +577,10 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.tokens_embed
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -682,9 +685,6 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
config, initializer_range=config.initializer_range, name="multiple_choice_head"
)
def get_output_embeddings(self):
return self.transformer.tokens_embed
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -841,9 +841,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
)
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.tokens_embed
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
......
......@@ -15,6 +15,8 @@
# limitations under the License.
""" TF 2.0 RoBERTa model. """
import warnings
import tensorflow as tf
from ...activations_tf import get_tf_activation
......@@ -502,7 +504,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune):
......@@ -827,6 +829,20 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states)
......@@ -849,13 +865,11 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.decoder
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
......@@ -573,15 +573,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def get_input_embeddings(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
......@@ -839,6 +830,26 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
return self.serving_output(output)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
try:
self.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
self.shared.weight = value
self.shared.vocab_size = shape_list(value)[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.embed_tokens = embed_tokens
if hasattr(self, "decoder"):
self.decoder.embed_tokens = embed_tokens
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
......@@ -1050,20 +1061,6 @@ class TFT5Model(TFT5PreTrainedModel):
decoder_config.is_decoder = True
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
def get_encoder(self):
return self.encoder
......@@ -1222,24 +1219,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
if not config.tie_word_embeddings:
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
def get_input_embeddings(self):
return self.shared
def get_output_embeddings(self):
if self.config.tie_word_embeddings:
return self.shared
return self.get_input_embeddings()
else:
return self.lm_head
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
# value has a shape (num_tokens, dim) then needs to be transposed
return tf.transpose(self.lm_head.kernel)
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
def set_output_embeddings(self, value):
if self.config.tie_word_embeddings:
self.set_input_embeddings(value)
else:
self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
# value has a shape (num_tokens, dim) then needs to be transposed
transposed_value = tf.transpose(value)
self.lm_head.kernel = transposed_value
def get_encoder(self):
return self.encoder
......@@ -1358,9 +1354,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# T5v1.1 does not tie output word embeddings and thus does not require downscaling
if self.config.tie_word_embeddings:
sequence_output = sequence_output * (self.model_dim ** -0.5)
logits = self.get_output_embeddings()(sequence_output, mode="linear")
logits = self.shared(sequence_output, mode="linear")
else:
logits = self.get_output_embeddings()(sequence_output)
logits = self.lm_head(sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
......@@ -1488,19 +1484,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
encoder_config.use_cache = False
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
def get_encoder(self):
return self.encoder
......
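T5 gets the most involved treatment: with `tie_word_embeddings` the output embeddings are simply the shared input embeddings, but for untied (T5v1.1-style) checkpoints they are the `lm_head` Dense kernel, which must be transposed to look like an embedding matrix. A toy illustration of that shape relationship, using t5-small dimensions as an example:

```python
import tensorflow as tf

d_model, vocab_size = 512, 32128
lm_head = tf.keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
lm_head.build((None, d_model))

print(lm_head.kernel.shape)                 # (512, 32128): Dense stores (last_dim, units)
print(tf.transpose(lm_head.kernel).shape)   # (32128, 512): embedding-style (num_tokens, dim)
```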
......@@ -468,9 +468,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
raise NotImplementedError
def _resize_token_embeddings(self, new_num_tokens):
return self.word_emb
def backward_compatible(self):
self.sample_softmax = -1
......@@ -909,25 +906,6 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
)
class TFTransfoXLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
@add_start_docstrings(
"""
The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive
......@@ -948,6 +926,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
)
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError()
def get_output_embeddings(self):
"""Double-check if you are using adaptive softmax."""
if len(self.crit.out_layers) > 0:
......
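Transformer-XL ties its weights through an adaptive softmax, so rather than pretending to support the generic path, the LM head model now raises `NotImplementedError` from `_resize_token_embeddings`. Assuming the generic `resize_token_embeddings` dispatches there, a caller would see something like the following (checkpoint name and token count are illustrative):

```python
from transformers import TFTransfoXLLMHeadModel

model = TFTransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
try:
    model.resize_token_embeddings(300000)
except NotImplementedError:
    print("Adaptive-softmax models are excluded from generic embedding resizing for now.")
```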