Unverified Commit 6af3ce77 authored by Sanchit Gandhi, committed by GitHub

[Flax LLaMA] Fix attn dropout (#28059)

parent 7e876dca
@@ -289,6 +289,10 @@ class FlaxLlamaAttention(nn.Module):
         attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
         attention_mask = combine_masks(attention_mask, causal_mask)
 
+        dropout_rng = None
+        if not deterministic and self.config.attention_dropout > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
         # During fast autoregressive decoding, we feed one position at a time,
         # and cache the keys and values step by step.
         if self.has_variable("cache", "cached_key") or init_cache:
@@ -307,6 +311,8 @@ class FlaxLlamaAttention(nn.Module):
             query,
             key,
             bias=attention_bias,
+            dropout_rng=dropout_rng,
+            dropout_rate=self.config.attention_dropout,
             deterministic=deterministic,
             dtype=attention_dtype,
         )
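For context, the patch makes attention dropout actually take effect by requesting a "dropout" RNG from the module and forwarding it, together with the configured dropout rate, to Flax's attention-weight computation. Below is a minimal sketch of that pattern outside of Transformers, assuming recent `jax` and `flax` releases; `ToyAttention` and its `attention_dropout` field are hypothetical stand-ins for `FlaxLlamaAttention` and `config.attention_dropout`, not the library's actual class.

```python
# Minimal sketch of the attention-dropout wiring introduced by the patch.
# ToyAttention is a hypothetical module, not part of Transformers.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.attention import dot_product_attention_weights


class ToyAttention(nn.Module):
    attention_dropout: float = 0.1  # stands in for config.attention_dropout

    @nn.compact
    def __call__(self, query, key, value, deterministic: bool = True):
        # Mirror the patched logic: only request an RNG when dropout is active,
        # so deterministic calls don't need rngs={"dropout": ...}.
        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query,
            key,
            dropout_rng=dropout_rng,
            dropout_rate=self.attention_dropout,
            deterministic=deterministic,
            dtype=jnp.float32,
        )
        # Standard attention readout: weights x values.
        return jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)


if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    q = k = v = jnp.ones((1, 4, 2, 8))  # (batch, seq, heads, head_dim)
    model = ToyAttention()
    params = model.init(rng, q, k, v)  # deterministic=True, no dropout RNG needed

    # Training-style call: dropout is applied, so a "dropout" RNG is required.
    out = model.apply(params, q, k, v, deterministic=False, rngs={"dropout": rng})
    print(out.shape)  # (1, 4, 2, 8)
```

The key point of the fix is that `dot_product_attention_weights` only applies dropout when it receives both a non-zero `dropout_rate` and a `dropout_rng`; without the two added keyword arguments in the diff, the configured attention dropout was silently ignored.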