Unverified Commit be0058bc authored by Liu-congo, committed by GitHub

[BugFix] replace the input_to_float8 used in dsv2 (#11612)


Signed-off-by: Liu-congo <1502632128@qq.com>
parent 9e3be1fa
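
For context: both the removed `input_to_float8` helper and the `per_tensor_quant_mla_fp8` kernel that replaces it perform per-tensor symmetric FP8 quantization, producing a quantized tensor plus a single float32 dequantization scale. The sketch below is a minimal, hypothetical illustration of that computation in plain PyTorch; the name `fp8_quant_sketch` is ours, and the actual sglang kernels differ in implementation and signature.

```python
import torch


def fp8_quant_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # One scale for the whole tensor: scale = amax / fp8_max, so that
    # dequantization is x_q.float() * scale.
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12).float()
    scale = amax / finfo.max
    x_q = (x.float() / scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_q, scale.reshape(1)
```

Under this convention the returned scale is a dequantization scale, which is what `bmm_fp8`-style kernels consume alongside the quantized operand.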
@@ -92,7 +92,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
-    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
@@ -1623,15 +1622,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
-            if _is_cublas_ge_129:
-                q_nope_val, q_nope_scale = input_to_float8(
-                    q_nope.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            # Fix bmm_fp8 error under cuBLAS 12.9 caused by BumpAllocator; details in PR #11612.
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
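
Why a fresh `torch.zeros` instead of `zero_allocator.allocate(1)`? A bump allocator hands out views into one preallocated buffer, so every allocated scale aliases the same storage; per the comment above, `bmm_fp8` errors under cuBLAS >= 12.9 when the scale tensor lives inside that shared buffer, so the fix allocates a standalone tensor on that path. The class below is a hypothetical sketch of such an allocator (not sglang's actual `BumpAllocator`), showing the aliasing:

```python
import torch


class BumpAllocatorSketch:
    """Hypothetical bump allocator: carves views out of one preallocated
    zero buffer instead of calling torch.zeros per request."""

    def __init__(self, capacity: int, dtype=torch.float32, device="cpu"):
        self._buffer = torch.zeros(capacity, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, n: int) -> torch.Tensor:
        assert self._offset + n <= self._buffer.numel(), "allocator exhausted"
        out = self._buffer[self._offset : self._offset + n]  # a view, not a copy
        self._offset += n
        return out


alloc = BumpAllocatorSketch(capacity=8)
scale = alloc.allocate(1)
# The returned tensor shares storage with the big buffer, which is the
# aliasing the workaround above avoids on cuBLAS >= 12.9.
assert scale.data_ptr() == alloc._buffer.data_ptr()
```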
@@ -1772,14 +1771,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            if _is_cublas_ge_129:
-                attn_output_val, attn_output_scale = input_to_float8(
-                    attn_output.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
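
Both hunks apply the same shape of fix, so the shared pattern can be factored as below. This is a hedged sketch, not code from the PR: `quant_fn` and `bmm_fn` stand in for sglang's `per_tensor_quant_mla_fp8` and `bmm_fp8`, whose argument order we take from the call sites above.

```python
import torch


def quant_bmm_sketch(x, w_fp8, w_scale, quant_fn, bmm_fn,
                     is_cublas_ge_129, zero_allocator):
    # Pick the one-element float32 scale buffer: a standalone tensor on
    # cuBLAS >= 12.9 (the PR's workaround), the bump allocator otherwise.
    scale_buf = (
        torch.zeros((1,), dtype=torch.float32, device=x.device)
        if is_cublas_ge_129
        else zero_allocator.allocate(1)
    )
    x_val, x_scale = quant_fn(x.transpose(0, 1), scale_buf)
    # bmm_fp8(A_q, B_q, A_scale, B_scale, out_dtype), per the call sites above.
    return bmm_fn(x_val, w_fp8, x_scale, w_scale, torch.bfloat16)
```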