Unverified Commit 162f3ccb authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Fix layernorm input shape (#1066)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 65e89bae
......@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_input[..., : self.kv_lora_rank]
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
k_pe = k_input[..., self.kv_lora_rank :]
v_input = k_input[..., : self.kv_lora_rank]
v_input = self.kv_a_layernorm(v_input.contiguous())
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
v_input = latent_cache[..., : self.kv_lora_rank]
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
k_input = latent_cache.unsqueeze(1)
k_input[..., : self.kv_lora_rank] = v_input
k_pe = k_input[..., self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q_input[..., self.kv_lora_rank :] = q_pe
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment