Unverified commit 975003ea, authored by théo gigant and committed by GitHub
Browse files

fix a typo in flax T5 attention - attention_mask variable is misnamed (#26663)

* fix a typo in flax t5 attention

* fix the typo in flax longt5 attention
parent e8fdd787
......@@ -545,7 +545,7 @@ class FlaxLongT5Attention(nn.Module):
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
- key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
......
......@@ -405,7 +405,7 @@ class FlaxT5Attention(nn.Module):
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
- key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment