Unverified Commit 0374304a authored by fzyzcjy, committed by GitHub

Add enable_flashinfer_mxfp4_bf16_moe for higher precision and slower moe backend (#9004)

parent 127d4b0d
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -262,6 +263,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]

         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -615,11 +619,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.topk import TopKOutputChecker

         if self.use_flashinfer:
-            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(
-                x, False, alignment=self.hidden_size
-            )  # to mxfp8
-            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            # When bf16 mode is enabled, we don't need to quantize the input;
+            # TRT-LLM automatically handles quantization in the kernel implementation
+            # and pipelines it with the GEMM operations, which can theoretically
+            # improve performance.
+            if self.flashinfer_mxfp4_moe_precision == "bf16":
+                assert x.dtype == torch.bfloat16
+                x_quant = x
+                x_scale = None
+
+                # May be fused later if this code branch is frequently needed.
+                origin_hidden_states_dim = x_quant.shape[-1]
+                if self.hidden_size != origin_hidden_states_dim:
+                    x_quant = torch.nn.functional.pad(
+                        x_quant,
+                        (0, self.hidden_size - origin_hidden_states_dim),
+                        mode="constant",
+                        value=0.0,
+                    )
+            elif self.flashinfer_mxfp4_moe_precision == "default":
+                x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
+                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            else:
+                raise NotImplementedError
+
             assert x_quant.shape[-1] == self.hidden_size

             assert TopKOutputChecker.format_is_bypassed(topk_output)
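For reference, a minimal standalone sketch of the new bf16 path (illustrative function name, shapes, and sizes only, not the sglang API): the input stays in bf16, no mxfp8 scale is produced, and the hidden dimension is zero-padded up to the size the kernel expects.

# Illustrative sketch of the bf16 input path above (hypothetical names/sizes, not sglang code).
import torch
import torch.nn.functional as F

def prepare_bf16_input(x: torch.Tensor, kernel_hidden_size: int):
    """Mirror the bf16 branch: skip mxfp8 quantization and zero-pad the hidden dim."""
    assert x.dtype == torch.bfloat16
    x_quant, x_scale = x, None  # no host-side quantization; the kernel quantizes internally
    dim = x_quant.shape[-1]
    if dim != kernel_hidden_size:
        # Zero-pad the last dimension up to the size the kernel expects.
        x_quant = F.pad(x_quant, (0, kernel_hidden_size - dim), mode="constant", value=0.0)
    return x_quant, x_scale

x = torch.randn(4, 2880, dtype=torch.bfloat16)            # hypothetical token batch
x_quant, x_scale = prepare_bf16_input(x, kernel_hidden_size=2944)
assert x_quant.shape == (4, 2944) and x_scale is None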
...
@@ -87,6 +87,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
+    "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
...
@@ -190,6 +190,7 @@ class ServerArgs:
         "flashinfer_cutlass",
         "flashinfer_mxfp4",
     ] = "auto"
+    flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
     ep_num_redundant_experts: int = 0
@@ -1496,10 +1497,18 @@ class ServerArgs:
                 "triton_kernel",
                 "flashinfer_trtllm",
                 "flashinfer_cutlass",
+                "flashinfer_mxfp4",
             ],
             default=ServerArgs.moe_runner_backend,
             help="Choose the runner backend for MoE.",
         )
+        parser.add_argument(
+            "--flashinfer-mxfp4-moe-precision",
+            type=str,
+            choices=["default", "bf16"],
+            default=ServerArgs.flashinfer_mxfp4_moe_precision,
+            help="Choose the computation precision of the flashinfer mxfp4 MoE.",
+        )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",