Unverified commit a7dabfb3, authored by Julien Plu and committed by GitHub
Browse files

Fix TF s2s models (#9478)

* Fix Seq2Seq models for serving

* Apply style

* Fix Longformer

* Fix mBart/Pegasus/Blenderbot

* Apply style

* Add a main intermediate layer

* Apply style

* Remove import

* Apply tf.function to Longformer

* Fix utils check_copy

* Update S2S template

* Fix BART + Blenderbot

* Fix BlenderbotSmall

* Fix BlenderbotSmall

* Fix BlenderbotSmall

* Fix MBart

* Fix Marian

* Fix Pegasus + template

* Apply style

* Fix common attributes test

* Forgot to fix the LED test

* Apply Patrick's comment on LED Decoder
parent 23e5a36e
...@@ -322,6 +322,7 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -322,6 +322,7 @@ def input_processing(func, config, input_ids, **kwargs):
""" """
signature = dict(inspect.signature(func).parameters) signature = dict(inspect.signature(func).parameters)
signature.pop("kwargs", None) signature.pop("kwargs", None)
signature.pop("self", None)
parameter_names = list(signature.keys()) parameter_names = list(signature.keys())
output = {} output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
...@@ -346,6 +347,8 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -346,6 +347,8 @@ def input_processing(func, config, input_ids, **kwargs):
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}." f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
) )
kwargs.pop("kwargs_call")
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, allowed_types) or v is None: if isinstance(v, allowed_types) or v is None:
output[k] = v output[k] = v
...@@ -356,8 +359,8 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -356,8 +359,8 @@ def input_processing(func, config, input_ids, **kwargs):
for i, input in enumerate(input_ids): for i, input in enumerate(input_ids):
# EagerTensors don't allow to use the .name property so we check for a real Tensor # EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor: if type(input) == tf.Tensor:
# Tensor names have always the pattern name:device_id then we check only the # Tensor names have always the pattern `name:id` then we check only the
# name and not the device id # `name` part
tensor_name = input.name.split(":")[0] tensor_name = input.name.split(":")[0]
if tensor_name in parameter_names: if tensor_name in parameter_names:
......
...@@ -411,29 +411,6 @@ class TFBartPretrainedModel(TFPreTrainedModel): ...@@ -411,29 +411,6 @@ class TFBartPretrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -605,6 +582,9 @@ class TFBartEncoder(tf.keras.layers.Layer): ...@@ -605,6 +582,9 @@ class TFBartEncoder(tf.keras.layers.Layer):
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -744,6 +724,9 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -744,6 +724,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -871,13 +854,15 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -871,13 +854,15 @@ class TFBartDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () present_key_values = () if inputs["use_cache"] else None
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop): if inputs["training"] and (dropout_probability < self.layerdrop):
...@@ -901,12 +886,12 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -901,12 +886,12 @@ class TFBartDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns
...@@ -919,16 +904,14 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -919,16 +904,14 @@ class TFBartDecoder(tf.keras.layers.Layer):
) )
@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TFBartModel(TFBartPretrainedModel): class TFBartMainLayer(tf.keras.layers.Layer):
base_model_prefix = "model" config_class = BartConfig
def __init__(self, config: BartConfig, *inputs, **kwargs): def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
...@@ -942,19 +925,20 @@ class TFBartModel(TFBartPretrainedModel): ...@@ -942,19 +925,20 @@ class TFBartModel(TFBartPretrainedModel):
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder") self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder") self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
def get_encoder(self): def get_input_embeddings(self):
return self.encoder return self.shared
def get_decoder(self): def set_input_embeddings(self, new_embeddings):
return self.decoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -1053,8 +1037,86 @@ class TFBartModel(TFBartPretrainedModel): ...@@ -1053,8 +1037,86 @@ class TFBartModel(TFBartPretrainedModel):
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
) )
@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
class TFBartModel(TFBartPretrainedModel):
def __init__(self, config: BartConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFBartMainLayer(config, name="model")
def get_encoder(self):
return self.model.encoder
def get_decoder(self):
return self.model.decoder
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1083,7 +1145,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel): ...@@ -1083,7 +1145,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TFBartModel(config, name="model") self.model = TFBartMainLayer(config, name="model")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency. # final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
...@@ -1199,7 +1261,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel): ...@@ -1199,7 +1261,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
) )
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -24,6 +24,7 @@ import tensorflow as tf ...@@ -24,6 +24,7 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...file_utils import ( from ...file_utils import (
add_code_sample_docstrings,
add_end_docstrings, add_end_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
...@@ -47,7 +48,6 @@ from ...modeling_tf_utils import ( ...@@ -47,7 +48,6 @@ from ...modeling_tf_utils import (
shape_list, shape_list,
) )
from ...utils import logging from ...utils import logging
from ..blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .configuration_blenderbot import BlenderbotConfig from .configuration_blenderbot import BlenderbotConfig
...@@ -416,31 +416,6 @@ class TFBlenderbotPreTrainedModel(TFPreTrainedModel): ...@@ -416,31 +416,6 @@ class TFBlenderbotPreTrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -604,6 +579,9 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer): ...@@ -604,6 +579,9 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -744,6 +722,9 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer): ...@@ -744,6 +722,9 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -921,16 +902,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer): ...@@ -921,16 +902,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
) )
@add_start_docstrings(
"The bare BLENDERBOT Model outputting raw hidden-states without any specific head on top.",
BLENDERBOT_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TFBlenderbotModel(TFBlenderbotPreTrainedModel): class TFBlenderbotMainLayer(tf.keras.layers.Layer):
base_model_prefix = "model" config_class = BlenderbotConfig
def __init__(self, config: BlenderbotConfig, *inputs, **kwargs): def __init__(self, config: BlenderbotConfig, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
...@@ -944,22 +923,20 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel): ...@@ -944,22 +923,20 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
self.encoder = TFBlenderbotEncoder(config, embed_tokens, name="encoder") self.encoder = TFBlenderbotEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBlenderbotDecoder(config, embed_tokens, name="decoder") self.decoder = TFBlenderbotDecoder(config, embed_tokens, name="decoder")
@classmethod def get_input_embeddings(self):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): return self.shared
if pretrained_model_name_or_path == "facebook/blenderbot-90M":
warnings.warn(
"The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical checkpoint `facebook/small_blenderbot-90M` with `TFBlenderbotSmallModel.from_pretrained('facebook/small_blenderbot-90M')` instead.",
FutureWarning,
)
return TFBlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path)
return super(TFBlenderbotModel, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
def get_decoder(self): def set_input_embeddings(self, new_embeddings):
return self.decoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -977,22 +954,6 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel): ...@@ -977,22 +954,6 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
training=False, training=False,
**kwargs **kwargs
): ):
r"""
Returns:
Example::
>>> from transformers import BlenderbotTokenizer, TFBlenderbotModel
>>> model = TFBlenderbotModel.from_pretrained("facebook/blenderbot-400M-distill")
>>> tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
...@@ -1066,9 +1027,100 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel): ...@@ -1066,9 +1027,100 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
) )
@add_start_docstrings(
"The bare BLENDERBOT Model outputting raw hidden-states without any specific head on top.",
BLENDERBOT_START_DOCSTRING,
)
class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
def __init__(self, config: BlenderbotConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFBlenderbotMainLayer(config, name="model")
def get_encoder(self):
return self.model.encoder
def get_decoder(self):
return self.model.decoder
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
if pretrained_model_name_or_path == "facebook/blenderbot-90M":
from ..blenderbot_small import TFBlenderbotSmallModel
warnings.warn(
"The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical checkpoint `facebook/small_blenderbot-90M` with `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.",
FutureWarning,
)
return TFBlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/blenderbot-400M-distill",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1097,25 +1149,43 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel): ...@@ -1097,25 +1149,43 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TFBlenderbotModel(config, name="model") self.model = TFBlenderbotMainLayer(config, name="model")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency. # final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
) )
def get_decoder(self):
return self.model.decoder
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
if pretrained_model_name_or_path == "facebook/blenderbot-90M": if pretrained_model_name_or_path == "facebook/blenderbot-90M":
from ..blenderbot_small import TFBlenderbotSmallForConditionalGeneration
warnings.warn( warnings.warn(
"The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical checkpoint `facebook/small_blenderbot-90M` with `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.", "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical checkpoint `facebook/small_blenderbot-90M` with `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.",
FutureWarning, FutureWarning,
) )
return TFBlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path) return TFBlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
return super(TFBlenderbotForConditionalGeneration, cls).from_pretrained( return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
pretrained_model_name_or_path, *model_args, **kwargs
)
@add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
...@@ -1208,7 +1278,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel): ...@@ -1208,7 +1278,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1283,21 +1353,6 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel): ...@@ -1283,21 +1353,6 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
else: else:
return logits return logits
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
def compute_loss(self, labels, logits): def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens""" """CrossEntropyLoss that ignores pad tokens"""
......
...@@ -52,7 +52,7 @@ if TYPE_CHECKING: ...@@ -52,7 +52,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
else: else:
import importlib import importlib
......
...@@ -22,6 +22,7 @@ import tensorflow as tf ...@@ -22,6 +22,7 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...file_utils import ( from ...file_utils import (
add_code_sample_docstrings,
add_end_docstrings, add_end_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
...@@ -414,31 +415,6 @@ class TFBlenderbotSmallPreTrainedModel(TFPreTrainedModel): ...@@ -414,31 +415,6 @@ class TFBlenderbotSmallPreTrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -608,6 +584,9 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): ...@@ -608,6 +584,9 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -748,6 +727,9 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -748,6 +727,9 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -922,16 +904,14 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -922,16 +904,14 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
) )
@add_start_docstrings(
"The bare BLENDERBOT_SMALL Model outputting raw hidden-states without any specific head on top.",
BLENDERBOT_SMALL_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
base_model_prefix = "model" config_class = BlenderbotSmallConfig
def __init__(self, config: BlenderbotSmallConfig, *inputs, **kwargs): def __init__(self, config: BlenderbotSmallConfig, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
...@@ -945,14 +925,20 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -945,14 +925,20 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
self.encoder = TFBlenderbotSmallEncoder(config, embed_tokens, name="encoder") self.encoder = TFBlenderbotSmallEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBlenderbotSmallDecoder(config, embed_tokens, name="decoder") self.decoder = TFBlenderbotSmallDecoder(config, embed_tokens, name="decoder")
def get_encoder(self): def get_input_embeddings(self):
return self.encoder return self.shared
def get_decoder(self): def set_input_embeddings(self, new_embeddings):
return self.decoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@add_start_docstrings_to_model_forward(BLENDERBOT_SMALL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -970,22 +956,6 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -970,22 +956,6 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
training=False, training=False,
**kwargs **kwargs
): ):
r"""
Returns:
Example::
>>> from transformers import BlenderbotSmallTokenizer, TFBlenderbotSmallModel
>>> model = TFBlenderbotSmallModel.from_pretrained("facebook/blenderbot_small-90M")
>>> tokenizer = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot_small-90M")
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
...@@ -1059,9 +1029,87 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1059,9 +1029,87 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
) )
@add_start_docstrings(
"The bare BLENDERBOT_SMALL Model outputting raw hidden-states without any specific head on top.",
BLENDERBOT_SMALL_START_DOCSTRING,
)
class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
def __init__(self, config: BlenderbotSmallConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFBlenderbotSmallMainLayer(config, name="model")
def get_encoder(self):
return self.model.encoder
def get_decoder(self):
return self.model.decoder
@add_start_docstrings_to_model_forward(BLENDERBOT_SMALL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/blenderbot_small-90M",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1090,7 +1138,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1090,7 +1138,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TFBlenderbotSmallModel(config, name="model") self.model = TFBlenderbotSmallMainLayer(config, name="model")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency. # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
...@@ -1206,7 +1254,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1206,7 +1254,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -320,6 +320,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ...@@ -320,6 +320,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
) )
# make sure that local attention probabilities are set to 0 for indices of global attn # make sure that local attention probabilities are set to 0 for indices of global attn
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
# because of the concat Line 713.
attn_probs = tf.where( attn_probs = tf.where(
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)), tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32), tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
...@@ -882,6 +884,7 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer): ...@@ -882,6 +884,7 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer):
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn], [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
training=training, training=training,
) )
attention_output = self.output_dense(self_outputs[0], training=training) attention_output = self.output_dense(self_outputs[0], training=training)
outputs = (attention_output,) + self_outputs[1:] outputs = (attention_output,) + self_outputs[1:]
...@@ -1046,15 +1049,16 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer): ...@@ -1046,15 +1049,16 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(hidden_states), shape_list(hidden_states),
shape_list(residual), shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout(hidden_states, training=training) hidden_states = self.activation_dropout(hidden_states, training=training)
...@@ -1182,29 +1186,6 @@ class TFLEDPreTrainedModel(TFPreTrainedModel): ...@@ -1182,29 +1186,6 @@ class TFLEDPreTrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.shared
def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)
try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -1521,6 +1502,9 @@ class TFLEDEncoder(tf.keras.layers.Layer): ...@@ -1521,6 +1502,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):
self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -1624,20 +1608,17 @@ class TFLEDEncoder(tf.keras.layers.Layer): ...@@ -1624,20 +1608,17 @@ class TFLEDEncoder(tf.keras.layers.Layer):
# check attention mask and invert # check attention mask and invert
if inputs["attention_mask"] is not None: if inputs["attention_mask"] is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(inputs["attention_mask"])[:, 0, 0, :] inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])[:, 0, 0, :]
attention_mask = attention_mask[:, :, None, None] inputs["attention_mask"] = inputs["attention_mask"][:, :, None, None]
else:
attention_mask = None
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
all_global_attentions = () if inputs["output_attentions"] and is_global_attn else None
# encoder layers # encoder layers
for encoder_layer in self.layers: for encoder_layer in self.layers:
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
encoder_states = encoder_states + (hidden_states_to_add,) encoder_states = encoder_states + (hidden_states_to_add,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
...@@ -1646,7 +1627,7 @@ class TFLEDEncoder(tf.keras.layers.Layer): ...@@ -1646,7 +1627,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
hidden_states=hidden_states, hidden_states=hidden_states,
attention_mask=attention_mask, attention_mask=inputs["attention_mask"],
is_index_masked=is_index_masked, is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn, is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn, is_global_attn=is_global_attn,
...@@ -1658,14 +1639,12 @@ class TFLEDEncoder(tf.keras.layers.Layer): ...@@ -1658,14 +1639,12 @@ class TFLEDEncoder(tf.keras.layers.Layer):
# bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
if is_global_attn: # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
# undo padding # undo padding
if padding_len > 0: # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) hidden_states = self.compute_hidden_states(hidden_states, padding_len)
hidden_states = hidden_states[:, :-padding_len]
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
...@@ -1679,6 +1658,11 @@ class TFLEDEncoder(tf.keras.layers.Layer): ...@@ -1679,6 +1658,11 @@ class TFLEDEncoder(tf.keras.layers.Layer):
global_attentions=all_global_attentions, global_attentions=all_global_attentions,
) )
@tf.function
def compute_hidden_states(self, hidden_states, padding_len):
return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
@tf.function
def _pad_to_window_size( def _pad_to_window_size(
self, self,
input_ids, input_ids,
...@@ -1777,19 +1761,14 @@ class TFLEDDecoder(tf.keras.layers.Layer): ...@@ -1777,19 +1761,14 @@ class TFLEDDecoder(tf.keras.layers.Layer):
Args: Args:
input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. provide it. Indices can be obtained using :class:`~transformers.LEDTokenizer`. See
Indices can be obtained using :class:`~transformers.LEDTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details. for details. `What are input IDs? <../glossary.html#input-ids>`__
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**, - 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
...@@ -1800,13 +1779,10 @@ class TFLEDDecoder(tf.keras.layers.Layer): ...@@ -1800,13 +1779,10 @@ class TFLEDDecoder(tf.keras.layers.Layer):
- 1 for tokens that are **not masked**, - 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding. If :obj:`past_key_values` are used, the user can optionally input only the last
If :obj:`past_key_values` are used, the user can optionally input only the last
:obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size,
sequence_length)`. sequence_length)`.
...@@ -1930,16 +1906,13 @@ class TFLEDDecoder(tf.keras.layers.Layer): ...@@ -1930,16 +1906,13 @@ class TFLEDDecoder(tf.keras.layers.Layer):
) )
@add_start_docstrings(
"The bare LED Model outputting raw hidden-states without any specific head on top.",
LED_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TFLEDModel(TFLEDPreTrainedModel): class TFLEDMainLayer(tf.keras.layers.Layer):
base_model_prefix = "led" config_class = LEDConfig
def __init__(self, config: LEDConfig, *inputs, **kwargs): def __init__(self, config: LEDConfig, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="led.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="led.shared")
with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name:
...@@ -1953,19 +1926,20 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -1953,19 +1926,20 @@ class TFLEDModel(TFLEDPreTrainedModel):
self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder") self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder")
self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder") self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder")
def get_encoder(self): def get_input_embeddings(self):
return self.encoder return self.shared
def get_decoder(self): def set_input_embeddings(self, new_embeddings):
return self.decoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/led-base-16384",
output_type=TFLEDSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -2007,12 +1981,6 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2007,12 +1981,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None: if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
inputs["use_cache"] = False inputs["use_cache"] = False
inputs["output_hidden_states"] = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
if inputs["encoder_outputs"] is None: if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
...@@ -2063,8 +2031,88 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2063,8 +2031,88 @@ class TFLEDModel(TFLEDPreTrainedModel):
encoder_global_attentions=inputs["encoder_outputs"].global_attentions, encoder_global_attentions=inputs["encoder_outputs"].global_attentions,
) )
@add_start_docstrings(
"The bare LED Model outputting raw hidden-states without any specific head on top.",
LED_START_DOCSTRING,
)
class TFLEDModel(TFLEDPreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.led = TFLEDMainLayer(config, name="led")
def get_encoder(self):
return self.led.encoder
def get_decoder(self):
return self.led.decoder
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/led-base-16384",
output_type=TFLEDSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None,
global_attention_mask=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.led(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
global_attention_mask=inputs["global_attention_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -2095,7 +2143,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2095,7 +2143,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.led = TFLEDModel(config, name="led") self.led = TFLEDMainLayer(config, name="led")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency. # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
...@@ -2157,6 +2205,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2157,6 +2205,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
>>> probs = tf.nn.softmax(logits[0]) >>> probs = tf.nn.softmax(logits[0])
>>> # probs[5] is associated with the mask token >>> # probs[5] is associated with the mask token
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
...@@ -2221,7 +2270,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2221,7 +2270,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
) )
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -974,6 +974,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -974,6 +974,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
) )
# make sure that local attention probabilities are set to 0 for indices of global attn # make sure that local attention probabilities are set to 0 for indices of global attn
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
# because of the concatenation on line 713.
attn_probs = tf.where( attn_probs = tf.where(
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)), tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32), tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
......
...@@ -23,6 +23,7 @@ import tensorflow as tf ...@@ -23,6 +23,7 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...file_utils import ( from ...file_utils import (
add_code_sample_docstrings,
add_end_docstrings, add_end_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
...@@ -444,31 +445,6 @@ class TFMarianPreTrainedModel(TFPreTrainedModel): ...@@ -444,31 +445,6 @@ class TFMarianPreTrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
def get_input_embeddings(self):
    """Return the ``shared`` attribute of the base model.

    The base model is the attribute of ``self`` named by
    ``self.base_model_prefix``; when ``self`` has no such attribute,
    ``self`` itself is used instead.
    """
    owner = getattr(self, self.base_model_prefix, self)
    shared_embeddings = owner.shared
    return shared_embeddings
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
def set_input_embeddings(self, value):
    """Install ``value`` as the shared embedding weight and re-wire encoder/decoder.

    Args:
        value: object assigned to ``shared.weight`` of the base model (the
            attribute named by ``self.base_model_prefix``, or ``self`` when
            that attribute is missing).
    """
    owner = getattr(self, self.base_model_prefix, self)

    try:
        owner.shared.weight = value
    except AttributeError:
        # Run the model once on its dummy inputs, then retry the assignment
        # (presumably this creates the missing weights - confirm).
        self(self.dummy_inputs)
        owner.shared.weight = value

    # Keep the recorded vocabulary size in step with the first weight dimension.
    owner.shared.vocab_size = shape_list(owner.shared.weight)[0]

    # Open the "model.shared" scope only to capture its scope object.
    with tf.compat.v1.variable_scope("model.shared") as captured_scope:
        pass

    wrapped = TFWrappedEmbeddings(owner.shared, abs_scope_name=captured_scope)
    for part_name in ("encoder", "decoder"):
        getattr(owner, part_name).set_embed_tokens(wrapped)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -625,6 +601,9 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -625,6 +601,9 @@ class TFMarianEncoder(tf.keras.layers.Layer):
) )
self.layers = [TFMarianEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFMarianEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    current_tokens = self.embed_tokens
    return current_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -761,6 +740,9 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -761,6 +740,9 @@ class TFMarianDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    current_tokens = self.embed_tokens
    return current_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -935,16 +917,14 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -935,16 +917,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
) )
@add_start_docstrings(
"The bare MARIAN Model outputting raw hidden-states without any specific head on top.",
MARIAN_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TFMarianModel(TFMarianPreTrainedModel): class TFMarianMainLayer(tf.keras.layers.Layer):
base_model_prefix = "model" config_class = MarianConfig
def __init__(self, config: MarianConfig, *inputs, **kwargs): def __init__(self, config: MarianConfig, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
...@@ -958,14 +938,20 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -958,14 +938,20 @@ class TFMarianModel(TFMarianPreTrainedModel):
self.encoder = TFMarianEncoder(config, embed_tokens, name="encoder") self.encoder = TFMarianEncoder(config, embed_tokens, name="encoder")
self.decoder = TFMarianDecoder(config, embed_tokens, name="decoder") self.decoder = TFMarianDecoder(config, embed_tokens, name="decoder")
def get_encoder(self): def get_input_embeddings(self):
return self.encoder return self.shared
def get_decoder(self): def set_input_embeddings(self, new_embeddings):
return self.decoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -983,24 +969,6 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -983,24 +969,6 @@ class TFMarianModel(TFMarianPreTrainedModel):
training=False, training=False,
**kwargs **kwargs
): ):
r"""
Returns:
Example::
>>> from transformers import MarianTokenizer, TFMarianModel
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> model = TFMarianModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("<pad> Studien haben gezeigt dass es hilfreich ist einen Hund zu besitzen",
... return_tensors="tf", add_special_tokens=False).input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
...@@ -1077,9 +1045,87 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1077,9 +1045,87 @@ class TFMarianModel(TFMarianPreTrainedModel):
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
) )
@add_start_docstrings(
"The bare MARIAN Model outputting raw hidden-states without any specific head on top.",
MARIAN_START_DOCSTRING,
)
class TFMarianModel(TFMarianPreTrainedModel):
def __init__(self, config: MarianConfig, *inputs, **kwargs):
    """Initialise the parent class and attach the Marian main layer as ``self.model``.

    Args:
        config: configuration forwarded to the parent constructor and to
            ``TFMarianMainLayer``.
        *inputs: extra positional arguments forwarded to the parent constructor.
        **kwargs: extra keyword arguments forwarded to the parent constructor.
    """
    super().__init__(config, *inputs, **kwargs)

    # The layer is created with the explicit name "model".
    main_layer = TFMarianMainLayer(config, name="model")
    self.model = main_layer
def get_encoder(self):
    """Return the ``encoder`` attribute of the wrapped ``self.model``."""
    main_layer = self.model
    return main_layer.encoder
def get_decoder(self):
    """Return the ``decoder`` attribute of the wrapped ``self.model``."""
    main_layer = self.model
    return main_layer.decoder
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    tokenizer_class=_TOKENIZER_FOR_DOC,
    checkpoint="Helsinki-NLP/opus-mt-en-de",
    output_type=TFSeq2SeqModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs
):
    # No inline docstring on purpose: the decorators above appear to assemble
    # the docstring for this method (assumption from their names - confirm).

    # Pass every call argument through input_processing; the result is indexed
    # by argument name below.
    processed = input_processing(
        func=self.call,
        config=self.config,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        kwargs_call=kwargs,
    )

    # Names forwarded to the wrapped main layer, as keyword arguments in this order.
    forwarded_names = (
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "past_key_values",
        "inputs_embeds",
        "decoder_inputs_embeds",
        "use_cache",
        "output_attentions",
        "output_hidden_states",
        "return_dict",
        "training",
    )
    return self.model(**{name: processed[name] for name in forwarded_names})
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1108,7 +1154,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ...@@ -1108,7 +1154,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TFMarianModel(config, name="model") self.model = TFMarianMainLayer(config, name="model")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency. # final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
...@@ -1225,7 +1271,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ...@@ -1225,7 +1271,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -417,31 +417,6 @@ class TFMBartPreTrainedModel(TFPreTrainedModel): ...@@ -417,31 +417,6 @@ class TFMBartPreTrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
def get_input_embeddings(self):
    """Return the ``shared`` attribute of the base model.

    The base model is the attribute of ``self`` named by
    ``self.base_model_prefix``; when ``self`` has no such attribute,
    ``self`` itself is used instead.
    """
    owner = getattr(self, self.base_model_prefix, self)
    shared_embeddings = owner.shared
    return shared_embeddings
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
def set_input_embeddings(self, value):
    """Install ``value`` as the shared embedding weight and re-wire encoder/decoder.

    Args:
        value: object assigned to ``shared.weight`` of the base model (the
            attribute named by ``self.base_model_prefix``, or ``self`` when
            that attribute is missing).
    """
    owner = getattr(self, self.base_model_prefix, self)

    try:
        owner.shared.weight = value
    except AttributeError:
        # Run the model once on its dummy inputs, then retry the assignment
        # (presumably this creates the missing weights - confirm).
        self(self.dummy_inputs)
        owner.shared.weight = value

    # Keep the recorded vocabulary size in step with the first weight dimension.
    owner.shared.vocab_size = shape_list(owner.shared.weight)[0]

    # Open the "model.shared" scope only to capture its scope object.
    with tf.compat.v1.variable_scope("model.shared") as captured_scope:
        pass

    wrapped = TFWrappedEmbeddings(owner.shared, abs_scope_name=captured_scope)
    for part_name in ("encoder", "decoder"):
        getattr(owner, part_name).set_embed_tokens(wrapped)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -615,6 +590,9 @@ class TFMBartEncoder(tf.keras.layers.Layer): ...@@ -615,6 +590,9 @@ class TFMBartEncoder(tf.keras.layers.Layer):
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    current_tokens = self.embed_tokens
    return current_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -757,6 +735,9 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -757,6 +735,9 @@ class TFMBartDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    current_tokens = self.embed_tokens
    return current_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -934,16 +915,14 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -934,16 +915,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
) )
@add_start_docstrings(
"The bare MBART Model outputting raw hidden-states without any specific head on top.",
MBART_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TFMBartModel(TFMBartPreTrainedModel): class TFMBartMainLayer(tf.keras.layers.Layer):
base_model_prefix = "model" config_class = MBartConfig
def __init__(self, config: MBartConfig, *inputs, **kwargs): def __init__(self, config: MBartConfig, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
...@@ -957,19 +936,20 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -957,19 +936,20 @@ class TFMBartModel(TFMBartPreTrainedModel):
self.encoder = TFMBartEncoder(config, embed_tokens, name="encoder") self.encoder = TFMBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFMBartDecoder(config, embed_tokens, name="decoder") self.decoder = TFMBartDecoder(config, embed_tokens, name="decoder")
def get_encoder(self): def get_input_embeddings(self):
return self.encoder return self.shared
def get_decoder(self): def set_input_embeddings(self, new_embeddings):
return self.decoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/mbart-large-cc25",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -1066,9 +1046,87 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1066,9 +1046,87 @@ class TFMBartModel(TFMBartPreTrainedModel):
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
) )
@add_start_docstrings(
"The bare MBART Model outputting raw hidden-states without any specific head on top.",
MBART_START_DOCSTRING,
)
class TFMBartModel(TFMBartPreTrainedModel):
def __init__(self, config: MBartConfig, *inputs, **kwargs):
    """Initialise the parent class and attach the MBart main layer as ``self.model``.

    Args:
        config: configuration forwarded to the parent constructor and to
            ``TFMBartMainLayer``.
        *inputs: extra positional arguments forwarded to the parent constructor.
        **kwargs: extra keyword arguments forwarded to the parent constructor.
    """
    super().__init__(config, *inputs, **kwargs)

    # The layer is created with the explicit name "model".
    main_layer = TFMBartMainLayer(config, name="model")
    self.model = main_layer
def get_encoder(self):
    """Return the ``encoder`` attribute of the wrapped ``self.model``."""
    main_layer = self.model
    return main_layer.encoder
def get_decoder(self):
    """Return the ``decoder`` attribute of the wrapped ``self.model``."""
    main_layer = self.model
    return main_layer.decoder
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    tokenizer_class=_TOKENIZER_FOR_DOC,
    checkpoint="facebook/mbart-large-cc25",
    output_type=TFSeq2SeqModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs
):
    # No inline docstring on purpose: the decorators above appear to assemble
    # the docstring for this method (assumption from their names - confirm).

    # Pass every call argument through input_processing; the result is indexed
    # by argument name below.
    processed = input_processing(
        func=self.call,
        config=self.config,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        kwargs_call=kwargs,
    )

    # Names forwarded to the wrapped main layer, as keyword arguments in this order.
    forwarded_names = (
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "past_key_values",
        "inputs_embeds",
        "decoder_inputs_embeds",
        "use_cache",
        "output_attentions",
        "output_hidden_states",
        "return_dict",
        "training",
    )
    return self.model(**{name: processed[name] for name in forwarded_names})
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1097,7 +1155,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ...@@ -1097,7 +1155,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TFMBartModel(config, name="model") self.model = TFMBartMainLayer(config, name="model")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency. # final_logits_bias is registered as a buffer in PyTorch, so it is not trainable here for the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
...@@ -1212,7 +1270,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ...@@ -1212,7 +1270,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -23,6 +23,7 @@ import tensorflow as tf ...@@ -23,6 +23,7 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...file_utils import ( from ...file_utils import (
add_code_sample_docstrings,
add_end_docstrings, add_end_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
...@@ -445,31 +446,6 @@ class TFPegasusPreTrainedModel(TFPreTrainedModel): ...@@ -445,31 +446,6 @@ class TFPegasusPreTrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
def get_input_embeddings(self):
    """Return the ``shared`` attribute of the base model.

    The base model is the attribute of ``self`` named by
    ``self.base_model_prefix``; when ``self`` has no such attribute,
    ``self`` itself is used instead.
    """
    owner = getattr(self, self.base_model_prefix, self)
    shared_embeddings = owner.shared
    return shared_embeddings
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
def set_input_embeddings(self, value):
    """Install ``value`` as the shared embedding weight and re-wire encoder/decoder.

    Args:
        value: object assigned to ``shared.weight`` of the base model (the
            attribute named by ``self.base_model_prefix``, or ``self`` when
            that attribute is missing).
    """
    owner = getattr(self, self.base_model_prefix, self)

    try:
        owner.shared.weight = value
    except AttributeError:
        # Run the model once on its dummy inputs, then retry the assignment
        # (presumably this creates the missing weights - confirm).
        self(self.dummy_inputs)
        owner.shared.weight = value

    # Keep the recorded vocabulary size in step with the first weight dimension.
    owner.shared.vocab_size = shape_list(owner.shared.weight)[0]

    # Open the "model.shared" scope only to capture its scope object.
    with tf.compat.v1.variable_scope("model.shared") as captured_scope:
        pass

    wrapped = TFWrappedEmbeddings(owner.shared, abs_scope_name=captured_scope)
    for part_name in ("encoder", "decoder"):
        getattr(owner, part_name).set_embed_tokens(wrapped)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -631,6 +607,9 @@ class TFPegasusEncoder(tf.keras.layers.Layer): ...@@ -631,6 +607,9 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    current_tokens = self.embed_tokens
    return current_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -770,6 +749,9 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -770,6 +749,9 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    current_tokens = self.embed_tokens
    return current_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -946,16 +928,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -946,16 +928,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
) )
@add_start_docstrings(
"The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
PEGASUS_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TFPegasusModel(TFPegasusPreTrainedModel): class TFPegasusMainLayer(tf.keras.layers.Layer):
base_model_prefix = "model" config_class = PegasusConfig
def __init__(self, config: PegasusConfig, *inputs, **kwargs): def __init__(self, config: PegasusConfig, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
...@@ -969,14 +949,20 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -969,14 +949,20 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
self.encoder = TFPegasusEncoder(config, embed_tokens, name="encoder") self.encoder = TFPegasusEncoder(config, embed_tokens, name="encoder")
self.decoder = TFPegasusDecoder(config, embed_tokens, name="decoder") self.decoder = TFPegasusDecoder(config, embed_tokens, name="decoder")
def get_encoder(self): def get_input_embeddings(self):
return self.encoder return self.shared
def get_decoder(self): def set_input_embeddings(self, new_embeddings):
return self.decoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -994,22 +980,6 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -994,22 +980,6 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
training=False, training=False,
**kwargs **kwargs
): ):
r"""
Returns:
Example::
>>> from transformers import PegasusTokenizer, TFPegasusModel
>>> tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-large")
>>> model = TFPegasusModel.from_pretrained("google/pegasus-large")
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
...@@ -1086,9 +1056,87 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1086,9 +1056,87 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
) )
@add_start_docstrings(
"The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
PEGASUS_START_DOCSTRING,
)
class TFPegasusModel(TFPegasusPreTrainedModel):
def __init__(self, config: PegasusConfig, *inputs, **kwargs):
    """Initialise the parent class and attach the Pegasus main layer as ``self.model``.

    Args:
        config: configuration forwarded to the parent constructor and to
            ``TFPegasusMainLayer``.
        *inputs: extra positional arguments forwarded to the parent constructor.
        **kwargs: extra keyword arguments forwarded to the parent constructor.
    """
    super().__init__(config, *inputs, **kwargs)

    # The layer is created with the explicit name "model".
    main_layer = TFPegasusMainLayer(config, name="model")
    self.model = main_layer
def get_encoder(self):
    """Return the ``encoder`` attribute of the wrapped ``self.model``."""
    main_layer = self.model
    return main_layer.encoder
def get_decoder(self):
    """Return the ``decoder`` attribute of the wrapped ``self.model``."""
    main_layer = self.model
    return main_layer.decoder
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    tokenizer_class=_TOKENIZER_FOR_DOC,
    checkpoint="google/pegasus-large",
    output_type=TFSeq2SeqModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs
):
    # No inline docstring on purpose: the decorators above appear to assemble
    # the docstring for this method (assumption from their names - confirm).

    # Pass every call argument through input_processing; the result is indexed
    # by argument name below.
    processed = input_processing(
        func=self.call,
        config=self.config,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        kwargs_call=kwargs,
    )

    # Names forwarded to the wrapped main layer, as keyword arguments in this order.
    forwarded_names = (
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "past_key_values",
        "inputs_embeds",
        "decoder_inputs_embeds",
        "use_cache",
        "output_attentions",
        "output_hidden_states",
        "return_dict",
        "training",
    )
    return self.model(**{name: processed[name] for name in forwarded_names})
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1117,7 +1165,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ...@@ -1117,7 +1165,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TFPegasusModel(config, name="model") self.model = TFPegasusMainLayer(config, name="model")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
...@@ -1234,7 +1282,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ...@@ -1234,7 +1282,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,) pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -1207,7 +1207,7 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -1207,7 +1207,7 @@ class TFT5Model(TFT5PreTrainedModel):
) )
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None,) pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -1437,7 +1437,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1437,7 +1437,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
) )
def serving_output(self, output): def serving_output(self, output):
pkv = (tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None,) pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers
from transformers.modeling_tf_outputs import TFCausalLMOutput from transformers.modeling_tf_outputs import TFCausalLMOutput
...@@ -1915,29 +1916,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel): ...@@ -1915,29 +1916,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
} }
return dummy_inputs return dummy_inputs
def get_input_embeddings(self):
    """
    Return the ``shared`` attribute of the base model.

    The base model is the attribute named by ``self.base_model_prefix``; when no such attribute exists,
    ``self`` is used instead.
    """
    return getattr(self, self.base_model_prefix, self).shared
def set_input_embeddings(self, value):
    """
    Assign a new weight to the shared embeddings and pass a fresh wrapper around them to the encoder and decoder.

    Args:
        value: object assigned to ``base_model.shared.weight``.
    """
    # The base model is the attribute named by `base_model_prefix`, or `self` when that attribute is absent.
    base_model = getattr(self, self.base_model_prefix, self)
    try:
        base_model.shared.weight = value
    except AttributeError:
        # Call the model once on its dummy inputs, then retry the assignment
        # (presumably this builds the layers so the assignment can succeed - confirm).
        self(self.dummy_inputs)
        base_model.shared.weight = value
    # Keep the recorded vocabulary size equal to the first dimension of the weight now in place.
    base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
    # The scope is entered only to capture the scope object that is handed to the wrapper below.
    with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
        pass
    embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
    base_model.encoder.set_embed_tokens(embed_tokens)
    base_model.decoder.set_embed_tokens(embed_tokens)
@tf.function( @tf.function(
input_signature=[ input_signature=[
{ {
...@@ -1948,6 +1926,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel): ...@@ -1948,6 +1926,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
} }
] ]
) )
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.serving
def serving(self, inputs): def serving(self, inputs):
output = self.call(inputs) output = self.call(inputs)
...@@ -2080,6 +2059,9 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ...@@ -2080,6 +2059,9 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    embed_tokens = self.embed_tokens
    return embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -2148,7 +2130,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ...@@ -2148,7 +2130,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["inputs_embeds"] is None: if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale inputs_embeds = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
embed_pos = self.embed_positions(input_shape) embed_pos = self.embed_positions(input_shape)
hidden_states = inputs["inputs_embeds"] + embed_pos hidden_states = inputs["inputs_embeds"] + embed_pos
...@@ -2158,9 +2140,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ...@@ -2158,9 +2140,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
# check attention mask and invert # check attention mask and invert
if inputs["attention_mask"] is not None: if inputs["attention_mask"] is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(inputs["attention_mask"]) inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])
else:
attention_mask = None
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
...@@ -2175,7 +2155,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ...@@ -2175,7 +2155,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask) hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
...@@ -2219,9 +2199,12 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer): ...@@ -2219,9 +2199,12 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
    """Return the object currently stored in ``self.embed_tokens``."""
    embed_tokens = self.embed_tokens
    return embed_tokens
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -2321,20 +2304,13 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer): ...@@ -2321,20 +2304,13 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, past_key_values_length) positions = self.embed_positions(input_shape, past_key_values_length)
if inputs["inputs_embeds"] is None: if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
hidden_states = inputs["inputs_embeds"] hidden_states = inputs["inputs_embeds"]
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] inputs["attention_mask"], combined_attention_mask = self.compute_combined_attns_mask(
if input_shape[-1] > 1: inputs, input_shape, past_key_values_length
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) )
else:
combined_attention_mask = _expand_mask(
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
combined_attention_mask = combined_attention_mask + _expand_mask(inputs["attention_mask"], tgt_len=input_shape[-1])
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...@@ -2344,13 +2320,15 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer): ...@@ -2344,13 +2320,15 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () present_key_values = () if inputs["use_cache"] else None
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop): if inputs["training"] and (dropout_probability < self.layerdrop):
...@@ -2374,12 +2352,12 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer): ...@@ -2374,12 +2352,12 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns
...@@ -2390,18 +2368,43 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer): ...@@ -2390,18 +2368,43 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
) )
@tf.function
def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_length):
    """
    Compute the attention mask and the combined attention mask for the given inputs.

    Args:
        inputs: mapping read for its ``"attention_mask"`` and ``"input_ids"`` entries.
        input_shape: indexable shape; positions ``0``, ``1`` and ``-1`` are read.
        past_key_values_length: length added to the sequence dimension of the masks
            (presumably the number of cached past positions - confirm against the caller).

    Returns:
        A tuple ``(attention_mask, combined_attention_mask)``.
    """
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    combined_attention_mask = None
    # More than one position in the last dimension: use `_make_causal_mask`;
    # otherwise expand an all-ones mask that also covers the past length.
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
    else:
        combined_attention_mask = _expand_mask(
            tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
        )
    # No mask supplied but token ids available: mark every position whose id differs from
    # `pad_token_id`, and prepend ones for the `past_key_values_length` leading positions.
    if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
        attention_mask = tf.cast(
            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
        )
        attention_mask = tf.concat(
            [
                tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype),
                attention_mask,
            ],
            axis=-1,
        )
    else:
        # NOTE(review): this branch is also taken when `inputs["attention_mask"]` is not None, so a
        # supplied mask is replaced by an all-ones mask here - confirm that this is intended.
        attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
    return attention_mask, combined_attention_mask
@add_start_docstrings(
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
@keras_serializable @keras_serializable
class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel): class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
base_model_prefix = "model" config_class = {{cookiecutter.camelcase_modelname}}Config
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
...@@ -2414,20 +2417,21 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -2414,20 +2417,21 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, embed_tokens, name="encoder") self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, embed_tokens, name="encoder")
self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, embed_tokens, name="decoder") self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, embed_tokens, name="decoder")
def get_input_embeddings(self):
return self.shared
def get_encoder(self): def set_input_embeddings(self, new_embeddings):
return self.encoder self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
def get_decoder(self): # retrieve correct absolute scope for embed token wrapper
return self.decoder with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
@add_code_sample_docstrings( embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
tokenizer_class=_TOKENIZER_FOR_DOC, self.encoder.set_embed_tokens(embed_tokens)
checkpoint="{{cookiecutter.checkpoint_identifier}}", self.decoder.set_embed_tokens(embed_tokens)
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -2467,12 +2471,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -2467,12 +2471,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None: if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
inputs["use_cache"] = False inputs["use_cache"] = False
inputs["output_hidden_states"] = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
if inputs["encoder_outputs"] is None: if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
...@@ -2520,10 +2518,88 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -2520,10 +2518,88 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
) )
@add_start_docstrings(
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
    # Forward all arguments to the parent constructor first, then create the main layer
    # (named "model") that `call`, `get_encoder` and `get_decoder` delegate to.
    super().__init__(config, *inputs, **kwargs)
    self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
def get_encoder(self):
    """Return the ``encoder`` attribute of the wrapped main layer (``self.model``)."""
    main_layer = self.model
    return main_layer.encoder
def get_decoder(self):
    """Return the ``decoder`` attribute of the wrapped main layer (``self.model``)."""
    main_layer = self.model
    return main_layer.decoder
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None, pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
...@@ -2552,7 +2628,8 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec ...@@ -2552,7 +2628,8 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TF{{cookiecutter.camelcase_modelname}}Model(config, name="model") self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
self.model._set_save_spec(inputs=self.serving.input_signature)
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight( self.final_logits_bias = self.add_weight(
...@@ -2675,7 +2752,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec ...@@ -2675,7 +2752,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output): def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None, pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
......
...@@ -264,22 +264,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -264,22 +264,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False models_equal = False
self.assertTrue(models_equal) self.assertTrue(models_equal)
@slow
def test_saved_model_with_hidden_states_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
@slow
def test_saved_model_with_attentions_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
def test_saved_model_creation(self): def test_saved_model_creation(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR # This test is too long (>30sec) and makes fail the CI
pass
def test_saved_model_creation_extended(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass pass
......
...@@ -200,22 +200,8 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -200,22 +200,8 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias() name = model.get_bias()
assert name is None assert name is None
@slow
def test_saved_model_with_hidden_states_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
@slow
def test_saved_model_with_attentions_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
def test_saved_model_creation(self): def test_saved_model_creation(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR # This test is too long (>30sec) and makes fail the CI
pass
def test_saved_model_creation_extended(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
......
...@@ -188,28 +188,19 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -188,28 +188,19 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_layer_with_bias()
assert x is None
name = model.get_prefix_bias_name()
assert name is None
@slow if model_class in self.all_generative_model_classes:
def test_saved_model_with_hidden_states_output(self): x = model.get_output_embeddings()
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR assert isinstance(x, tf.keras.layers.Layer)
pass name = model.get_bias()
assert isinstance(name, dict)
@slow for k, v in name.items():
def test_saved_model_with_attentions_output(self): assert isinstance(v, tf.Variable)
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR else:
pass x = model.get_output_embeddings()
assert x is None
def test_saved_model_creation(self): name = model.get_bias()
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR assert name is None
pass
def test_saved_model_creation_extended(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -274,6 +265,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -274,6 +265,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False models_equal = False
self.assertTrue(models_equal) self.assertTrue(models_equal)
def test_saved_model_creation(self):
    # Intentionally a no-op: this test takes too long (>30 sec) and makes the CI fail.
    pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
......
...@@ -211,36 +211,35 @@ class TFModelTesterMixin: ...@@ -211,36 +211,35 @@ class TFModelTesterMixin:
def test_saved_model_with_hidden_states_output(self): def test_saved_model_with_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True config.output_hidden_states = True
config.output_attentions = False
if hasattr(config, "use_cache"):
config.use_cache = False
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# A saved model is always executed in graph mode, since we merged the PR #8777
# the booleans in graph mode are always the ones in the config, then we update
# the use_cache property if it exists in order to have similar booleans with the inputs
if "use_cache" in class_inputs_dict:
config.use_cache = class_inputs_dict.pop("use_cache")
model = model_class(config) model = model_class(config)
num_out = len(model(class_inputs_dict)) num_out = len(model(class_inputs_dict))
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model") model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1"))
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict) outputs = model(class_inputs_dict)
if self.is_encoder_decoder: if self.is_encoder_decoder:
output = outputs["encoder_hidden_states"] if isinstance(outputs, dict) else outputs[-1] output = outputs["encoder_hidden_states"]
else: else:
output = outputs["hidden_states"] if isinstance(outputs, dict) else outputs[-1] output = outputs["hidden_states"]
hidden_states = [t.numpy() for t in output]
self.assertEqual(len(outputs), num_out) self.assertEqual(len(outputs), num_out)
expected_num_layers = getattr( expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
) )
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertEqual(len(output), expected_num_layers)
self.assertListEqual( self.assertListEqual(
list(hidden_states[0].shape[-2:]), list(output[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size], [self.model_tester.seq_length, self.model_tester.hidden_size],
) )
...@@ -248,36 +247,33 @@ class TFModelTesterMixin: ...@@ -248,36 +247,33 @@ class TFModelTesterMixin:
def test_saved_model_with_attentions_output(self): def test_saved_model_with_attentions_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_attentions = True config.output_attentions = True
config.output_hidden_states = False
if hasattr(config, "use_cache"):
config.use_cache = False
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length) encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# A saved model is always executed in graph mode. Since PR #8777 was merged,
# the booleans used in graph mode are always the ones from the config, so we update
# the config's use_cache property (if it exists) to match the boolean given in the inputs
if "use_cache" in class_inputs_dict:
config.use_cache = class_inputs_dict.pop("use_cache")
model = model_class(config) model = model_class(config)
num_out = len(model(class_inputs_dict)) num_out = len(model(class_inputs_dict))
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
saved_model_dir = os.path.join(tmpdirname, "saved_model") model.save_pretrained(tmpdirname, saved_model=True)
model.save_pretrained(saved_model_dir) model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1"))
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict) outputs = model(class_inputs_dict)
if self.is_encoder_decoder: if self.is_encoder_decoder:
output = outputs["encoder_attentions"] if isinstance(outputs, dict) else outputs[-1] output = outputs["encoder_attentions"]
else: else:
output = outputs["attentions"] if isinstance(outputs, dict) else outputs[-1] output = outputs["attentions"]
attentions = [t.numpy() for t in output] self.assertEqual(len(output), num_out)
self.assertEqual(len(outputs), num_out) self.assertEqual(len(output), self.model_tester.num_hidden_layers)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual( self.assertListEqual(
list(attentions[0].shape[-3:]), list(output[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
) )
......
...@@ -352,30 +352,20 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -352,30 +352,20 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs) check_encoder_attentions_output(outputs)
@slow def test_saved_model_creation(self):
def test_saved_model_with_attentions_output(self): # This test is too long (>30sec) and makes the CI fail
# longformer has special attentions which are not
# compatible in graph mode
pass
@slow
def test_saved_model_with_hidden_states_output(self):
# TODO(JPLU, PVP) this test should pass!!! PVP:
# IMO there is a problem with the signature check.
# The test passes for TFLEDModel, but not for TFLEDForConditionalGeneration.
# IMO the reason is that the tensor variable name cannot be changed
# from decoder_input_ids -> input_ids, which poses a BIG restriction
pass
@slow
def test_saved_model_creation_extended(self):
# All the tests about building a saved model
# fail because the Seq2Seq models use a model within a model
# as a layer.
# TODO(JPLU) WARNING: NEED TO BE FIXED ASAP
pass pass
def test_saved_model_creation(self): def test_saved_model_with_attentions_output(self):
# This test doesn't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs at line 323 in modeling_tf_led.py because the condition at line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] otherwise.
# This is due to the tf.concat call at line 703 that adds one dimension
# Need to check with PVP how to properly fix this
pass pass
......
...@@ -233,22 +233,8 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -233,22 +233,8 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias() name = model.get_bias()
assert name is None assert name is None
@slow
def test_saved_model_with_hidden_states_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
@slow
def test_saved_model_with_attentions_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
def test_saved_model_creation(self): def test_saved_model_creation(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR # This test is too long (>30sec) and makes the CI fail
pass
def test_saved_model_creation_extended(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
......
...@@ -204,22 +204,8 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -204,22 +204,8 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias() name = model.get_bias()
assert name is None assert name is None
@slow
def test_saved_model_with_hidden_states_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
@slow
def test_saved_model_with_attentions_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
def test_saved_model_creation(self): def test_saved_model_creation(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR # This test is too long (>30sec) and makes the CI fail
pass
def test_saved_model_creation_extended(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
......
...@@ -231,22 +231,8 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -231,22 +231,8 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias() name = model.get_bias()
assert name is None assert name is None
@slow
def test_saved_model_with_hidden_states_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
@slow
def test_saved_model_with_attentions_output(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass
def test_saved_model_creation(self): def test_saved_model_creation(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR # This test is too long (>30sec) and makes the CI fail
pass
def test_saved_model_creation_extended(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
pass pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment