Unverified Commit de527e1c authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[UX] Add `--moe-backend` arg for explicit kernel selection (#33807)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 1976356e
...@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block" ...@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
accuracy_threshold: 0.85 accuracy_threshold: 0.85
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
env:
VLLM_USE_FLASHINFER_MOE_FP8: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
...@@ -2,7 +2,6 @@ model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block" ...@@ -2,7 +2,6 @@ model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
accuracy_threshold: 0.85 accuracy_threshold: 0.85
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=triton"
env: env:
VLLM_USE_FLASHINFER_MOE_FP8: "0"
VLLM_USE_DEEP_GEMM: "0" VLLM_USE_DEEP_GEMM: "0"
...@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4" ...@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88 accuracy_threshold: 0.88
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
...@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4" ...@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88 accuracy_threshold: 0.88
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "latency"
...@@ -2,6 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4" ...@@ -2,6 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88 accuracy_threshold: 0.88
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "0"
...@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4" ...@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88 accuracy_threshold: 0.88
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
...@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4" ...@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88 accuracy_threshold: 0.88
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "latency"
...@@ -2,6 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4" ...@@ -2,6 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88 accuracy_threshold: 0.88
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "0"
...@@ -85,34 +85,34 @@ def can_initialize( ...@@ -85,34 +85,34 @@ def can_initialize(
) )
) )
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): def test_llama4_fp8_tensor_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
can_initialize( can_initialize(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", hf_overrides=HF_OVERRIDE_MM "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
hf_overrides=HF_OVERRIDE_MM,
extra_args=["--moe-backend=flashinfer_cutlass"],
) )
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_llama4_fp8_tensor_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize( can_initialize(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", hf_overrides=HF_OVERRIDE_MM "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
hf_overrides=HF_OVERRIDE_MM,
extra_args=["--moe-backend=flashinfer_trtllm"],
) )
def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
can_initialize( can_initialize(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", hf_overrides=HF_OVERRIDE_MM "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
hf_overrides=HF_OVERRIDE_MM,
extra_args=["--moe-backend=flashinfer_cutlass"],
) )
def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize( can_initialize(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", hf_overrides=HF_OVERRIDE_MM "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
hf_overrides=HF_OVERRIDE_MM,
extra_args=["--moe-backend=flashinfer_trtllm"],
) )
...@@ -120,8 +120,11 @@ def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): ...@@ -120,8 +120,11 @@ def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") can_initialize(
can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT) "deepseek-ai/DeepSeek-V3.1",
hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--moe-backend=deep_gemm"],
)
@pytest.mark.skip( @pytest.mark.skip(
...@@ -131,27 +134,35 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): ...@@ -131,27 +134,35 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
) )
) )
def test_deepseek_fp8_block_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): def test_deepseek_fp8_block_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") can_initialize(
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") "deepseek-ai/DeepSeek-V3.1",
can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT) hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--moe-backend=flashinfer_cutlass"],
)
def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") can_initialize(
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") "deepseek-ai/DeepSeek-V3.1",
can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT) hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--moe-backend=flashinfer_trtllm"],
)
def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") can_initialize(
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") "nvidia/DeepSeek-R1-0528-FP4-v2",
can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", hf_overrides=HF_OVERRIDE_TEXT) hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--moe-backend=flashinfer_cutlass"],
)
def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") can_initialize(
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") "nvidia/DeepSeek-R1-0528-FP4-v2",
can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", hf_overrides=HF_OVERRIDE_TEXT) hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--moe-backend=flashinfer_trtllm"],
)
## GPT-OSS ## ## GPT-OSS ##
...@@ -184,5 +195,8 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch): ...@@ -184,5 +195,8 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1") can_initialize(
can_initialize("Qwen/Qwen3-Next-80B-A3B-Instruct", hf_overrides=HF_OVERRIDE_TEXT) "Qwen/Qwen3-Next-80B-A3B-Instruct",
hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--moe-backend=flashinfer_trtllm"],
)
...@@ -2,13 +2,25 @@ ...@@ -2,13 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any, Literal
from pydantic import Field, field_validator from pydantic import Field, field_validator
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
MoEBackend = Literal[
"auto",
"triton",
"deep_gemm",
"cutlass",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_cutedsl",
"marlin",
"aiter",
]
@config @config
class KernelConfig: class KernelConfig:
...@@ -17,6 +29,26 @@ class KernelConfig: ...@@ -17,6 +29,26 @@ class KernelConfig:
enable_flashinfer_autotune: bool = Field(default=None) enable_flashinfer_autotune: bool = Field(default=None)
"""If True, run FlashInfer autotuning during kernel warmup.""" """If True, run FlashInfer autotuning during kernel warmup."""
moe_backend: MoEBackend = "auto"
"""Backend for MoE expert computation kernels. Available options:
- "auto": Automatically select the best backend based on model and hardware\n
- "triton": Use Triton-based fused MoE kernels\n
- "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n
- "cutlass": Use vLLM CUTLASS kernels\n
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n
- "marlin": Use Marlin kernels (weight-only quantization)\n
- "aiter": Use AMD AITer kernels (ROCm only)"""
@field_validator("moe_backend", mode="before")
@classmethod
def _normalize_moe_backend(cls, value: Any) -> Any:
if isinstance(value, str):
return value.lower().replace("-", "_")
return value
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
......
...@@ -70,6 +70,7 @@ from vllm.config.cache import ( ...@@ -70,6 +70,7 @@ from vllm.config.cache import (
PrefixCachingHashAlgo, PrefixCachingHashAlgo,
) )
from vllm.config.device import Device from vllm.config.device import Device
from vllm.config.kernel import MoEBackend
from vllm.config.lora import MaxLoRARanks from vllm.config.lora import MaxLoRARanks
from vllm.config.model import ( from vllm.config.model import (
ConvertOption, ConvertOption,
...@@ -416,6 +417,7 @@ class EngineArgs: ...@@ -416,6 +417,7 @@ class EngineArgs:
data_parallel_external_lb: bool = False data_parallel_external_lb: bool = False
data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
enable_dbo: bool = ParallelConfig.enable_dbo enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size ubatch_size: int = ParallelConfig.ubatch_size
...@@ -1227,6 +1229,9 @@ class EngineArgs: ...@@ -1227,6 +1229,9 @@ class EngineArgs:
"--enable-flashinfer-autotune", "--enable-flashinfer-autotune",
**kernel_kwargs["enable_flashinfer_autotune"], **kernel_kwargs["enable_flashinfer_autotune"],
) )
moe_backend_kwargs = kernel_kwargs["moe_backend"]
moe_backend_kwargs["type"] = lambda s: s.lower().replace("-", "_")
kernel_group.add_argument("--moe-backend", **moe_backend_kwargs)
# vLLM arguments # vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig) vllm_kwargs = get_kwargs(VllmConfig)
...@@ -1817,6 +1822,8 @@ class EngineArgs: ...@@ -1817,6 +1822,8 @@ class EngineArgs:
"are mutually exclusive" "are mutually exclusive"
) )
kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune
if self.moe_backend != "auto":
kernel_config.moe_backend = self.moe_backend
load_config = self.create_load_config() load_config = self.create_load_config()
......
...@@ -1066,7 +1066,6 @@ class FusedMoEParallelConfig: ...@@ -1066,7 +1066,6 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the experts are split - Comment: There are 2 engine instances and the experts are split
between the 4 devices. between the 4 devices.
""" """
use_ep = ( use_ep = (
dp_size_ * pcp_size_ * tp_size_ > 1 dp_size_ * pcp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel and vllm_parallel_config.enable_expert_parallel
...@@ -1155,6 +1154,7 @@ class FusedMoEConfig: ...@@ -1155,6 +1154,7 @@ class FusedMoEConfig:
# Defaults to in_dtype if not specified. # Defaults to in_dtype if not specified.
router_logits_dtype: torch.dtype | None = None router_logits_dtype: torch.dtype | None = None
moe_backend: str = "auto"
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False has_bias: bool = False
is_act_and_mul: bool = True is_act_and_mul: bool = True
......
...@@ -198,7 +198,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -198,7 +198,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x = x[0].permute(2, 0, 1) x = x[0].permute(2, 0, 1)
num_experts, max_tokens, hidden_dim_by_2 = x.shape num_experts, max_tokens, hidden_dim_by_2 = x.shape
hidden_dim = hidden_dim_by_2 * 2 hidden_dim = hidden_dim_by_2 * 2
assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
logger.info_once( logger.info_once(
"Quantization is fused with DeepEP nvfp4 dispatch for " "Quantization is fused with DeepEP nvfp4 dispatch for "
"FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1" "FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
......
...@@ -550,6 +550,7 @@ class FusedMoE(CustomOp): ...@@ -550,6 +550,7 @@ class FusedMoE(CustomOp):
num_logical_experts=self.logical_num_experts, num_logical_experts=self.logical_num_experts,
moe_parallel_config=self.moe_parallel_config, moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype, in_dtype=moe_in_dtype,
moe_backend=vllm_config.kernel_config.moe_backend,
router_logits_dtype=router_logits_dtype, router_logits_dtype=router_logits_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias, has_bias=has_bias,
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import ( from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize, maybe_make_prepare_finalize,
...@@ -180,6 +181,25 @@ def backend_to_kernel_cls( ...@@ -180,6 +181,25 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
def map_fp8_backend(runner_backend: MoEBackend) -> Fp8MoeBackend:
"""Map user's MoEBackend to Fp8MoeBackend."""
mapping = {
"triton": Fp8MoeBackend.TRITON,
"deep_gemm": Fp8MoeBackend.DEEPGEMM,
"cutlass": Fp8MoeBackend.VLLM_CUTLASS,
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
"flashinfer_cutlass": Fp8MoeBackend.FLASHINFER_CUTLASS,
"marlin": Fp8MoeBackend.MARLIN,
"aiter": Fp8MoeBackend.AITER,
}
if backend := mapping.get(runner_backend):
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for FP8 MoE. "
f"Expected one of {list(mapping.keys())}."
)
def select_fp8_moe_backend( def select_fp8_moe_backend(
config: FusedMoEConfig, config: FusedMoEConfig,
weight_key: QuantKey | None, weight_key: QuantKey | None,
...@@ -242,6 +262,45 @@ def select_fp8_moe_backend( ...@@ -242,6 +262,45 @@ def select_fp8_moe_backend(
return backend, k_cls return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason)) raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_fp8_backend(runner_backend)
# For batched activation format, use batched variants if available.
if activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
if requested_backend == Fp8MoeBackend.DEEPGEMM:
requested_backend = Fp8MoeBackend.BATCHED_DEEPGEMM
elif requested_backend == Fp8MoeBackend.TRITON:
requested_backend = Fp8MoeBackend.BATCHED_TRITON
elif requested_backend == Fp8MoeBackend.VLLM_CUTLASS:
requested_backend = Fp8MoeBackend.BATCHED_VLLM_CUTLASS
if (
requested_backend
in [
Fp8MoeBackend.VLLM_CUTLASS,
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
]
and not allow_vllm_cutlass
):
raise ValueError(
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
)
# Handle FLASHINFER_TRTLLM specially (no kernel class).
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
# Handle explicit FlashInfer FP8 configuration. # Handle explicit FlashInfer FP8 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"): if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP8: if not envs.VLLM_USE_FLASHINFER_MOE_FP8:
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import ( from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize, maybe_make_prepare_finalize,
...@@ -103,6 +104,23 @@ def backend_to_kernel_cls( ...@@ -103,6 +104,23 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
"""Map user's MoEBackend to NvFp4MoeBackend."""
mapping = {
"cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
"flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
"flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
"flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
"marlin": NvFp4MoeBackend.MARLIN,
}
if backend := mapping.get(runner_backend):
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
f"Expected one of {list(mapping.keys())}."
)
def select_nvfp4_moe_backend( def select_nvfp4_moe_backend(
config: FusedMoEConfig, config: FusedMoEConfig,
weight_key: QuantKey | None, weight_key: QuantKey | None,
...@@ -170,6 +188,23 @@ def select_nvfp4_moe_backend( ...@@ -170,6 +188,23 @@ def select_nvfp4_moe_backend(
return backend, k_cls return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason)) raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_nvfp4_backend(runner_backend)
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP4: if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
# If the user rejects FlashInfer remove those backends. # If the user rejects FlashInfer remove those backends.
......
...@@ -9,6 +9,7 @@ from torch.nn import Module ...@@ -9,6 +9,7 @@ from torch.nn import Module
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
...@@ -51,6 +52,22 @@ UNSUPPORTED_BACKEND = [ ...@@ -51,6 +52,22 @@ UNSUPPORTED_BACKEND = [
] ]
def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
"""Map user's MoEBackend to UnquantizedMoeBackend."""
mapping = {
"triton": UnquantizedMoeBackend.TRITON,
"flashinfer_trtllm": UnquantizedMoeBackend.FLASHINFER_TRTLLM,
"flashinfer_cutlass": UnquantizedMoeBackend.FLASHINFER_CUTLASS,
"aiter": UnquantizedMoeBackend.AITER,
}
if backend := mapping.get(runner_backend):
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for unquantized MoE. "
f"Expected one of {list(mapping.keys())}."
)
def select_unquantized_moe_backend( def select_unquantized_moe_backend(
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
use_ep: bool, use_ep: bool,
...@@ -64,8 +81,6 @@ def select_unquantized_moe_backend( ...@@ -64,8 +81,6 @@ def select_unquantized_moe_backend(
def _make_log_backend(backend: UnquantizedMoeBackend): def _make_log_backend(backend: UnquantizedMoeBackend):
return f"Using {backend.value} backend for Unquantized MoE" return f"Using {backend.value} backend for Unquantized MoE"
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
activation_format = ( activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts mk.FusedMoEActivationFormat.BatchedExperts
if moe_config.moe_parallel_config.use_batched_activation_format if moe_config.moe_parallel_config.use_batched_activation_format
...@@ -77,20 +92,49 @@ def select_unquantized_moe_backend( ...@@ -77,20 +92,49 @@ def select_unquantized_moe_backend(
moe_config=moe_config, moe_config=moe_config,
activation_format=activation_format, activation_format=activation_format,
) )
flashinfer_trtllm_moe_enabled = ( flashinfer_trtllm_available = has_flashinfer() and trtllm_supported
has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and trtllm_supported
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
)
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
flashinfer_cutlass_moe_enabled = ( flashinfer_cutlass_available = (
has_flashinfer_cutlass_fused_moe() has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and use_ep and use_ep
and (not use_dp) and (not use_dp)
and current_platform.has_device_capability(90) and current_platform.has_device_capability(90)
) )
flashinfer_trtllm_moe_enabled = (
flashinfer_trtllm_available
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
)
flashinfer_cutlass_moe_enabled = (
flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16
)
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# Handle explicit moe_backend from user.
runner_backend = moe_config.moe_backend
if runner_backend != "auto":
requested_backend = map_unquantized_backend(runner_backend)
if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
if not flashinfer_trtllm_available:
raise ValueError(
"FlashInfer TRTLLM MoE backend is not available for this "
"configuration."
)
elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
if not flashinfer_cutlass_available:
raise ValueError(
"FlashInfer CUTLASS MoE backend is not available for this "
"configuration."
)
elif requested_backend == UnquantizedMoeBackend.AITER and not (
current_platform.is_rocm() and rocm_aiter_moe_enabled
):
raise ValueError(
"ROCm AITer MoE backend is not available for this configuration."
)
logger.info_once(_make_log_backend(requested_backend), scope="local")
return requested_backend
if current_platform.is_rocm(): if current_platform.is_rocm():
if rocm_aiter_moe_enabled: if rocm_aiter_moe_enabled:
backend = UnquantizedMoeBackend.AITER backend = UnquantizedMoeBackend.AITER
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment