Unverified Commit 3b61d289 authored by Aashiq Muhamed's avatar Aashiq Muhamed Committed by GitHub
Browse files

Include decoder_attention_mask in T5 model inputs (#22835)

parent 91d6a593
...@@ -1807,6 +1807,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): ...@@ -1807,6 +1807,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None, cross_attn_head_mask=None,
use_cache=None, use_cache=None,
encoder_outputs=None, encoder_outputs=None,
...@@ -1823,6 +1824,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): ...@@ -1823,6 +1824,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask, "cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, "use_cache": use_cache,
} }
......
...@@ -1774,6 +1774,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1774,6 +1774,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None, cross_attn_head_mask=None,
use_cache=None, use_cache=None,
encoder_outputs=None, encoder_outputs=None,
...@@ -1790,6 +1791,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1790,6 +1791,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask, "cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, "use_cache": use_cache,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment