Unverified Commit 87bd9189 authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[MoE Refactor] Mxfp4 oracle rebased (#37128)


Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent c7f98b4d
...@@ -88,8 +88,8 @@ To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE k ...@@ -88,8 +88,8 @@ To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE k
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmMxfp4ExpertsMonolithic`][vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe.TrtLlmMxfp4ExpertsMonolithic],</br>[`TrtLlmMxfp4ExpertsModular`][vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe.TrtLlmMxfp4ExpertsModular],</br>[`TrtLlmNvFp4ExpertsMonolithic`][vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe.TrtLlmNvFp4ExpertsMonolithic],</br>[`TrtLlmNvfp4ExpertsModular`][vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe.TrtLlmNvFp4ExpertsModular] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | rocm aiter moe | standard | mxfp4,</br>fp8 | G(32),G(128),A,T | silu, gelu,</br>swigluoai | Y | N | `rocm_aiter_fused_experts`,</br>`AiterExperts` |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | | naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
......
...@@ -84,7 +84,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -84,7 +84,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# TODO: remove this after finishing migration from envs to model kwargs # TODO: remove this after finishing migration from envs to model kwargs
if model_name == "openai/gpt-oss-20b": if model_name == "openai/gpt-oss-20b":
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") from .common import is_blackwell
if is_blackwell():
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
# Disable, compile cache to make sure custom passes run. # Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs. # Otherwise, we can't verify fusion happened through the logs.
......
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
if not has_triton_kernels(): if not has_triton_kernels():
...@@ -14,6 +15,7 @@ if not has_triton_kernels(): ...@@ -14,6 +15,7 @@ if not has_triton_kernels():
allow_module_level=True, allow_module_level=True,
) )
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
import triton_kernels.swiglu import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData from triton_kernels.numerics import InFlexData
...@@ -303,6 +305,12 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init): ...@@ -303,6 +305,12 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2, pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
if current_platform.is_device_capability_family(100):
constraints = {
"is_persistent": True,
}
opt_flags.update_opt_flags_constraints(constraints)
if a_dtype == "bf16" and w_dtype == "mx4": if a_dtype == "bf16" and w_dtype == "mx4":
quant_config = mxfp4_w4a16_moe_quant_config( quant_config = mxfp4_w4a16_moe_quant_config(
w1_scale=pc1, w1_scale=pc1,
......
...@@ -82,7 +82,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): ...@@ -82,7 +82,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
model_case.model_id, model_case.model_id,
tensor_parallel_size=model_case.tp, tensor_parallel_size=model_case.tp,
load_format="dummy", load_format="dummy",
cudagraph_capture_sizes=[16], compilation_config={"cudagraph_capture_sizes": [16]},
) as llm: ) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562 # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model): # def check_model(model):
......
...@@ -17,89 +17,6 @@ from unittest.mock import MagicMock, patch ...@@ -17,89 +17,6 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
Mxfp4MoEMethod,
)
def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
"""Create a mock FusedMoEConfig with the given EP size."""
parallel_config = MagicMock()
parallel_config.ep_size = ep_size
moe_config = MagicMock()
moe_config.ep_size = ep_size
moe_config.is_lora_enabled = False
moe_config.moe_parallel_config = parallel_config
return moe_config
class TestMxfp4TritonIsMonolithic:
"""Verify that is_monolithic is always True for the TRITON backend,
regardless of EP size, since triton_kernel_moe_forward now handles
expert_map remapping internally."""
@pytest.mark.parametrize(
"backend,ep_size,expected_monolithic",
[
# TRITON is always monolithic (handles EP via expert_map remapping)
(Mxfp4Backend.TRITON, 1, True),
(Mxfp4Backend.TRITON, 2, True),
(Mxfp4Backend.TRITON, 4, True),
# SM100 backends are always monolithic
(Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
(Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
(Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
(Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
# MARLIN is never monolithic
(Mxfp4Backend.MARLIN, 1, False),
(Mxfp4Backend.MARLIN, 2, False),
],
ids=[
"triton-no-ep",
"triton-ep2",
"triton-ep4",
"sm100-trtllm-no-ep",
"sm100-trtllm-ep2",
"sm100-bf16-no-ep",
"sm100-bf16-ep2",
"marlin-no-ep",
"marlin-ep2",
],
)
@patch(
"vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
)
@patch(
"vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
)
def test_is_monolithic(
self,
mock_get_config,
mock_get_backend,
backend,
ep_size,
expected_monolithic,
):
"""is_monolithic should be True for TRITON regardless of EP size."""
mock_get_backend.return_value = backend
mock_compilation_config = MagicMock()
mock_compilation_config.max_cudagraph_capture_size = 1024
mock_vllm_config = MagicMock()
mock_vllm_config.compilation_config = mock_compilation_config
mock_get_config.return_value = mock_vllm_config
moe_config = _make_mock_moe_config(ep_size=ep_size)
method = Mxfp4MoEMethod(moe_config)
assert method.is_monolithic == expected_monolithic, (
f"Expected is_monolithic={expected_monolithic} for "
f"backend={backend.name}, ep_size={ep_size}, "
f"but got {method.is_monolithic}."
)
class TestTritonMoeForwardExpertMap: class TestTritonMoeForwardExpertMap:
"""Test that triton_kernel_moe_forward applies expert_map remapping """Test that triton_kernel_moe_forward applies expert_map remapping
......
...@@ -9,80 +9,236 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -9,80 +9,236 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kMxfp4Static,
kMxfp8Dynamic,
) )
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
class TrtLlmGenExperts(mk.FusedMoEExpertsModular): class TrtLlmMxfp4ExpertsBase:
"""TensorRT-LLM-based fused MoE expert implementation.""" """
MXFP4 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic.
"""
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
max_capture_size,
): ):
super().__init__(moe_config, quant_config) # NOTE: FusedMoEExperts.__init__ is called by the concrete subclass
self.device = torch.accelerator.current_device_index() # (Monolithic/Modular) via MRO, not here, to avoid mypy issues with
self.num_experts = moe_config.num_local_experts # multiple inheritance. This matches the NvFP4 expert pattern.
self.moe_config = moe_config
self.quant_config = quant_config
self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.hidden_dim = moe_config.hidden_dim
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
# MXFP4-specific TRTLLM parameters
device = torch.accelerator.current_device_index()
self.gemm1_alpha = torch.tensor( self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device [1.702] * self.local_num_experts,
dtype=torch.float32,
device=device,
) )
self.gemm1_beta = torch.tensor( self.gemm1_beta = torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32, device=self.device [1.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
) )
self.gemm1_clamp_limit = torch.tensor( self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32, device=self.device [7.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
) )
self.max_capture_size = max_capture_size
@staticmethod from vllm.config import get_current_vllm_config
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
# P1-5 fix: use public quant_dtype property instead of private _a1
self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8"
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
raise NotImplementedError( p = current_platform
"TrtLlmGenExperts is not yet used by an Oracle. " return p.is_cuda() and p.is_device_capability_family(100) and has_flashinfer()
"This method should not be called."
)
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
raise NotImplementedError( return False
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod @staticmethod
def _supports_quant_scheme( def _supports_quant_scheme(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> bool: ) -> bool:
raise NotImplementedError( SUPPORTED_W_A = [
"TrtLlmGenExperts is not yet used by an Oracle. " (kMxfp4Static, None),
"This method should not be called." (kMxfp4Static, kMxfp8Dynamic),
) ]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError( return activation == MoEActivation.SWIGLUOAI
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def activation_format() -> mk.FusedMoEActivationFormat:
raise NotImplementedError( return mk.FusedMoEActivationFormat.Standard
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called." def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
@property
def expects_unquantized_inputs(self) -> bool:
# Expert handles MXFP8 quantization internally if needed
return True
class TrtLlmMxfp4ExpertsMonolithic(
TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsMonolithic
):
"""
Monolithic version of the MXFP4 TRTLLM kernel (router + experts).
Wraps flashinfer.trtllm_fp4_block_scale_moe().
"""
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
return (
not moe_parallel_config.use_all2all_kernels
and not moe_parallel_config.enable_eplb
and moe_parallel_config.dp_size <= 1
) )
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return routing_method in [
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
# Kernel converts to bfloat16 internally
return True
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
from flashinfer import trtllm_fp4_block_scale_moe
# Handle input quantization
if self.use_mxfp8_input:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(
hidden_states,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
else:
assert hidden_states.dtype == torch.bfloat16
x_quant = hidden_states
x_scale = None
output = torch.empty_like(hidden_states)
return trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.w1_scale,
gemm1_bias=self.w1_bias,
gemm1_alpha=self.gemm1_alpha,
gemm1_beta=self.gemm1_beta,
gemm1_clamp_limit=self.gemm1_clamp_limit,
gemm2_weights=w2,
gemm2_weights_scale=self.w2_scale,
gemm2_bias=self.w2_bias,
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=self.routing_method_type,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)[0]
class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular):
"""
Modular version of the MXFP4 TRTLLM kernel (just the experts).
Wraps flashinfer.trtllm_fp4_block_scale_routed_moe().
Moved from trtllm_moe.py.
"""
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
...@@ -129,10 +285,22 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular): ...@@ -129,10 +285,22 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
intermediate_size = w2.size(1) intermediate_size = w2.size(1)
local_expert_offset = self.moe_config.ep_rank * local_num_experts local_expert_offset = self.moe_config.ep_rank * local_num_experts
x_quant = hidden_states # Handle input quantization
x_scale = a1q_scale if self.use_mxfp8_input:
if x_scale is not None: from flashinfer import mxfp8_quantize
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1)
x_quant, x_scale = mxfp8_quantize(
hidden_states,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
else:
assert hidden_states.dtype == torch.bfloat16
x_quant = hidden_states
x_scale = None
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16 torch.bfloat16
...@@ -165,7 +333,7 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular): ...@@ -165,7 +333,7 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"local_expert_offset": local_expert_offset, "local_expert_offset": local_expert_offset,
"local_num_experts": local_num_experts, "local_num_experts": local_num_experts,
"routed_scaling_factor": None, "routed_scaling_factor": None,
"routing_method_type": 1, "routing_method_type": self.routing_method_type,
"do_finalize": True, "do_finalize": True,
"output": output, "output": output,
"tune_max_num_tokens": max(self.max_capture_size, 1), "tune_max_num_tokens": max(self.max_capture_size, 1),
......
...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym, kFp8Static128BlockSym,
kFp8StaticChannelSym, kFp8StaticChannelSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kMxfp4Static,
kNvfp4Static, kNvfp4Static,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -574,12 +575,13 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular): ...@@ -574,12 +575,13 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> bool: ) -> bool:
# TODO(rob): add int4, mxfp4, int8 as integrations # TODO(rob): add int4, int8 as integrations
# are migrated to use the oracle one-by-one. # are migrated to use the oracle one-by-one.
SUPPORTED_W = [ SUPPORTED_W = [
kFp8Static128BlockSym, kFp8Static128BlockSym,
kFp8StaticChannelSym, kFp8StaticChannelSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kMxfp4Static,
kNvfp4Static, kNvfp4Static,
] ]
return weight_key in SUPPORTED_W return weight_key in SUPPORTED_W
......
...@@ -11,8 +11,10 @@ from vllm.logger import init_logger ...@@ -11,8 +11,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
...@@ -20,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -20,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kMxfp4Static,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -537,43 +540,43 @@ def make_routing_data( ...@@ -537,43 +540,43 @@ def make_routing_data(
class BaseOAITritonExperts(mk.FusedMoEExpertsModular): class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
raise NotImplementedError( p = current_platform
"OAITritonExperts is not yet used by an Oracle. " if not p.is_cuda_alike():
"This method should not be called." return False
) cap = p.get_device_capability()
if cap is None:
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
raise NotImplementedError( return False
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod @staticmethod
def _supports_quant_scheme( def _supports_quant_scheme(
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> bool: ) -> bool:
raise NotImplementedError( SUPPORTED_W_A = [
"OAITritonExperts is not yet used by an Oracle. " (kMxfp4Static, None),
"This method should not be called." ]
) return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError( raise NotImplementedError
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError( return True
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
...@@ -630,6 +633,10 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular): ...@@ -630,6 +633,10 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
class OAITritonExperts(BaseOAITritonExperts): class OAITritonExperts(BaseOAITritonExperts):
"""OAI Triton-based fused MoE expert implementation.""" """OAI Triton-based fused MoE expert implementation."""
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SWIGLUOAI
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
...@@ -714,6 +721,15 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -714,6 +721,15 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum. One use case for it is to inject LoRA modules on the activation and moe_sum.
""" """
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
]
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
...@@ -839,3 +855,118 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -839,3 +855,118 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
) )
self.moe_sum(intermediate_cache3.view(-1, topk, K), output) self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
"""Monolithic Triton MXFP4 expert. Wraps triton_kernel_moe_forward()."""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
self.topk = moe_config.experts_per_token
self.renormalize = moe_config.routing_method in (
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
if not p.is_cuda_alike():
return False
cap = p.get_device_capability()
if cap is None:
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kMxfp4Static, None),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SWIGLUOAI
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
return (
not moe_parallel_config.use_all2all_kernels
and not moe_parallel_config.enable_eplb
and moe_parallel_config.dp_size <= 1
)
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return routing_method in [
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
@property
def expects_unquantized_inputs(self) -> bool:
return True
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
return triton_kernel_moe_forward(
hidden_states=hidden_states,
w1=w1,
w2=w2,
gating_output=router_logits,
topk=self.topk,
renormalize=self.renormalize,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
)
...@@ -52,7 +52,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -52,7 +52,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -218,7 +217,6 @@ def maybe_roundup_hidden_size( ...@@ -218,7 +217,6 @@ def maybe_roundup_hidden_size(
moe_parallel_config: FusedMoEParallelConfig, moe_parallel_config: FusedMoEParallelConfig,
is_lora_enabled: bool, is_lora_enabled: bool,
model_type: str | None, model_type: str | None,
is_mxfp4_quant: bool,
) -> int: ) -> int:
""" """
Given layer hidden size and MoE configurations, round up hidden_size Given layer hidden size and MoE configurations, round up hidden_size
...@@ -232,7 +230,6 @@ def maybe_roundup_hidden_size( ...@@ -232,7 +230,6 @@ def maybe_roundup_hidden_size(
is used in the case of mxfp4 quantization in selecting the is used in the case of mxfp4 quantization in selecting the
MxFP4Backend. MxFP4Backend.
model_type: for checking if gpt-oss model_type: for checking if gpt-oss
is_mxfp4_quant: whether the layer is quantized with mxfp4
Return: Return:
Rounded up hidden_size if rounding up is required based on the configs. Rounded up hidden_size if rounding up is required based on the configs.
...@@ -246,28 +243,6 @@ def maybe_roundup_hidden_size( ...@@ -246,28 +243,6 @@ def maybe_roundup_hidden_size(
hidden_size, act_dtype, moe_parallel_config hidden_size, act_dtype, moe_parallel_config
) )
# we are padding globally so EP buffer allocation works
if model_type == "gpt_oss" and is_mxfp4_quant:
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
get_mxfp4_backend,
)
current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
if (
current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
):
hidden_size = round_up(hidden_size, 128)
elif (
current_platform.is_rocm()
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.MARLIN
):
hidden_size = round_up(hidden_size, 256)
return hidden_size return hidden_size
...@@ -540,9 +515,6 @@ class FusedMoE(CustomOp): ...@@ -540,9 +515,6 @@ class FusedMoE(CustomOp):
moe_parallel_config=self.moe_parallel_config, moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None, is_lora_enabled=vllm_config.lora_config is not None,
model_type=self.model_type, model_type=self.model_type,
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
) )
self.hidden_size = hidden_size self.hidden_size = hidden_size
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Union
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_swizzle_mxfp4,
get_padding_alignment,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
kMxfp8Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
if has_triton_kernels():
try:
from triton_kernels.matmul_ogs import PrecisionConfig
except (ImportError, AttributeError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible. Error: %s",
e,
)
class Mxfp4MoeBackend(Enum):
NONE = "None"
# FlashInfer TRTLLM backends
FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"
# FlashInfer CUTLASS backends
FLASHINFER_CUTLASS_MXFP4_MXFP8 = "FLASHINFER_CUTLASS_MXFP4_MXFP8"
FLASHINFER_CUTLASS_MXFP4_BF16 = "FLASHINFER_CUTLASS_MXFP4_BF16"
# Marlin
BATCHED_MARLIN = "BATCHED_MARLIN"
MARLIN = "MARLIN"
# ROCm AITER (CK)
CK = "CK"
# Triton
TRITON = "TRITON"
TRITON_UNFUSED = "TRITON_UNFUSED"
# XPU
XPU = "XPU"
# Backends that share the same TRTLLM weight format
TRTLLM_BACKENDS = (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
)
TRITON_BACKENDS = (
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.TRITON_UNFUSED,
)
def backend_to_kernel_cls(
backend: Mxfp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
if backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
):
from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import (
TrtLlmMxfp4ExpertsModular,
TrtLlmMxfp4ExpertsMonolithic,
)
# NOTE: prefer Monolithic > Modular, so return Monolithic first.
return [TrtLlmMxfp4ExpertsMonolithic, TrtLlmMxfp4ExpertsModular]
elif backend in (
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
return [FlashInferExperts]
elif backend == Mxfp4MoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
OAITritonMxfp4ExpertsMonolithic,
)
# NOTE: prefer Monolithic > Modular, so return Monolithic first.
return [OAITritonMxfp4ExpertsMonolithic, OAITritonExperts]
elif backend == Mxfp4MoeBackend.TRITON_UNFUSED:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
UnfusedOAITritonExperts,
)
return [UnfusedOAITritonExperts]
elif backend == Mxfp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
return [MarlinExperts]
elif backend == Mxfp4MoeBackend.BATCHED_MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
)
return [BatchedMarlinExperts]
elif backend == Mxfp4MoeBackend.CK:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
return [AiterExperts]
elif backend == Mxfp4MoeBackend.XPU:
raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")
else:
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
"""Map user's moe_backend string to Mxfp4MoeBackend."""
mapping: dict[str, Mxfp4MoeBackend] = {
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
"flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
"flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
"triton": Mxfp4MoeBackend.TRITON,
"marlin": Mxfp4MoeBackend.MARLIN,
"ck": Mxfp4MoeBackend.CK,
}
if backend := mapping.get(runner_backend):
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for MXFP4 MoE. "
f"Expected one of {list(mapping.keys())}."
)
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
"""
Get available backends in priority order based on platform and config.
Only includes BF16 backends. MXFP8 backends are selected via env vars.
"""
_AVAILABLE_BACKENDS = [
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.CK,
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.TRITON_UNFUSED,
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
]
return _AVAILABLE_BACKENDS
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
"""Map backend to its activation key (MXFP8 or None for BF16)."""
if backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
return kMxfp8Dynamic
return None
def select_mxfp4_moe_backend(
config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
"""
Select the primary MXFP4 MoE backend.
Note: Shape-specific fallbacks may still occur at runtime.
"""
triton_kernels_supported = has_triton_kernels() and (
9,
0,
) <= current_platform.get_device_capability() < (11, 0)
# LoRA: separate experts backend path
if config.is_lora_enabled:
if not current_platform.is_cuda():
raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.")
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("Using Triton backend for mxfp4 lora")
return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
Mxfp4MoeBackend.TRITON_UNFUSED
)[0]
logger.info_once("Using Marlin backend for mxfp4 lora")
return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0]
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
def _make_log_backend(backend: Mxfp4MoeBackend):
return f"Using '{backend.value}' Mxfp4 MoE backend."
def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
if reason:
return (
f"Mxfp4 MoE backend '{backend.value}' does not support the "
f"deployment configuration since {reason}."
)
return (
f"Mxfp4 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: Mxfp4MoeBackend,
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
reason: str | None = None
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_mxfp4_backend(runner_backend)
if (
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
and requested_backend == Mxfp4MoeBackend.MARLIN
):
requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
return _return_or_raise(
requested_backend,
config,
kMxfp4Static,
_backend_activation_key(requested_backend),
activation_format,
)
# Select kernels in order of backend.
AVAILABLE_BACKENDS = _get_priority_backends()
# Handle explicit FlashInfer MXFP4 BF16 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16)
AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16)
else:
if current_platform.is_device_capability(90):
return _return_or_raise(
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
config,
kMxfp4Static,
None,
activation_format,
)
if current_platform.is_device_capability_family(100):
return _return_or_raise(
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
config,
kMxfp4Static,
None,
activation_format,
)
raise ValueError(
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 is set but the "
"current device capability is not supported. "
"Only SM90 (CUTLASS) and SM100+ (TRTLLM) are supported."
)
# Handle explicit FlashInfer MXFP4 MXFP8 TRTLLM configuration.
if (
envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
return _return_or_raise(
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
config,
kMxfp4Static,
kMxfp8Dynamic,
activation_format,
)
# Handle explicit FlashInfer MXFP4 MXFP8 CUTLASS configuration.
if (
envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS")
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
):
return _return_or_raise(
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
config,
kMxfp4Static,
kMxfp8Dynamic,
activation_format,
)
# Handle explicit Marlin MXFP4 configuration.
if envs.is_set("VLLM_MXFP4_USE_MARLIN") and envs.VLLM_MXFP4_USE_MARLIN:
return _return_or_raise(
Mxfp4MoeBackend.MARLIN,
config,
kMxfp4Static,
None,
activation_format,
)
for backend in AVAILABLE_BACKENDS:
activation_key = _backend_activation_key(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, kMxfp4Static, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
if current_platform.is_xpu():
backend = Mxfp4MoeBackend.XPU
logger.info_once(_make_log_backend(backend))
return backend, None
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No MXFP4 MoE backend supports the deployment configuration."
)
return Mxfp4MoeBackend.NONE, None
def mxfp4_round_up_hidden_size_and_intermediate_size(
backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]:
"""Round up hidden_size and intermediate_size based on backend requirements."""
if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
intermediate_size = round_up(intermediate_size, 128)
if current_platform.is_xpu():
hidden_size = round_up(hidden_size, 128)
else:
hidden_size = round_up(hidden_size, 256)
elif backend in TRTLLM_BACKENDS:
intermediate_size = round_up(intermediate_size, 256)
hidden_size = round_up(hidden_size, 256)
elif backend in (
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
intermediate_size = round_up(intermediate_size, 128)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
pad_align = get_padding_alignment()
intermediate_size = round_up(intermediate_size, pad_align)
hidden_size = round_up(hidden_size, pad_align)
else:
intermediate_size = round_up(intermediate_size, 64)
return hidden_size, intermediate_size
def convert_to_mxfp4_moe_kernel_format(
mxfp4_backend: Mxfp4MoeBackend,
layer: torch.nn.Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
_cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor,
Union[torch.Tensor, "PrecisionConfig"],
Union[torch.Tensor, "PrecisionConfig"],
torch.Tensor | None,
torch.Tensor | None,
]:
"""Convert loaded weights into backend-specific kernel format."""
num_experts = w13_weight.shape[0]
intermediate_size = w13_weight.shape[1] // 2
hidden_size = w13_weight.shape[2] * 2
sf_block_size = 32 # mxfp4 block size
if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_mxfp4_layer_for_marlin,
)
return prepare_moe_mxfp4_layer_for_marlin(
layer,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
elif mxfp4_backend in TRTLLM_BACKENDS:
assert _cache_permute_indices is not None
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
# gemm1_alpha/beta/clamp_limit are created by the expert class
# (TrtLlmMxfp4ExpertsBase), not on the layer.
w13_weight = w13_weight.data
w2_weight = w2_weight.data
w13_weight_scale = w13_weight_scale.data
w2_weight_scale = w2_weight_scale.data
assert w13_bias is not None and w2_bias is not None
w13_bias = w13_bias.data.to(torch.float32)
w2_bias = w2_bias.data.to(torch.float32)
# Swap w1 and w3 as the definition of swiglu is different in trtllm-gen
def swap_every_two_rows(x, axis=-1):
shape = x.shape
if axis < 0:
axis = len(shape) + axis
new_shape = list(shape)
new_shape[axis] = shape[axis] // 2
new_shape.insert(axis + 1, 2)
x = x.reshape(*new_shape)
x = x.flip(axis + 1)
new_shape = list(shape)
return x.reshape(*new_shape)
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
w13_weight = swap_every_two_rows(w13_weight, -2)
w13_bias = swap_every_two_rows(w13_bias, -1)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_shuffled = []
gemm1_scales_shuffled = []
gemm2_weights_shuffled = []
gemm2_scales_shuffled = []
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128
for i in range(num_experts):
# w13 weight
permute_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_shuffled.append(
w13_weight[i]
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
.contiguous()
)
# w13 scale
permute_sf_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_shuffled.append(
nvfp4_block_scale_interleave(
w13_weight_scale[i]
.view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
.contiguous()
)
)
# w13 bias
permute_bias_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(
w13_bias[i]
.clone()
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
.contiguous()
)
# w2 weight
permute_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_shuffled.append(
w2_weight[i]
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
.contiguous()
)
# w2 scale
permute_sf_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_shuffled.append(
nvfp4_block_scale_interleave(
w2_weight_scale[i]
.view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
.contiguous()
)
)
# w2 bias
permute_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(
w2_bias[i]
.clone()
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
.contiguous()
)
w13_weight = torch.stack(gemm1_weights_shuffled)
w13_weight_scale = (
torch.stack(gemm1_scales_shuffled)
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
w2_weight = torch.stack(gemm2_weights_shuffled)
w2_weight_scale = (
torch.stack(gemm2_scales_shuffled)
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
return (
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
# De-interleave and swap for w13 weight, bias, and scales
w13_w = w13_weight.data
gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
assert w13_bias is not None and w2_bias is not None
w13_b = w13_bias.data.to(torch.float32)
gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
w13_s = w13_weight_scale.data
gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
w13_scale_swapped = torch.cat([s3, s1], dim=1)
if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8:
from flashinfer import block_scale_interleave
orig_shape = w13_scale_swapped.shape
w13_scale_interleaved = block_scale_interleave(
w13_scale_swapped.view(torch.uint8)
).reshape(orig_shape)
w2_s = w2_weight_scale.data
orig_shape = w2_s.shape
w2_scale_interleaved = block_scale_interleave(
w2_s.view(torch.uint8)
).reshape(orig_shape)
return (
w13_weight_swapped,
w2_weight,
w13_scale_interleaved,
w2_scale_interleaved,
w13_bias_swapped,
w2_bias,
)
else:
assert mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16
def _interleave_mxfp4_cutlass_sm90(w):
w_shape = w.shape
w_interleaved = w.reshape(w_shape[0], w_shape[1], (w_shape[2] // 4), 4)
w_interleaved = w_interleaved.permute(0, 2, 1, 3)
w_interleaved = w_interleaved.reshape(
w_shape[0], w_shape[2] // 4, w_shape[1] * 4
)
return w_interleaved
w31_scales = w13_scale_swapped.to(torch.uint8)
w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
w2_scale = w2_weight_scale.data.to(torch.uint8)
w2_scale_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scale)
return (
w13_weight_swapped,
w2_weight,
w31_scales_interleaved,
w2_scale_interleaved,
w13_bias_swapped,
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.CK:
from vllm._aiter_ops import rocm_aiter_ops
if w13_bias is not None:
w13_bias = w13_bias.data.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.data.to(torch.float32)
e, n, k = w13_weight.shape
# De-interleave w13 rows: gate/up pairs -> contiguous gate, up blocks
w13_weight.view(torch.uint8).copy_(
w13_weight.data.view(torch.uint8)
.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
w13_weight_scale.data = (
w13_weight_scale.data.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)
# View as native FP4 dtype for AITER shuffle
w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)
# Shuffle weights and scales for AITER CK kernel layout
w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
num_experts,
True,
)
w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
num_experts,
False,
)
# Permute bias to match de-interleaved weight layout
if w13_bias is not None:
w13_bias = (
w13_bias.data.view(-1, n // 2, 2)
.permute(0, 2, 1)
.contiguous()
.view(-1, n)
)
return (
w13_weight,
w2_weight,
shuffled_w13_scale,
shuffled_w2_scale,
w13_bias,
w2_bias,
)
elif mxfp4_backend in TRITON_BACKENDS:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
assert w13_bias is not None and w2_bias is not None
w13_bias = w13_bias.to(torch.float32)
w2_bias = w2_bias.to(torch.float32)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
w13_weight,
w13_weight_scale,
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
w2_weight,
w2_weight_scale,
)
w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
del layer.w13_weight
del layer.w2_weight
return (
w13_weight,
w2_weight,
w13_precision_config,
w2_precision_config,
w13_bias,
w2_bias,
)
else:
raise ValueError(
f"Unsupported mxfp4_backend: {mxfp4_backend}: "
f"should be one of: {list(Mxfp4MoeBackend)}."
)
def make_mxfp4_moe_quant_config(
mxfp4_backend: Mxfp4MoeBackend,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
return mxfp4_mxfp8_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.TRITON_UNFUSED,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.CK,
):
return mxfp4_w4a16_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
else:
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
def make_mxfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEExperts],
mxfp4_backend: Mxfp4MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
"""Create a FusedMoEKernel for the given MXFP4 backend."""
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
use_monolithic=is_monolithic,
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_deepep_ll_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
inplace=(
not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS
),
)
return kernel
...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import ( ...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
mxfp4_w4a16_moe_quant_config,
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
...@@ -347,16 +346,6 @@ def convert_to_nvfp4_moe_kernel_format( ...@@ -347,16 +346,6 @@ def convert_to_nvfp4_moe_kernel_format(
) )
def make_mxfp4_moe_quant_config(
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
return mxfp4_w4a16_moe_quant_config(
w1_scale=w13_scale,
w2_scale=w2_scale,
)
def make_nvfp4_moe_quant_config( def make_nvfp4_moe_quant_config(
backend: NvFp4MoeBackend, backend: NvFp4MoeBackend,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
......
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym, kFp8Static128BlockSym,
kFp8StaticChannelSym, kFp8StaticChannelSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kMxfp4Static,
) )
...@@ -201,6 +202,8 @@ def rocm_aiter_fused_experts( ...@@ -201,6 +202,8 @@ def rocm_aiter_fused_experts(
activation_method = ActivationMethod.SILU activation_method = ActivationMethod.SILU
elif activation == MoEActivation.GELU: elif activation == MoEActivation.GELU:
activation_method = ActivationMethod.GELU activation_method = ActivationMethod.GELU
elif activation == MoEActivation.SWIGLUOAI:
activation_method = rocm_aiter_ops.get_aiter_activation_type("swiglu")
else: else:
raise ValueError(f"Unsupported activation: {activation}") raise ValueError(f"Unsupported activation: {activation}")
...@@ -247,8 +250,8 @@ def rocm_aiter_fused_experts( ...@@ -247,8 +250,8 @@ def rocm_aiter_fused_experts(
else: else:
quant_method = QuantMethod.NO.value quant_method = QuantMethod.NO.value
# quark moe for mxfp4 w_dtype mxfp4 a_dtype # mxfp4: both w4a4 (quark) and w4a16 (oracle CK) use BLOCK_1X32
if quant_config.use_mxfp4_w4a4: if quant_config.use_mxfp4_w4a4 or quant_config.use_mxfp4_w4a16:
quant_method = QuantMethod.BLOCK_1X32.value quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled # w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
...@@ -289,6 +292,8 @@ def rocm_aiter_fused_experts( ...@@ -289,6 +292,8 @@ def rocm_aiter_fused_experts(
doweight_stage1=apply_router_weight_on_input, doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens, num_local_tokens=num_local_tokens,
output_dtype=output_dtype, output_dtype=output_dtype,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
) )
...@@ -319,21 +324,23 @@ class AiterExperts(mk.FusedMoEExpertsModular): ...@@ -319,21 +324,23 @@ class AiterExperts(mk.FusedMoEExpertsModular):
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> bool: ) -> bool:
# TODO(rob): AITER also supports MXFP4, which is not
# yet supported via an Oracle. Once it is, we will add
# MXFP4 to this list.
SUPPORTED_W_A = [ SUPPORTED_W_A = [
(None, None), (None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym), (kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym), (kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym), (kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym), (kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kMxfp4Static, None),
] ]
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.GELU] return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......
...@@ -45,11 +45,14 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -45,11 +45,14 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
Mxfp4MoeBackend,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
...@@ -235,7 +238,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -235,7 +238,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe): def __init__(self, moe):
super().__init__(moe) super().__init__(moe)
self.group_size = 32 self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.mxfp4_backend = Mxfp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts self.experts_cls = MarlinExperts
def create_weights( def create_weights(
...@@ -310,7 +313,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -310,7 +313,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
return make_mxfp4_moe_quant_config( return make_mxfp4_moe_quant_config(
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale mxfp4_backend=self.mxfp4_backend,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
) )
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
...@@ -334,10 +339,11 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -334,10 +339,11 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None: if self.moe_quant_config is not None:
self.moe_kernel = make_nvfp4_moe_kernel( self.moe_kernel = make_mxfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
mxfp4_backend=self.mxfp4_backend,
shared_experts=layer.shared_experts, shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(), routing_tables=layer._maybe_init_expert_routing_tables(),
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
...@@ -17,173 +13,31 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -17,173 +13,31 @@ from vllm.model_executor.layers.fused_moe import (
MoEActivation, MoEActivation,
) )
from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
) )
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
OAITritonExperts, TRITON_BACKENDS,
UnfusedOAITritonExperts, Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
CK_MXFP4_MOE_DIM_ALIGNMENT,
_can_support_mxfp4,
_swizzle_mxfp4,
get_padding_alignment,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import round_up
logger = init_logger(__name__) logger = init_logger(__name__)
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
NONE = 0
# FlashInfer Backend
SM100_FI_MXFP4_MXFP8_TRTLLM = 1
SM100_FI_MXFP4_MXFP8_CUTLASS = 2
SM100_FI_MXFP4_BF16 = 3
SM90_FI_MXFP4_BF16 = 4
# Marlin Backend
MARLIN = 5
# Triton Backend
TRITON = 6
CK = 7
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
"""
Not all MXFP4 backends support LoRA. Select backends that are known to
have LoRA support.
"""
if not current_platform.is_cuda():
return Mxfp4Backend.NONE
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
return Mxfp4Backend.TRITON
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
return Mxfp4Backend.MARLIN
def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
# Backend Selection
if with_lora_support:
return get_mxfp4_backend_with_lora()
if current_platform.is_cuda():
if (
current_platform.is_device_capability(90)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
):
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
return Mxfp4Backend.SM90_FI_MXFP4_BF16
elif (
current_platform.is_device_capability_family(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
):
logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
elif (
current_platform.is_device_capability_family(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
logger.info_once(
"Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100", scope="local"
)
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
elif current_platform.is_device_capability_family(100) and has_flashinfer():
logger.info_once(
"Using FlashInfer MXFP4 BF16 backend for SM100, "
"For faster performance on SM100, consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
"accuracy."
)
return Mxfp4Backend.SM100_FI_MXFP4_BF16
elif (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
) and not has_flashinfer():
logger.warning_once(
"MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results."
)
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
else:
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_xpu():
logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN
elif current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx950
if rocm_aiter_ops.is_enabled() and on_gfx950():
logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
return Mxfp4Backend.CK
elif has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
return Mxfp4Backend.NONE
class Mxfp4Config(QuantizationConfig): class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: list[str] | None = None): def __init__(self, ignored_layers: list[str] | None = None):
super().__init__() super().__init__()
...@@ -219,9 +73,6 @@ class Mxfp4Config(QuantizationConfig): ...@@ -219,9 +73,6 @@ class Mxfp4Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping, fused_mapping=self.packed_modules_mapping,
): ):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
# TODO: Add support for MXFP4 Linear Method.
# MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
# if you are interested in enabling MXFP4 here.
logger.debug_once( logger.debug_once(
"MXFP4 linear layer is not implemented - falling back to " "MXFP4 linear layer is not implemented - falling back to "
"UnquantizedLinearMethod.", "UnquantizedLinearMethod.",
...@@ -232,10 +83,8 @@ class Mxfp4Config(QuantizationConfig): ...@@ -232,10 +83,8 @@ class Mxfp4Config(QuantizationConfig):
if current_platform.is_xpu(): if current_platform.is_xpu():
return XpuMxfp4MoEMethod(layer.moe_config) return XpuMxfp4MoEMethod(layer.moe_config)
else: else:
quant_method = Mxfp4MoEMethod(layer.moe_config) return Mxfp4MoEMethod(layer.moe_config)
return quant_method
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
logger.debug_once( logger.debug_once(
"MXFP4 attention layer is not implemented. " "MXFP4 attention layer is not implemented. "
"Skipping quantization for this layer.", "Skipping quantization for this layer.",
...@@ -254,51 +103,36 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -254,51 +103,36 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.weight_dtype = "mxfp4" self.weight_dtype = "mxfp4"
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
self.max_capture_size = ( self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size get_current_vllm_config().compilation_config.max_cudagraph_capture_size
) )
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. Fall back to Triton when not met.
if (
self.mxfp4_backend == Mxfp4Backend.CK
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
):
if has_triton_kernels():
logger.warning_once(
"CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of "
"%d). Falling back to Triton backend.",
moe.intermediate_size_per_partition,
CK_MXFP4_MOE_DIM_ALIGNMENT,
)
self.mxfp4_backend = Mxfp4Backend.TRITON
else:
raise ValueError(
f"CK MXFP4 MoE GEMM does not support "
f"intermediate_size_per_partition="
f"{moe.intermediate_size_per_partition} (not a multiple "
f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton "
f"fallback is available. Use a compatible "
f"tensor_parallel_size."
)
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
"Please check your environment and try again."
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_kernel: mk.FusedMoEKernel | None = None self.moe_kernel: mk.FusedMoEKernel | None = None
# Round up dims once based on backend. This mutates the shared
# FusedMoEConfig in-place so that create_weights() and all
# downstream code see the padded dimensions. This must happen
# before create_weights() is called.
self.moe.hidden_dim, self.moe.intermediate_size_per_partition = (
mxfp4_round_up_hidden_size_and_intermediate_size(
self.mxfp4_backend,
self.moe.hidden_dim,
self.moe.intermediate_size_per_partition,
)
)
# Used for triton kernel precision configs
self.w13_precision_config = None
self.w2_precision_config = None
@property @property
def skip_forward_padding(self) -> bool: def skip_forward_padding(self) -> bool:
# SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
# so can skip the padding in the forward before applying the moe method # so can skip the padding in the forward before applying the moe method
return self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
def create_weights( def create_weights(
self, self,
...@@ -312,77 +146,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -312,77 +146,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.num_experts = num_experts self.num_experts = num_experts
weight_dtype = torch.uint8 weight_dtype = torch.uint8
scale_dtype = torch.uint8 scale_dtype = torch.uint8
# FIXME (zyongye): ship after torch and safetensors support mxfp4
# is_torch_mxfp4_available = (
# hasattr(torch, "float4_e2m1fn_x2") and
# hasattr(torch, "float8_e8m0fnu"))
# if is_torch_mxfp4_available:
# weight_dtype = torch.float4_e2m1fn_x2
# scale_dtype = torch.float8_e8m0fnu
mxfp4_block = 32 mxfp4_block = 32
intermediate_size_per_partition_after_pad = intermediate_size_per_partition # Use pre-rounded sizes from config
if self.mxfp4_backend == Mxfp4Backend.MARLIN: self.intermediate_size = intermediate_size_per_partition_after_pad = (
# The moe marlin kernel requires that for each linear self.moe.intermediate_size_per_partition
# n % 256 == 0 and k % 128 == 0.
# In gate_up_proj:
# n = 2 * intermediate_size_per_partition_after_pad
# k = hidden_size
# In down_proj
# n = hidden_size
# k = intermediate_size_per_partition_after_pad
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128
)
if current_platform.is_xpu():
hidden_size = round_up(hidden_size, 128)
else:
hidden_size = round_up(hidden_size, 256)
layer.params_dtype = params_dtype
layer.num_experts = num_experts
layer.hidden_size = hidden_size
layer.intermediate_size_per_partition = (
intermediate_size_per_partition_after_pad
)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# for to hold non-uniform sharded tensor as well as swizzling
# other padding to increase performance
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256
)
hidden_size = round_up(hidden_size, 256)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128
)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
pad_align = get_padding_alignment()
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, pad_align
)
hidden_size = round_up(hidden_size, pad_align)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64
)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
) )
self.hidden_size = hidden_size = self.moe.hidden_dim
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.zeros( torch.zeros(
...@@ -408,17 +179,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -408,17 +179,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel) # down_proj (row parallel)
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.zeros( torch.zeros(
...@@ -444,604 +204,170 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -444,604 +204,170 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_bias = torch.nn.Parameter( if self.moe.has_bias:
torch.zeros( w13_bias = torch.nn.Parameter(
num_experts, torch.zeros(
hidden_size, num_experts,
dtype=torch.bfloat16, 2 * intermediate_size_per_partition_after_pad,
), dtype=torch.bfloat16,
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(
layer, input_dtype=get_marlin_input_dtype()
)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
prepare_finalize = maybe_make_prepare_finalize(
moe=self.moe,
quant_config=self.moe_quant_config,
routing_tables=layer._maybe_init_expert_routing_tables(),
allow_new_interface=True,
)
assert prepare_finalize is not None
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
MarlinExperts(
self.moe,
self.moe_quant_config,
), ),
inplace=not self.moe.disable_inplace,
shared_experts=None,
)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False, requires_grad=False,
) )
layer.gemm1_beta = Parameter( layer.register_parameter("w13_bias", w13_bias)
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(), set_weight_attrs(w13_bias, extra_weight_attrs)
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
and layer.w2_weight.shape[2] == self.intermediate_size // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size
)
assert (
layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2
)
assert (
layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size
)
w13_weight_scale = layer.w13_weight_scale.data
w2_weight_scale = layer.w2_weight_scale.data
w13_weight = layer.w13_weight.data
w2_weight = layer.w2_weight.data
w13_bias = layer.w13_bias.data.to(torch.float32)
w2_bias = layer.w2_bias.data.to(torch.float32)
# Swap w1 and w3 as the definition of
# swiglu is different in the trtllm-gen
def swap_every_two_rows(x, axis=-1):
shape = x.shape
if axis < 0:
axis = len(shape) + axis
# Create a new shape with pairs swapped along specified axis
new_shape = list(shape)
new_shape[axis] = shape[axis] // 2
new_shape.insert(axis + 1, 2)
# Reshape to expose pairs, swap them, and reshape back w2_bias = torch.nn.Parameter(
x = x.reshape(*new_shape) torch.zeros(
x = x.flip(axis + 1) num_experts,
new_shape = list(shape) hidden_size,
return x.reshape(*new_shape) dtype=torch.bfloat16,
),
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
w13_weight = swap_every_two_rows(w13_weight, -2)
w13_bias = swap_every_two_rows(w13_bias, -1)
# Do not interleave as the checkpoint is already interleaved
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_mxfp4_shuffled = []
gemm1_scales_mxfp4_shuffled = []
gemm2_weights_mxfp4_shuffled = []
gemm2_scales_mxfp4_shuffled = []
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(self.num_experts):
# w13 weight shuffling
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_mxfp4_shuffled.append(
w13_weight[i]
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
.contiguous()
)
# w13 scale shuffling
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(
w13_weight_scale[i]
.view(torch.uint8)[
permute_sf_indices.to(w13_weight_scale.device)
]
.contiguous()
)
)
# w13 bias shuffling
permute_bias_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(
w13_bias[i]
.clone()
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
.contiguous()
)
# w2 weight shuffling
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_mxfp4_shuffled.append(
w2_weight[i]
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
.contiguous()
)
# w2 scale shuffling
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(
w2_weight_scale[i]
.view(torch.uint8)[
permute_sf_indices.to(w2_weight_scale.device)
]
.contiguous()
)
)
# w2 bias shuffling
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(
w2_bias[i]
.clone()
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
.contiguous()
)
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
w13_weight_scale = (
torch.stack(gemm1_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
2 * self.intermediate_size,
self.hidden_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
w2_weight_scale = (
torch.stack(gemm2_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
self.hidden_size,
self.intermediate_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
layer.w13_bias = Parameter(
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False,
)
layer.w2_bias = Parameter(
torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False, requires_grad=False,
) )
elif ( layer.register_parameter("w2_bias", w2_bias)
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS set_weight_attrs(w2_bias, extra_weight_attrs)
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
sf_block_size = 32 # mxfp4 block size
# Common shape assertions def _setup_kernel(
assert ( self,
layer.w13_weight.dim() == 3 layer: FusedMoE,
and layer.w13_weight.shape[0] == self.num_experts w13: torch.Tensor,
and layer.w13_weight.shape[1] == self.intermediate_size * 2 w2: torch.Tensor,
and layer.w13_weight.shape[2] == self.hidden_size // 2 w13_scale: torch.Tensor,
) w2_scale: torch.Tensor,
assert ( w13_bias: torch.Tensor | None = None,
layer.w13_weight_scale.dim() == 3 w2_bias: torch.Tensor | None = None,
and layer.w13_weight_scale.shape[0] == self.num_experts ) -> None:
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 num_experts = self.num_experts
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size intermediate_size = self.intermediate_size
) hidden_size = self.hidden_size
assert ( sf_block_size = 32
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts # Shape assertions
and layer.w2_weight.shape[1] == self.hidden_size assert (
and layer.w2_weight.shape[2] == self.intermediate_size // 2 w13.dim() == 3
) and w13.shape[0] == num_experts
assert ( and w13.shape[1] == intermediate_size * 2
layer.w2_weight_scale.dim() == 3 and w13.shape[2] == hidden_size // 2
and layer.w2_weight_scale.shape[1] == self.hidden_size )
and layer.w2_weight_scale.shape[2] assert (
== self.intermediate_size // sf_block_size w13_scale.dim() == 3
) and w13_scale.shape[0] == num_experts
and w13_scale.shape[1] == intermediate_size * 2
and w13_scale.shape[2] == hidden_size // sf_block_size
)
assert (
w2.dim() == 3
and w2.shape[0] == num_experts
and w2.shape[1] == hidden_size
and w2.shape[2] == intermediate_size // 2
)
assert (
w2_scale.dim() == 3
and w2_scale.shape[1] == hidden_size
and w2_scale.shape[2] == intermediate_size // sf_block_size
)
if w13_bias is not None:
assert ( assert (
layer.w13_bias.dim() == 2 w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts and w13_bias.shape[0] == num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2 and w13_bias.shape[1] == intermediate_size * 2
) )
if w2_bias is not None:
assert ( assert (
layer.w2_bias.dim() == 2 w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts and w2_bias.shape[0] == num_experts
and layer.w2_bias.shape[1] == self.hidden_size and w2_bias.shape[1] == hidden_size
) )
# De-interleave and swap for w13 weight, bias, and scales # Convert weights to kernel format
w13_w = layer.w13_weight.data w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :] convert_to_mxfp4_moe_kernel_format(
deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1) mxfp4_backend=self.mxfp4_backend,
w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1) layer=layer,
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) w13_weight=w13,
w2_weight=w2,
w13_b = layer.w13_bias.data.to(torch.float32) w13_weight_scale=w13_scale,
gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2] w2_weight_scale=w2_scale,
deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1) w13_bias=w13_bias,
b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1) w2_bias=w2_bias,
w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) _cache_permute_indices=self._cache_permute_indices,
w13_s = layer.w13_weight_scale.data
gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
w13_scale_swapped = torch.cat([s3, s1], dim=1)
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import block_scale_interleave
orig_shape = w13_scale_swapped.shape
w13_scale_interleaved = block_scale_interleave(
w13_scale_swapped.view(torch.uint8)
).reshape(orig_shape)
w2_s = layer.w2_weight_scale.data
orig_shape = w2_s.shape
w2_scale_interleaved = block_scale_interleave(
w2_s.view(torch.uint8)
).reshape(orig_shape)
layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
layer.w13_weight_scale = Parameter(
w13_scale_interleaved, requires_grad=False
)
layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
layer.w2_weight_scale = Parameter(
w2_scale_interleaved, requires_grad=False
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
def _interleave_mxfp4_cutlass_sm90(w):
w_shape = w.shape
w_interleaved = w.reshape(
w_shape[0], w_shape[1], (w_shape[2] // 4), 4
)
w_interleaved = w_interleaved.permute(0, 2, 1, 3)
w_interleaved = w_interleaved.reshape(
w_shape[0], w_shape[2] // 4, w_shape[1] * 4
)
return w_interleaved
w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
w2_weight_scale = layer.w2_weight_scale.data
w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)
layer.w13_weight = torch.nn.Parameter(
torch.cat([w3_w, w1_w], dim=1), requires_grad=False
)
layer.w13_bias = torch.nn.Parameter(
w13_bias_swapped, requires_grad=False
)
layer.w13_weight_scale = torch.nn.Parameter(
w31_scales_interleaved, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales_interleaved, requires_grad=False
)
# theses two kernels go through the `flashinfer_cutlass_fused_moe` path
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
) )
)
self.moe_quant_config = self.get_fused_moe_quant_config(layer) # For TRITON backends, weights are wrapped tensors from triton_kernels
assert self.moe_quant_config is not None # that don't support .detach(). Manually assign parameters.
prepare_finalize = maybe_make_prepare_finalize( if self.mxfp4_backend not in TRITON_BACKENDS:
moe=self.moe, replace_parameter(layer, "w13_weight", w13)
quant_config=self.moe_quant_config, replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
else:
layer.w13_weight = w13
layer.w2_weight = w2
self.w13_precision_config = w13_scale
self.w2_precision_config = w2_scale
if w13_bias is not None and w2_bias is not None:
replace_parameter(layer, "w13_bias", w13_bias)
replace_parameter(layer, "w2_bias", w2_bias)
# Build quant config
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
# Build kernel (modular or monolithic)
if self.moe_quant_config is not None and self.experts_cls is not None:
self.moe_kernel = make_mxfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
mxfp4_backend=self.mxfp4_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(), routing_tables=layer._maybe_init_expert_routing_tables(),
allow_new_interface=True, shared_experts=layer.shared_experts,
)
assert prepare_finalize is not None
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
),
shared_experts=None,
)
elif self.mxfp4_backend == Mxfp4Backend.CK:
if layer.w13_bias is not None:
layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
if layer.w2_bias.data is not None:
layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
e, n, k = layer.w13_weight.shape
layer.w13_weight.view(torch.uint8).copy_(
layer.w13_weight.data.view(torch.uint8)
.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
layer.w13_weight_scale.data = (
layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)
layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w13_weight, 16, True
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
self.num_experts,
True,
)
layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w2_weight, 16, False
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
self.num_experts,
False,
) )
layer.w13_bias.data = ( def process_weights_after_loading(self, layer):
layer.w13_bias.data.view(-1, n // 2, 2) w13 = layer.w13_weight
.permute(0, 2, 1) w2 = layer.w2_weight
.contiguous() w13_scale = layer.w13_weight_scale
.view(-1, n) w2_scale = layer.w2_weight_scale
) w13_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
layer.w13_weight_scale = torch.nn.Parameter(
shuffled_w13_scale, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
shuffled_w2_scale, requires_grad=False
)
# replace_parameter(layer, "w13_bias", w13_bias)
# replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
# replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
# replace_parameter(layer, "w13_weight", w13_weight)
# replace_parameter(layer, "w2_weight", w2_weight)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = (
self.moe.use_deepep_ll_kernels or self.moe.use_nixl_ep_kernels
)
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_precision_config = PrecisionConfig( if self.mxfp4_backend == Mxfp4MoeBackend.NONE:
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) return
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight = w13_weight
self.w2_weight = w2_weight
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else: self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias)
raise ValueError(
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
f"should be one of: {list(Mxfp4Backend)}."
)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
if self.mxfp4_backend == Mxfp4Backend.MARLIN: w1_scale = layer.w13_weight_scale
return mxfp4_w4a16_moe_quant_config( w2_scale = layer.w2_weight_scale
w1_bias=layer.w13_bias, w1_bias = getattr(layer, "w13_bias", None)
w2_bias=layer.w2_bias, w2_bias = getattr(layer, "w2_bias", None)
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, if self.mxfp4_backend in TRITON_BACKENDS:
) assert self.w13_precision_config is not None
elif self.mxfp4_backend == Mxfp4Backend.TRITON: assert self.w2_precision_config is not None
w1_scale = self.w13_precision_config w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config w2_scale = self.w2_precision_config
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias, return make_mxfp4_moe_quant_config(
w2_bias=layer.w2_bias, mxfp4_backend=self.mxfp4_backend,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
) w1_bias=w1_bias,
elif self.mxfp4_backend in [ w2_bias=w2_bias,
Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, )
Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
]:
return mxfp4_mxfp8_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16,
Mxfp4Backend.CK,
]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEExpertsModular: ) -> mk.FusedMoEExpertsModular:
if ( raise ValueError(
prepare_finalize.activation_format f"{self.__class__.__name__} uses the new modular kernel "
== mk.FusedMoEActivationFormat.BatchedExperts "initialization logic. This function should not be called."
):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
assert self.moe_quant_config is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
moe_config=self.moe,
)
else:
raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
"EP batched experts format"
)
else:
assert self.moe_quant_config is not None
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
# B200 code-path
kwargs = {
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
}
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
return MarlinExperts(self.moe, self.moe_quant_config)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
if self.moe.is_lora_enabled:
return UnfusedOAITritonExperts(self.moe, self.moe_quant_config)
return OAITritonExperts(self.moe, self.moe_quant_config)
else:
raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
)
@property
def is_monolithic(self) -> bool:
if self.moe.is_lora_enabled:
return False
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
or self.mxfp4_backend == Mxfp4Backend.CK
) )
def apply( def apply(
...@@ -1053,30 +379,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1053,30 +379,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
assert _can_support_mxfp4(
layer.use_grouped_topk,
layer.topk_group,
layer.num_expert_group,
layer.expert_map,
layer.custom_routing_function,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.eplb_state.expert_load_view,
layer.eplb_state.logical_to_physical_map,
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
assert (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.MARLIN
)
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
hidden_states=x, hidden_states=x,
...@@ -1098,126 +400,17 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1098,126 +400,17 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None
if layer.enable_eplb: return self.moe_kernel.apply_monolithic(
raise NotImplementedError("EPLB is not supported for mxfp4") hidden_states=x,
w1=layer.w13_weight,
assert _can_support_mxfp4( w2=layer.w2_weight,
layer.use_grouped_topk, router_logits=router_logits,
layer.topk_group, activation=layer.activation,
layer.num_expert_group, global_num_experts=layer.global_num_experts,
layer.expert_map, expert_map=layer.expert_map,
layer.custom_routing_function, apply_router_weight_on_input=layer.apply_router_weight_on_input,
layer.e_score_correction_bias, )
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.eplb_state.expert_load_view,
layer.eplb_state.logical_to_physical_map,
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
# Apply routing simulation strategy if specified.
# This applies to all monolithic backends (SM100_FI and TRITON).
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy == "uniform_random":
router_logits = torch.rand_like(router_logits)
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
from flashinfer import trtllm_fp4_block_scale_moe
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
from flashinfer import mxfp8_quantize
# x_quant is padded in hidden dimension with alignment=256
x_quant, x_scale = mxfp8_quantize(
x,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
# output with original unpadded hidden size
output = torch.empty_like(x)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
gemm1_bias=layer.w13_bias, # fp32 per expert per channel
gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
gemm1_beta=layer.gemm1_beta, # fp32 per expert
gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
gemm2_bias=layer.w2_bias, # fp32 per expert per channel
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size, # padded to multiple of 256
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=self.num_experts,
routed_scaling_factor=None,
routing_method_type=1 if layer.renormalize else 0,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)[0]
return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.CK:
topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
x, router_logits, layer.top_k, True
)
output = rocm_aiter_ops.fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
doweight_stage1=False,
hidden_pad=self.hidden_pad // 128 * 128,
intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
bias1=layer.w13_bias,
bias2=layer.w2_bias,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
)
return triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=layer.top_k,
renormalize=layer.renormalize,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
class XpuMxfp4MoEMethod(Mxfp4MoEMethod): class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
......
...@@ -25,9 +25,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -25,9 +25,9 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config, ocp_mx_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.mxfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
Mxfp4Backend, Mxfp4MoeBackend,
get_mxfp4_backend, select_mxfp4_moe_backend,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin, prepare_fp8_moe_layer_for_marlin,
...@@ -699,9 +699,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -699,9 +699,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
f"Please check that the combination is supported in OCP_MX_Scheme." f"Please check that the combination is supported in OCP_MX_Scheme."
) )
self.mxfp4_backend: Mxfp4Backend | None = None self.mxfp4_backend: Mxfp4MoeBackend | None = None
if self.ocp_mx_scheme == "w_mxfp4": if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)
if self.input_quant is not None: if self.input_quant is not None:
self.static_input_scales = not self.input_quant.get("is_dynamic") self.static_input_scales = not self.input_quant.get("is_dynamic")
......
...@@ -389,9 +389,9 @@ def prepare_moe_fp4_layer_for_marlin( ...@@ -389,9 +389,9 @@ def prepare_moe_fp4_layer_for_marlin(
group_size = 16 if is_nvfp4 else 32 group_size = 16 if is_nvfp4 else 32
e = layer.num_experts e = layer.moe_config.num_experts
k = layer.hidden_size k = layer.moe_config.hidden_dim
n = layer.intermediate_size_per_partition n = layer.moe_config.intermediate_size_per_partition
# WORKSPACE # WORKSPACE
device = layer.w13_weight.device device = layer.w13_weight.device
...@@ -500,6 +500,120 @@ def prepare_moe_fp4_layer_for_marlin( ...@@ -500,6 +500,120 @@ def prepare_moe_fp4_layer_for_marlin(
setattr(layer, name, bias) setattr(layer, name, bias)
def prepare_moe_mxfp4_layer_for_marlin(
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
]:
"""Pure-function version of prepare_moe_fp4_layer_for_marlin for MXFP4.
Takes weight tensors as inputs and returns transformed tensors.
Does NOT modify the layer in-place.
"""
input_dtype = get_marlin_input_dtype()
if (
input_dtype is not None
and input_dtype.itemsize == 1
and input_dtype != torch.float8_e4m3fn
):
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
group_size = 32 # MXFP4 block size
# Derive dimensions from actual weight shapes to handle rounded/padded
# sizes correctly (e.g., Mxfp4MoEMethod rounds up hidden_dim).
# w13 shape: (E, 2*N, K//2)
e = w13.shape[0]
n = w13.shape[1] // 2 # intermediate_size_per_partition
k = w13.shape[2] * 2 # hidden_size
device = w13.device
param_dtype = layer.params_dtype
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT: Repack weights to marlin format
def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
assert weight.shape == (e, size_n, size_k // 2)
for i in range(e):
qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
tensor_list.append(marlin_qweight)
return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
w13 = repack_weight(w13, "w13")
w2 = repack_weight(w2, "w2")
# WEIGHT SCALES: Permute scales
def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
scales = scales.view(torch.float8_e8m0fnu)
scales = scales.to(param_dtype)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
for i in range(e):
scale = scales[i].T
marlin_scales = marlin_permute_scales(
s=scale,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = mxfp4_marlin_process_scales(
marlin_scales, input_dtype=input_dtype
)
tensor_list.append(marlin_scales)
return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
w13_scale = permute_scales(w13_scale, "w13")
w2_scale = permute_scales(w2_scale, "w2")
# BIAS: Permute bias
def permute_bias(bias: torch.Tensor | None) -> torch.Tensor | None:
if bias is None:
return None
bias = bias.to(param_dtype)
tensor_list = []
for i in range(e):
tensor_list.append(marlin_permute_bias(bias[i]))
return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
w13_bias = permute_bias(w13_bias)
w2_bias = permute_bias(w2_bias)
return w13, w2, w13_scale, w2_scale, w13_bias, w2_bias
def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None): def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
...@@ -22,7 +20,7 @@ logger = init_logger(__name__) ...@@ -22,7 +20,7 @@ logger = init_logger(__name__)
CK_MXFP4_MOE_DIM_ALIGNMENT = 256 CK_MXFP4_MOE_DIM_ALIGNMENT = 256
def _swizzle_mxfp4(quant_tensor, scale, num_warps): def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel""" """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
assert has_triton_kernels() assert has_triton_kernels()
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
...@@ -87,35 +85,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): ...@@ -87,35 +85,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
return quant_tensor, InFlexData(), scale return quant_tensor, InFlexData(), scale
def _can_support_mxfp4(
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
scoring_func: str = "softmax",
activation: MoEActivation = MoEActivation.SWIGLUOAI,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
):
return not (
use_grouped_topk
or topk_group
or num_expert_group
or custom_routing_function
or e_score_correction_bias
or apply_router_weight_on_input
or scoring_func != "softmax"
or activation != MoEActivation.SWIGLUOAI
or expert_load_view
or logical_to_physical_map
or logical_replica_count
)
def get_padding_alignment(): def get_padding_alignment():
return ( return (
256 256
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment