Unverified Commit 6af3ce77 authored by Sanchit Gandhi, committed by GitHub

[Flax LLaMA] Fix attn dropout (#28059)

parent 7e876dca
@@ -289,6 +289,10 @@ class FlaxLlamaAttention(nn.Module):
         attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
         attention_mask = combine_masks(attention_mask, causal_mask)

+        dropout_rng = None
+        if not deterministic and self.config.attention_dropout > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
         # During fast autoregressive decoding, we feed one position at a time,
         # and cache the keys and values step by step.
         if self.has_variable("cache", "cached_key") or init_cache:
@@ -307,6 +311,8 @@ class FlaxLlamaAttention(nn.Module):
             query,
             key,
             bias=attention_bias,
+            dropout_rng=dropout_rng,
+            dropout_rate=self.config.attention_dropout,
             deterministic=deterministic,
             dtype=attention_dtype,
         )
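For context, here is a minimal sketch (not part of the commit) of how the fixed path is exercised: with `attention_dropout` set above zero in the config and `train=True`, the attention module draws the `"dropout"` RNG stream that it now forwards to `dot_product_attention_weights`. The tiny config values below are hypothetical and chosen only to keep the example cheap to run.

```python
# Minimal sketch, assuming a transformers version that includes this fix.
import jax
import jax.numpy as jnp
from transformers import FlaxLlamaForCausalLM, LlamaConfig

# Hypothetical tiny config; attention_dropout > 0.0 is what this patch makes effective.
config = LlamaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    attention_dropout=0.1,
)
model = FlaxLlamaForCausalLM(config)

input_ids = jnp.ones((1, 8), dtype="i4")

# train=True makes the forward pass non-deterministic, so FlaxLlamaAttention
# calls self.make_rng("dropout") and applies dropout to the attention weights;
# the caller supplies the PRNG key via dropout_rng.
outputs = model(input_ids, dropout_rng=jax.random.PRNGKey(0), train=True)
print(outputs.logits.shape)  # (1, 8, 64)
```

Before this change, `dropout_rng` and `dropout_rate` were never passed to `dot_product_attention_weights`, so `config.attention_dropout` had no effect on the Flax attention weights even in training mode.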