Unverified commit 975003ea, authored by théo gigant and committed by GitHub

fix a typo in flax T5 attention - attention_mask variable is misnamed (#26663)

* fix a typo in flax t5 attention

* fix the typo in flax longt5 attention
parent e8fdd787
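
For context, here is a minimal sketch (plain JAX, with hypothetical shapes and names; `append_to_cache` is not the real `_concatenate_to_cache`) of the KV-cache pattern both hunks below touch: during fast autoregressive decoding, each step writes its new key/value rows into a pre-allocated cache via `lax.dynamic_update_slice`, and an updated attention mask is returned so queries only attend to slots that have already been filled.

import jax.numpy as jnp
from jax import lax

MAX_LEN, HEAD_DIM = 8, 4  # hypothetical cache capacity and per-head dim

def append_to_cache(cached_key, cache_index, new_key):
    # Write the new step's key rows at the current cache position.
    cached_key = lax.dynamic_update_slice(cached_key, new_key, (cache_index, 0))
    cache_index = cache_index + new_key.shape[0]
    # Hide cache slots that have not been written yet.
    attention_mask = jnp.arange(MAX_LEN) < cache_index
    return cached_key, cache_index, attention_mask

cached_key = jnp.zeros((MAX_LEN, HEAD_DIM))  # pre-allocated cache
cached_key, cache_index, attention_mask = append_to_cache(cached_key, 0, jnp.ones((1, HEAD_DIM)))
print(attention_mask)  # only slot 0 is visible after one step

The third return value is the point of this commit: the caller is expected to rebind `attention_mask` to it.
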
@@ -545,7 +545,7 @@ class FlaxLongT5Attention(nn.Module):
         # During fast autoregressive decoding, we feed one position at a time,
         # and cache the keys and values step by step.
         if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
-            key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+            key_states, value_states, attention_mask = self._concatenate_to_cache(
                 key_states, value_states, query_states, attention_mask
             )
@@ -405,7 +405,7 @@ class FlaxT5Attention(nn.Module):
         # During fast autoregressive decoding, we feed one position at a time,
         # and cache the keys and values step by step.
         if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
-            key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+            key_states, value_states, attention_mask = self._concatenate_to_cache(
                 key_states, value_states, query_states, attention_mask
             )
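
Why the rename matters: in Python the misspelled assignment target simply creates a new variable, so the mask returned by `_concatenate_to_cache` was bound to the unused `attention_attention_mask` while the stale `attention_mask` kept flowing into the attention-bias computation that follows. (Whether this changed numerics in the upstream modules is not claimed here, since the incoming mask may already incorporate the sliced causal mask; the toy below, with made-up values, only demonstrates the shadowing pattern.)

import jax.numpy as jnp

attention_mask = jnp.array([True, True, True, True])  # stale: every slot visible
updated_mask = jnp.array([True, True, False, False])  # cache-aware: two slots filled

# Buggy binding: the update lands in a throwaway name, the stale mask is used.
attention_attention_mask = updated_mask
buggy_bias = jnp.where(attention_mask, 0.0, -1e10)    # no slot is masked

# Fixed binding: rebinding attention_mask lets the bias see the update.
attention_mask = updated_mask
fixed_bias = jnp.where(attention_mask, 0.0, -1e10)    # empty slots get -1e10

print(buggy_bias)
print(fixed_bias)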