Unverified Commit a163c9ca authored by ZhuBaohe, committed by GitHub

[T5] Fix Cross Attention position bias (#4499)

* fix

* fix1
parent 1d690289
@@ -745,7 +745,7 @@ class T5Stack(T5PreTrainedModel):
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if self.output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
- encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
+ encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
@@ -682,7 +682,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if self.output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
- encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
+ encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
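Context for the index change (not part of the commit itself): per the layer_outputs comment in both hunks above, each T5 block returns hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) when output_attentions is enabled. The cross-attention position bias therefore sits at index 5, while index 4 holds the cross-attention weights, which the old code was picking up by mistake. The sketch below uses placeholder strings rather than real tensors, purely to show what each index selects before and after the fix.

```python
# Illustrative sketch only, not the library code: the per-layer output tuple
# that T5Stack / TFT5MainLayer unpack when output_attentions is enabled.
# Placeholder strings stand in for the actual tensors.
layer_outputs = (
    "hidden_states",              # index 0
    "key_value_states",           # index 1
    "self_attn_weights",          # index 2
    "self_attn_position_bias",    # index 3
    "cross_attn_weights",         # index 4
    "cross_attn_position_bias",   # index 5
)

output_attentions = True

# Before the fix: index 4 returns the cross-attention *weights*, not the bias.
assert layer_outputs[4 if output_attentions else 3] == "cross_attn_weights"

# After the fix: index 5 returns the cross-attention position bias, as intended.
assert layer_outputs[5 if output_attentions else 3] == "cross_attn_position_bias"
```

When output_attentions is False, the two attention-weight entries are absent from the tuple, so the fallback index 3 already points at the cross-attention position bias and is unchanged by this commit.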