Unverified commit 23a75a53, authored by Dahlbomii, committed by GitHub

Type hints and decorator for TF T5 (#16376)



* Type hints and TF decorator added

* Re-add XLA generation method

* Re-add lines that were deleted by conflicting updates

* Re-add lines that were deleted by conflicting updates

* Re-add lines that were deleted by conflicting updates
Co-authored-by: matt <rocketknight1@gmail.com>
parent 2a27c800
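This change removes the per-method `input_processing(...)` boilerplate in favour of the `@unpack_inputs` decorator and adds `Optional[...]` type hints to the TF T5 `call` signatures. A rough sketch of the new calling convention is below; `toy_unpack_inputs` is a hypothetical stand-in for the library's `unpack_inputs`, which additionally consults the model config and normalizes positional arguments.

```python
import functools
from typing import Optional, Union

import numpy as np
import tensorflow as tf


def toy_unpack_inputs(func):
    """Hypothetical stand-in for `unpack_inputs`: forward keyword arguments
    directly instead of returning an `inputs` dict to index into."""

    @functools.wraps(func)
    def wrapper(self, **kwargs):
        # The real decorator also folds positional args and `kwargs_call`
        # and fills config-driven defaults; this sketch only forwards kwargs.
        return func(self, **kwargs)

    return wrapper


class ToyLayer(tf.keras.layers.Layer):
    # Before this commit the body started with
    #   inputs = input_processing(func=self.call, config=self.config, ...)
    # and then read inputs["input_ids"], inputs["attention_mask"], ...
    @toy_unpack_inputs
    def call(
        self,
        input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: Optional[bool] = False,
    ):
        # After: the arguments are used directly, no inputs["..."] indexing.
        return input_ids, attention_mask
```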
@@ -19,7 +19,7 @@ import copy
import itertools
import math
import warnings
from typing import Tuple
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
@@ -34,11 +34,12 @@ from ...modeling_tf_outputs import (
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
@@ -637,6 +638,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
@unpack_inputs
def call(
self,
input_ids=None,
@@ -654,72 +656,48 @@ class TFT5MainLayer(tf.keras.layers.Layer):
training=False,
**kwargs,
) -> Tuple:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
)
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], (-1, input_shape[-1]))
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if inputs["inputs_embeds"] is None:
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = (
shape_list(inputs["past_key_values"][0][0])[2] + seq_length
if inputs["past_key_values"] is not None
else seq_length
shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length
)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill((batch_size, mask_seq_length), 1)
if (
self.is_decoder
and inputs["encoder_attention_mask"] is None
and inputs["encoder_hidden_states"] is not None
):
encoder_seq_length = shape_list(inputs["encoder_hidden_states"])[1]
inputs["encoder_attention_mask"] = tf.fill((batch_size, encoder_seq_length), 1)
if attention_mask is None:
attention_mask = tf.fill((batch_size, mask_seq_length), 1)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = shape_list(encoder_hidden_states)[1]
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
# initialize past_key_values with `None` if past does not exist
if inputs["past_key_values"] is None:
inputs["past_key_values"] = [None] * len(self.block)
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype)
num_dims_attention_mask = len(shape_list(attention_mask))
if num_dims_attention_mask == 3:
extended_attention_mask = inputs["attention_mask"][:, None, :, :]
extended_attention_mask = attention_mask[:, None, :, :]
elif num_dims_attention_mask == 2:
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
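As an aside on the mask bookkeeping above: when a cache is passed, the required mask length is the cached length plus the current sequence length, and a default all-ones attention mask is created if none was given. A small standalone sketch with illustrative shapes:

```python
import tensorflow as tf

batch_size, seq_length = 2, 1        # decoding one new token per step
num_heads, head_dim, past_len = 8, 64, 5

# past_key_values[0][0] has shape [batch_size, num_heads, past_len, head_dim]
past_key = tf.zeros((batch_size, num_heads, past_len, head_dim))

# required mask length = cached length + current length
mask_seq_length = past_key.shape[2] + seq_length          # 5 + 1 = 6

# default mask: attend to every cached and new position
attention_mask = tf.fill((batch_size, mask_seq_length), 1)
print(attention_mask.shape)                                # (2, 6)
```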
@@ -730,12 +708,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
if inputs["past_key_values"][0] is not None:
causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if past_key_values[0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = inputs["attention_mask"][:, None, None, :]
extended_attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
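For reference, a standalone sketch of the causal mask built just above: comparing tiled position ids yields a lower-triangular matrix, and when a cache is present only the rows for the new tokens are kept. The real code slices the 4-D extended mask; a 3-D mask is shown here for brevity.

```python
import tensorflow as tf

batch_size, mask_seq_length, seq_length = 1, 4, 1   # 3 cached tokens + 1 new one

seq_ids = tf.range(mask_seq_length)
# causal_mask[b, i, j] == 1 when position i may attend to position j (j <= i)
causal_mask = tf.math.less_equal(
    tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
    seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, tf.float32)
print(causal_mask[0].numpy())        # lower-triangular 4x4 matrix of 0s and 1s

# with a cache, only the rows belonging to the new tokens are needed
causal_mask = causal_mask[:, -seq_length:, :]
print(causal_mask.numpy())           # [[[1. 1. 1. 1.]]]
```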
@@ -750,18 +728,16 @@ class TFT5MainLayer(tf.keras.layers.Layer):
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
if self.is_decoder and encoder_attention_mask is not None:
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast(
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
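The `(1.0 - mask) * -1e9` step above turns a 0/1 keep-mask into an additive bias for the attention logits, so masked positions receive effectively zero probability after the softmax. A minimal sketch:

```python
import tensorflow as tf

# 0/1 padding mask: 1 = attend, 0 = padded position
attention_mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])

# make it broadcastable to [batch_size, num_heads, query_len, key_len]
extended_attention_mask = attention_mask[:, None, None, :]

# 0.0 where attention is allowed, -1e9 where it is masked
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

scores = tf.zeros((1, 1, 1, 4))                      # dummy attention logits
probs = tf.nn.softmax(scores + extended_attention_mask, axis=-1)
print(probs.numpy().round(3))                        # [[[[0.333 0.333 0.333 0.   ]]]]
```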
@@ -772,33 +748,31 @@ class TFT5MainLayer(tf.keras.layers.Layer):
else:
encoder_extended_attention_mask = None
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
all_cross_attentions = () if (inputs["output_attentions"] and self.is_decoder) else None
present_key_value_states = () if use_cache and self.is_decoder else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
hidden_states = self.dropout(inputs_embeds, training=training)
for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
if inputs["output_hidden_states"]:
for idx, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
past_key_value=past_key_value,
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
training=inputs["training"],
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
# layer_outputs is a tuple with:
@@ -810,33 +784,33 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
position_bias = layer_outputs[2]
if self.is_decoder and inputs["encoder_hidden_states"] is not None:
encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder:
if present_key_value_state is not None and use_cache and self.is_decoder:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if inputs["output_attentions"]:
if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
hidden_states = self.dropout(hidden_states, training=training)
# Add last layer
if inputs["output_hidden_states"]:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not inputs["return_dict"]:
if not return_dict:
outputs = (hidden_states,)
# need to check if is decoder here as well for special cases when using keras compile
if inputs["use_cache"] and self.is_decoder:
if use_cache and self.is_decoder:
outputs = outputs + (present_key_value_states,)
if inputs["output_hidden_states"]:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if inputs["output_attentions"]:
if output_attentions:
outputs = outputs + (all_attentions,)
if self.is_decoder:
outputs = outputs + (all_cross_attentions,)
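When `return_dict=False`, the tuple above is assembled in a fixed order, with optional entries only present when the corresponding flag is set. A toy illustration of the resulting layout, using stand-in values instead of real tensors:

```python
# stand-ins for the tensors collected during the forward pass
hidden_states = "last_hidden_state"
present_key_value_states = "present_key_values"
all_hidden_states = ("hidden_states_per_layer",)
all_attentions = ("attentions_per_layer",)
all_cross_attentions = ("cross_attentions_per_layer",)

is_decoder = use_cache = output_hidden_states = output_attentions = True

outputs = (hidden_states,)
if use_cache and is_decoder:
    outputs = outputs + (present_key_value_states,)
if output_hidden_states:
    outputs = outputs + (all_hidden_states,)
if output_attentions:
    outputs = outputs + (all_attentions,)
    if is_decoder:
        outputs = outputs + (all_cross_attentions,)

print(len(outputs))   # 5: hidden state, cache, hidden states, attentions, cross-attentions
```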
@@ -1158,27 +1132,28 @@ class TFT5Model(TFT5PreTrainedModel):
def get_decoder(self):
return self.decoder
@unpack_inputs
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
) -> Union[Tuple, TFSeq2SeqModelOutput]:
r"""
Returns:
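Typical usage of the typed signature above, as a hedged example assuming the public `t5-small` checkpoint; inputs may be `tf.Tensor` or `np.ndarray`, per the `Optional[Union[...]]` hints:

```python
from transformers import T5Tokenizer, TFT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")

input_ids = tokenizer(
    "Studies have shown that owning a dog is good for you", return_tensors="tf"
).input_ids
decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
last_hidden_state = outputs.last_hidden_state   # TFSeq2SeqModelOutput field
```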
@@ -1204,68 +1179,47 @@ class TFT5Model(TFT5PreTrainedModel):
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
# Encode if needed (training, first prediction pass)
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids,
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=inputs["inputs_embeds"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs_embeds,
head_mask=head_mask,
past_key_values=None,
use_cache=False,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
hidden_states = inputs["encoder_outputs"][0]
hidden_states = encoder_outputs[0]
# Decode
decoder_outputs = self.decoder(
inputs["decoder_input_ids"],
attention_mask=inputs["decoder_attention_mask"],
decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
past = decoder_outputs[1] if inputs["use_cache"] else None
past = decoder_outputs[1] if use_cache else None
if not inputs["return_dict"]:
if not return_dict:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + inputs["encoder_outputs"]
return decoder_outputs + encoder_outputs
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1273,9 +1227,9 @@ class TFT5Model(TFT5PreTrainedModel):
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def serving_output(self, output):
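The "Encode if needed" branch above lets a caller run the encoder once and then decode incrementally by passing `encoder_outputs` and the returned cache back in. A hedged manual sketch of that flow, roughly what `generate()` does internally, assuming the `t5-small` checkpoint:

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5Model
from transformers.modeling_tf_outputs import TFBaseModelOutput

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")
input_ids = tokenizer("translate English to German: Hello", return_tensors="tf").input_ids
decoder_input_ids = tokenizer("Hallo", return_tensors="tf").input_ids

# First pass: no encoder_outputs supplied, so the encoder runs; ask for the cache.
first = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, use_cache=True)

# Later passes: skip the encoder by passing its output back in, reuse the cached
# key/values, and feed only the newest decoder token (a pad id as a stand-in here).
encoder_outputs = TFBaseModelOutput(last_hidden_state=first.encoder_last_hidden_state)
next_token = tf.zeros_like(decoder_input_ids[:, :1])
second = model(
    encoder_outputs=encoder_outputs,
    decoder_input_ids=next_token,
    past_key_values=first.past_key_values,
    use_cache=True,
)
```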
@@ -1354,28 +1308,29 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def get_decoder(self):
return self.decoder
@unpack_inputs
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
) -> Union[Tuple, TFSeq2SeqLMOutput]:
r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
@@ -1411,65 +1366,39 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
# Encode if needed (training, first prediction pass)
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["inputs_embeds"],
head_mask=inputs["head_mask"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
hidden_states = inputs["encoder_outputs"][0]
hidden_states = encoder_outputs[0]
if (
inputs["labels"] is not None
and inputs["decoder_input_ids"] is None
and inputs["decoder_inputs_embeds"] is None
):
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
decoder_input_ids = self._shift_right(labels)
# Decode
decoder_outputs = self.decoder(
inputs["decoder_input_ids"],
attention_mask=inputs["decoder_attention_mask"],
decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=decoder_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = decoder_outputs[0]
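When only `labels` are provided, the code above builds `decoder_input_ids` by shifting the labels one position to the right. A standalone sketch of that shift, assuming T5's convention that the pad token id (0) doubles as the decoder start token:

```python
import tensorflow as tf

pad_token_id = 0   # for T5, the pad token id also serves as the decoder start token

labels = tf.constant([[37, 423, 52, 1]])   # illustrative target ids ending in </s> (= 1)

# shift right: prepend the start token and drop the last position; the library's
# `_shift_right` additionally replaces -100 (ignored label positions) with the pad id
decoder_input_ids = tf.pad(labels[:, :-1], [[0, 0], [1, 0]], constant_values=pad_token_id)
print(decoder_input_ids.numpy())           # [[  0  37 423  52]]
```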
@@ -1483,29 +1412,29 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
logits = tf.cast(logits, tf.float32)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
loss = None if labels is None else self.hf_compute_loss(labels, logits)
past = decoder_outputs[1] if inputs["use_cache"] else None
if not inputs["return_dict"]:
past = decoder_outputs[1] if use_cache else None
if not return_dict:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
output = (logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
elif isinstance(inputs["encoder_outputs"], tuple):
last_hidden_state = inputs["encoder_outputs"][0]
elif isinstance(encoder_outputs, tuple):
last_hidden_state = encoder_outputs[0]
hidden_states = None
attentions = None
idx = 0
if inputs["output_hidden_states"]:
if output_hidden_states:
idx += 1
hidden_states = inputs["encoder_outputs"][idx]
if inputs["output_attentions"]:
hidden_states = encoder_outputs[idx]
if output_attentions:
idx += 1
attentions = inputs["encoder_outputs"][idx]
attentions = encoder_outputs[idx]
inputs["encoder_outputs"] = TFBaseModelOutput(
encoder_outputs = TFBaseModelOutput(
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
attentions=attentions,
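A toy walk-through of the index bookkeeping above: when the caller passed the encoder outputs as a plain tuple, the optional entries follow the last hidden state in a fixed order, so the indices depend on which flags were set. Stand-in strings are used instead of tensors:

```python
from transformers.modeling_tf_outputs import TFBaseModelOutput

output_hidden_states, output_attentions = True, True
encoder_outputs = ("last_hidden_state", "all_hidden_states", "all_attentions")  # stand-ins

last_hidden_state = encoder_outputs[0]
hidden_states = None
attentions = None
idx = 0
if output_hidden_states:
    idx += 1
    hidden_states = encoder_outputs[idx]
if output_attentions:
    idx += 1
    attentions = encoder_outputs[idx]

wrapped = TFBaseModelOutput(
    last_hidden_state=last_hidden_state,
    hidden_states=hidden_states,
    attentions=attentions,
)
```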
@@ -1518,9 +1447,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def serving_output(self, output):
@@ -1685,20 +1614,21 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
def get_encoder(self):
return self.encoder
@unpack_inputs
@add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
) -> Union[Tuple, TFBaseModelOutput]:
r"""
Returns:
@@ -1715,36 +1645,23 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
>>> ).input_ids # Batch size 1
>>> outputs = model(input_ids)
```"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
encoder_outputs = self.encoder(
input_ids,
attention_mask=inputs["attention_mask"],
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=inputs["inputs_embeds"],
inputs_embeds=inputs_embeds,
head_mask=head_mask,
past_key_values=None,
use_cache=False,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not inputs["return_dict"]:
if not return_dict:
return encoder_outputs
return TFBaseModelOutput(
......