Unverified Commit 47824c14 authored by Xiaoyu Zhang, committed by GitHub

[Perf] Auto-enable the best FlashInfer MXFP4 kernel on B200 (#8898)

parent c36a6693
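
For context: this commit removes the SGLANG_USE_FLASHINFER_MXFP4_MOE and SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE environment variables and replaces them with a single --enable-flashinfer-mxfp4-moe server argument that is turned on automatically when the GPU reports SM100 (Blackwell/B200). A minimal sketch of the resulting selection logic (the standalone helper below is illustrative only, not code from this commit):

```python
# Illustrative sketch of the backend selection this commit wires into ServerArgs.
# On SM100 (B200) the FlashInfer MXFP4 MoE kernel is enabled and the Triton-kernel
# MoE fallback is disabled; on other GPUs the Triton path remains the default.
def resolve_mxfp4_moe_backend(is_sm100_supported: bool) -> dict:
    if is_sm100_supported:
        return {"enable_flashinfer_mxfp4_moe": True, "enable_triton_kernel_moe": False}
    return {"enable_flashinfer_mxfp4_moe": False, "enable_triton_kernel_moe": True}


# Example: on a B200 the FlashInfer path is picked without any flag.
assert resolve_mxfp4_moe_backend(True)["enable_flashinfer_mxfp4_moe"] is True
```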
@@ -206,13 +206,13 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
-            and (
-                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
-                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
-            )
+            and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.hidden_size = hidden_size
...
@@ -3,22 +3,20 @@
 from __future__ import annotations

-import importlib
+import importlib.util
 import logging
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch
 from torch.nn.parameter import Parameter

-# from vllm.model_executor.layers.fused_moe import (
-#     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
-#     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,11 +30,6 @@ from sglang.srt.utils import (
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

-# Environment variables for FlashInfer MXFP4 MoE backend
-USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
-USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-)

 if is_flashinfer_available():
     # from flashinfer.fused_moe import cutlass_fused_moe
@@ -193,7 +186,12 @@ class Mxfp4Config(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
+            use_flashinfer = global_server_args_dict.get(
+                "enable_flashinfer_mxfp4_moe", False
+            )
+            return Mxfp4MoEMethod(
+                use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
+            )
         else:
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -204,11 +202,18 @@ class Mxfp4Config(QuantizationConfig):
 class Mxfp4MoEMethod(FusedMoEMethodBase):

-    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
+    def __init__(
+        self,
+        use_triton_kernels: bool = True,
+        with_bias: bool = True,
+        use_flashinfer: bool = False,
+    ):
         super().__init__()

         self.topk_indices_dtype = None
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = with_bias
+        self.use_flashinfer = use_flashinfer
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:
@@ -239,7 +244,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
             hidden_size = round_up(hidden_size, 256)
         elif is_hip():
@@ -319,7 +324,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_bias, extra_weight_attrs)

     def process_weights_after_loading(self, layer):
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             logger.info(
                 "Shuffling MoE weights for FlashInfer, it might take a while..."
             )
@@ -544,20 +549,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         activation_alpha: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
-            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
-            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
-            # which can theoretically improve performance
-            if USE_FLASHINFER_MXFP4_BF16_MOE:
-                assert x.dtype == torch.bfloat16
-                x_quant = x
-                x_scale = None
-            else:
-                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
-                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+        if self.use_flashinfer:
+            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
+            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)

-            topk_weights, topk_ids, router_logits = topk_output
-            top_k = topk_weights.shape[-1]
+            top_k, router_logits = topk_output
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
...
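
Note on the apply() change above: the separate BF16 pass-through path (formerly gated by SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE) is gone; activations are now always quantized to MXFP8 before calling trtllm_fp4_block_scale_moe, which profiling showed to be faster. A minimal sketch of that pre-quantization step, with the quantizer passed in as a callable rather than imported (its import path is not shown in this diff):

```python
import torch


def prequantize_activations(x: torch.Tensor, mxfp8_quantize):
    """Quantize MoE activations to MXFP8 before the fused TRT-LLM kernel.

    `mxfp8_quantize` is FlashInfer's quantization helper as used in the diff;
    it returns the quantized payload plus per-block scales.
    """
    x_quant, x_scale = mxfp8_quantize(x, False)  # payload + block scales
    # The kernel expects the scales flattened and viewed as float8_e4m3fn.
    x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
    return x_quant, x_scale
```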
@@ -107,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
...
@@ -102,11 +102,15 @@ class GptOssSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_local_experts}."
             )

-        self.topk = TopK(
-            top_k=config.num_experts_per_tok,
-            renormalize=True,
-        )
+        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
+            self.topk = None
+        else:
+            self.topk = TopK(
+                top_k=config.num_experts_per_tok,
+                renormalize=True,
+            )
+        self.top_k = config.num_experts_per_tok

         experts_type = get_moe_impl_class()
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
@@ -176,7 +180,7 @@ class GptOssSparseMoeBlock(nn.Module):
         if self.topk is not None:
             kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)

         if self.tp_size > 1:
...
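
The GptOssSparseMoeBlock change above pairs with the new FusedMoE apply() path: when enable_flashinfer_mxfp4_moe is set, the TopK module is skipped and the expert layer receives (top_k, router_logits) so the fused FlashInfer kernel can perform routing itself. A small sketch of the two call paths (the standalone function name is illustrative, not from the diff):

```python
# Illustrative: how the MoE block now builds kwargs for the expert layer.
def build_expert_kwargs(hidden_states, router_logits, topk_module, top_k):
    kwargs = {"hidden_states": hidden_states}
    if topk_module is not None:
        # Default path: top-k selection runs in Python/Triton before the experts.
        kwargs["topk_output"] = topk_module(hidden_states, router_logits)
    else:
        # FlashInfer MXFP4 path: pass raw logits plus the static top_k;
        # routing happens inside trtllm_fp4_block_scale_moe.
        kwargs["topk_output"] = (top_k, router_logits)
    return kwargs
```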
@@ -248,6 +248,7 @@ class ServerArgs:
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
     enable_triton_kernel_moe: bool = False
+    enable_flashinfer_mxfp4_moe: bool = False

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -476,18 +477,10 @@ class ServerArgs:
                 or self.attention_backend == "triton"
             )

-            # Check if FlashInfer MXFP4 MoE is enabled
-            from sglang.srt.utils import get_bool_env_var
-
-            USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
-            )
-            USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-            )
-
-            # Only enable Triton kernel MoE if FlashInfer is not enabled
-            if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
+            if is_sm100_supported():
+                self.enable_flashinfer_mxfp4_moe = True
+                self.enable_triton_kernel_moe = False
+            else:
                 self.enable_triton_kernel_moe = True

             self.disable_hybrid_swa_memory = True
@@ -1846,6 +1839,11 @@ class ServerArgs:
             action="store_true",
             help="Use triton moe grouped gemm kernel.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mxfp4-moe",
+            action="store_true",
+            help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+        )

         # Debug tensor dumps
         parser.add_argument(
...