Unverified Commit 23a75a53 authored by Dahlbomii, committed by GitHub

Type hints and decorator for TF T5 (#16376)



* Type hints and TF decorator added

* Re-add XLA generation method

* Re-add lines that were deleted by conflicting updates

* Re-add lines that were deleted by conflicting updates

* Re-add lines that were deleted by conflicting updates
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
parent 2a27c800
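
For context: `@unpack_inputs` replaces the per-method `input_processing(...)` boilerplate that each hunk below deletes. The decorator normalizes whatever the caller passed and invokes `call` with plain keyword arguments, which is why the method bodies can read `input_ids` directly instead of `inputs["input_ids"]`. Below is a minimal sketch of the pattern only, not the real implementation in `transformers.modeling_tf_utils` (which also resolves config defaults and dict/tuple first arguments); `unpack_inputs_sketch` is a hypothetical name.

import functools
import inspect


def unpack_inputs_sketch(call):
    """Hypothetical stand-in for transformers' `unpack_inputs` decorator."""

    @functools.wraps(call)
    def wrapper(self, *args, **kwargs):
        # Bind positional arguments to their parameter names and fill in
        # defaults, so the wrapped `call` body can use plain locals such as
        # `input_ids` instead of an `inputs` dict built by `input_processing`.
        bound = inspect.signature(call).bind(self, *args, **kwargs)
        bound.apply_defaults()
        return call(*bound.args, **bound.kwargs)

    return wrapper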
@@ -19,7 +19,7 @@ import copy
 import itertools
 import math
 import warnings
-from typing import Tuple
+from typing import Optional, Tuple, Union

 import numpy as np
 import tensorflow as tf
@@ -34,11 +34,12 @@ from ...modeling_tf_outputs import (
 )
 from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
+    TFModelInputType,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import (
@@ -637,6 +638,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError  # Not implemented yet in the library fr TF 2.0 models

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -654,72 +656,48 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ) -> Tuple:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            inputs_embeds=inputs_embeds,
-            head_mask=head_mask,
-            encoder_head_mask=encoder_head_mask,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             err_msg_prefix = "decoder_" if self.is_decoder else ""
             raise ValueError(
                 f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
             )
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-            inputs["input_ids"] = tf.reshape(inputs["input_ids"], (-1, input_shape[-1]))
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             err_msg_prefix = "decoder_" if self.is_decoder else ""
             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

-        if inputs["inputs_embeds"] is None:
+        if inputs_embeds is None:
             assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
+            inputs_embeds = self.embed_tokens(input_ids)

         batch_size, seq_length = input_shape

         # required mask seq length can be calculated via length of past
         mask_seq_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] + seq_length
-            if inputs["past_key_values"] is not None
-            else seq_length
+            shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length
         )

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill((batch_size, mask_seq_length), 1)
-        if (
-            self.is_decoder
-            and inputs["encoder_attention_mask"] is None
-            and inputs["encoder_hidden_states"] is not None
-        ):
-            encoder_seq_length = shape_list(inputs["encoder_hidden_states"])[1]
-            inputs["encoder_attention_mask"] = tf.fill((batch_size, encoder_seq_length), 1)
+        if attention_mask is None:
+            attention_mask = tf.fill((batch_size, mask_seq_length), 1)
+        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
+            encoder_seq_length = shape_list(encoder_hidden_states)[1]
+            encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)

         # initialize past_key_values with `None` if past does not exist
-        if inputs["past_key_values"] is None:
-            inputs["past_key_values"] = [None] * len(self.block)
+        if past_key_values is None:
+            past_key_values = [None] * len(self.block)

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
-        num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
+        attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype)
+        num_dims_attention_mask = len(shape_list(attention_mask))
         if num_dims_attention_mask == 3:
-            extended_attention_mask = inputs["attention_mask"][:, None, :, :]
+            extended_attention_mask = attention_mask[:, None, :, :]
         elif num_dims_attention_mask == 2:
             # Provided a padding mask of dimensions [batch_size, mask_seq_length]
             # - if the model is a decoder, apply a causal mask in addition to the padding mask
@@ -730,12 +708,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                 seq_ids[None, :, None],
             )
-            causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
-            extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
-            if inputs["past_key_values"][0] is not None:
+            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+            if past_key_values[0] is not None:
                 extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
         else:
-            extended_attention_mask = inputs["attention_mask"][:, None, None, :]
+            extended_attention_mask = attention_mask[:, None, None, :]

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -750,18 +728,16 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

-        if self.is_decoder and inputs["encoder_attention_mask"] is not None:
+        if self.is_decoder and encoder_attention_mask is not None:
             # If a 2D ou 3D attention mask is provided for the cross-attention
             # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            inputs["encoder_attention_mask"] = tf.cast(
-                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
-            )
-            num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
+            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
             if num_dims_encoder_attention_mask == 3:
-                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
             if num_dims_encoder_attention_mask == 2:
-                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

             # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
             # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
@@ -772,33 +748,31 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         else:
             encoder_extended_attention_mask = None

-        present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_attentions = () if inputs["output_attentions"] else None
-        all_cross_attentions = () if (inputs["output_attentions"] and self.is_decoder) else None
+        present_key_value_states = () if use_cache and self.is_decoder else None
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
         position_bias = None
         encoder_decoder_position_bias = None

-        hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
+        hidden_states = self.dropout(inputs_embeds, training=training)

-        for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
-            if inputs["output_hidden_states"]:
+        for idx, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
             layer_outputs = layer_module(
                 hidden_states,
                 attention_mask=extended_attention_mask,
                 position_bias=position_bias,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
+                encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
-                if inputs["encoder_head_mask"] is not None
-                else None,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
                 past_key_value=past_key_value,
-                use_cache=inputs["use_cache"],
-                output_attentions=inputs["output_attentions"],
-                training=inputs["training"],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                training=training,
             )

             # layer_outputs is a tuple with:
@@ -810,33 +784,33 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             # (self-attention position bias), (cross-attention position bias), (cross-attention weights),
             position_bias = layer_outputs[2]

-            if self.is_decoder and inputs["encoder_hidden_states"] is not None:
-                encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

             # append next layer key value states
-            if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder:
+            if present_key_value_state is not None and use_cache and self.is_decoder:
                 present_key_value_states = present_key_value_states + (present_key_value_state,)

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[3],)
                 if self.is_decoder:
                     all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

         hidden_states = self.final_layer_norm(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # Add last layer
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             outputs = (hidden_states,)
             # need to check if is decoder here as well for special cases when using keras compile
-            if inputs["use_cache"] and self.is_decoder:
+            if use_cache and self.is_decoder:
                 outputs = outputs + (present_key_value_states,)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 outputs = outputs + (all_hidden_states,)
-            if inputs["output_attentions"]:
+            if output_attentions:
                 outputs = outputs + (all_attentions,)
             if self.is_decoder:
                 outputs + (all_cross_attentions,)
@@ -1158,27 +1132,28 @@ class TFT5Model(TFT5PreTrainedModel):
     def get_decoder(self):
         return self.decoder

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[Tuple, TFSeq2SeqModelOutput]:
         r"""
         Returns:
@@ -1204,68 +1179,47 @@ class TFT5Model(TFT5PreTrainedModel):
            warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
            decoder_head_mask = head_mask

-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
         # Encode if needed (training, first prediction pass)
-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids,
+                attention_mask=attention_mask,
                 encoder_hidden_states=None,
                 encoder_attention_mask=None,
-                inputs_embeds=inputs["inputs_embeds"],
-                head_mask=inputs["head_mask"],
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
                 past_key_values=None,
                 use_cache=False,
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )

-        hidden_states = inputs["encoder_outputs"][0]
+        hidden_states = encoder_outputs[0]

         # Decode
         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
             encoder_hidden_states=hidden_states,
-            encoder_attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            head_mask=inputs["decoder_head_mask"],
-            encoder_head_mask=inputs["head_mask"],
-            past_key_values=inputs["past_key_values"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            encoder_attention_mask=attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            head_mask=decoder_head_mask,
+            encoder_head_mask=head_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
        )

-        past = decoder_outputs[1] if inputs["use_cache"] else None
+        past = decoder_outputs[1] if use_cache else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
-            return decoder_outputs + inputs["encoder_outputs"]
+            return decoder_outputs + encoder_outputs

         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1273,9 +1227,9 @@ class TFT5Model(TFT5PreTrainedModel):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )

     def serving_output(self, output):
@@ -1354,28 +1308,29 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
     def get_decoder(self):
         return self.decoder

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[Tuple, TFSeq2SeqLMOutput]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
@@ -1411,65 +1366,39 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
            warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
            decoder_head_mask = head_mask

-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            labels=labels,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
         # Encode if needed (training, first prediction pass)
-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                head_mask=inputs["head_mask"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )

-        hidden_states = inputs["encoder_outputs"][0]
+        hidden_states = encoder_outputs[0]

-        if (
-            inputs["labels"] is not None
-            and inputs["decoder_input_ids"] is None
-            and inputs["decoder_inputs_embeds"] is None
-        ):
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
             # get decoder inputs from shifting lm labels to the right
-            inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
+            decoder_input_ids = self._shift_right(labels)

         # Decode
         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
             encoder_hidden_states=hidden_states,
-            encoder_attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            head_mask=inputs["decoder_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            encoder_attention_mask=attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            head_mask=decoder_head_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
        )

         sequence_output = decoder_outputs[0]
@@ -1483,29 +1412,29 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
            logits = tf.cast(logits, tf.float32)

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        past = decoder_outputs[1] if inputs["use_cache"] else None
+        past = decoder_outputs[1] if use_cache else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
-            output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
+            output = (logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif isinstance(inputs["encoder_outputs"], tuple):
-            last_hidden_state = inputs["encoder_outputs"][0]
+        elif isinstance(encoder_outputs, tuple):
+            last_hidden_state = encoder_outputs[0]
             hidden_states = None
             attentions = None
             idx = 0
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 idx += 1
-                hidden_states = inputs["encoder_outputs"][idx]
-            if inputs["output_attentions"]:
+                hidden_states = encoder_outputs[idx]
+            if output_attentions:
                 idx += 1
-                attentions = inputs["encoder_outputs"][idx]
-            inputs["encoder_outputs"] = TFBaseModelOutput(
+                attentions = encoder_outputs[idx]
+            encoder_outputs = TFBaseModelOutput(
                 last_hidden_state=last_hidden_state,
                 hidden_states=hidden_states,
                 attentions=attentions,
@@ -1518,9 +1447,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )

     def serving_output(self, output):
@@ -1685,20 +1614,21 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
     def get_encoder(self):
         return self.encoder

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        input_ids,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[Tuple, TFBaseModelOutput]:
         r"""
         Returns:
@@ -1715,36 +1645,23 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
         >>> ).input_ids  # Batch size 1
         >>> outputs = model(input_ids)
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
         encoder_outputs = self.encoder(
             input_ids,
-            attention_mask=inputs["attention_mask"],
+            attention_mask=attention_mask,
             encoder_hidden_states=None,
             encoder_attention_mask=None,
-            inputs_embeds=inputs["inputs_embeds"],
+            inputs_embeds=inputs_embeds,
             head_mask=head_mask,
             past_key_values=None,
             use_cache=False,
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

-        if not inputs["return_dict"]:
+        if not return_dict:
             return encoder_outputs

         return TFBaseModelOutput(
......
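
Not part of the diff: a quick usage sketch showing the annotated signature in action; this PR changes annotations only, not behavior. Assumes the public `t5-small` checkpoint and `transformers` with `sentencepiece` installed.

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# Tokenize an input/target pair; both end up as tf.Tensor, which the new
# annotations accept alongside np.ndarray.
batch = tokenizer("translate English to German: Hello", return_tensors="tf")
labels = tokenizer("Hallo", return_tensors="tf").input_ids

# input_ids is Optional[TFModelInputType]; labels is Optional[Union[np.ndarray, tf.Tensor]].
outputs = model(input_ids=batch.input_ids, labels=labels, return_dict=True)
print(tf.reduce_mean(outputs.loss).numpy(), outputs.logits.shape)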