Unverified Commit 5e5c30d9 authored by fzyzcjy, committed by GitHub

Tiny let DeepGEMM scale checks cover more cases (#7182)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 9f00ec44
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
     from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 
+_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
+
 
 # TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
     _, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
     num_groups, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
     num_groups = 1
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.fp8_gemm_nt(
             lhs,
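
All three wrappers receive their fp8 operands as (tensor, scale) tuples, which is why _sanity_check_input is applied to both lhs and rhs before the kernel launch. A hedged sketch of a call site, assuming the wrapper's signature is (lhs, rhs, out); the shapes and the per-128-block scale layout are illustrative, not taken from sglang:

import torch

m, k, n = 128, 256, 512
lhs = (
    torch.empty(m, k, dtype=torch.float8_e4m3fn, device="cuda"),
    torch.empty(m, k // 128, dtype=torch.float32, device="cuda"),        # assumed scale layout
)
rhs = (
    torch.empty(n, k, dtype=torch.float8_e4m3fn, device="cuda"),
    torch.empty(n // 128, k // 128, dtype=torch.float32, device="cuda"),  # assumed scale layout
)
out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")

# With SGLANG_DEEPGEMM_SANITY_CHECK=1, both scale tensors are validated here.
gemm_nt_f8f8bf16(lhs, rhs, out)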
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
         yield
     finally:
         deep_gemm.set_num_sms(original_num_sms)
+
+
+def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+    if not _SANITY_CHECK:
+        return
+
+    x, x_scale = x_fp8
+
+    if x_scale.dtype == torch.int:
+        return
+
+    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+    x_scale_ceil = ceil_to_ue8m0(x_scale)
+    assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
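
The new _sanity_check_input is off by default and only runs when the SGLANG_DEEPGEMM_SANITY_CHECK environment variable is set. It unpacks the (tensor, scale) tuple and asserts that every scale is already in UE8M0 form; scales packed into int32 are skipped, since the packing already implies UE8M0. Because UE8M0 is an 8-bit-exponent, zero-bit-mantissa format, a valid scale is an exact power of two, so a scale tensor passes exactly when it is a fixed point of ceil_to_ue8m0. A minimal standalone sketch of that invariant, assuming ceil_to_ue8m0 rounds each value up to the nearest power of two (the real helper lives in sglang.srt.layers.quantization.fp8_utils and may differ in detail):

import torch

def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Assumed implementation: UE8M0 has an 8-bit exponent and no mantissa,
    # so valid scales are powers of two; round everything else up.
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

good = torch.tensor([0.25, 0.5, 1.0, 4.0])       # already powers of two
assert torch.all(good == ceil_to_ue8m0(good))    # fixed point: the check passes

bad = torch.tensor([0.3, 1.5])                   # not powers of two
assert not torch.all(bad == ceil_to_ue8m0(bad))  # the sanity check would fire here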
@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
 
-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
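
The second file (judging from the imports, sglang.srt.layers.quantization.fp8_utils) drops the old _check_ue8m0 helper and its commented-out, SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0-gated call sites: the check now lives in the deep_gemm wrapper behind SGLANG_DEEPGEMM_SANITY_CHECK, where it covers all three GEMM entry points instead of a single linear path. Since the wrapper reads the flag once at module import, it has to be set before sglang is imported; a minimal sketch, assuming importing sglang pulls in the wrapper module:

import os

# The flag is read at import time, so set it first.
os.environ["SGLANG_DEEPGEMM_SANITY_CHECK"] = "1"

import sglang  # noqa: E402  # wrapped DeepGEMM calls now assert UE8M0 scales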