"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a0939977a3b3c34c925c565c3fd3dcbe5d09e23c"
Unverified Commit 50f1b6d6 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Remove copy after bmm (#7441)

parent 5962e70d
...@@ -1084,13 +1084,16 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1084,13 +1084,16 @@ class DeepseekV2AttentionMLA(nn.Module):
masked_m, masked_m,
expected_m, expected_m,
) )
attn_bmm_output = attn_bmm_output[:, :expected_m, :] attn_bmm_output = (
attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
)
elif _is_hip: elif _is_hip:
# TODO(haishaw): add bmm_fp8 to ROCm # TODO(haishaw): add bmm_fp8 to ROCm
attn_bmm_output = torch.bmm( attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1), attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale, self.w_vc.to(torch.bfloat16) * self.w_scale,
) )
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
elif self.w_vc.dtype == torch.float8_e4m3fn: elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1), attn_output.transpose(0, 1),
...@@ -1103,10 +1106,21 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1103,10 +1106,21 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale, self.w_scale,
torch.bfloat16, torch.bfloat16,
) )
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
else: else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) attn_bmm_output = torch.empty(
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
output, _ = self.o_proj(attn_output) dtype=attn_output.dtype,
device=attn_output.device,
)
torch.bmm(
attn_output.transpose(0, 1),
self.w_vc,
out=attn_bmm_output.view(
-1, self.num_local_heads, self.v_head_dim
).transpose(0, 1),
)
output, _ = self.o_proj(attn_bmm_output)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment