Unverified Commit 0374304a authored by fzyzcjy's avatar fzyzcjy Committed by GitHub

Add enable_flashinfer_mxfp4_bf16_moe for higher precision and slower moe backend (#9004)

parent 127d4b0d
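
For context, a minimal usage sketch (not part of this commit): the new precision mode is exposed as a `ServerArgs` field and as the `--flashinfer-mxfp4-moe-precision` CLI flag added below. The model path here is a placeholder assumption.

```python
# Hypothetical usage sketch, not part of this commit.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="path/to/model",  # placeholder (assumption)
    moe_runner_backend="flashinfer_mxfp4",
    # New in this commit: "default" keeps the mxfp8 input quantization,
    # "bf16" feeds bf16 activations directly (higher precision, slower).
    flashinfer_mxfp4_moe_precision="bf16",
)
```

The equivalent CLI form of the new option is `--flashinfer-mxfp4-moe-precision bf16`.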
......@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
......@@ -262,6 +263,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
        self.with_bias = False
        self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
            "flashinfer_mxfp4_moe_precision"
        ]
        self.triton_kernel_moe_forward = None
        self.triton_kernel_moe_with_bias_forward = None
......@@ -615,11 +619,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        from sglang.srt.layers.moe.topk import TopKOutputChecker

        if self.use_flashinfer:
            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
            x_quant, x_scale = mxfp8_quantize(
                x, False, alignment=self.hidden_size
            )  # to mxfp8
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
            # When bf16 mode is enabled, we don't need to quantize the input;
            # TRT-LLM handles quantization inside the kernel and pipelines it with the GEMM operations,
            # which can theoretically improve performance.
            if self.flashinfer_mxfp4_moe_precision == "bf16":
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None

                # May be fused later if this code branch is frequently needed
                origin_hidden_states_dim = x_quant.shape[-1]
                if self.hidden_size != origin_hidden_states_dim:
                    x_quant = torch.nn.functional.pad(
                        x_quant,
                        (0, self.hidden_size - origin_hidden_states_dim),
                        mode="constant",
                        value=0.0,
                    )
            elif self.flashinfer_mxfp4_moe_precision == "default":
                x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
            else:
                raise NotImplementedError

            assert x_quant.shape[-1] == self.hidden_size
            assert TopKOutputChecker.format_is_bypassed(topk_output)
......
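
A standalone sketch of the input-preparation logic introduced above, with flashinfer's `mxfp8_quantize` stubbed out so it runs with plain PyTorch; the function name and the hidden size used in the example are illustrative, not taken from the diff.

```python
# Illustrative sketch (not from the diff): the two input paths selected by
# flashinfer_mxfp4_moe_precision, with the flashinfer dependency stubbed out.
import torch
import torch.nn.functional as F


def prepare_flashinfer_input(x: torch.Tensor, hidden_size: int, precision: str):
    if precision == "bf16":
        # Skip host-side quantization; the TRT-LLM kernel quantizes internally.
        assert x.dtype == torch.bfloat16
        x_quant, x_scale = x, None
        pad = hidden_size - x_quant.shape[-1]
        if pad > 0:
            # Zero-pad the hidden dimension to the size the kernel expects.
            x_quant = F.pad(x_quant, (0, pad), mode="constant", value=0.0)
    elif precision == "default":
        # The real code calls flashinfer's mxfp8_quantize(x, False, alignment=hidden_size)
        # and reshapes the returned scales; omitted here to keep the sketch torch-only.
        raise NotImplementedError("requires flashinfer's mxfp8_quantize")
    else:
        raise NotImplementedError(precision)
    return x_quant, x_scale


# Example: 2880 is an arbitrary illustrative hidden size.
x = torch.randn(4, 2880, dtype=torch.bfloat16)
x_quant, x_scale = prepare_flashinfer_input(x, hidden_size=2880, precision="bf16")
assert x_quant.shape[-1] == 2880 and x_scale is None
```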
......@@ -87,6 +87,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
......
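
The key is added to `GLOBAL_SERVER_ARGS_KEYS` so that it gets copied into `global_server_args_dict`, which `Mxfp4MoEMethod.__init__` reads (see the first hunk). A minimal plumbing sketch under that assumption, using a stand-in args object rather than the real `ServerArgs`:

```python
# Minimal plumbing sketch (assumption): keys listed in GLOBAL_SERVER_ARGS_KEYS
# are copied from the server args into a global dict that model code can read.
from types import SimpleNamespace

GLOBAL_SERVER_ARGS_KEYS = ["flashinfer_mxfp4_moe_precision"]  # trimmed for illustration

server_args = SimpleNamespace(flashinfer_mxfp4_moe_precision="bf16")  # stand-in object
global_server_args_dict = {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}

# Mirrors the lookup added in Mxfp4MoEMethod.__init__ above.
precision = global_server_args_dict["flashinfer_mxfp4_moe_precision"]
assert precision == "bf16"
```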
......@@ -190,6 +190,7 @@ class ServerArgs:
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
......@@ -1496,10 +1497,18 @@ class ServerArgs:
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
)
        parser.add_argument(
            "--flashinfer-mxfp4-moe-precision",
            type=str,
            choices=["default", "bf16"],
            default=ServerArgs.flashinfer_mxfp4_moe_precision,
            help="Choose the computation precision of the flashinfer mxfp4 MoE backend ('bf16' is higher precision but slower).",
        )
        parser.add_argument(
            "--enable-flashinfer-allreduce-fusion",
            action="store_true",
......
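
A quick argparse round trip showing how the new flag is expected to parse; this is a standalone sketch mirroring the `add_argument` call above, not the full ServerArgs CLI.

```python
# Standalone sketch mirroring the --flashinfer-mxfp4-moe-precision argument above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--flashinfer-mxfp4-moe-precision",
    type=str,
    choices=["default", "bf16"],
    default="default",
    help="Choose the computation precision of the flashinfer mxfp4 MoE backend.",
)

args = parser.parse_args(["--flashinfer-mxfp4-moe-precision", "bf16"])
assert args.flashinfer_mxfp4_moe_precision == "bf16"
```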