"vscode:/vscode.git/clone" did not exist on "933790c209fe73398a542d9fb31a138103ee3ccb"
Commit 4f0f844b (unverified), authored by Po-Han Huang (NVIDIA), committed by GitHub
Browse files

Fix cuda illegal mem access with Llama4 TP8 + rms_norm custom op (#22701)


Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
parent c5830381
@@ -224,10 +224,14 @@ class Llama4Attention(nn.Module):
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
         if self.qk_norm is not None:
-            q = q.reshape(-1, self.num_heads, self.head_dim)
+            # Normalization is applied on the head_dim dimension. The rest of
+            # the dimensions are collapsed into a single dimension to support
+            # custom rms_norm cuda kernel.
+            q = q.reshape(-1, self.head_dim)
             q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
-            k = k.reshape(-1, self.num_kv_heads, self.head_dim)
+            k = k.reshape(-1, self.head_dim)
             k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment