Unverified Commit 131e2584 authored by Yih-Dar, committed by GitHub

Fix TF T5/LED missing cross attn in return values (#15511)



* add cross attn to outputs

* add cross attn to outputs for TFLED

* add undo padding

* remove unused import

* fix style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 6775b211
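For context, a minimal usage sketch of what this fix enables: reading the cross-attention weights from a TF T5 forward pass. `TFT5Model` appears in the diff below; `T5Tokenizer`, the `t5-small` checkpoint, and the example strings are illustrative assumptions, not part of the commit.

```python
# Minimal sketch, assuming the standard transformers TF API
# (T5Tokenizer and t5-small are illustrative choices, not part of this commit).
from transformers import T5Tokenizer, TFT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")

enc = tokenizer("translate English to German: How are you?", return_tensors="tf")
dec = tokenizer("Wie geht es dir?", return_tensors="tf")

outputs = model(
    input_ids=enc.input_ids,
    decoder_input_ids=dec.input_ids,
    output_attentions=True,  # after this fix, also populates cross_attentions
    return_dict=True,
)

# One tensor per decoder layer, shape (batch, num_heads, decoder_len, encoder_len)
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```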
@@ -29,7 +29,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
-from ...modeling_tf_outputs import TFBaseModelOutputWithPast
+from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
# Public API
from ...modeling_tf_utils import (
@@ -1220,7 +1220,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
-) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
"""
Args:
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
@@ -1254,12 +1254,13 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block
cross_attn_present_key_value = None
+cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
+hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
@@ -1285,6 +1286,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
return (
hidden_states,
self_attn_weights,
+cross_attn_weights,
present_key_value,
)
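As an illustration of the new layer contract in this hunk, a hedged sketch with placeholder values: the layer now returns the cross-attention weights as a third element, and the cross-attention key/value cache is assumed to sit in the last two slots of `present_key_value`, which is why `past_key_value[-2:]` selects it above.

```python
# Illustrative sketch with placeholder values (assumed layout, not from the diff).
layer_output = (
    "hidden_states",       # 0
    "self_attn_weights",   # 1
    "cross_attn_weights",  # 2  <- newly returned by this change
    ("self_k", "self_v", "cross_k", "cross_v"),  # 3: present_key_value
)
hidden_states, self_attn_weights, cross_attn_weights, present_key_value = layer_output

# the cross-attention cache sits in the last two slots ("positions 3,4" above)
cross_attn_past_key_value = present_key_value[-2:]
assert cross_attn_past_key_value == ("cross_k", "cross_v")
```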
@@ -1808,6 +1810,14 @@ class TFLEDEncoder(tf.keras.layers.Layer):
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
hidden_states = self.compute_hidden_states(hidden_states, padding_len)
+# undo padding
+if inputs["output_attentions"]:
+    all_attentions = (
+        tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+        if padding_len > 0
+        else all_attentions
+    )
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
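A small self-contained sketch of the "undo padding" step added above: when the encoder input was padded to a multiple of the attention window, each per-layer attention tensor is sliced back to the original length along the query axis. The shapes and the `(batch, heads, seq_len, key_dim)` layout are toy assumptions for illustration.

```python
# Toy illustration of the slicing pattern added above (assumed tensor layout:
# batch x heads x padded_seq_len x key_dim).
import tensorflow as tf

padding_len = 2                  # sequence was padded from 8 to 10
attn = tf.ones((1, 4, 10, 16))
all_attentions = (attn, attn)    # one tensor per encoder layer

if padding_len > 0:
    all_attentions = tuple(state[:, :, :-padding_len, :] for state in all_attentions)

print(all_attentions[0].shape)   # (1, 4, 8, 16): back to the unpadded length
```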
@@ -2038,6 +2048,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
# decoder layers
all_hidden_states = ()
all_self_attns = ()
+all_cross_attentions = ()
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
@@ -2059,7 +2070,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
-hidden_states, layer_self_attn, present_key_value = decoder_layer(
+hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
@@ -2076,24 +2087,31 @@ class TFLEDDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,)
+all_cross_attentions += (layer_cross_attn,)
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
-all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
+all_self_attns = all_self_attns if inputs["output_attentions"] else None
+all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
if not inputs["return_dict"]:
-return hidden_states, present_key_values, all_hidden_states, all_self_attns
+return tuple(
+    v
+    for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+    if v is not None
+)
else:
-return TFBaseModelOutputWithPast(
+return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
+cross_attentions=all_cross_attentions,
)
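The non-`return_dict` branch above uses a tuple-filtering pattern: outputs that were not requested are simply dropped from the legacy tuple. A toy sketch with placeholder values:

```python
# Placeholder values standing in for tensors / tuples of tensors.
hidden_states = "h"
present_key_values = None        # use_cache=False
all_hidden_states = ("h0", "h1")
all_self_attns = None            # output_attentions=False
all_cross_attentions = None

legacy_output = tuple(
    v
    for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
    if v is not None
)
print(legacy_output)  # ('h', ('h0', 'h1'))
```

Because positions in the tuple shift with the requested flags, the `return_dict=True` path below is the more robust way to consume these outputs.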
@@ -2223,6 +2241,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
+cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -2475,6 +2494,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
+cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out
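The LED counterpart, again as a hedged sketch: `TFLEDForConditionalGeneration` appears in this hunk, while `LEDTokenizer`, the `allenai/led-base-16384` checkpoint, and the toy inputs are illustrative assumptions.

```python
# Minimal LED sketch, assuming the standard transformers TF API.
from transformers import LEDTokenizer, TFLEDForConditionalGeneration

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A long document to summarize.", return_tensors="tf")
outputs = model(
    input_ids=inputs.input_ids,
    decoder_input_ids=inputs.input_ids[:, :4],
    output_attentions=True,  # now also fills outputs.cross_attentions
    return_dict=True,
)

# One tensor per decoder layer; assumed shape (batch, num_heads, decoder_len, encoder_len)
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```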
@@ -33,7 +33,7 @@ from ...file_utils import (
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
-TFBaseModelOutputWithPast,
+TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
@@ -771,6 +771,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
+all_cross_attentions = () if (inputs["output_attentions"] and self.is_decoder) else None
position_bias = None
encoder_decoder_position_bias = None
@@ -814,6 +815,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if inputs["output_attentions"]:
all_attentions = all_attentions + (layer_outputs[3],)
+if self.is_decoder:
+    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
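For reference, a hedged sketch of the per-block output layout that the indexing above assumes (inferred from the surrounding T5 code, not stated in this hunk): index 3 holds the self-attention weights and, in the decoder, index 5 holds the cross-attention weights.

```python
# Assumed decoder block output layout (illustrative placeholders only).
layer_outputs = (
    "hidden_states",             # 0
    "present_key_value_state",   # 1
    "self_attn_position_bias",   # 2
    "self_attn_weights",         # 3  -> appended to all_attentions
    "cross_attn_position_bias",  # 4
    "cross_attn_weights",        # 5  -> appended to all_cross_attentions
)
assert layer_outputs[3] == "self_attn_weights"
assert layer_outputs[5] == "cross_attn_weights"
```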
@@ -831,14 +834,17 @@ class TFT5MainLayer(tf.keras.layers.Layer):
outputs = outputs + (all_hidden_states,)
if inputs["output_attentions"]:
outputs = outputs + (all_attentions,)
-return outputs  # last-layer hidden state, (all hidden states), (all attentions)
+if self.is_decoder:
+    outputs = outputs + (all_cross_attentions,)
+return outputs  # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions)
if self.is_decoder:
-return TFBaseModelOutputWithPast(
+return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
+cross_attentions=all_cross_attentions,
)
else:
return TFBaseModelOutput(
@@ -1264,6 +1270,7 @@ class TFT5Model(TFT5PreTrainedModel):
past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
+cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1508,6 +1515,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
+cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -322,7 +322,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
-[self.model_tester.num_attention_heads, encoder_seq_length, seq_length],
+[self.model_tester.num_attention_heads, seq_length, seq_length],
)
self.assertListEqual(
list(global_attentions[0].shape[-3:]),