Unverified Commit b188a89a authored by YanbingJiang's avatar YanbingJiang Committed by GitHub
Browse files

Fix CI xeon test with triton 3.3.1 (#8086)

parent 497efe74
......@@ -29,6 +29,7 @@ from sglang.srt.utils import (
direct_register_custom_op,
get_device_core_count,
get_device_name,
is_cpu,
is_cuda,
is_hip,
log_info_on_rank0,
......@@ -37,6 +38,7 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import (
......@@ -1168,7 +1170,7 @@ def scaled_fp8_quant(
return output, scale
@triton.autotune(
fp8_autotune = triton.autotune(
configs=[
triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
for block_m in [16, 32, 64, 128]
......@@ -1176,6 +1178,8 @@ def scaled_fp8_quant(
],
key=["K", "BLOCK_K", "M_ALIGNMENT"],
)
@triton.jit
def _per_token_group_quant_fp8_hopper_moe_mn_major(
a, # (M, K):(K, 1)
......@@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major(
tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
if not _is_cpu:
_per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune(
_per_token_group_quant_fp8_hopper_moe_mn_major
)
def per_token_group_quant_fp8_hopper_moe_mn_major(
A: torch.Tensor,
expert_offsets: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment