Unverified Commit 131e2584 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TF T5/LED missing cross attn in return values (#15511)



* add cross attn to outputs

* add cross attn to outputs for TFLED

* add undo padding

* remove unused import

* fix style
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 6775b211
...@@ -29,7 +29,7 @@ from ...file_utils import ( ...@@ -29,7 +29,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
replace_return_docstrings, replace_return_docstrings,
) )
from ...modeling_tf_outputs import TFBaseModelOutputWithPast from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
# Public API # Public API
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
...@@ -1220,7 +1220,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer): ...@@ -1220,7 +1220,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
encoder_layer_head_mask: Optional[tf.Tensor] = None, encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
""" """
Args: Args:
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)* hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
...@@ -1254,12 +1254,13 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer): ...@@ -1254,12 +1254,13 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
...@@ -1285,6 +1286,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer): ...@@ -1285,6 +1286,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
...@@ -1808,6 +1810,14 @@ class TFLEDEncoder(tf.keras.layers.Layer): ...@@ -1808,6 +1810,14 @@ class TFLEDEncoder(tf.keras.layers.Layer):
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
hidden_states = self.compute_hidden_states(hidden_states, padding_len) hidden_states = self.compute_hidden_states(hidden_states, padding_len)
# undo padding
if inputs["output_attentions"]:
all_attentions = (
tuple([state[:, :, :-padding_len, :] for state in all_attentions])
if padding_len > 0
else all_attentions
)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
...@@ -2038,6 +2048,7 @@ class TFLEDDecoder(tf.keras.layers.Layer): ...@@ -2038,6 +2048,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = ()
all_self_attns = () all_self_attns = ()
all_cross_attentions = ()
present_key_values = () present_key_values = ()
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
...@@ -2059,7 +2070,7 @@ class TFLEDDecoder(tf.keras.layers.Layer): ...@@ -2059,7 +2070,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
...@@ -2076,24 +2087,31 @@ class TFLEDDecoder(tf.keras.layers.Layer): ...@@ -2076,24 +2087,31 @@ class TFLEDDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
all_cross_attentions += (layer_cross_attn,)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else: else:
all_hidden_states = None all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None all_self_attns = all_self_attns if inputs["output_attentions"] else None
all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return tuple(
v
for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attentions,
) )
...@@ -2223,6 +2241,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer): ...@@ -2223,6 +2241,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
...@@ -2475,6 +2494,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2475,6 +2494,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
......
...@@ -33,7 +33,7 @@ from ...file_utils import ( ...@@ -33,7 +33,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
...@@ -771,6 +771,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -771,6 +771,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
all_cross_attentions = () if (inputs["output_attentions"] and self.is_decoder) else None
position_bias = None position_bias = None
encoder_decoder_position_bias = None encoder_decoder_position_bias = None
...@@ -814,6 +815,8 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -814,6 +815,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
...@@ -831,14 +834,17 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -831,14 +834,17 @@ class TFT5MainLayer(tf.keras.layers.Layer):
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if inputs["output_attentions"]: if inputs["output_attentions"]:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions) if self.is_decoder:
outputs + (all_cross_attentions,)
return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions)
if self.is_decoder: if self.is_decoder:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=present_key_value_states,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions,
) )
else: else:
return TFBaseModelOutput( return TFBaseModelOutput(
...@@ -1264,6 +1270,7 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -1264,6 +1270,7 @@ class TFT5Model(TFT5PreTrainedModel):
past_key_values=past, past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
...@@ -1508,6 +1515,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1508,6 +1515,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
past_key_values=past, past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
......
...@@ -322,7 +322,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -322,7 +322,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual( self.assertListEqual(
list(attentions[0].shape[-3:]), list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, seq_length], [self.model_tester.num_attention_heads, seq_length, seq_length],
) )
self.assertListEqual( self.assertListEqual(
list(global_attentions[0].shape[-3:]), list(global_attentions[0].shape[-3:]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment