"src/vscode:/vscode.git/clone" did not exist on "5249a2666e51c4381156faa0f6a4b4d079e0c2a7"
Unverified commit 162f3ccb, authored by Ke Bao and committed by GitHub
Browse files

Fix layernorm input shape (#1066)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 65e89bae
...@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_input[..., : self.kv_lora_rank] q_nope_out = q_input[..., : self.kv_lora_rank]
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1)) torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
k_pe = k_input[..., self.kv_lora_rank :] v_input = latent_cache[..., : self.kv_lora_rank]
v_input = k_input[..., : self.kv_lora_rank] v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
v_input = self.kv_a_layernorm(v_input.contiguous()) k_input = latent_cache.unsqueeze(1)
k_input[..., : self.kv_lora_rank] = v_input k_input[..., : self.kv_lora_rank] = v_input
k_pe = k_input[..., self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q_input[..., self.kv_lora_rank :] = q_pe q_input[..., self.kv_lora_rank :] = q_pe
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment