Unverified commit b2a189dd authored by strgrb, committed by GitHub

Use sglang_per_token_group_quant_fp8 from sgl-kernel instead of the Triton kernel (#5473)


Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent f28d8299
@@ -275,6 +275,8 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -282,11 +284,28 @@ def sglang_per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"

     x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )

     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
     return x_q, x_s
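The new allocation logic is worth a closer look. Below is a minimal sketch of what each of the three branches produces, using made-up sizes (6 tokens, hidden size 512, group size 128) and plain float32 CPU tensors so it runs anywhere; the variable names x_s_row, x_s_col, and x_s_tma are illustrative, but the shapes, permutes, and slicing mirror the diff above.

import torch

x = torch.randn(6, 512)                # 6 tokens, hidden size 512
group_size = 128
n_groups = x.shape[-1] // group_size   # 4 scale groups per token

# Default branch: row-major scales, one float per (token, group).
x_s_row = torch.empty(x.shape[:-1] + (n_groups,))
assert x_s_row.shape == (6, 4)

# column_major_scales=True: allocate transposed, then permute back, so
# the logical shape is unchanged but storage is column-major over tokens.
x_s_col = torch.empty((n_groups,) + x.shape[:-1]).permute(-1, -2)
assert x_s_col.shape == (6, 4) and x_s_col.stride() == (1, 6)

# scale_tma_aligned=True additionally pads the token dimension up to a
# multiple of 4 floats before slicing the padding back off, so each scale
# column keeps a stride of 8 in this example.
aligned_size = (x.shape[-2] + 3) // 4 * 4   # 6 -> 8
x_s_tma = torch.empty(
    x.shape[:-2] + (n_groups, aligned_size)
).permute(-1, -2)[: x.shape[-2], :]
assert x_s_tma.shape == (6, 4) and x_s_tma.stride() == (1, 8)

The column-major layouts match what the DeepGEMM path below asks for, and the padding to 4 * sizeof(float) presumably keeps each scale column's byte stride at a multiple of 16 bytes, as TMA transfers require.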
@@ -141,7 +141,7 @@ def apply_w8a8_block_fp8_linear(
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
         if _enable_jit_deepgemm:
-            q_input, x_scale = per_token_group_quant_fp8(
+            q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
                 column_major_scales=True,
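For context, this is roughly how the updated DeepGEMM path ends up calling the kernel. The import path is an assumption about sglang's module layout at the time of this PR, and the sizes are made up; the keyword argument is the one added above.

import torch
# Assumed import path; adjust to wherever sglang_per_token_group_quant_fp8
# lives in your checkout.
from sglang.srt.layers.quantization.fp8_kernel import (
    sglang_per_token_group_quant_fp8,
)

input_2d = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
q_input, x_scale = sglang_per_token_group_quant_fp8(
    input_2d,
    128,                       # block_size[1] at the call site above
    column_major_scales=True,  # layout the DeepGEMM path requests
)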