Unverified commit 0d1e27a0 authored by Xiaoyu Zhang, committed by GitHub

Better optimization log for gpt-oss model (#8953)

parent 774b47f3
@@ -24,6 +24,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    log_info_on_rank0,
     next_power_of_2,
     round_up,
     set_weight_attrs,
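The new import is the rank-aware logging helper used later in this commit. On a tensor-parallel run, a plain logger.info fires once per GPU process; log_info_on_rank0 restricts the message to rank 0. A minimal sketch of such a helper, assuming torch.distributed supplies the rank (the actual implementation in sglang.srt.utils may differ):

import logging

import torch.distributed as dist


def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
    # Log only on rank 0 (or when torch.distributed is not initialized),
    # so multi-GPU runs do not repeat the same message once per process.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)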
@@ -34,7 +35,6 @@ has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 if is_flashinfer_available():
-    # from flashinfer.fused_moe import cutlass_fused_moe
     from flashinfer import (
         mxfp8_quantize,
         shuffle_matrix_a,
@@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
         mx_axis=1, num_warps=num_warps
     )
-    if is_cuda() and torch.cuda.get_device_capability()[0] == 10:
+    if _is_sm100_supported:
         constraints = {
             "is_persistent": True,
             "epilogue_subtile": 1,
@@ -331,8 +331,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer):
         if self.use_flashinfer:
-            logger.info(
-                "Shuffling MoE weights for FlashInfer, it might take a while..."
+            log_info_on_rank0(
+                logger,
+                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
             )
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
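Two things change at this call site: the message now goes through the rank-0 helper imported above, and it names the exact kernel path (FlashInfer MXFP4) being prepared. Incidentally, the gemm1_alpha value of 1.702 visible in the context is the standard coefficient of the sigmoid approximation to GELU, gelu(x) ≈ x * sigmoid(1.702 * x). A quick self-contained check of that approximation (illustrative only, not part of the commit):

import torch

x = torch.linspace(-4.0, 4.0, steps=1001)
exact = torch.nn.functional.gelu(x)
approx = x * torch.sigmoid(1.702 * x)
# The sigmoid form tracks exact GELU to within roughly 0.02 on this range.
print(f"max abs error: {(exact - approx).abs().max().item():.4f}")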
@@ -488,8 +488,14 @@ class ServerArgs:
         if is_sm100_supported() and is_mxfp4_quant_format:
             self.enable_flashinfer_mxfp4_moe = True
             self.enable_triton_kernel_moe = False
+            logger.info(
+                "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+            )
         else:
             self.enable_triton_kernel_moe = True
+            logger.info(
+                "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+            )
             self.disable_hybrid_swa_memory = True
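The net effect of this hunk is that the gpt-oss MoE backend choice is now stated in the server log instead of being silent. The selection itself is a two-way branch: FlashInfer's MXFP4 kernel on SM100 with MXFP4-quantized weights, triton_kernels otherwise. A standalone sketch of that decision (pick_gpt_oss_moe_backend is a hypothetical helper for illustration, not sglang API):

import logging

logger = logging.getLogger(__name__)


def pick_gpt_oss_moe_backend(sm100_supported: bool, is_mxfp4: bool) -> str:
    # Mirrors the ServerArgs branch above: prefer the FlashInfer MXFP4
    # kernel on SM100 hardware, otherwise fall back to triton_kernels.
    if sm100_supported and is_mxfp4:
        logger.info(
            "Detected SM100 and MXFP4 quantization format for GPT-OSS model, "
            "enabling FlashInfer MXFP4 MOE kernel."
        )
        return "flashinfer_mxfp4_moe"
    logger.info("Detected GPT-OSS model, enabling triton_kernels MOE kernel.")
    return "triton_kernel_moe"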