Unverified Commit 50f1b6d6 authored by Ke Bao, committed by GitHub

Remove copy after bmm (#7441)

parent 5962e70d
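
The copy this commit removes is the one hidden in the old attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) step: flattening a transposed bmm result cannot be expressed as a view of the existing storage, so PyTorch materializes a fresh tensor. A minimal sketch of that behavior, with made-up sizes standing in for the model's head count and dimensions:

import torch

# Illustrative sizes only (stand-ins for num_local_heads, seq_len, v_head_dim).
h, s, v = 4, 8, 32
bmm_out = torch.randn(h, s, v)  # what torch.bmm produces: (heads, seq, v_dim)

t = bmm_out.transpose(0, 1)     # (seq, heads, v_dim) view, non-contiguous
print(t.is_contiguous())        # False
flat = t.flatten(1, 2)          # (seq, heads * v_dim): the dims cannot be
                                # merged as a view, so this allocates and copies
print(flat.data_ptr() == bmm_out.data_ptr())  # False: a new buffer was written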
@@ -1084,13 +1084,16 @@ class DeepseekV2AttentionMLA(nn.Module):
                 masked_m,
                 expected_m,
             )
-            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+            attn_bmm_output = (
+                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+            )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1103,10 +1106,21 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
-        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
-        output, _ = self.o_proj(attn_output)
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        output, _ = self.o_proj(attn_bmm_output)
         return output
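
A minimal sketch of the new pattern in the else branch, again with illustrative sizes: preallocate the flat (seq, heads * v_dim) buffer that o_proj expects, then hand torch.bmm a strided (heads, seq, v_dim) view of it via out=, so the kernel writes the final layout directly and the follow-up transpose-and-flatten copy disappears. The assert only checks that the two paths agree numerically:

import torch

# Stand-ins for seq_len, num_local_heads, kv_lora_rank, v_head_dim.
s, h, c, v = 8, 4, 16, 32
attn_output = torch.randn(s, h, c)
w_vc = torch.randn(h, c, v)

# Old path: bmm gives (h, s, v); transpose + flatten copies into (s, h * v).
old = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1).flatten(1, 2)

# New path: bmm writes straight into a strided view of the flat output buffer.
new = torch.empty(s, h * v)
torch.bmm(
    attn_output.transpose(0, 1),
    w_vc,
    out=new.view(-1, h, v).transpose(0, 1),
)

assert torch.allclose(old, new)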