Unverified Commit b2a189dd authored by strgrb's avatar strgrb Committed by GitHub
Browse files

use sglang_per_token_group_quant_fp8 from sgl-kernel instead of trion kernel (#5473)


Co-authored-by: default avatarZhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent f28d8299
......@@ -275,6 +275,8 @@ def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
):
assert (
x.shape[-1] % group_size == 0
......@@ -282,11 +284,28 @@ def sglang_per_token_group_quant_fp8(
assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
if column_major_scales:
if scale_tma_aligned:
# aligned to 4 * sizeof(float)
aligned_size = (x.shape[-2] + 3) // 4 * 4
x_s = torch.empty(
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
device=x.device,
dtype=torch.float32,
).permute(-1, -2)[: x.shape[-2], :]
else:
x_s = torch.empty(
(x.shape[-1] // group_size,) + x.shape[:-1],
device=x.device,
dtype=torch.float32,
).permute(-1, -2)
else:
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
return x_q, x_s
......
......@@ -141,7 +141,7 @@ def apply_w8a8_block_fp8_linear(
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
else:
if _enable_jit_deepgemm:
q_input, x_scale = per_token_group_quant_fp8(
q_input, x_scale = sglang_per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment