Unverified Commit a3b9c17b authored by Shu Wang's avatar Shu Wang Committed by GitHub
Browse files

Support Tensorrt-LLM MoE fp4 for low-latency (#21331)


Signed-off-by: default avatarShu Wang <shuw@nvidia.com>
Signed-off-by: default avatarPo-Han Huang <pohanh@nvidia.com>
Signed-off-by: default avatarShu Wang. <shuw@nvidia.com>
Signed-off-by: default avatarXIn Li <xinli@nvidia.com>
Co-authored-by: default avatarXIn Li <xinli@nvidia.com>
parent d57dc236
...@@ -129,6 +129,7 @@ if TYPE_CHECKING: ...@@ -129,6 +129,7 @@ if TYPE_CHECKING:
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
...@@ -982,6 +983,20 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -982,6 +983,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALL2ALL_BACKEND": "VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both
# require compute capability 10.0 or above.
# Available options:
# - "throughput": [default]
# Uses CUTLASS kernels optimized for high-throughput batch inference.
# - "latency":
# Uses TensorRT-LLM kernels optimized for low-latency inference.
# To set this backend, define the environment variable:
# export VLLM_FLASHINFER_MOE_BACKEND=latency.
# If not set, defaults to "throughput".
"VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
"VLLM_FLASHINFER_MOE_BACKEND", "throughput"
),
# Control the maximum number of tokens per expert supported by the # Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization. # the blockscale tensor of activations NVFP4 Quantization.
......
...@@ -192,7 +192,8 @@ class FusedMoEParallelConfig: ...@@ -192,7 +192,8 @@ class FusedMoEParallelConfig:
@property @property
def use_flashinfer_cutlass_kernels(self): def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE_FP4 return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()) and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
......
...@@ -105,7 +105,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -105,7 +105,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
detect_nvfp4_moe_support) detect_nvfp4_moe_support)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.group_size = 16 self.group_size = 16
self.fused_experts = None # type: ignore[assignment] self.fused_experts = None # type: ignore[assignment]
...@@ -212,7 +212,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -212,7 +212,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
requires_grad=False) requires_grad=False)
# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
if self.allow_flashinfer_cutlass: if self.allow_flashinfer:
w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data, w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
layer.w13_weight_scale.data, layer.w13_weight_scale.data,
dim=-2) dim=-2)
...@@ -266,7 +266,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -266,7 +266,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False) (layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config): def maybe_swap_experts_impl(self, moe_parallel_config):
if not self.allow_flashinfer_cutlass: if not self.allow_flashinfer:
return return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config) moe_parallel_config)
...@@ -277,8 +277,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -277,8 +277,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl) select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe, return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
logger)
def apply( def apply(
self, self,
......
...@@ -126,7 +126,7 @@ def flashinfer_fp4_cutlass_moe_forward( ...@@ -126,7 +126,7 @@ def flashinfer_fp4_cutlass_moe_forward(
def select_nvfp4_gemm_impl( def select_nvfp4_gemm_impl(
allow_flashinfer_cutlass: bool, allow_flashinfer: bool,
moe, # FusedMoEConfig moe, # FusedMoEConfig
logger): logger):
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
...@@ -137,8 +137,14 @@ def select_nvfp4_gemm_impl( ...@@ -137,8 +137,14 @@ def select_nvfp4_gemm_impl(
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
if allow_flashinfer_cutlass: if allow_flashinfer:
logger.debug_once("Using FlashInferExperts") flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_backend != "throughput":
raise ValueError(
f"Only throughput backend is supported for FlashInferExperts, "
f"but got {flashinfer_backend}.")
logger.debug_once(
"Initializing FlashInferExperts with throughput backend.")
return FlashInferExperts( return FlashInferExperts(
use_nvfp4_w4a4=True, use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1, use_dp=moe.moe_parallel_config.dp_size > 1,
......
...@@ -21,7 +21,7 @@ class NvFp4Support: ...@@ -21,7 +21,7 @@ class NvFp4Support:
"""Result container for NV-FP4 capability probing.""" """Result container for NV-FP4 capability probing."""
cutlass_supported: bool cutlass_supported: bool
allow_flashinfer_cutlass: bool allow_flashinfer: bool
use_marlin: bool use_marlin: bool
...@@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: ...@@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
return NvFp4Support( return NvFp4Support(
cutlass_supported=cutlass_supported, cutlass_supported=cutlass_supported,
allow_flashinfer_cutlass=allow_flashinfer, allow_flashinfer=allow_flashinfer,
use_marlin=use_marlin, use_marlin=use_marlin,
) )
...@@ -86,6 +86,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", ...@@ -86,6 +86,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper( nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer", "nvfp4_block_scale_interleave") "flashinfer", "nvfp4_block_scale_interleave")
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
"flashinfer", "trtllm_fp4_block_scale_moe")
# Special case for autotune since it returns a context manager # Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper( autotune = _lazy_import_wrapper(
...@@ -112,6 +114,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool: ...@@ -112,6 +114,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
("flashinfer.fused_moe", "cutlass_fused_moe"), ("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"), ("flashinfer", "fp4_quantize"),
("flashinfer", "nvfp4_block_scale_interleave"), ("flashinfer", "nvfp4_block_scale_interleave"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
] ]
for module_name, attr_name in required_functions: for module_name, attr_name in required_functions:
...@@ -188,6 +191,7 @@ __all__ = [ ...@@ -188,6 +191,7 @@ __all__ = [
"flashinfer_cutlass_fused_moe", "flashinfer_cutlass_fused_moe",
"fp4_quantize", "fp4_quantize",
"nvfp4_block_scale_interleave", "nvfp4_block_scale_interleave",
"trtllm_fp4_block_scale_moe",
"autotune", "autotune",
"has_flashinfer_moe", "has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutlass_fused_moe",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment