Unverified commit ec7da6fc, authored by Lucia Fang and committed by GitHub
Browse files

[BugFix] llama4 qknorm should not be shared across heads (#16311)


Signed-off-by: Lu Fang <fanglu@fb.com>
parent 819d548e
...@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module): ...@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm( self.qk_norm = RMSNorm(
hidden_size=self.q_size, hidden_size=self.head_dim,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
has_weight=False, has_weight=False,
dtype=torch.float32, dtype=torch.float32,
...@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module): ...@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module):
if self.rotary_emb is not None: if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None: if self.qk_norm is not None:
q = self.q_norm(q.float()).to(q.dtype) q = q.reshape(-1, self.num_heads, self.head_dim)
if self.k_norm is not None: q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
k = self.k_norm(k.float()).to(k.dtype) k = k.reshape(-1, self.num_kv_heads, self.head_dim)
k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function # to NoPE layers, where the inference-time temperature tuning function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment