Unverified commit be0058bc, authored by Liu-congo and committed by GitHub

[BugFix] replace the input_to_float8 used in dsv2 (#11612)


Signed-off-by: Liu-congo <1502632128@qq.com>
parent 9e3be1fa
@@ -92,7 +92,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
-    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
@@ -1623,15 +1622,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
-            if _is_cublas_ge_129:
-                q_nope_val, q_nope_scale = input_to_float8(
-                    q_nope.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            # Fix the bmm_fp8 error under cuBLAS 12.9 caused by the bump allocator; see PR #11612
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
@@ -1772,14 +1771,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            if _is_cublas_ge_129:
-                attn_output_val, attn_output_scale = input_to_float8(
-                    attn_output.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
...
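
For context, the hunks above replace the cuBLAS-12.9 special case that used input_to_float8 with a single per_tensor_quant_mla_fp8 call whose scale buffer is a freshly zeroed tensor rather than a slice from the shared bump allocator. Below is a minimal, self-contained sketch of that call pattern only; per_tensor_quant_mla_fp8 and the allocator are stubbed here for illustration and are not the actual sglang implementations.

# Sketch of the new call pattern; the real kernel and allocator live in sglang.
# The key point is the scale-buffer choice: on cuBLAS >= 12.9 a fresh zero
# tensor is passed instead of a slice from the shared bump allocator.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


class ZeroAllocator:
    """Stand-in for sglang's bump allocator: hands out zeroed fp32 slices."""

    def __init__(self, size: int, device: str = "cpu"):
        self.buf = torch.zeros((size,), dtype=torch.float32, device=device)
        self.offset = 0

    def allocate(self, n: int) -> torch.Tensor:
        out = self.buf[self.offset : self.offset + n]
        self.offset += n
        return out


def per_tensor_quant_mla_fp8(x: torch.Tensor, scale_out: torch.Tensor):
    """Stubbed per-tensor quantization: fills the caller-provided scale buffer."""
    scale = (x.abs().amax().float() / FP8_MAX).clamp(min=1e-12)
    scale_out.fill_(scale.item())
    x_fp8 = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale_out


_is_cublas_ge_129 = True  # pretend cuBLAS >= 12.9 was detected
zero_allocator = ZeroAllocator(size=16)
q_nope = torch.randn(8, 16, 128)  # made-up (seq, heads, head_dim) shape

q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
    q_nope.transpose(0, 1),
    (
        torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
        if _is_cublas_ge_129
        else zero_allocator.allocate(1)  # legacy path kept for older cuBLAS
    ),
)
print(q_nope_val.dtype, q_nope_scale)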