Commit 9057fc2f authored by Julien Denize's avatar Julien Denize Committed by Kevin H. Luu
Browse files

[BUGFIX] llama_4_scaling wrongly passed to DeepseekAttention (#29908)


Signed-off-by: default avatarjuliendenize <julien.denize@mistral.ai>
(cherry picked from commit 5e5646e2)
parent a05b5805
......@@ -1135,6 +1135,8 @@ class DeepseekV2DecoderLayer(nn.Module):
dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
)
self.use_mha = use_mha
if use_mha:
attn_cls = DeepseekAttention
elif model_config.use_mla:
......@@ -1196,11 +1198,14 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=llama_4_scaling,
)
attn_kwargs = {
"positions": positions,
"hidden_states": hidden_states,
}
if not self.use_mha:
attn_kwargs["llama_4_scaling"] = llama_4_scaling
hidden_states = self.self_attn(**attn_kwargs)
if (
not isinstance(self.self_attn, DeepseekAttention)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment