Unverified Commit b188a89a authored by YanbingJiang's avatar YanbingJiang Committed by GitHub
Browse files

Fix CI xeon test with triton 3.3.1 (#8086)

parent 497efe74
...@@ -29,6 +29,7 @@ from sglang.srt.utils import ( ...@@ -29,6 +29,7 @@ from sglang.srt.utils import (
direct_register_custom_op, direct_register_custom_op,
get_device_core_count, get_device_core_count,
get_device_name, get_device_name,
is_cpu,
is_cuda, is_cuda,
is_hip, is_hip,
log_info_on_rank0, log_info_on_rank0,
...@@ -37,6 +38,7 @@ from sglang.srt.utils import ( ...@@ -37,6 +38,7 @@ from sglang.srt.utils import (
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_cpu = is_cpu()
if _is_cuda: if _is_cuda:
from sgl_kernel import ( from sgl_kernel import (
...@@ -1168,7 +1170,7 @@ def scaled_fp8_quant( ...@@ -1168,7 +1170,7 @@ def scaled_fp8_quant(
return output, scale return output, scale
@triton.autotune( fp8_autotune = triton.autotune(
configs=[ configs=[
triton.Config({"BLOCK_M": block_m}, num_warps=num_warps) triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
for block_m in [16, 32, 64, 128] for block_m in [16, 32, 64, 128]
...@@ -1176,6 +1178,8 @@ def scaled_fp8_quant( ...@@ -1176,6 +1178,8 @@ def scaled_fp8_quant(
], ],
key=["K", "BLOCK_K", "M_ALIGNMENT"], key=["K", "BLOCK_K", "M_ALIGNMENT"],
) )
@triton.jit @triton.jit
def _per_token_group_quant_fp8_hopper_moe_mn_major( def _per_token_group_quant_fp8_hopper_moe_mn_major(
a, # (M, K):(K, 1) a, # (M, K):(K, 1)
...@@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major( ...@@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major(
tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m) tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
if not _is_cpu:
_per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune(
_per_token_group_quant_fp8_hopper_moe_mn_major
)
def per_token_group_quant_fp8_hopper_moe_mn_major( def per_token_group_quant_fp8_hopper_moe_mn_major(
A: torch.Tensor, A: torch.Tensor,
expert_offsets: torch.Tensor, expert_offsets: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment