Unverified Commit 0f1dfa1e authored by fzyzcjy, committed by GitHub

Tiny add sanity checks for DeepGEMM inputs (#7157)

parent e3ec6bf4
@@ -239,6 +239,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
column_major_scales=True,
scale_tma_aligned=True,
)
if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
_check_ue8m0("x_scale", x_scale)
_check_ue8m0("weight_scale", weight_scale)
output = w8a8_block_fp8_matmul_deepgemm(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
)
@@ -247,6 +252,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
return output.to(dtype=output_dtype).view(*output_shape)

def _check_ue8m0(name, x):
x_ceil = ceil_to_ue8m0(x)
assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"

def aiter_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -380,6 +390,11 @@ def block_quant_dequant(
return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)

# COPIED FROM DeepGEMM
def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

def channel_quant_to_tensor_quant(
x_q_channel: torch.Tensor,
x_s: torch.Tensor,
......
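For readers unfamiliar with the format: UE8M0 is an unsigned 8-bit exponent with no mantissa, so a valid UE8M0 scale tensor can only contain powers of two. The sketch below is a minimal, self-contained illustration of the check this commit adds; ceil_to_ue8m0 and _check_ue8m0 mirror the helpers in the diff above, while the example scale tensors are made up for demonstration and are not taken from SGLang.

import torch

def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the nearest power of two, the only values
    # representable in UE8M0 (8-bit exponent, zero mantissa).
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

def _check_ue8m0(name: str, x: torch.Tensor) -> None:
    # A tensor already in UE8M0 form is left unchanged by the rounding above.
    x_ceil = ceil_to_ue8m0(x)
    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"

# Power-of-two scales pass the check (illustrative values).
_check_ue8m0("good_scales", torch.tensor([0.25, 0.5, 1.0, 4.0]))

# Arbitrary float scales fail: 0.3 rounds up to 0.5, 1.5 rounds up to 2.0.
try:
    _check_ue8m0("bad_scales", torch.tensor([0.3, 1.5]))
except AssertionError:
    print("bad_scales are not UE8M0 (power-of-two) values")

In the patch itself, the assertions only run when the SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0 environment variable is set, so the extra check is opt-in.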