"tests/models/vscode:/vscode.git/clone" did not exist on "12d51db243a00726a548a43cc333390ebae731e3"
Unverified Commit df3f4d2a authored by Julien Plu, committed by GitHub

Fix T5 and BART for TF (#9063)

* Fix T5 for graph compilation+execution

* Fix BART

* Fix import

* Fix naming

* fix attribute name

* Oops

* fix import

* fix tests

* fix tests

* Update test

* Add missing import

* Address Patrick's comments

* Style

* Address Patrick's comment
parent a9c8bff7
@@ -91,8 +91,6 @@ TensorFlow loss functions
 TensorFlow Helper Functions
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~

-.. autofunction:: transformers.modeling_tf_utils.cast_bool_to_primitive
-
 .. autofunction:: transformers.modeling_tf_utils.get_initializer

 .. autofunction:: transformers.modeling_tf_utils.keras_serializable
...
@@ -51,6 +51,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""):
     )  # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
     tf_name = re.sub(r"//+", "/", tf_name)  # Remove empty levels at the end
     tf_name = tf_name.split("/")  # Convert from TF2.0 '/' separators to PyTorch '.' separators
-    tf_name = tf_name[1:]  # Remove level zero
+    # Some weights have a single name without "/" such as final_logits_bias in BART
+    if len(tf_name) > 1:
+        tf_name = tf_name[1:]  # Remove level zero

     # When should we transpose the weights
...
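For illustration, here is a minimal standalone sketch of the new guard (the helper name `split_tf_name` is hypothetical; only the regex and the length check mirror the diff):

```python
import re

def split_tf_name(tf_name: str):
    tf_name = re.sub(r"//+", "/", tf_name)  # remove empty levels
    parts = tf_name.split("/")
    # Single-segment names such as "final_logits_bias" have no level zero to strip
    if len(parts) > 1:
        parts = parts[1:]  # drop the level-zero (model name) prefix
    return parts

print(split_tf_name("tf_bart_model/model/shared/weight"))  # ['model', 'shared', 'weight']
print(split_tf_name("final_logits_bias"))                  # ['final_logits_bias']
```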
@@ -354,7 +354,7 @@ def input_processing(func, config, input_ids, **kwargs):
             if isinstance(v, allowed_types) or v is None:
                 output[k] = v
             else:
-                raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
+                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")

     if isinstance(input_ids, (tuple, list)):
         for i, input in enumerate(input_ids):
@@ -372,7 +372,7 @@ def input_processing(func, config, input_ids, **kwargs):
                 output[parameter_names[i]] = input
             else:
                 raise ValueError(
-                    f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
+                    f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
                 )
     elif isinstance(input_ids, (dict, BatchEncoding)):
         if "inputs" in input_ids:
@@ -399,13 +399,13 @@ def input_processing(func, config, input_ids, **kwargs):
                 )
                 continue
             else:
-                raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
+                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
     else:
         if isinstance(input_ids, tf.Tensor) or input_ids is None:
             output[parameter_names[0]] = input_ids
         else:
             raise ValueError(
-                f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
+                f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
             )

     for name in parameter_names:
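A quick sketch of the effect of this change (the `allowed_types` tuple below is an assumption standing in for the one defined inside `input_processing`): the error now names every accepted type instead of only `tf.Tensor`.

```python
import tensorflow as tf

# Assumption: a stand-in for the allowed_types tuple defined inside input_processing
allowed_types = (tf.Tensor, bool, int, dict, tuple, list)

def check(name, value):
    if not (isinstance(value, allowed_types) or value is None):
        raise ValueError(f"Data of type {type(value)} is not allowed only {allowed_types} is accepted for {name}.")

check("input_ids", tf.constant([[1, 2, 3]]))  # passes
check("attention_mask", None)                 # passes
# check("input_ids", 1.5) would now raise, listing every accepted type in the message
```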
@@ -1366,31 +1366,6 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
     return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


-def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
-    """
-    Function arguments can be inserted as boolean tensors and bool variables; to cope with Keras serialization we need
-    to cast the bool arguments (like :obj:`output_attentions` for instance) to a correct boolean if it is a tensor.
-
-    Args:
-        bool_variable (:obj:`Union[tf.Tensor, bool]`):
-            The variable to convert to a boolean.
-        default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`):
-            The default value to use in case the tensor has no numpy attribute.
-
-    Returns:
-        :obj:`bool`: The converted value.
-    """
-    # if bool variable is tensor and has numpy value
-    if tf.is_tensor(bool_variable):
-        if hasattr(bool_variable, "numpy"):
-            return bool(bool_variable.numpy())
-        elif default_tensor_to_true:
-            return True
-
-    # else variable is bool
-    return bool_variable
-
-
 class TFWrappedEmbeddings:
     """
     this class wraps the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problems with
...
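The helper could be dropped because the TF models now keep these flags as plain Python booleans, which `tf.function` resolves at trace time. A minimal sketch of that behavior (toy function, not library code):

```python
import tensorflow as tf

@tf.function
def forward(x, use_cache: bool):
    # `use_cache` is a Python bool, so this `if` is evaluated during tracing;
    # TensorFlow builds one graph per boolean value instead of needing a cast
    if use_cache:
        return x * 2
    return x

print(forward(tf.constant(1), True).numpy())   # 2
print(forward(tf.constant(1), False).numpy())  # 1
```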
@@ -41,7 +41,6 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    cast_bool_to_primitive,
     input_processing,
     keras_serializable,
     shape_list,
@@ -258,9 +257,11 @@ class TFEncoderLayer(tf.keras.layers.Layer):
         if self.normalize_before:
             x = self.self_attn_layer_norm(x)
         x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
-        assert shape_list(x) == shape_list(
-            residual
-        ), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}"
+        tf.debugging.assert_equal(
+            shape_list(x),
+            shape_list(residual),
+            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}",
+        )
         x = self.dropout(x, training=training)
         x = residual + x
         if not self.normalize_before:
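The switch from `assert` to `tf.debugging.assert_equal` matters because a Python `assert` on tensor-derived values only runs at trace time, while the debugging op is embedded in the compiled graph. A minimal sketch:

```python
import tensorflow as tf

@tf.function
def add_residual(x, residual):
    # Executes inside the graph, unlike a Python `assert`, so the check
    # also fires in compiled/saved-model execution
    tf.debugging.assert_equal(
        tf.shape(x),
        tf.shape(residual),
        message="Self attn modified the shape of query",
    )
    return x + residual

print(add_residual(tf.ones((2, 3)), tf.ones((2, 3))).numpy())
# add_residual(tf.ones((2, 3)), tf.ones((2, 4))) would raise InvalidArgumentError
```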
@@ -295,9 +296,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
         self.dropout = tf.keras.layers.Dropout(config.dropout)
         self.layerdrop = config.encoder_layerdrop
-        self.output_hidden_states = config.output_hidden_states
-        self.output_attentions = config.output_attentions
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
@@ -328,7 +326,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
             if config.add_final_layer_norm
             else None
         )
-        self.return_dict = config.return_dict

     def call(
         self,
@@ -355,10 +352,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
                 - **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
                   During training might not be of length n_layers because of layer dropout.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
-        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
-        return_dict = return_dict if return_dict is not None else self.return_dict
-
         # check attention mask and invert
         if attention_mask is not None:
             assert (
@@ -546,9 +539,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
         )
         self.dropout = tf.keras.layers.Dropout(config.dropout)
-        self.output_hidden_states = config.output_hidden_states
-        self.output_attentions = config.output_attentions
-        self.use_cache = config.use_cache
         self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm

     def call(
@@ -565,14 +555,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
         return_dict=None,
         training=False,
     ):
-        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
-        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
-        use_cache = use_cache if use_cache is not None else self.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
-        if use_cache:
-            assert not training, "Training + use cache are incompatible"
         # check attention mask and invert
-        use_cache = cast_bool_to_primitive(use_cache)
         if encoder_padding_mask is not None:
             encoder_padding_mask = invert_mask(encoder_padding_mask)
@@ -1046,7 +1029,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
         self.use_cache = config.use_cache
         # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
         self.final_logits_bias = self.add_weight(
-            name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
+            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
         )

     def resize_token_embeddings(self, new_num_tokens):
...
@@ -32,12 +32,16 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
+from ...modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPast,
+    TFSeq2SeqLMOutput,
+    TFSeq2SeqModelOutput,
+)
 from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
-    cast_bool_to_primitive,
     input_processing,
     keras_serializable,
     shape_list,
@@ -311,7 +315,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         )
         # to cope with keras serialization
-        if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
+        if self.is_decoder and use_cache:
             present_key_value_state = (key_states, value_states)
         else:
             present_key_value_state = None
@@ -594,6 +598,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        return_dict=None,
         training=False,
         **kwargs,
     ) -> Tuple:
@@ -610,6 +615,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
             training=training,
             kwargs_call=kwargs,
         )
@@ -713,10 +719,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             assert inputs["head_mask"] is None, "Head mask not supported"
             inputs["head_mask"] = [None] * self.num_hidden_layers

-        present_key_value_states = ()
-        all_hidden_states = ()
-        all_attentions = ()
+        present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
+        all_hidden_states = () if inputs["output_hidden_states"] else None
+        all_attentions = () if inputs["output_attentions"] else None
         position_bias = None
         encoder_decoder_position_bias = None
@@ -725,7 +730,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
             if inputs["output_hidden_states"]:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-
             layer_outputs = layer_module(
                 hidden_states,
                 attention_mask=extended_attention_mask,
@@ -739,6 +743,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 output_attentions=inputs["output_attentions"],
                 training=inputs["training"],
             )
+
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states, present_key_value_state = layer_outputs[:2]
@@ -747,9 +752,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 # layer_outputs = hidden-states, past_key_values, (self-attention weights),
                 # (self-attention position bias), (cross-attention position bias), (cross-attention weights),
                 position_bias = layer_outputs[2]
+
                 if self.is_decoder and inputs["encoder_hidden_states"] is not None:
-                    encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+                    encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3]
+
             # append next layer key value states
-            present_key_value_states = present_key_value_states + (present_key_value_state,)
+            if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder:
+                present_key_value_states = present_key_value_states + (present_key_value_state,)

             if inputs["output_attentions"]:
@@ -762,9 +770,10 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        outputs = (hidden_states,)
-        # need to check if is decoder here as well for special cases when using keras compile
-        if cast_bool_to_primitive(inputs["use_cache"], self.use_cache) is True and self.is_decoder:
-            outputs = outputs + (present_key_value_states,)
-        if inputs["output_hidden_states"]:
-            outputs = outputs + (all_hidden_states,)
+        if not inputs["return_dict"]:
+            outputs = (hidden_states,)
+            # need to check if is decoder here as well for special cases when using keras compile
+            if inputs["use_cache"] and self.is_decoder:
+                outputs = outputs + (present_key_value_states,)
+            if inputs["output_hidden_states"]:
+                outputs = outputs + (all_hidden_states,)
@@ -772,6 +781,20 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 outputs = outputs + (all_attentions,)
             return outputs  # last-layer hidden state, (all hidden states), (all attentions)

+        if self.is_decoder:
+            return TFBaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=present_key_value_states,
+                hidden_states=all_hidden_states,
+                attentions=all_attentions,
+            )
+        else:
+            return TFBaseModelOutput(
+                last_hidden_state=hidden_states,
+                hidden_states=all_hidden_states,
+                attentions=all_attentions,
+            )

 ####################################################
 # TFT5PreTrainedModel is a sub-class of tf.keras.Model
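With `TFT5MainLayer` returning output objects, callers can read members by name instead of by position, and members that were not requested are `None` rather than silently shifting the tuple layout. A small illustration (tensor shapes are made up):

```python
import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutputWithPast

out = TFBaseModelOutputWithPast(
    last_hidden_state=tf.zeros((1, 4, 8)),  # made-up shape
    past_key_values=None,  # use_cache was off -> None instead of a missing slot
    hidden_states=None,
    attentions=None,
)

print(out.last_hidden_state.shape)  # (1, 4, 8)
print(out.past_key_values)          # None
```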
@@ -1102,6 +1125,7 @@ class TFT5Model(TFT5PreTrainedModel):
                 use_cache=False,
                 output_attentions=inputs["output_attentions"],
                 output_hidden_states=inputs["output_hidden_states"],
+                return_dict=inputs["return_dict"],
                 training=inputs["training"],
             )
@@ -1119,38 +1143,25 @@ class TFT5Model(TFT5PreTrainedModel):
             use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )

-        past = (
-            (inputs["encoder_outputs"], decoder_outputs[1])
-            if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
-            else None
-        )
+        past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None

         if not inputs["return_dict"]:
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
             return decoder_outputs + inputs["encoder_outputs"]

-        # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
-        # TF refuses to compile anymore.
-        if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
-            decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
-        if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
-            decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
-        if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
-            decoder_outputs = decoder_outputs + (None,)
-
         return TFSeq2SeqModelOutput(
-            last_hidden_state=decoder_outputs[0],
+            last_hidden_state=decoder_outputs.last_hidden_state,
             past_key_values=past,
-            decoder_hidden_states=decoder_outputs[2],
-            decoder_attentions=decoder_outputs[3],
-            encoder_last_hidden_state=inputs["encoder_outputs"][0],
-            encoder_hidden_states=inputs["encoder_outputs"][1],
-            encoder_attentions=inputs["encoder_outputs"][2],
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
+            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
+            encoder_attentions=inputs["encoder_outputs"].attentions,
         )
@@ -1280,6 +1291,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             head_mask=inputs["head_mask"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
@@ -1313,6 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
@@ -1327,37 +1340,41 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):

         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

-        past = (
-            (inputs["encoder_outputs"], decoder_outputs[1])
-            if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
-            else None
-        )
+        past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None

         if not inputs["return_dict"]:
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
             output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
             return ((loss,) + output) if loss is not None else output
-
-        # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
-        # TF refuses to compile anymore.
-        if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
-            decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
-        if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
-            decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
-        if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
-            decoder_outputs = decoder_outputs + (None,)
+        # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
+        elif isinstance(inputs["encoder_outputs"], tuple):
+            last_hidden_state = inputs["encoder_outputs"][0]
+            hidden_states = None
+            attentions = None
+            idx = 0
+
+            if inputs["output_hidden_states"]:
+                idx += 1
+                hidden_states = inputs["encoder_outputs"][idx]
+
+            if inputs["output_attentions"]:
+                idx += 1
+                attentions = inputs["encoder_outputs"][idx]
+
+            inputs["encoder_outputs"] = TFBaseModelOutput(
+                last_hidden_state=last_hidden_state,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )

         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=logits,
             past_key_values=past,
-            decoder_hidden_states=decoder_outputs[2],
-            decoder_attentions=decoder_outputs[3],
-            encoder_last_hidden_state=inputs["encoder_outputs"][0],
-            encoder_hidden_states=inputs["encoder_outputs"][1],
-            encoder_attentions=inputs["encoder_outputs"][2],
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
+            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
+            encoder_attentions=inputs["encoder_outputs"].attentions,
         )

     def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
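The wrapping logic above assumes the conventional tuple layout `(last_hidden_state, hidden_states?, attentions?)`, where the optional members are present only when the matching flag is set. A hypothetical standalone version of the same index walk:

```python
def unpack_encoder_tuple(encoder_outputs, output_hidden_states, output_attentions):
    # Hypothetical helper mirroring the wrapping logic in the diff above
    last_hidden_state = encoder_outputs[0]
    hidden_states = attentions = None
    idx = 0
    if output_hidden_states:
        idx += 1
        hidden_states = encoder_outputs[idx]
    if output_attentions:
        idx += 1
        attentions = encoder_outputs[idx]
    return last_hidden_state, hidden_states, attentions

print(unpack_encoder_tuple(("last",), False, False))        # ('last', None, None)
print(unpack_encoder_tuple(("last", "attn"), False, True))  # ('last', None, 'attn')
```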
@@ -1498,19 +1515,15 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
             use_cache=False,
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )

         if not inputs["return_dict"]:
             return encoder_outputs

-        if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
-            encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
-        if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
-            encoder_outputs = encoder_outputs + (None,)
-
         return TFBaseModelOutput(
-            last_hidden_state=encoder_outputs[0],
-            hidden_states=encoder_outputs[1],
-            attentions=encoder_outputs[2],
+            last_hidden_state=encoder_outputs.last_hidden_state,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
         )
...
@@ -118,14 +118,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
         # inputs_embeds not supported
         pass

-    def test_saved_model_with_hidden_states_output(self):
-        # Should be uncommented during patrick TF refactor
-        pass
-
-    def test_saved_model_with_attentions_output(self):
-        # Should be uncommented during patrick TF refactor
-        pass
-
     def test_model_common_attributes(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -171,6 +171,11 @@ class TFModelTesterMixin:
         for model_class in self.all_model_classes:
             class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)

+            # A saved model is always executed in graph mode. Since we merged PR #8777,
+            # the booleans in graph mode always come from the config, so we move the
+            # use_cache flag from the inputs into the config to keep them in sync.
+            if "use_cache" in class_inputs_dict:
+                config.use_cache = class_inputs_dict.pop("use_cache")
             model = model_class(config)
             num_out = len(model(class_inputs_dict))
             model._saved_model_inputs_spec = None
@@ -207,6 +212,11 @@ class TFModelTesterMixin:
         for model_class in self.all_model_classes:
             class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)

+            # A saved model is always executed in graph mode. Since we merged PR #8777,
+            # the booleans in graph mode always come from the config, so we move the
+            # use_cache flag from the inputs into the config to keep them in sync.
+            if "use_cache" in class_inputs_dict:
+                config.use_cache = class_inputs_dict.pop("use_cache")
             model = model_class(config)
             num_out = len(model(class_inputs_dict))
             model._saved_model_inputs_spec = None
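The constraint these tests work around: once a model is traced for saving, Python-level flags are baked in, so the per-call boolean has to be folded into the config first. A toy layer (not library code) showing the freeze:

```python
import tensorflow as tf

class ToyLayer(tf.keras.layers.Layer):
    def __init__(self, use_cache: bool):
        super().__init__()
        self.use_cache = use_cache  # config-style flag, fixed at build time

    def call(self, x):
        # Resolved while tracing: the saved graph keeps only one branch
        return (x, x) if self.use_cache else (x,)

layer = ToyLayer(use_cache=True)
traced = tf.function(layer)
print(len(traced(tf.ones((1,)))))  # 2 -- the True branch was frozen into the graph
```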
@@ -249,10 +259,11 @@ class TFModelTesterMixin:
             if "T5" in main_layer_class.__name__:
                 # Take the same values as in TFT5ModelTester for this shared layer
                 shared = TFSharedEmbeddings(99, 32, name="shared")
-                config.use_cache = False
+                config.use_cache = inputs_dict.pop("use_cache", None)
                 main_layer = main_layer_class(config, embed_tokens=shared)
             else:
                 main_layer = main_layer_class(config)
+
             symbolic_inputs = {
                 name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
             }
@@ -321,10 +332,13 @@ class TFModelTesterMixin:
             # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
             pt_model.eval()
-            pt_inputs_dict = dict(
-                (name, torch.from_numpy(key.numpy()).to(torch.long))
-                for name, key in self._prepare_for_class(inputs_dict, model_class).items()
-            )
+            pt_inputs_dict = {}
+            for name, key in self._prepare_for_class(inputs_dict, model_class).items():
+                if type(key) == bool:
+                    pt_inputs_dict[name] = key
+                else:
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)

             # need to rename encoder-decoder "inputs" for PyTorch
             if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                 pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
@@ -358,10 +372,13 @@ class TFModelTesterMixin:
             # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
             pt_model.eval()
-            pt_inputs_dict = dict(
-                (name, torch.from_numpy(key.numpy()).to(torch.long))
-                for name, key in self._prepare_for_class(inputs_dict, model_class).items()
-            )
+            pt_inputs_dict = {}
+            for name, key in self._prepare_for_class(inputs_dict, model_class).items():
+                if type(key) == bool:
+                    key = np.array(key, dtype=bool)
+                    pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
+                else:
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)

             # need to rename encoder-decoder "inputs" for PyTorch
             if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                 pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
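Booleans need the special case because the TF inputs normally pass through `.numpy()` before `torch.from_numpy`, and a plain Python `bool` has neither. A minimal check of the NumPy round-trip used above:

```python
import numpy as np
import torch

flag = True  # e.g. a use_cache entry in the inputs dict
tensor = torch.from_numpy(np.array(flag, dtype=bool))
print(tensor)  # tensor(True)
# torch.from_numpy(flag) would raise TypeError: expected np.ndarray
```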
@@ -574,7 +591,23 @@ class TFModelTesterMixin:
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )

-            hidden_states = outputs[-1]
+            if model.config.is_encoder_decoder:
+                encoder_hidden_states = outputs.encoder_hidden_states
+                decoder_hidden_states = outputs.decoder_hidden_states
+
+                self.assertEqual(config.output_attentions, False)
+                self.assertEqual(len(encoder_hidden_states), expected_num_layers)
+                self.assertListEqual(
+                    list(encoder_hidden_states[0].shape[-2:]),
+                    [self.model_tester.seq_length, self.model_tester.hidden_size],
+                )
+                self.assertEqual(len(decoder_hidden_states), expected_num_layers)
+                self.assertListEqual(
+                    list(decoder_hidden_states[0].shape[-2:]),
+                    [self.model_tester.seq_length, self.model_tester.hidden_size],
+                )
+            else:
+                hidden_states = outputs.hidden_states
                 self.assertEqual(config.output_attentions, False)
                 self.assertEqual(len(hidden_states), expected_num_layers)
                 self.assertListEqual(
@@ -796,7 +829,7 @@ class TFModelTesterMixin:
     def test_lm_head_model_random_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
+        input_ids = inputs_dict["input_ids"]

         for model_class in self.all_generative_model_classes:
             model = model_class(config)
...
@@ -133,8 +133,6 @@ class TFT5ModelTester:
         self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
         self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

-        output, past_key_values = outputs
-
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -142,7 +140,7 @@ class TFT5ModelTester:
         next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)

         output_from_no_past = model(next_input_ids)[0]
-        output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
+        output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]

         # select random slice
         random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
@@ -164,7 +162,7 @@ class TFT5ModelTester:
         attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)

         # first forward pass
-        _, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)

         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -187,7 +185,7 @@ class TFT5ModelTester:
         # get two different outputs
         output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
-        output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[0]
+        output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0]

         # select random slice
         random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
@@ -208,8 +206,6 @@ class TFT5ModelTester:
         # first forward pass
         outputs = model(input_ids, use_cache=True)

-        output, past_key_values = outputs
-
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -217,7 +213,7 @@ class TFT5ModelTester:
         next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)

         output_from_no_past = model(next_input_ids)[0]
-        output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
+        output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]

         self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
@@ -236,7 +232,7 @@ class TFT5ModelTester:
             "input_ids": input_ids,
             "decoder_input_ids": input_ids,
             "decoder_attention_mask": input_mask,
-            "use_cache": tf.convert_to_tensor([False]),
+            "use_cache": False,
         }
         return config, inputs_dict
@@ -298,14 +294,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         model = TFT5Model.from_pretrained("t5-small")
         self.assertIsNotNone(model)

-    @slow
-    def test_saved_model_with_attentions_output(self):
-        pass
-
-    @slow
-    def test_saved_model_with_hidden_states_output(self):
-        pass


 class TFT5EncoderOnlyModelTester:
     def __init__(
@@ -411,6 +399,7 @@ class TFT5EncoderOnlyModelTester:
 class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
+    is_encoder_decoder = False
     all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()

     def setUp(self):
...