Unverified Commit 47824c14 authored by Xiaoyu Zhang, committed by GitHub

[Perf] Auto enable best flashinfer mxfp4 kernel in b200 (#8898)

parent c36a6693
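
In short: the FlashInfer MXFP4 MoE backend is no longer selected through the SGLANG_USE_FLASHINFER_MXFP4_MOE / SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE environment variables. It is now controlled by a server argument, --enable-flashinfer-mxfp4-moe, which is turned on automatically when an SM100 (B200-class) GPU is detected, disabling the Triton kernel MoE fallback in that case. A minimal sketch of the resulting selection logic, assuming is_sm100_supported is importable from sglang.srt.utils (illustrative only; the exact wiring is in the ServerArgs hunk below):

# Sketch: pick the MoE backend once at server start-up.
from sglang.srt.utils import is_sm100_supported  # assumed location of this helper

def select_mxfp4_moe_backend(server_args):
    if is_sm100_supported():
        # Blackwell (SM100 / B200): use the FlashInfer MXFP4 MoE kernels
        # and turn off the Triton kernel MoE fallback.
        server_args.enable_flashinfer_mxfp4_moe = True
        server_args.enable_triton_kernel_moe = False
    else:
        # Other GPUs keep the Triton kernel MoE path.
        server_args.enable_triton_kernel_moe = True

The flag can also be passed explicitly as --enable-flashinfer-mxfp4-moe when launching the server.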
@@ -206,13 +206,13 @@ class FusedMoE(torch.nn.Module):
assert self.quant_method is not None
self.quant_config = quant_config
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
"enable_flashinfer_mxfp4_moe", False
)
if (
self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and (
get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
)
and self.use_enable_flashinfer_mxfp4_moe
):
hidden_size = round_up(hidden_size, 256)
self.hidden_size = hidden_size
......
@@ -3,22 +3,20 @@
from __future__ import annotations
import importlib
import importlib.util
import logging
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
from torch.nn.parameter import Parameter
# from vllm.model_executor.layers.fused_moe import (
# FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
# FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
@@ -32,11 +30,6 @@ from sglang.srt.utils import (
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
# Environment variables for FlashInfer MXFP4 MoE backend
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
)
if is_flashinfer_available():
# from flashinfer.fused_moe import cutlass_fused_moe
@@ -193,7 +186,12 @@ class Mxfp4Config(QuantizationConfig):
):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
use_flashinfer = global_server_args_dict.get(
"enable_flashinfer_mxfp4_moe", False
)
return Mxfp4MoEMethod(
use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
)
else:
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
@@ -204,11 +202,18 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
def __init__(
self,
use_triton_kernels: bool = True,
with_bias: bool = True,
use_flashinfer: bool = False,
):
super().__init__()
self.topk_indices_dtype = None
self.use_triton_kernels = use_triton_kernels
self.with_bias = with_bias
self.use_flashinfer = use_flashinfer
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
if torch.cuda.is_available() and has_triton_kernels:
@@ -239,7 +244,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# for to hold non-uniform sharded tensor as well as swizzling
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
if self.use_flashinfer:
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
hidden_size = round_up(hidden_size, 256)
elif is_hip():
@@ -319,7 +324,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer):
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
if self.use_flashinfer:
logger.info(
"Shuffling MoE weights for FlashInfer, it might take a while..."
)
@@ -544,20 +549,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
# When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
# which can theoretically improve performance
if USE_FLASHINFER_MXFP4_BF16_MOE:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
else:
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
topk_weights, topk_ids, router_logits = topk_output
top_k = topk_weights.shape[-1]
top_k, router_logits = topk_output
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
......
@@ -107,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_triton_kernel_moe",
"enable_flashinfer_mxfp4_moe",
"enable_multimodal",
"enable_symm_mem",
"quantization",
......
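
The new key has to appear in GLOBAL_SERVER_ARGS_KEYS so that layer and model code can read it through global_server_args_dict, as the FusedMoE and GptOss hunks in this commit do. Roughly, the dict is populated from the matching ServerArgs attributes; a hedged, purely illustrative sketch of that propagation (the real wiring lives alongside GLOBAL_SERVER_ARGS_KEYS in schedule_batch):

# Hypothetical illustration: each listed key is copied from ServerArgs
# into the process-wide dict that layers query at construction time.
def publish_global_server_args(server_args, keys):
    return {key: getattr(server_args, key) for key in keys}

# Layers then read it as in the diff above:
#   use_flashinfer = global_server_args_dict.get("enable_flashinfer_mxfp4_moe", False)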
@@ -102,11 +102,15 @@ class GptOssSparseMoeBlock(nn.Module):
f"the number of experts {config.num_local_experts}."
)
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
self.topk = None
else:
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
self.top_k = config.num_experts_per_tok
experts_type = get_moe_impl_class()
extra_kwargs = {}
if experts_type.__name__ == "FusedMoE":
@@ -176,7 +180,7 @@ class GptOssSparseMoeBlock(nn.Module):
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
kwargs["topk_output"] = (self.top_k, router_logits)
final_hidden_states = self.experts(**kwargs)
if self.tp_size > 1:
......
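
For context on the call-site change above: with enable_flashinfer_mxfp4_moe set, GptOssSparseMoeBlock no longer runs TopK itself; routing is deferred to the FlashInfer TRT-LLM MoE kernel, which receives the raw router logits, so the block only forwards the top-k count and the logits, and Mxfp4MoEMethod unpacks them as top_k, router_logits = topk_output. A rough sketch of the two shapes topk_output can now take (hypothetical helper; names follow the diff):

def build_topk_output(topk, top_k, hidden_states, router_logits):
    # Non-FlashInfer path: TopK produces the usual (weights, ids, ...) output.
    if topk is not None:
        return topk(hidden_states, router_logits)
    # FlashInfer MXFP4 path: defer routing to the kernel; pass the number of
    # experts per token plus the unmodified router logits.
    return (top_k, router_logits)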
@@ -248,6 +248,7 @@ class ServerArgs:
disable_fast_image_processor: bool = False
enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
@@ -476,18 +477,10 @@ class ServerArgs:
or self.attention_backend == "triton"
)
# Check if FlashInfer MXFP4 MoE is enabled
from sglang.srt.utils import get_bool_env_var
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
"SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
)
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
)
# Only enable Triton kernel MoE if FlashInfer is not enabled
if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
if is_sm100_supported():
self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False
else:
self.enable_triton_kernel_moe = True
self.disable_hybrid_swa_memory = True
@@ -1846,6 +1839,11 @@ class ServerArgs:
action="store_true",
help="Use triton moe grouped gemm kernel.",
)
parser.add_argument(
"--enable-flashinfer-mxfp4-moe",
action="store_true",
help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
)
# Debug tensor dumps
parser.add_argument(
......