Unverified Commit 87bd9189 authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[MoE Refactor] Mxfp4 oracle rebased (#37128)


Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent c7f98b4d
......@@ -88,8 +88,8 @@ To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE k
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmMxfp4ExpertsMonolithic`][vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe.TrtLlmMxfp4ExpertsMonolithic],</br>[`TrtLlmMxfp4ExpertsModular`][vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe.TrtLlmMxfp4ExpertsModular],</br>[`TrtLlmNvFp4ExpertsMonolithic`][vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe.TrtLlmNvFp4ExpertsMonolithic],</br>[`TrtLlmNvFp4ExpertsModular`][vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe.TrtLlmNvFp4ExpertsModular] |
| rocm aiter moe | standard | mxfp4,</br>fp8 | G(32),G(128),A,T | silu, gelu,</br>swigluoai | Y | N | `rocm_aiter_fused_experts`,</br>`AiterExperts` |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
......
......@@ -84,7 +84,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# TODO: remove this after finishing migration from envs to model kwargs
if model_name == "openai/gpt-oss-20b":
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
from .common import is_blackwell
if is_blackwell():
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
......
......@@ -6,6 +6,7 @@ import pytest
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_triton_kernels
if not has_triton_kernels():
......@@ -14,6 +15,7 @@ if not has_triton_kernels():
allow_module_level=True,
)
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
......@@ -303,6 +305,12 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
if current_platform.is_device_capability_family(100):
constraints = {
"is_persistent": True,
}
opt_flags.update_opt_flags_constraints(constraints)
if a_dtype == "bf16" and w_dtype == "mx4":
quant_config = mxfp4_w4a16_moe_quant_config(
w1_scale=pc1,
......
......@@ -82,7 +82,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy",
cudagraph_capture_sizes=[16],
compilation_config={"cudagraph_capture_sizes": [16]},
) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model):
......
......@@ -17,89 +17,6 @@ from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
Mxfp4MoEMethod,
)
def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
"""Create a mock FusedMoEConfig with the given EP size."""
parallel_config = MagicMock()
parallel_config.ep_size = ep_size
moe_config = MagicMock()
moe_config.ep_size = ep_size
moe_config.is_lora_enabled = False
moe_config.moe_parallel_config = parallel_config
return moe_config
class TestMxfp4TritonIsMonolithic:
    """Check ``Mxfp4MoEMethod.is_monolithic`` per backend and EP size.

    The expectation table below encodes: TRITON and the two SM100 FlashInfer
    backends are expected to be monolithic for every listed EP size, while
    MARLIN is expected never to be.
    """

    @pytest.mark.parametrize(
        "backend,ep_size,expected_monolithic",
        [
            # TRITON is always monolithic (handles EP via expert_map remapping)
            pytest.param(Mxfp4Backend.TRITON, 1, True, id="triton-no-ep"),
            pytest.param(Mxfp4Backend.TRITON, 2, True, id="triton-ep2"),
            pytest.param(Mxfp4Backend.TRITON, 4, True, id="triton-ep4"),
            # SM100 backends are always monolithic
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
                1,
                True,
                id="sm100-trtllm-no-ep",
            ),
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
                2,
                True,
                id="sm100-trtllm-ep2",
            ),
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True, id="sm100-bf16-no-ep"
            ),
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True, id="sm100-bf16-ep2"
            ),
            # MARLIN is never monolithic
            pytest.param(Mxfp4Backend.MARLIN, 1, False, id="marlin-no-ep"),
            pytest.param(Mxfp4Backend.MARLIN, 2, False, id="marlin-ep2"),
        ],
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
    )
    def test_is_monolithic(
        self,
        mock_get_config,
        mock_get_backend,
        backend,
        ep_size,
        expected_monolithic,
    ):
        """Construct the method with patched lookups and compare the flag."""
        # The innermost @patch supplies the first mock argument, so
        # mock_get_config patches get_current_vllm_config and
        # mock_get_backend patches get_mxfp4_backend.
        mock_get_backend.return_value = backend
        mock_get_config.return_value = MagicMock(
            compilation_config=MagicMock(max_cudagraph_capture_size=1024)
        )

        method = Mxfp4MoEMethod(_make_mock_moe_config(ep_size=ep_size))

        assert method.is_monolithic == expected_monolithic, (
            f"Expected is_monolithic={expected_monolithic} for "
            f"backend={backend.name}, ep_size={ep_size}, "
            f"but got {method.is_monolithic}."
        )
class TestTritonMoeForwardExpertMap:
"""Test that triton_kernel_moe_forward applies expert_map remapping
......
......@@ -9,80 +9,236 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
kMxfp8Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"""TensorRT-LLM-based fused MoE expert implementation."""
class TrtLlmMxfp4ExpertsBase:
"""
MXFP4 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_capture_size,
):
super().__init__(moe_config, quant_config)
self.device = torch.accelerator.current_device_index()
self.num_experts = moe_config.num_local_experts
# NOTE: FusedMoEExperts.__init__ is called by the concrete subclass
# (Monolithic/Modular) via MRO, not here, to avoid mypy issues with
# multiple inheritance. This matches the NvFP4 expert pattern.
self.moe_config = moe_config
self.quant_config = quant_config
self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.hidden_dim = moe_config.hidden_dim
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
# MXFP4-specific TRTLLM parameters
device = torch.accelerator.current_device_index()
self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
[1.702] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
self.gemm1_beta = torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32, device=self.device
[1.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32, device=self.device
[7.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
self.max_capture_size = max_capture_size
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
from vllm.config import get_current_vllm_config
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
# P1-5 fix: use public quant_dtype property instead of private _a1
self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8"
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100) and has_flashinfer()
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
SUPPORTED_W_A = [
(kMxfp4Static, None),
(kMxfp4Static, kMxfp8Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
return activation == MoEActivation.SWIGLUOAI
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
@property
def expects_unquantized_inputs(self) -> bool:
# Expert handles MXFP8 quantization internally if needed
return True
class TrtLlmMxfp4ExpertsMonolithic(
    TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsMonolithic
):
    """
    Router + experts in one call for the MXFP4 TRTLLM kernel.

    Adapter around ``flashinfer.trtllm_fp4_block_scale_moe()``.
    """

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Accept only configs without all2all kernels, EPLB, or dp_size > 1."""
        if moe_parallel_config.use_all2all_kernels:
            return False
        if moe_parallel_config.enable_eplb:
            return False
        return moe_parallel_config.dp_size <= 1

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept the two renormalizing routing methods; the keys are unused."""
        accepted = (
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        )
        return routing_method in accepted

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """Accept any logits dtype."""
        # apply() casts the router logits to bfloat16 before the kernel call.
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """Run routing and both expert GEMMs through the FlashInfer kernel.

        Only ``hidden_states``, ``w1``, ``w2``, ``router_logits`` and
        ``global_num_experts`` are read from the arguments; the remaining
        parameters are accepted but not consulted by this body. Scales,
        biases and routing settings come from attributes on ``self``.

        Returns:
            The first element of the kernel's return value.
        """
        from flashinfer import trtllm_fp4_block_scale_moe

        if not self.use_mxfp8_input:
            # Unquantized path: the input is handed over as-is, without scales.
            assert hidden_states.dtype == torch.bfloat16
            x_quant, x_scale = hidden_states, None
        else:
            from flashinfer import mxfp8_quantize

            x_quant, x_scale = mxfp8_quantize(
                hidden_states,
                is_sf_swizzled_layout=False,
                alignment=256,
            )
            # Reinterpret the scales as float8_e4m3fn and give them the
            # leading dimensions of hidden_states.
            leading_dims = hidden_states.shape[:-1]
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*leading_dims, -1)

        output = torch.empty_like(hidden_states)

        kernel_outputs = trtllm_fp4_block_scale_moe(
            routing_logits=router_logits.to(torch.bfloat16),
            routing_bias=None,
            hidden_states=x_quant,
            hidden_states_scale=x_scale,
            gemm1_weights=w1,
            gemm1_weights_scale=self.w1_scale,
            gemm1_bias=self.w1_bias,
            gemm1_alpha=self.gemm1_alpha,
            gemm1_beta=self.gemm1_beta,
            gemm1_clamp_limit=self.gemm1_clamp_limit,
            gemm2_weights=w2,
            gemm2_weights_scale=self.w2_scale,
            gemm2_bias=self.w2_bias,
            output1_scale_scalar=None,
            output1_scale_gate_scalar=None,
            output2_scale_scalar=None,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=None,
            topk_group=None,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=None,
            routing_method_type=self.routing_method_type,
            do_finalize=True,
            tune_max_num_tokens=max(self.max_capture_size, 1),
            output=output,
        )
        return kernel_outputs[0]
class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular):
"""
Modular version of the MXFP4 TRTLLM kernel (just the experts).
Wraps flashinfer.trtllm_fp4_block_scale_routed_moe().
Moved from trtllm_moe.py.
"""
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
......@@ -129,10 +285,22 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
intermediate_size = w2.size(1)
local_expert_offset = self.moe_config.ep_rank * local_num_experts
x_quant = hidden_states
x_scale = a1q_scale
if x_scale is not None:
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1)
# Handle input quantization
if self.use_mxfp8_input:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(
hidden_states,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
else:
assert hidden_states.dtype == torch.bfloat16
x_quant = hidden_states
x_scale = None
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
......@@ -165,7 +333,7 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"local_expert_offset": local_expert_offset,
"local_num_experts": local_num_experts,
"routed_scaling_factor": None,
"routing_method_type": 1,
"routing_method_type": self.routing_method_type,
"do_finalize": True,
"output": output,
"tune_max_num_tokens": max(self.max_capture_size, 1),
......
......@@ -40,6 +40,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Static,
kNvfp4Static,
)
from vllm.platforms import current_platform
......@@ -574,12 +575,13 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO(rob): add int4, mxfp4, int8 as integrations
# TODO(rob): add int4, int8 as integrations
# are migrated to use the oracle one-by-one.
SUPPORTED_W = [
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Static,
kNvfp4Static,
]
return weight_key in SUPPORTED_W
......
......@@ -11,8 +11,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
......@@ -20,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
......@@ -537,43 +540,43 @@ def make_routing_data(
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
p = current_platform
if not p.is_cuda_alike():
return False
cap = p.get_device_capability()
if cap is None:
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
SUPPORTED_W_A = [
(kMxfp4Static, None),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
raise NotImplementedError
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
return True
def supports_expert_map(self) -> bool:
return True
......@@ -630,6 +633,10 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
class OAITritonExperts(BaseOAITritonExperts):
"""OAI Triton-based fused MoE expert implementation."""
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SWIGLUOAI
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
......@@ -714,6 +721,15 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
]
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
......@@ -839,3 +855,118 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
)
self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
    """Monolithic Triton MXFP4 expert built on triton_kernel_moe_forward()."""

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        """Cache top-k and whether the configured routing renormalizes."""
        super().__init__(moe_config, quant_config)
        renormalizing_methods = [
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        ]
        self.topk = moe_config.experts_per_token
        self.renormalize = moe_config.routing_method in renormalizing_methods

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Report the Standard activation format."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Accept CUDA-alike devices with capability in [9.0, 11.0)."""
        if not current_platform.is_cuda_alike():
            return False
        capability = current_platform.get_device_capability()
        if capability is None:
            return False
        # The half-open range [9.0, 11.0) covers CUDA SM90 (Hopper) and the
        # 10.x family (Blackwell), and ROCm gfx942/gfx950 (which map to
        # 9.4/9.5).
        version = (capability.major, capability.minor)
        return (9, 0) <= version < (11, 0)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Always False."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only MXFP4 static weights with no activation key."""
        supported_pairs = ((kMxfp4Static, None),)
        return (weight_key, activation_key) in supported_pairs

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Accept only SWIGLUOAI."""
        return activation == MoEActivation.SWIGLUOAI

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Accept only configs without all2all kernels, EPLB, or dp_size > 1."""
        if moe_parallel_config.use_all2all_kernels:
            return False
        if moe_parallel_config.enable_eplb:
            return False
        return moe_parallel_config.dp_size <= 1

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept the two renormalizing routing methods; the keys are unused."""
        accepted = (
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        )
        return routing_method in accepted

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """Accept any logits dtype."""
        return True

    def supports_expert_map(self) -> bool:
        """Always True; apply() forwards expert_map to the kernel wrapper."""
        return True

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Always True."""
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """Delegate the whole MoE forward to ``triton_kernel_moe_forward``.

        ``activation``, ``a1q_scale`` and the grouped-topk parameters are
        accepted but not forwarded by this body; top-k, the renormalize flag
        and the quant config come from ``self``.
        """
        forward_kwargs = dict(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            gating_output=router_logits,
            topk=self.topk,
            renormalize=self.renormalize,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            quant_config=self.quant_config,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
        return triton_kernel_moe_forward(**forward_kwargs)
......@@ -52,7 +52,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
......@@ -218,7 +217,6 @@ def maybe_roundup_hidden_size(
moe_parallel_config: FusedMoEParallelConfig,
is_lora_enabled: bool,
model_type: str | None,
is_mxfp4_quant: bool,
) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
......@@ -232,7 +230,6 @@ def maybe_roundup_hidden_size(
is used in the case of mxfp4 quantization in selecting the
MxFP4Backend.
model_type: for checking if gpt-oss
is_mxfp4_quant: whether the layer is quantized with mxfp4
Return:
Rounded up hidden_size if rounding up is required based on the configs.
......@@ -246,28 +243,6 @@ def maybe_roundup_hidden_size(
hidden_size, act_dtype, moe_parallel_config
)
# we are padding globally so EP buffer allocation works
if model_type == "gpt_oss" and is_mxfp4_quant:
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
get_mxfp4_backend,
)
current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
if (
current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
):
hidden_size = round_up(hidden_size, 128)
elif (
current_platform.is_rocm()
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.MARLIN
):
hidden_size = round_up(hidden_size, 256)
return hidden_size
......@@ -540,9 +515,6 @@ class FusedMoE(CustomOp):
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=self.model_type,
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
)
self.hidden_size = hidden_size
......
This diff is collapsed.
......@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
mxfp4_w4a16_moe_quant_config,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
......@@ -347,16 +346,6 @@ def convert_to_nvfp4_moe_kernel_format(
)
def make_mxfp4_moe_quant_config(
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build the MXFP4 W4A16 MoE quant config from two weight-scale tensors.

    ``w13_scale`` is forwarded as ``w1_scale`` and ``w2_scale`` is forwarded
    unchanged to ``mxfp4_w4a16_moe_quant_config``.
    """
    scale_kwargs = {"w1_scale": w13_scale, "w2_scale": w2_scale}
    return mxfp4_w4a16_moe_quant_config(**scale_kwargs)
def make_nvfp4_moe_quant_config(
backend: NvFp4MoeBackend,
w13_scale: torch.Tensor,
......
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Static,
)
......@@ -201,6 +202,8 @@ def rocm_aiter_fused_experts(
activation_method = ActivationMethod.SILU
elif activation == MoEActivation.GELU:
activation_method = ActivationMethod.GELU
elif activation == MoEActivation.SWIGLUOAI:
activation_method = rocm_aiter_ops.get_aiter_activation_type("swiglu")
else:
raise ValueError(f"Unsupported activation: {activation}")
......@@ -247,8 +250,8 @@ def rocm_aiter_fused_experts(
else:
quant_method = QuantMethod.NO.value
# quark moe for mxfp4 w_dtype mxfp4 a_dtype
if quant_config.use_mxfp4_w4a4:
# mxfp4: both w4a4 (quark) and w4a16 (oracle CK) use BLOCK_1X32
if quant_config.use_mxfp4_w4a4 or quant_config.use_mxfp4_w4a16:
quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
......@@ -289,6 +292,8 @@ def rocm_aiter_fused_experts(
doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens,
output_dtype=output_dtype,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
)
......@@ -319,21 +324,23 @@ class AiterExperts(mk.FusedMoEExpertsModular):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO(rob): AITER also supports MXFP4, which is not
# yet supported via an Oracle. Once it is, we will add
# MXFP4 to this list.
SUPPORTED_W_A = [
(None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kMxfp4Static, None),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.GELU]
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......
......@@ -45,11 +45,14 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
Mxfp4MoeBackend,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
......@@ -235,7 +238,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe):
super().__init__(moe)
self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.mxfp4_backend = Mxfp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
def create_weights(
......@@ -310,7 +313,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return make_mxfp4_moe_quant_config(
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
mxfp4_backend=self.mxfp4_backend,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
def process_weights_after_loading(self, layer: FusedMoE) -> None:
......@@ -334,10 +339,11 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
self.moe_kernel = make_nvfp4_moe_kernel(
self.moe_kernel = make_mxfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
mxfp4_backend=self.mxfp4_backend,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
......
......@@ -25,9 +25,9 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
get_mxfp4_backend,
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
Mxfp4MoeBackend,
select_mxfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
......@@ -699,9 +699,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
f"Please check that the combination is supported in OCP_MX_Scheme."
)
self.mxfp4_backend: Mxfp4Backend | None = None
self.mxfp4_backend: Mxfp4MoeBackend | None = None
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)
if self.input_quant is not None:
self.static_input_scales = not self.input_quant.get("is_dynamic")
......
......@@ -389,9 +389,9 @@ def prepare_moe_fp4_layer_for_marlin(
group_size = 16 if is_nvfp4 else 32
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
e = layer.moe_config.num_experts
k = layer.moe_config.hidden_dim
n = layer.moe_config.intermediate_size_per_partition
# WORKSPACE
device = layer.w13_weight.device
......@@ -500,6 +500,120 @@ def prepare_moe_fp4_layer_for_marlin(
setattr(layer, name, bias)
def prepare_moe_mxfp4_layer_for_marlin(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Convert MXFP4 MoE weights, scales and biases to the Marlin layout.

    Functional counterpart of prepare_moe_fp4_layer_for_marlin for MXFP4:
    tensors come in as arguments and the converted tensors are returned.
    ``layer`` is only read (for ``params_dtype``); nothing is set on it.

    Returns:
        ``(w13, w2, w13_scale, w2_scale, w13_bias, w2_bias)`` after
        conversion; a bias that was ``None`` stays ``None``.

    Raises:
        RuntimeError: if the Marlin input dtype is one byte wide but is not
            ``torch.float8_e4m3fn``.
    """
    input_dtype = get_marlin_input_dtype()
    is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
    if is_a_8bit and input_dtype != torch.float8_e4m3fn:
        raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")

    group_size = 32  # MXFP4 block size

    # Dimensions are taken from the weight tensor itself rather than from the
    # layer config so that rounded/padded sizes are handled correctly
    # (e.g., Mxfp4MoEMethod rounds up hidden_dim).
    # w13 shape: (E, 2*N, K//2)
    e = w13.shape[0]
    n = w13.shape[1] // 2  # intermediate_size_per_partition
    k = w13.shape[2] * 2  # hidden_size

    # (size_n, size_k) as seen by the Marlin helpers for each projection.
    w13_dims = (n * 2, k)
    w2_dims = (k, n)

    param_dtype = layer.params_dtype
    perm = torch.empty(0, dtype=torch.int, device=w13.device)

    def _repack(weight: torch.Tensor, size_n: int, size_k: int) -> torch.Tensor:
        """Repack each expert's 4-bit weight with gptq_marlin_repack."""
        assert weight.shape == (e, size_n, size_k // 2)
        per_expert = [
            ops.gptq_marlin_repack(
                b_q_weight=weight[i].view(torch.int32).T.contiguous(),
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4,
                is_a_8bit=is_a_8bit,
            )
            for i in range(e)
        ]
        return torch.stack(per_expert, dim=0)

    def _convert_scales(
        scales: torch.Tensor, size_n: int, size_k: int
    ) -> torch.Tensor:
        """Permute and post-process each expert's block scales."""
        # Reinterpret as float8_e8m0fnu, then cast to the layer's param dtype.
        scales = scales.view(torch.float8_e8m0fnu).to(param_dtype)
        per_expert = []
        for i in range(e):
            permuted = marlin_permute_scales(
                s=scales[i].T,
                size_k=size_k,
                size_n=size_n,
                group_size=group_size,
                is_a_8bit=is_a_8bit,
            )
            per_expert.append(
                mxfp4_marlin_process_scales(permuted, input_dtype=input_dtype)
            )
        return torch.stack(per_expert, dim=0)

    def _convert_bias(bias: torch.Tensor | None) -> torch.Tensor | None:
        """Permute each expert's bias; pass ``None`` through."""
        if bias is None:
            return None
        bias = bias.to(param_dtype)
        return torch.stack(
            [marlin_permute_bias(bias[i]) for i in range(e)], dim=0
        )

    # WEIGHT: Repack weights to marlin format
    w13 = _repack(w13, *w13_dims)
    w2 = _repack(w2, *w2_dims)

    # WEIGHT SCALES: Permute scales
    w13_scale = _convert_scales(w13_scale, *w13_dims)
    w2_scale = _convert_scales(w2_scale, *w2_dims)

    # BIAS: Permute bias
    w13_bias = _convert_bias(w13_bias)
    w2_bias = _convert_bias(w2_bias)

    return w13, w2, w13_scale, w2_scale, w13_bias, w2_bias
def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.import_utils import has_triton_kernels
......@@ -22,7 +20,7 @@ logger = init_logger(__name__)
CK_MXFP4_MOE_DIM_ALIGNMENT = 256
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
assert has_triton_kernels()
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
......@@ -87,35 +85,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
return quant_tensor, InFlexData(), scale
def _can_support_mxfp4(
    use_grouped_topk: bool = False,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    expert_map: torch.Tensor | None = None,
    custom_routing_function: Callable | None = None,
    e_score_correction_bias: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    scoring_func: str = "softmax",
    activation: MoEActivation = MoEActivation.SWIGLUOAI,
    expert_load_view: torch.Tensor | None = None,
    logical_to_physical_map: torch.Tensor | None = None,
    logical_replica_count: torch.Tensor | None = None,
) -> bool:
    """Return True when none of the given MoE options rules out MXFP4.

    The result is False as soon as any of the following holds:

    * ``use_grouped_topk`` or ``apply_router_weight_on_input`` is truthy;
    * ``topk_group`` or ``num_expert_group`` is set to a non-zero value;
    * ``custom_routing_function``, ``e_score_correction_bias``,
      ``expert_load_view``, ``logical_to_physical_map`` or
      ``logical_replica_count`` is provided (not ``None``);
    * ``scoring_func`` is not ``"softmax"``;
    * ``activation`` is not ``MoEActivation.SWIGLUOAI``.

    ``expert_map`` is accepted but does not influence the result.
    """
    # Optional objects are tested for presence with `is not None`.
    # Evaluating a multi-element torch.Tensor in boolean context raises
    # RuntimeError, so truthiness must not be used for the tensor arguments.
    optional_objects = (
        custom_routing_function,
        e_score_correction_bias,
        expert_load_view,
        logical_to_physical_map,
        logical_replica_count,
    )
    if any(obj is not None for obj in optional_objects):
        return False

    # Plain flags / integers keep their truthiness semantics (0 and None both
    # count as "not requested").
    if (
        use_grouped_topk
        or topk_group
        or num_expert_group
        or apply_router_weight_on_input
    ):
        return False

    return scoring_func == "softmax" and activation == MoEActivation.SWIGLUOAI
def get_padding_alignment():
return (
256
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment