Unverified commit bec3e484, authored by fzyzcjy and committed by GitHub

Support new DeepGEMM format in per token group quant (part 2: srt) (#7155)

parent 8ab7d93c
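Background for reviewers: the new DeepGEMM layout apparently stores each per-group scale as a UE8M0 value, i.e. an unsigned 8-bit biased exponent (bias 127, no sign or mantissa bits), so every scale is a power of two and four scale bytes fit in one `int32`. That matches what the `scale_ue8m0` path below does when it allocates the scale tensor with `dtype=torch.int` and divides the K-dimension scale count by 4. A minimal CPU-side sketch of such an encoding, assuming round-up-to-power-of-two semantics and host byte order (both assumptions; the real conversion happens inside the sgl-kernel CUDA kernel):

```python
import torch

def pack_ue8m0_sketch(scales: torch.Tensor) -> torch.Tensor:
    # Illustrative only: rounding mode and byte order are assumptions,
    # and the helper name is hypothetical.
    # UE8M0 keeps just a biased exponent byte (bias 127, no sign/mantissa),
    # so each positive scale is rounded up to the nearest power of two.
    e = (torch.ceil(torch.log2(scales)) + 127).clamp(0, 255).to(torch.uint8)
    # Reinterpret every 4 exponent bytes as one int32 (host byte order);
    # assumes the number of scales is already padded to a multiple of 4.
    return e.reshape(-1, 4).view(torch.int32)

# e.g. scales [0.5, 1.0, 2.0, 3.0] become exponent bytes [126, 127, 128, 129]
# (3.0 rounds up to 4.0), packed into a single int32.
```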
```diff
@@ -49,7 +49,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.7",
+    "sgl-kernel==0.1.8.post1",
     "flashinfer_python==0.2.6.post1",
     "torch==2.7.1",
     "torchaudio==2.7.1",
```
```diff
@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.7",
+            "0.1.8.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
```
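`assert_pkg_version` fails fast at server startup when the installed wheel is older than required. A hypothetical stand-in showing the shape of such a check (the helper's actual body is not part of this diff, so this is an assumption):

```python
from importlib.metadata import version
from packaging.version import Version

def assert_pkg_version_sketch(pkg: str, min_version: str, message: str) -> None:
    # Hypothetical stand-in for sglang's assert_pkg_version; illustrative only.
    installed = version(pkg)
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"{pkg} {installed} is older than the required {min_version}. {message}"
        )
```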
```diff
@@ -280,6 +280,7 @@ def sglang_per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
```
```diff
@@ -287,8 +288,20 @@ def sglang_per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    if column_major_scales:
+    if scale_ue8m0:
+        assert column_major_scales and scale_tma_aligned
+        x_q_mn, x_q_k = x.shape
+        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
+        aligned_mn = align(x_s_mn, 4)
+        aligned_k = align(x_s_k, 4)
+        x_s = torch.empty(
+            (aligned_k // 4, aligned_mn),
+            device=x.device,
+            dtype=torch.int,
+        ).permute(-1, -2)[:x_s_mn, :]
+    elif column_major_scales:
         if scale_tma_aligned:
+            # TODO extract "align" function
             # aligned to 4 * sizeof(float)
             aligned_size = (x.shape[-2] + 3) // 4 * 4
             x_s = torch.empty(
```
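The `align` helper used above (and flagged by the TODO for extraction from the pre-existing inline expression) is presumably the standard round-up-to-multiple; a sketch whose body is inferred from the neighboring `(x.shape[-2] + 3) // 4 * 4` pattern, not taken from this diff:

```python
def align(value: int, alignment: int) -> int:
    # Round value up to the nearest multiple of alignment,
    # e.g. align(7, 4) == 8 and align(4, 4) == 4.
    return (value + alignment - 1) // alignment * alignment
```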
```diff
@@ -309,7 +322,9 @@ def sglang_per_token_group_quant_fp8(
         dtype=torch.float32,
     )
     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
     return x_q, x_s
```
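A quick shape check for the new path, as a hedged usage sketch (assumes a CUDA device, the bumped sgl-kernel wheel, and sglang's usual module path for this function; the shape arithmetic follows the allocation logic above):

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

x = torch.randn(7, 512, device="cuda", dtype=torch.bfloat16)
x_q, x_s = sglang_per_token_group_quant_fp8(
    x,
    group_size=128,
    column_major_scales=True,
    scale_tma_aligned=True,
    scale_ue8m0=True,  # the flag added by this commit
)
# x_s_mn = 7 rows of scales, x_s_k = 512 // 128 = 4 groups per row;
# aligned to (8, 4), allocated as int32 of shape (4 // 4, 8), then
# permuted and sliced to a column-major (7, 1) view, so each int32
# element packs the row's 4 UE8M0 scale bytes.
print(x_q.shape, x_q.dtype)  # torch.Size([7, 512]), the fp8 dtype
print(x_s.shape, x_s.dtype)  # torch.Size([7, 1]), torch.int32
```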