Unverified Commit d9b90a07 authored by yzong-rh's avatar yzong-rh Committed by GitHub
Browse files

[MoE Refactor] Migrate Unquantized to Full Oracle Flow (#36286)


Signed-off-by: Yifan Zong <yzong@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: yzong-rh <yzong@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent 598190aa
...@@ -1664,7 +1664,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( ...@@ -1664,7 +1664,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
intermediate_size_per_partition=n, intermediate_size_per_partition=n,
num_local_experts=e, num_local_experts=e,
num_logical_experts=e, num_logical_experts=e,
activation="silu", activation=MoEActivation.SILU,
device="cuda", device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=dtype, in_dtype=dtype,
...@@ -1695,13 +1695,25 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( ...@@ -1695,13 +1695,25 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
layer.topk_group = 1 layer.topk_group = 1
layer.intermediate_size_per_partition = n layer.intermediate_size_per_partition = n
layer.ep_rank = 0 layer.ep_rank = 0
layer.activation = "silu" layer.activation = MoEActivation.SILU
layer.e_score_correction_bias = None layer.e_score_correction_bias = None
layer.routing_method_type = RoutingMethodType.Renormalize layer.routing_method_type = RoutingMethodType.Renormalize
layer.expert_map = None
layer.apply_router_weight_on_input = False
layer.routed_scaling_factor = None
layer.shared_experts = None
layer._maybe_init_expert_routing_tables = lambda: None
quant_method.process_weights_after_loading(layer) quant_method.process_weights_after_loading(layer)
trtllm_output = quant_method.forward_monolithic_cuda( assert quant_method.moe_kernel is not None, (
"moe_kernel should be set after process_weights_after_loading"
)
assert quant_method.supports_internal_mk, (
"supports_internal_mk should be True after setup"
)
trtllm_output = quant_method.apply_monolithic(
layer=layer, layer=layer,
x=a, x=a,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -24,7 +24,7 @@ from vllm.platforms import current_platform ...@@ -24,7 +24,7 @@ from vllm.platforms import current_platform
], ],
) )
@patch( @patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", "vllm.utils.flashinfer.has_flashinfer",
return_value=False, return_value=False,
) )
@patch( @patch(
...@@ -54,13 +54,29 @@ def test_select_default_backend_by_platform( ...@@ -54,13 +54,29 @@ def test_select_default_backend_by_platform(
# Set only the specified platform to True # Set only the specified platform to True
getattr(mock_platform, platform_method).return_value = True getattr(mock_platform, platform_method).return_value = True
with (
patch.object(current_platform, "is_cuda", return_value=False),
patch.object(current_platform, "is_rocm", return_value=False),
patch.object(current_platform, "is_cpu", return_value=False),
patch.object(current_platform, "is_xpu", return_value=False),
patch.object(current_platform, "is_tpu", return_value=False),
patch.object(current_platform, "is_out_of_tree", return_value=False),
patch.object(current_platform, platform_method, return_value=True),
):
moe_config = make_dummy_moe_config() moe_config = make_dummy_moe_config()
selected_backend = select_unquantized_moe_backend( selected_backend, expert_cls = select_unquantized_moe_backend(
moe_config=moe_config, moe_config=moe_config
use_dp=False,
) )
assert selected_backend == expected_backend assert selected_backend == expected_backend
if expected_backend in [
UnquantizedMoeBackend.CPU,
UnquantizedMoeBackend.OOT,
UnquantizedMoeBackend.TPU,
]:
assert expert_cls is None
else:
assert expert_cls is not None
@patch( @patch(
...@@ -87,88 +103,90 @@ def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer): ...@@ -87,88 +103,90 @@ def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer):
mock_platform.is_out_of_tree.return_value = False mock_platform.is_out_of_tree.return_value = False
moe_config = make_dummy_moe_config() moe_config = make_dummy_moe_config()
selected_backend = select_unquantized_moe_backend( selected_backend, expert_cls = select_unquantized_moe_backend(
moe_config=moe_config, moe_config=moe_config,
use_dp=False,
) )
assert selected_backend == UnquantizedMoeBackend.AITER assert selected_backend == UnquantizedMoeBackend.AITER
assert expert_cls is not None
@patch( @patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", "vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
return_value=True,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
return_value=(True, None), return_value=(True, None),
) )
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms." not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
) )
def test_select_cuda_flashinfer_trtllm_backend( def test_select_cuda_flashinfer_trtllm_backend(mock_is_supported_trtllm, monkeypatch):
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
):
"""Test CUDA backend selection when FlashInfer TRTLLM is available and enabled.""" """Test CUDA backend selection when FlashInfer TRTLLM is available and enabled."""
with patch( with (
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform" patch.object(current_platform, "is_cuda", return_value=True),
) as mock_platform: patch.object(current_platform, "is_rocm", return_value=False),
# Set as CUDA platform patch.object(current_platform, "is_cpu", return_value=False),
mock_platform.is_cuda.return_value = True patch.object(current_platform, "is_xpu", return_value=False),
mock_platform.is_rocm.return_value = False patch.object(current_platform, "is_tpu", return_value=False),
mock_platform.is_cpu.return_value = False patch.object(current_platform, "is_out_of_tree", return_value=False),
mock_platform.is_xpu.return_value = False patch.object(current_platform, "has_device_capability", return_value=True),
mock_platform.is_tpu.return_value = False ):
mock_platform.is_out_of_tree.return_value = False
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1") monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
moe_config = make_dummy_moe_config() moe_config = make_dummy_moe_config()
# TRTLLM requires EP and does not support DP
moe_config.moe_parallel_config.use_ep = True
moe_config.moe_parallel_config.use_dp = False
selected_backend = select_unquantized_moe_backend( selected_backend, experts_cls = select_unquantized_moe_backend(
moe_config=moe_config, moe_config=moe_config
use_dp=False,
) )
assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
assert experts_cls is not None
@patch( @patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", "vllm.utils.flashinfer.has_flashinfer",
return_value=True, return_value=True,
) )
@patch( @patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16", "vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
return_value=(False, None), return_value=(False, None),
) )
@patch(
"vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts.is_supported_config",
return_value=(True, None),
)
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms." not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
) )
def test_select_cuda_flashinfer_cutlass_backend( def test_select_cuda_flashinfer_cutlass_backend(
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch mock_has_flashinfer,
mock_is_supported_trtllm,
mock_is_supported_cutlass,
monkeypatch,
): ):
"""Test CUDA backend selection when FlashInfer TRTLLM is not available """Test CUDA backend selection when FlashInfer TRTLLM is not available
and FlashInfer CUTLASS is available.""" and FlashInfer CUTLASS is available."""
with patch( with (
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform" patch.object(current_platform, "is_cuda", return_value=True),
) as mock_platform: patch.object(current_platform, "is_rocm", return_value=False),
# Set as CUDA platform with Hopper capability patch.object(current_platform, "is_cpu", return_value=False),
mock_platform.is_cuda.return_value = True patch.object(current_platform, "is_xpu", return_value=False),
mock_platform.is_rocm.return_value = False patch.object(current_platform, "is_tpu", return_value=False),
mock_platform.is_cpu.return_value = False patch.object(current_platform, "is_out_of_tree", return_value=False),
mock_platform.is_xpu.return_value = False patch.object(current_platform, "has_device_capability", return_value=True),
mock_platform.is_tpu.return_value = False ):
mock_platform.is_out_of_tree.return_value = False
mock_platform.has_device_capability.return_value = True # SM90+
# Enable FlashInfer via env var # Enable FlashInfer via env var
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1") monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
moe_config = make_dummy_moe_config() moe_config = make_dummy_moe_config()
# CUTLASS requires EP and does not support DP
moe_config.moe_parallel_config.use_ep = True
moe_config.moe_parallel_config.use_dp = False
selected_backend = select_unquantized_moe_backend( selected_backend, experts_cls = select_unquantized_moe_backend(
moe_config=moe_config, moe_config=moe_config
use_dp=False, # CUTLASS doesn't support DP
) )
assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
assert experts_cls is not None
...@@ -210,6 +210,13 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch): ...@@ -210,6 +210,13 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
## Qwen3 Next ## ## Qwen3 Next ##
@pytest.mark.skip(
reason=(
"FLASHINFER TRTLLM MoE has a bug with all negative router logits "
"for models with RENORMALIZE. This will be re-enabled once the "
"issue is fixed in flashinfer."
)
)
def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
can_initialize( can_initialize(
"Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen/Qwen3-Next-80B-A3B-Instruct",
......
...@@ -49,6 +49,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -49,6 +49,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
assert not self.base_layer.use_ep, ( assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet." "EP support for Fused MoE LoRA is not implemented yet."
) )
assert not self.base_layer.quant_method.is_monolithic, (
"Monolithic kernels are not supported for Fused MoE LoRA."
)
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.device = _get_lora_device(base_layer) self.device = _get_lora_device(base_layer)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
    """
    Monolithic experts implementation for the BF16 (unquantized)
    TRTLLM-Gen MoE kernel exposed by FlashInfer.

    The static ``_supports_*`` hooks describe which configurations the
    kernel accepts, and ``apply`` forwards the router logits and weights to
    ``flashinfer.fused_moe.trtllm_bf16_moe``.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)

        # Cache the configuration values that ``apply`` passes to the kernel.
        parallel_config = moe_config.moe_parallel_config
        self.ep_rank = parallel_config.ep_rank
        self.local_num_experts = moe_config.num_local_experts
        self.hidden_dim = moe_config.hidden_dim
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.topk = moe_config.experts_per_token
        self.routing_method_type = moe_config.routing_method

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format used by this kernel (Standard)."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """
        Return True only on CUDA devices of the capability-100
        (Blackwell) family where the FlashInfer TRTLLM fused MoE is available.
        """
        platform = current_platform
        if not platform.is_cuda():
            return False
        return (
            platform.is_device_capability_family(100)
            and has_flashinfer_trtllm_fused_moe()
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Non-gated MoE is not supported by the BF16 kernels."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only the unquantized case: both quant keys must be None."""
        if weight_key is not None:
            return False
        return activation_key is None

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Return True if ``activation`` is SiLU, the only accepted one."""
        supported_activations = (MoEActivation.SILU,)
        return activation in supported_activations

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """
        Return True if ``routing_method`` is one of the accepted routing
        methods. The quant keys are not consulted.
        """
        # NOTE: the TRTLLM kernel has an issue with the Qwen3.5 router, so
        # RoutingMethodType.Renormalize and RoutingMethodType.RenormalizeNaive
        # are intentionally left out. Re-enable once the issue is resolved.
        # https://github.com/vllm-project/vllm/issues/37591
        supported_methods = (
            RoutingMethodType.Default,
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Llama4,
        )
        return routing_method in supported_methods

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """
        Monolithic kernel, so it is only used with naive DP/EP and TP.

        EPLB is rejected; otherwise the config is accepted when all2all
        kernels are not in use, or when the AG/RS all2all kernels are in use.
        """
        if moe_parallel_config.enable_eplb:
            return False
        return (
            not moe_parallel_config.use_all2all_kernels
            or moe_parallel_config.use_ag_rs_all2all_kernels
        )

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """Every router-logits dtype is accepted."""
        return True

    def supports_chunking(self) -> bool:
        """Chunking is not supported."""
        return False

    def supports_expert_map(self) -> bool:
        """Expert maps are not supported."""
        return False

    @property
    def expects_unquantized_inputs(self) -> bool:
        """This kernel expects unquantized inputs."""
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Run the FlashInfer TRTLLM BF16 fused MoE kernel and return its output.

        ``activation``, ``expert_map``, ``a1q_scale``,
        ``apply_router_weight_on_input`` and ``routed_scaling_factor`` are
        accepted but are not forwarded to the kernel call.
        """
        # Imported here rather than at module scope.
        import flashinfer

        # Passed to the kernel as ``local_expert_offset``: this rank's EP
        # index multiplied by its number of local experts.
        local_expert_offset = self.ep_rank * self.local_num_experts

        return flashinfer.fused_moe.trtllm_bf16_moe(
            routing_logits=router_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states,
            gemm1_weights=w1,
            gemm2_weights=w2,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=num_expert_group,
            topk_group=topk_group,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=local_expert_offset,
            local_num_experts=self.local_num_experts,
            routing_method_type=self.routing_method_type,
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
    """Return True only on CUDA devices of the capability-100 (Blackwell) family."""
    if not current_platform.is_cuda():
        return False
    return current_platform.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""BF16 kernels do not support non-gated MoE"""
return False
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True if ``activation`` is SiLU, the only accepted activation."""
    supported_activations = (MoEActivation.SILU,)
    return activation in supported_activations
def _supports_routing_method_bf16(
    routing_method: RoutingMethodType,
) -> bool:
    """Return True if ``routing_method`` is one of the accepted routing methods."""
    supported_methods = (
        RoutingMethodType.Default,
        RoutingMethodType.Renormalize,
        RoutingMethodType.DeepSeekV3,
        RoutingMethodType.Llama4,
        RoutingMethodType.RenormalizeNaive,
    )
    return routing_method in supported_methods
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Supports TRTLLM Kernel does not support EPLB."""
return not moe_parallel_config.enable_eplb
def is_supported_config_trtllm_bf16(
    moe_config: FusedMoEConfig,
    activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
    """
    Decide whether the BF16 unquantized TRTLLM kernel can serve ``moe_config``.

    Mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config for the
    BF16 unquantized kernels. The checks run in a fixed order (device,
    act_and_mul, activation, parallel config, routing method, activation
    format) and the first failing check determines the reported reason.

    Returns:
        ``(True, None)`` when every check passes, otherwise ``(False, reason)``.
    """

    def _unsupported(what: str) -> tuple[bool, str | None]:
        # Build the (False, reason) result for a failed check.
        return False, f"kernel does not support {what}"

    if not _supports_current_device():
        return _unsupported(f"current device {current_platform.device_name}")

    # Gated (act_and_mul) layers are always fine; non-gated ones need the
    # kernel to opt in explicitly.
    if not moe_config.is_act_and_mul and not _supports_no_act_and_mul():
        return _unsupported("no act_and_mul MLP layer")

    if not _supports_activation(moe_config.activation):
        return _unsupported(f"{moe_config.activation} activation")

    if not _supports_parallel_config(moe_config.moe_parallel_config):
        return _unsupported(f"parallel config {moe_config.moe_parallel_config}")

    if not _supports_routing_method_bf16(moe_config.routing_method):
        return _unsupported(f"routing method {moe_config.routing_method}")

    if activation_format != mk.FusedMoEActivationFormat.Standard:
        return _unsupported(f"activation format {activation_format}")

    return True, None
def flashinfer_fused_moe_bf16(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    hidden_states: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    num_experts: int,
    top_k: int,
    n_group: int | None,
    topk_group: int | None,
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    routing_method_type: int,
    tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
    """
    Thin pass-through to ``flashinfer_trtllm_bf16_moe``.

    Every argument is forwarded unchanged, by keyword, and the helper's
    result is returned as-is.
    """
    # Imported here rather than at module scope.
    from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe

    kernel_kwargs = dict(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states,
        gemm1_weights=gemm1_weights,
        gemm2_weights=gemm2_weights,
        num_experts=num_experts,
        top_k=top_k,
        n_group=n_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routing_method_type=routing_method_type,
        tune_max_num_tokens=tune_max_num_tokens,
    )
    return flashinfer_trtllm_bf16_moe(**kernel_kwargs)
def flashinfer_fused_moe_bf16_fake(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    hidden_states: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    num_experts: int,
    top_k: int,
    n_group: int | None,
    topk_group: int | None,
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    routing_method_type: int = RoutingMethodType.Renormalize,
    tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
    """
    Fake implementation registered alongside ``flashinfer_fused_moe_bf16``.

    No kernel is run: the result is an uninitialized tensor created with
    ``torch.empty_like(hidden_states)``. All other arguments are ignored.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
# Register the BF16 MoE wrapper as the custom op "flashinfer_fused_moe_bf16",
# pairing it with the fake implementation above. The op is tagged with
# torch.Tag.needs_fixed_stride_order.
direct_register_custom_op(
op_name="flashinfer_fused_moe_bf16",
op_func=flashinfer_fused_moe_bf16,
fake_impl=flashinfer_fused_moe_bf16_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
...@@ -1967,6 +1967,10 @@ class TritonExperts(mk.FusedMoEExpertsModular): ...@@ -1967,6 +1967,10 @@ class TritonExperts(mk.FusedMoEExpertsModular):
or moe_parallel_config.use_fi_nvl_one_sided_kernels or moe_parallel_config.use_fi_nvl_one_sided_kernels
) )
@staticmethod
def _supports_batch_invariance() -> bool:
    """Report that this experts class supports batch invariance."""
    return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
......
...@@ -9,6 +9,7 @@ from typing import final ...@@ -9,6 +9,7 @@ from typing import final
import torch import torch
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import ( from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation, MoEActivation,
...@@ -563,6 +564,8 @@ class FusedMoEExperts(ABC): ...@@ -563,6 +564,8 @@ class FusedMoEExperts(ABC):
) )
elif activation_format != cls.activation_format(): elif activation_format != cls.activation_format():
return False, _make_reason(f"{activation_format.value} activation format") return False, _make_reason(f"{activation_format.value} activation format")
elif envs.VLLM_BATCH_INVARIANT and not cls._supports_batch_invariance():
return False, _make_reason("batch invariance")
return True, None return True, None
@staticmethod @staticmethod
...@@ -645,6 +648,15 @@ class FusedMoEExperts(ABC): ...@@ -645,6 +648,15 @@ class FusedMoEExperts(ABC):
""" """
return True return True
@staticmethod
def _supports_batch_invariance() -> bool:
    """
    Whether the kernel supports batch invariance, i.e. the output does not
    depend on the order of the tokens in the input batch. This is useful
    for determining if the kernel can be used with VLLM_BATCH_INVARIANT=1.

    The base implementation returns False.
    """
    return False
# #
# Various helpers for accessing quantization parameters from the # Various helpers for accessing quantization parameters from the
# quant_config. # quant_config.
......
...@@ -11,21 +11,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -11,21 +11,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.kernel import MoEBackend from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_bf16,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
convert_moe_weights_to_flashinfer_trtllm_block_layout,
get_flashinfer_moe_backend,
swap_w13_to_w31, swap_w13_to_w31,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -35,21 +34,96 @@ class UnquantizedMoeBackend(Enum): ...@@ -35,21 +34,96 @@ class UnquantizedMoeBackend(Enum):
FLASHINFER_CUTLASS = "FlashInfer CUTLASS" FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
AITER = "ROCm AITER" AITER = "ROCm AITER"
TRITON = "TRITON" TRITON = "TRITON"
BATCHED_TRITON = "BATCHED_TRITON"
CPU = "CPU" CPU = "CPU"
XPU = "XPU" XPU = "XPU"
TPU = "TPU" TPU = "TPU"
OOT = "OOT" OOT = "OOT"
# NOTE(zyongye): Unsupported backend means backend def _get_priority_backends(moe_config: FusedMoEConfig) -> list[UnquantizedMoeBackend]:
# that is not conform with Modular kernel format. """
# We will directly call the kernel for those backend Get available backends in priority order based on platform and config.
UNSUPPORTED_BACKEND = [
UnquantizedMoeBackend.FLASHINFER_TRTLLM, This function can be extended to become more complex as needed.
UnquantizedMoeBackend.CPU, """
UnquantizedMoeBackend.TPU,
UnquantizedMoeBackend.OOT, def _move_to_back(
] backends: list[UnquantizedMoeBackend],
backend: UnquantizedMoeBackend,
) -> None:
backends.append(backends.pop(backends.index(backend)))
if current_platform.is_rocm():
_AVAILABLE_BACKENDS = [
UnquantizedMoeBackend.AITER,
UnquantizedMoeBackend.TRITON,
UnquantizedMoeBackend.BATCHED_TRITON,
]
elif current_platform.is_cuda():
_AVAILABLE_BACKENDS = [
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
UnquantizedMoeBackend.FLASHINFER_CUTLASS,
UnquantizedMoeBackend.TRITON,
UnquantizedMoeBackend.BATCHED_TRITON,
]
# HACK: Qwen3.5 has crash with FLASHINFER_CUTLASS BF16 if DEP.
# Updating the oracle querying logic is out of the scope of this
# PR. Need to fix the kernel or update structure in follow up.
if moe_config.moe_parallel_config.dp_size > 1:
_move_to_back(_AVAILABLE_BACKENDS, UnquantizedMoeBackend.FLASHINFER_CUTLASS)
elif current_platform.is_xpu():
_AVAILABLE_BACKENDS = [UnquantizedMoeBackend.XPU]
elif current_platform.is_cpu():
_AVAILABLE_BACKENDS = [UnquantizedMoeBackend.CPU]
return _AVAILABLE_BACKENDS
def backend_to_kernel_cls(
    backend: UnquantizedMoeBackend,
) -> type[mk.FusedMoEExperts]:
    """
    Map an unquantized MoE backend to the experts class that implements it.

    Each experts module is imported only inside the branch that needs it.

    Raises:
        ValueError: if ``backend`` is not one of the backends handled below.
    """
    if backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
        from vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe import (
            TrtLlmBf16Experts,
        )

        return TrtLlmBf16Experts

    if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return FlashInferExperts

    if backend == UnquantizedMoeBackend.AITER:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        return AiterExperts

    if backend == UnquantizedMoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

        return TritonExperts

    if backend == UnquantizedMoeBackend.BATCHED_TRITON:
        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
            BatchedTritonExperts,
        )

        return BatchedTritonExperts

    if backend == UnquantizedMoeBackend.XPU:
        from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExperts

        return XPUExperts

    # No experts class is mapped for any other backend.
    raise ValueError(f"Unknown unquantized MoE backend: {backend.value}")
def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend: def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
...@@ -70,194 +144,224 @@ def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend ...@@ -70,194 +144,224 @@ def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend
def select_unquantized_moe_backend( def select_unquantized_moe_backend(
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
use_dp: bool, ) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]:
) -> UnquantizedMoeBackend:
""" """
Select the primary Unquantized MoE backend Select the primary Unquantized MoE backend.
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
""" """
def _make_log_backend(backend: UnquantizedMoeBackend): if current_platform.is_cpu():
return f"Using {backend.value} backend for Unquantized MoE" # TODO: migrate to MK structure.
return UnquantizedMoeBackend.CPU, None
if current_platform.is_tpu():
return UnquantizedMoeBackend.TPU, None
if current_platform.is_out_of_tree():
return UnquantizedMoeBackend.OOT, None
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = _get_priority_backends(moe_config)
# NOTE(rob): We need to peek into the P/F selection to determine
# if we are using the batched or standard expert format, which
# is not ideal. Once we unify TP + DP/EP, we can select P/F first.
activation_format = ( activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts mk.FusedMoEActivationFormat.BatchedExperts
if moe_config.moe_parallel_config.use_batched_activation_format if moe_config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard else mk.FusedMoEActivationFormat.Standard
) )
# Check if FlashInfer TRTLLM BF16 MoE is supported def _make_log_backend(backend: UnquantizedMoeBackend) -> str:
trtllm_supported, _ = is_supported_config_trtllm_bf16( available_strs = [b.value for b in AVAILABLE_BACKENDS]
moe_config=moe_config, return (
activation_format=activation_format, f"Using {backend.value} Unquantized MoE backend out "
) f"of potential backends: {available_strs}."
flashinfer_trtllm_available = has_flashinfer() and trtllm_supported )
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
flashinfer_cutlass_available = ( def _make_log_unsupported(
has_flashinfer_cutlass_fused_moe() backend: UnquantizedMoeBackend, reason: str | None
and (not use_dp) ) -> str:
and current_platform.has_device_capability(90) if reason:
) return (
flashinfer_trtllm_moe_enabled = ( f"Unquantized MoE backend {backend.value} does not support the "
flashinfer_trtllm_available f"deployment configuration since {reason}."
and envs.VLLM_USE_FLASHINFER_MOE_FP16 )
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency" return (
) f"Unquantized MoE backend '{backend.value}' does not support the "
flashinfer_cutlass_moe_enabled = ( "deployment configuration."
flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16 )
)
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() def _return_or_raise(
backend: UnquantizedMoeBackend,
config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, None, None, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = moe_config.moe_backend runner_backend = moe_config.moe_backend
if runner_backend != "auto": if runner_backend != "auto":
requested_backend = map_unquantized_backend(runner_backend) requested_backend = map_unquantized_backend(runner_backend)
if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM: if (
if not flashinfer_trtllm_available: activation_format == mk.FusedMoEActivationFormat.BatchedExperts
and requested_backend == UnquantizedMoeBackend.TRITON
):
requested_backend = UnquantizedMoeBackend.BATCHED_TRITON
return _return_or_raise(requested_backend, moe_config, activation_format)
# Handle explicit FlashInfer FP16 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP16:
if UnquantizedMoeBackend.FLASHINFER_TRTLLM in AVAILABLE_BACKENDS:
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_TRTLLM)
if UnquantizedMoeBackend.FLASHINFER_CUTLASS in AVAILABLE_BACKENDS:
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_CUTLASS)
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.CUTLASS:
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
else:
raise ValueError( raise ValueError(
"FlashInfer TRTLLM MoE backend is not available for this " f"FlashInfer MOE backend {fi_backend} "
"configuration." "does not support unquantized MoE."
) )
elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: k_cls = backend_to_kernel_cls(backend)
if not flashinfer_cutlass_available: return _return_or_raise(backend, moe_config, activation_format)
raise ValueError( else:
"FlashInfer CUTLASS MoE backend is not available for this " # If the user is not explicit about the backend, try both.
"configuration." for backend in [
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
UnquantizedMoeBackend.FLASHINFER_CUTLASS,
]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, moe_config, None, None, activation_format
) )
elif requested_backend == UnquantizedMoeBackend.AITER and not ( if supported:
current_platform.is_rocm() and rocm_aiter_moe_enabled logger.info_once(_make_log_backend(backend), scope="local")
): return backend, k_cls
raise ValueError( else:
"ROCm AITer MoE backend is not available for this configuration." logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP16=1, but no "
"FlashInfer unquantized MoE backend supports the configuration."
) )
logger.info_once(_make_log_backend(requested_backend), scope="local")
return requested_backend
if current_platform.is_rocm(): # Handle explicit AITER FP8 configuration.
if rocm_aiter_moe_enabled: if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"):
backend = UnquantizedMoeBackend.AITER if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE:
if UnquantizedMoeBackend.AITER in AVAILABLE_BACKENDS:
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.AITER)
else: else:
backend = UnquantizedMoeBackend.TRITON backend = UnquantizedMoeBackend.AITER
if current_platform.is_cuda(): return _return_or_raise(backend, moe_config, activation_format)
if flashinfer_trtllm_moe_enabled:
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
elif flashinfer_cutlass_moe_enabled:
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
if trtllm_supported:
logger.info_once(
"FlashInfer TRTLLM MoE is available but not enabled, "
"consider setting VLLM_FLASHINFER_MOE_BACKEND=latency "
"to enable it for better performance.",
scope="local",
)
else:
if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported:
logger.info_once(
"FlashInfer TRTLLM MoE is available but not enabled, "
"consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
"and VLLM_FLASHINFER_MOE_BACKEND=latency "
"to enable it for better performance.",
scope="local",
)
elif not use_dp and flashinfer_cutlass_available:
logger.info_once(
"FlashInfer CUTLASS MoE is available"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
scope="local",
)
elif use_dp:
logger.info_once(
"FlashInfer CUTLASS MoE is currently not available for DP.",
scope="local",
)
backend = UnquantizedMoeBackend.TRITON
if current_platform.is_xpu():
backend = UnquantizedMoeBackend.XPU
if current_platform.is_cpu():
backend = UnquantizedMoeBackend.CPU
if current_platform.is_tpu():
backend = UnquantizedMoeBackend.TPU
if current_platform.is_out_of_tree():
backend = UnquantizedMoeBackend.OOT
logger.info_once(_make_log_backend(backend), scope="local") for backend in AVAILABLE_BACKENDS:
return backend k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, moe_config, None, None, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No Unquantized MoE backend supports the deployment configuration."
)
def convert_to_unquantized_kernel_format( def convert_to_unquantized_kernel_format(
unquantized_backend: UnquantizedMoeBackend, unquantized_backend: UnquantizedMoeBackend,
layer: Module, layer: Module,
w13_weight: torch.Tensor | None = None, w13_weight: torch.Tensor,
w2_weight: torch.Tensor | None = None, w2_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if unquantized_backend == UnquantizedMoeBackend.AITER: if unquantized_backend == UnquantizedMoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(w13_weight, w2_weight)
layer.w13_weight.data, layer.w2_weight.data
)
elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
if layer.moe_config.is_act_and_mul:
# Swap halves to arrange as [w3; w1] (kernel expectation)
# Non-gated MoE: w13 is a single projection, no need to swap.
w13_weight = swap_w13_to_w31(w13_weight)
elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
# Swap halves to arrange as [w3; w1] (kernel expectation) # Swap halves to arrange as [w3; w1] (kernel expectation)
w13_weight = swap_w13_to_w31(layer.w13_weight.data) w13_weight = swap_w13_to_w31(w13_weight)
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
w13_weight, w2_weight = convert_moe_weights_to_flashinfer_trtllm_block_layout(
_cache_permute_indices,
w13_weight,
w2_weight,
)
return w13_weight, w2_weight return w13_weight.contiguous(), w2_weight.contiguous()
def make_unquantized_moe_kernel( def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
) -> mk.FusedMoEKernel | None: backend: UnquantizedMoeBackend,
if backend in UNSUPPORTED_BACKEND: experts_cls: type[mk.FusedMoEExperts],
return None routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
use_monolithic=is_monolithic,
)
assert prepare_finalize is not None
if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
kernel = mk.FusedMoEKernel( # Create Experts
MoEPrepareAndFinalizeNoDPEPModular(), if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
FlashInferExperts( max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
moe_config=moe_config, assert max_num_tokens is not None
quant_config=quant_config, experts = experts_cls(
), moe_config=moe_config,
inplace=False, quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
) )
else:
elif backend == UnquantizedMoeBackend.AITER: experts = experts_cls(
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( moe_config=moe_config,
AiterExperts, quant_config=quant_config,
) )
kernel = mk.FusedMoEKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(), prepare_finalize,
AiterExperts( experts,
moe_config=moe_config, shared_experts=(
quant_config=quant_config, shared_experts
), if moe_config.moe_parallel_config.use_deepep_ll_kernels
inplace=not moe_config.disable_inplace, else None
) ),
elif backend == UnquantizedMoeBackend.TRITON: moe_parallel_config=moe_config.moe_parallel_config,
from vllm.model_executor.layers.fused_moe import TritonExperts inplace=(not moe_config.disable_inplace and not is_monolithic),
)
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
TritonExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
XPUExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
return kernel return kernel
...@@ -6,11 +6,8 @@ from collections.abc import Callable ...@@ -6,11 +6,8 @@ from collections.abc import Callable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -23,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ...@@ -23,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEExpertsModular, FusedMoEExpertsModular,
FusedMoEPrepareAndFinalizeModular, FusedMoEPrepareAndFinalizeModular,
) )
...@@ -33,20 +29,10 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( ...@@ -33,20 +29,10 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel, make_unquantized_moe_kernel,
select_unquantized_moe_backend, select_unquantized_moe_backend,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
convert_moe_weights_to_flashinfer_trtllm_block_layout,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
if current_platform.is_cuda_alike() or current_platform.is_xpu():
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts
else:
TritonExperts = None # type: ignore
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -59,45 +45,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -59,45 +45,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.unquantized_backend = select_unquantized_moe_backend( self.unquantized_backend, self.experts_cls = select_unquantized_moe_backend(
moe_config=self.moe, moe_config=self.moe,
use_dp=self.moe.moe_parallel_config.dp_size > 1,
)
# AITER only supports gated activations (silu/gelu), so disable it
# for non-gated MoE (is_act_and_mul=False)
self.rocm_aiter_moe_enabled = (
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
) )
self.kernel: mk.FusedMoEKernel | None = None
self._is_monolithic = (
current_platform.is_cpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
)
if self.is_monolithic:
self.apply_monolithic: Callable = self._select_monolithic()
def _select_monolithic(self) -> Callable:
"""Select the monolithic implementation based on platform."""
if current_platform.is_cpu():
return self.forward_monolithic_cpu
else:
return self.forward_monolithic_cuda
def forward_native(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
@property @property
def is_monolithic(self) -> bool: def is_monolithic(self) -> bool:
return self._is_monolithic # Escape hatch for CPU, which stays on the old monolithic path.
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
return True
return super().is_monolithic
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
...@@ -106,43 +63,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -106,43 +63,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalizeModular | None: ):
return super().maybe_make_prepare_finalize(routing_tables) raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic for all but the CPU backend. CPU backend is monolithic. "
"So this function should not be called."
)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalizeModular, prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEExpertsModular: ) -> FusedMoEExpertsModular:
assert self.moe_quant_config is not None raise ValueError(
if ( f"{self.__class__.__name__} uses the new modular kernel initialization "
prepare_finalize.activation_format "logic. This function should not be called."
== FusedMoEActivationFormat.BatchedExperts )
):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
elif (
self.unquantized_backend == UnquantizedMoeBackend.AITER
and rocm_aiter_ops.is_fused_moe_enabled()
):
from .rocm_aiter_fused_moe import AiterExperts
logger.debug("AiterExperts %s", self.moe)
return AiterExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
def create_weights( def create_weights(
self, self,
...@@ -227,14 +163,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -227,14 +163,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
replace_parameter(layer, "w13_weight", w13) replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2) replace_parameter(layer, "w2_weight", w2)
# Setup Modular Kernel for TP Case # Setup moe kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
assert self.experts_cls is not None
self.kernel = make_unquantized_moe_kernel( self.moe_kernel = make_unquantized_moe_kernel(
backend=self.unquantized_backend,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
backend=self.unquantized_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -244,22 +183,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -244,22 +183,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM: if self.unquantized_backend in [
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {} UnquantizedMoeBackend.TPU,
# Swap halves to arrange as [w3; w1] (kernel expectation) UnquantizedMoeBackend.OOT,
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) ]:
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) # OOT handles internally.
layer.w13_weight.data = w13_weight_swapped.contiguous() return
w13_weights_shuffled, w2_weights_shuffled = (
convert_moe_weights_to_flashinfer_trtllm_block_layout(
_cache_permute_indices,
layer.w13_weight.data,
layer.w2_weight.data,
)
)
layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False)
layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False)
elif self.unquantized_backend == UnquantizedMoeBackend.CPU: elif self.unquantized_backend == UnquantizedMoeBackend.CPU:
# CPU stays on the old path — no oracle, no moe_kernel.
from vllm.model_executor.layers.fused_moe import cpu_fused_moe from vllm.model_executor.layers.fused_moe import cpu_fused_moe
if current_platform.get_cpu_architecture() == CpuArchEnum.X86: if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
...@@ -290,13 +222,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -290,13 +222,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else: else:
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike() or current_platform.is_xpu(): else:
self._setup_kernel( self._setup_kernel(
layer=layer, layer=layer,
w13=layer.w13_weight, w13=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
) )
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG
def apply( def apply(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
...@@ -313,16 +254,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -313,16 +254,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_experts_input=shared_experts_input, shared_experts_input=shared_experts_input,
) )
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: def forward_native(
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG
def forward_cuda(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
...@@ -330,9 +262,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -330,9 +262,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply(
return self.kernel.apply(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -345,53 +276,58 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -345,53 +276,58 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_experts_input=shared_experts_input, shared_experts_input=shared_experts_input,
) )
def forward_monolithic_cuda( def forward_cuda(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401 return self.forward_native(
layer, x, topk_weights, topk_ids, shared_experts_input
assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
return torch.ops.vllm.flashinfer_fused_moe_bf16(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
hidden_states=x,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routing_method_type=layer.routing_method_type,
) )
def forward_monolithic_cpu( def apply_monolithic(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.cpu_fused_moe( assert self.is_monolithic
layer, if self.unquantized_backend == UnquantizedMoeBackend.CPU:
x, assert self.moe_kernel is None
layer.use_grouped_topk, return self.cpu_fused_moe(
layer.top_k, layer,
router_logits, x,
layer.renormalize, layer.use_grouped_topk,
layer.topk_group, layer.top_k,
layer.num_expert_group, router_logits,
layer.global_num_experts, layer.renormalize,
layer.expert_map, layer.topk_group,
layer.custom_routing_function, layer.num_expert_group,
layer.scoring_func, layer.global_num_experts,
layer.routed_scaling_factor, layer.expert_map,
layer.e_score_correction_bias, layer.custom_routing_function,
layer.apply_router_weight_on_input, layer.scoring_func,
layer.activation, layer.routed_scaling_factor,
) layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.activation,
)
else:
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
...@@ -202,6 +202,7 @@ def has_flashinfer_trtllm_fused_moe() -> bool: ...@@ -202,6 +202,7 @@ def has_flashinfer_trtllm_fused_moe() -> bool:
("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"), ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"), ("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_bf16_moe"),
] ]
for module_name, attr_name in required_functions: for module_name, attr_name in required_functions:
mod = _get_submodule(module_name) mod = _get_submodule(module_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment