Unverified Commit 0d1e27a0 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Better optimization log for gpt-oss model (#8953)

parent 774b47f3
...@@ -24,6 +24,7 @@ from sglang.srt.utils import ( ...@@ -24,6 +24,7 @@ from sglang.srt.utils import (
is_cuda, is_cuda,
is_flashinfer_available, is_flashinfer_available,
is_hip, is_hip,
log_info_on_rank0,
next_power_of_2, next_power_of_2,
round_up, round_up,
set_weight_attrs, set_weight_attrs,
...@@ -34,7 +35,6 @@ has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None ...@@ -34,7 +35,6 @@ has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
if is_flashinfer_available(): if is_flashinfer_available():
# from flashinfer.fused_moe import cutlass_fused_moe
from flashinfer import ( from flashinfer import (
mxfp8_quantize, mxfp8_quantize,
shuffle_matrix_a, shuffle_matrix_a,
...@@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): ...@@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps mx_axis=1, num_warps=num_warps
) )
if is_cuda() and torch.cuda.get_device_capability()[0] == 10: if _is_sm100_supported:
constraints = { constraints = {
"is_persistent": True, "is_persistent": True,
"epilogue_subtile": 1, "epilogue_subtile": 1,
...@@ -331,8 +331,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -331,8 +331,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
if self.use_flashinfer: if self.use_flashinfer:
logger.info( log_info_on_rank0(
"Shuffling MoE weights for FlashInfer, it might take a while..." logger,
"Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
) )
layer.gemm1_alpha = Parameter( layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
......
...@@ -488,8 +488,14 @@ class ServerArgs: ...@@ -488,8 +488,14 @@ class ServerArgs:
if is_sm100_supported() and is_mxfp4_quant_format: if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False self.enable_triton_kernel_moe = False
logger.info(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else: else:
self.enable_triton_kernel_moe = True self.enable_triton_kernel_moe = True
logger.info(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True self.disable_hybrid_swa_memory = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment