Commit 078de197 authored by lizhigong

Merge branch 'v0.5.4_rzc' into 'v0.5.4_dev'

Add a w_scale == 1 check to eliminate two elementwise scalar multiplications

See merge request OpenDAS/sglang!39
parents 59259b56 8da7ca78
@@ -1618,11 +1618,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                     torch.bfloat16,
                     q_nope_out,
                 )
-            else:
-                q_nope_out = torch.bmm(
-                    q_nope.to(torch.bfloat16).transpose(0, 1),
-                    self.w_kc.to(torch.bfloat16) * self.w_scale,
-                )
+            else:  # TODO: hand-write a fused operator
+                _q_nope_safe = q_nope.to(torch.bfloat16).transpose(0, 1)
+                _w_kc_safe = self.w_kc.to(torch.bfloat16)
+                if abs(self.w_scale - 1) < 1e-6:
+                    q_nope_out = torch.bmm(_q_nope_safe, _w_kc_safe)
+                else:
+                    q_nope_out = torch.bmm(_q_nope_safe, _w_kc_safe * self.w_scale,
+                    )
+                # q_nope_out = torch.bmm(
+                #     q_nope.to(torch.bfloat16).transpose(0, 1),
+                #     self.w_kc.to(torch.bfloat16) * self.w_scale,
+                # )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
@@ -1763,11 +1770,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                     torch.bfloat16,
                     attn_bmm_output,
                 )
-            else:
-                attn_bmm_output = torch.bmm(
-                    attn_output.to(torch.bfloat16).transpose(0, 1),
-                    self.w_vc.to(torch.bfloat16) * self.w_scale,
-                )
+            else:  # TODO: hand-write a fused operator
+                _attn_output_safe = attn_output.to(torch.bfloat16).transpose(0, 1)
+                _w_vc_safe = self.w_vc.to(torch.bfloat16)
+                if abs(self.w_scale - 1) < 1e-6:
+                    attn_bmm_output = torch.bmm(_attn_output_safe, _w_vc_safe)
+                else:
+                    attn_bmm_output = torch.bmm(_attn_output_safe, _w_vc_safe * self.w_scale,
+                    )
+                # attn_bmm_output = torch.bmm(
+                #     attn_output.to(torch.bfloat16).transpose(0, 1),
+                #     self.w_vc.to(torch.bfloat16) * self.w_scale,
+                # )
         if self.o_proj.weight.dtype == torch.uint8:
             attn_bmm_output = attn_bmm_output.transpose(0, 1)
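Both hunks apply the same optimization: when w_scale is effectively 1, the elementwise multiply `weight * w_scale` only copies the weight tensor, so the new branch skips it and feeds the bfloat16 weight straight into torch.bmm, saving a kernel launch and a full read/write of the weight per call. Below is a minimal standalone sketch of that pattern; the helper name and tensor shapes are illustrative only and do not come from sglang.

import torch

def bmm_with_optional_scale(x: torch.Tensor, w: torch.Tensor, w_scale: float) -> torch.Tensor:
    """Batched matmul with an optional per-tensor weight scale.

    x: [B, M, K], w: [B, K, N]; returns [B, M, N] in bfloat16.
    Hypothetical helper, used only to illustrate the fast path added in this MR.
    """
    x_bf16 = x.to(torch.bfloat16)
    w_bf16 = w.to(torch.bfloat16)
    if abs(w_scale - 1) < 1e-6:
        # Fast path: scaling by 1 is a no-op, so skip the elementwise multiply.
        return torch.bmm(x_bf16, w_bf16)
    # General path: materialize the scaled weight, then batch-matmul.
    return torch.bmm(x_bf16, w_bf16 * w_scale)

if __name__ == "__main__":
    x = torch.randn(16, 8, 128)    # illustrative shapes
    w = torch.randn(16, 128, 512)
    out = bmm_with_optional_scale(x, w, w_scale=1.0)
    print(out.shape)  # torch.Size([16, 8, 512])

The "TODO: hand-write a fused operator" comments point at the next step: folding the scale into the matmul kernel itself, so the general path would not need a separate elementwise multiply either.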