Unverified commit 97995f63, authored by Robert Shaw and committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
...@@ -221,16 +221,16 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group ...@@ -221,16 +221,16 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
) )
marlin_output = fused_marlin_moe( marlin_output = fused_marlin_moe(
a, hidden_states=a,
w1_marlin, w1=w1_marlin,
w2_marlin, w2=w2_marlin,
None, bias1=None,
None, bias2=None,
w1_scales_marlin, w1_scale=w1_scales_marlin,
w2_scales_marlin, w2_scale=w2_scales_marlin,
None, # gating_output not needed when topk_weights/ids provided topk_weights=topk_weights,
topk_weights, topk_ids=topk_ids,
topk_ids, quant_type_id=scalar_types.uint4b8.id,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
global_scale1=None, global_scale1=None,
...@@ -244,7 +244,6 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group ...@@ -244,7 +244,6 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
w1_zeros=None, w1_zeros=None,
w2_zeros=None, w2_zeros=None,
input_dtype=dtype, input_dtype=dtype,
quant_type_id=scalar_types.uint4b8.id,
is_k_full=True, is_k_full=True,
) )
......
...@@ -168,7 +168,6 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16] ...@@ -168,7 +168,6 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]
def is_nyi_config(config: Config) -> bool: def is_nyi_config(config: Config) -> bool:
# We know these configs to be legitimate, but they still fail. # We know these configs to be legitimate, but they still fail.
info = expert_info(config.fused_experts_type) info = expert_info(config.fused_experts_type)
if info.needs_matching_quant: if info.needs_matching_quant:
# The triton kernels expect both per-act-token-quant and # The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither. # per-out-ch-quant or neither.
...@@ -259,7 +258,7 @@ def test_modular_kernel_combinations_multigpu( ...@@ -259,7 +258,7 @@ def test_modular_kernel_combinations_multigpu(
dtype: torch.dtype, dtype: torch.dtype,
quant_config: TestMoEQuantConfig | None, quant_config: TestMoEQuantConfig | None,
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, fused_experts_type: mk.FusedMoEExperts,
chunk_size: int | None, chunk_size: int | None,
world_size: int, world_size: int,
pytestconfig, pytestconfig,
...@@ -301,7 +300,7 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -301,7 +300,7 @@ def test_modular_kernel_combinations_singlegpu(
dtype: torch.dtype, dtype: torch.dtype,
quant_config: TestMoEQuantConfig | None, quant_config: TestMoEQuantConfig | None,
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, fused_experts_type: mk.FusedMoEExperts,
chunk_size: int | None, chunk_size: int | None,
world_size: int, world_size: int,
pytestconfig, pytestconfig,
......
...@@ -7,6 +7,7 @@ Test modular OAI Triton MoE ...@@ -7,6 +7,7 @@ Test modular OAI Triton MoE
import pytest import pytest
import torch import torch
from tests.utils import wait_for_gpu_memory_to_clear
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
...@@ -24,15 +25,15 @@ from triton_kernels.tensor_details import layout ...@@ -24,15 +25,15 @@ from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close from triton_kernels.testing import assert_close
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts, OAITritonExperts,
UnfusedOAITritonExperts, UnfusedOAITritonExperts,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
...@@ -174,19 +175,25 @@ def oai_triton_moe_impl( ...@@ -174,19 +175,25 @@ def oai_triton_moe_impl(
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
) )
moe_config = make_dummy_moe_config()
if unfused: if unfused:
fused_experts = UnfusedOAITritonExperts(make_dummy_moe_config(), quant_config) fused_experts = UnfusedOAITritonExperts(moe_config, quant_config)
else: else:
fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config) fused_experts = OAITritonExperts(moe_config, quant_config)
mk = FusedMoEModularKernel( mk = FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts, fused_experts,
inplace=False, inplace=False,
) )
return mk.forward( return mk.apply(
hidden_states=x, hidden_states=x,
w1=w1, w1=w1,
w2=w2, w2=w2,
...@@ -217,6 +224,7 @@ def test_oai_triton_moe( ...@@ -217,6 +224,7 @@ def test_oai_triton_moe(
unfused: bool, unfused: bool,
workspace_init, workspace_init,
): ):
wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)
set_random_seed(0) set_random_seed(0)
( (
w1, w1,
......
...@@ -346,14 +346,16 @@ def test_fused_moe( ...@@ -346,14 +346,16 @@ def test_fused_moe(
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn( return m_fused_moe_fn.apply(
a, a,
w1, w1,
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
activation=MoEActivation.SILU,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=False,
) )
fused_moe_fn = functools.partial(fused_moe, renormalize=False) fused_moe_fn = functools.partial(fused_moe, renormalize=False)
...@@ -500,14 +502,16 @@ def test_naive_block_assignment_moe( ...@@ -500,14 +502,16 @@ def test_naive_block_assignment_moe(
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn( return m_fused_moe_fn.apply(
a, a,
w1, w1,
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
activation=MoEActivation.SILU,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=False,
) )
fused_moe_fn = functools.partial(fused_moe, renormalize=False) fused_moe_fn = functools.partial(fused_moe, renormalize=False)
......
...@@ -15,12 +15,15 @@ from vllm import _custom_ops as ops ...@@ -15,12 +15,15 @@ from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4, CutlassExpertsFp4,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, make_moe_prepare_and_finalize_no_dp_ep,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
...@@ -89,22 +92,32 @@ def test_cutlass_fp4_moe_no_graph( ...@@ -89,22 +92,32 @@ def test_cutlass_fp4_moe_no_graph(
w1_scale=w1_blockscale, w1_scale=w1_blockscale,
w2_scale=w2_blockscale, w2_scale=w2_blockscale,
) )
moe_config = make_dummy_moe_config()
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
CutlassExpertsFp4( CutlassExpertsFp4(
moe_config=make_dummy_moe_config(), moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=False, inplace=False,
) )
cutlass_output = kernel( cutlass_output = kernel.apply(
hidden_states=a, hidden_states=a,
w1=w1_q, w1=w1_q,
w2=w2_q, w2=w2_q,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
global_num_experts=e,
activation=mk.MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=None,
) )
# Reference check: # Reference check:
...@@ -207,8 +220,8 @@ def test_cutlass_fp4_moe_swiglustep( ...@@ -207,8 +220,8 @@ def test_cutlass_fp4_moe_swiglustep(
w2_scale=w2_blockscale, w2_scale=w2_blockscale,
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), make_moe_prepare_and_finalize_no_dp_ep(use_monolithic=False),
CutlassExpertsFp4( CutlassExpertsFp4(
moe_config=make_dummy_moe_config(), moe_config=make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
...@@ -216,13 +229,16 @@ def test_cutlass_fp4_moe_swiglustep( ...@@ -216,13 +229,16 @@ def test_cutlass_fp4_moe_swiglustep(
inplace=False, inplace=False,
) )
cutlass_output = kernel( cutlass_output = kernel.apply(
hidden_states=a, hidden_states=a,
w1=w1_q, w1=w1_q,
w2=w2_q, w2=w2_q,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=MoEActivation.SWIGLUSTEP, activation=MoEActivation.SWIGLUSTEP,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=False,
) )
# Reference: dequantize everything and run torch_moe with swiglustep # Reference: dequantize everything and run torch_moe with swiglustep
......
...@@ -8,6 +8,9 @@ from tests.kernels.quant_utils import per_block_cast_to_int8 ...@@ -8,6 +8,9 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, TritonExperts,
fused_experts, fused_experts,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.deep_gemm import per_block_cast_to_fp8
...@@ -125,7 +125,9 @@ def batched_moe( ...@@ -125,7 +125,9 @@ def batched_moe(
a2_scale=a2_scale, a2_scale=a2_scale,
) )
fused_experts = FusedMoEModularKernel( moe_config = make_dummy_moe_config()
fused_experts = FusedMoEKernel(
BatchedPrepareAndFinalize( BatchedPrepareAndFinalize(
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
), ),
...@@ -133,12 +135,22 @@ def batched_moe( ...@@ -133,12 +135,22 @@ def batched_moe(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
quant_config=quant_config, quant_config=quant_config,
moe_config=make_dummy_moe_config(), moe_config=moe_config,
), ),
inplace=False, inplace=False,
) )
return fused_experts(a, w1, w2, topk_weight, topk_ids) return fused_experts.apply(
a,
w1,
w2,
topk_weight,
topk_ids,
global_num_experts=w1.shape[0],
activation=moe_config.activation,
apply_router_weight_on_input=False,
expert_map=None,
)
def naive_batched_moe( def naive_batched_moe(
...@@ -166,8 +178,9 @@ def naive_batched_moe( ...@@ -166,8 +178,9 @@ def naive_batched_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
) )
moe_config = make_dummy_moe_config()
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEKernel(
BatchedPrepareAndFinalize( BatchedPrepareAndFinalize(
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
), ),
...@@ -175,12 +188,22 @@ def naive_batched_moe( ...@@ -175,12 +188,22 @@ def naive_batched_moe(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
quant_config=quant_config, quant_config=quant_config,
moe_config=make_dummy_moe_config(), moe_config=moe_config,
), ),
inplace=False, inplace=False,
) )
return fused_experts(a, w1, w2, topk_weight, topk_ids) return fused_experts.apply(
a,
w1,
w2,
topk_weight,
topk_ids,
global_num_experts=w1.shape[0],
activation=moe_config.activation,
apply_router_weight_on_input=False,
expert_map=None,
)
def chunk_scales( def chunk_scales(
...@@ -581,9 +604,14 @@ def modular_triton_fused_moe( ...@@ -581,9 +604,14 @@ def modular_triton_fused_moe(
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel: ) -> FusedMoEKernel:
return FusedMoEModularKernel( return FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
TritonExperts(moe_config, quant_config), TritonExperts(moe_config, quant_config),
shared_experts, shared_experts,
inplace=False, inplace=False,
......
...@@ -127,6 +127,14 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): ...@@ -127,6 +127,14 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
) )
def test_deepseek_fp8_block_moe_vllm_triton(monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for DeepSeek-V3.1 with ``--moe-backend=triton``."""
    model = "deepseek-ai/DeepSeek-V3.1"
    backend_args = ["--moe-backend=triton"]
    can_initialize(model, hf_overrides=HF_OVERRIDE_TEXT, extra_args=backend_args)
@pytest.mark.skip( @pytest.mark.skip(
reason=( reason=(
"Known issue: lack of kernel support. " "Known issue: lack of kernel support. "
...@@ -149,6 +157,14 @@ def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatc ...@@ -149,6 +157,14 @@ def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatc
) )
def test_deepseek_nvfp4_moe_flashinfer_vllm(monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for DeepSeek-R1 FP4 with ``--moe-backend=cutlass``."""
    model = "nvidia/DeepSeek-R1-0528-FP4-v2"
    backend_args = ["--moe-backend=cutlass"]
    can_initialize(model, hf_overrides=HF_OVERRIDE_TEXT, extra_args=backend_args)
def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
can_initialize( can_initialize(
"nvidia/DeepSeek-R1-0528-FP4-v2", "nvidia/DeepSeek-R1-0528-FP4-v2",
...@@ -200,3 +216,67 @@ def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): ...@@ -200,3 +216,67 @@ def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
hf_overrides=HF_OVERRIDE_TEXT, hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--moe-backend=flashinfer_trtllm"], extra_args=["--moe-backend=flashinfer_trtllm"],
) )
## NemoTron ##
def test_nemotron_fp8_moe_flashinfer_throughput(monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for Nemotron FP8 with the flashinfer_cutlass backend."""
    model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8"
    backend_args = ["--moe-backend=flashinfer_cutlass"]
    can_initialize(model, hf_overrides=HF_OVERRIDE_TEXT, extra_args=backend_args)
@pytest.mark.skip(
    reason="FP8 MoE backend FLASHINFER_TRTLLM does not support the "
    "deployment configuration since kernel does not support "
    "no act_and_mul MLP layer."
)
def test_nemotron_fp8_moe_flashinfer_latency(monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for Nemotron FP8 with the flashinfer_trtllm backend.

    Currently skipped; the skip reason above records why.
    """
    model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8"
    backend_args = ["--moe-backend=flashinfer_trtllm"]
    can_initialize(model, hf_overrides=HF_OVERRIDE_TEXT, extra_args=backend_args)
@pytest.mark.skip(
    reason="FP8 MoE backend TRITON does not support the "
    "deployment configuration since kernel does not support "
    "no act_and_mul MLP layer."
)
def test_nemotron_fp8_moe_vllm_triton(monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for Nemotron FP8 with ``--moe-backend=triton``.

    Currently skipped; the skip reason above records why.
    """
    model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8"
    backend_args = ["--moe-backend=triton"]
    can_initialize(model, hf_overrides=HF_OVERRIDE_TEXT, extra_args=backend_args)
def test_nemotron_fp4_moe_flashinfer_throughput(monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for Nemotron NVFP4 with the flashinfer_cutlass backend."""
    model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
    backend_args = ["--moe-backend=flashinfer_cutlass"]
    can_initialize(model, hf_overrides=HF_OVERRIDE_TEXT, extra_args=backend_args)
@pytest.mark.skip(
    reason="FP4 MoE backend FLASHINFER_TRTLLM does not support the "
    "deployment configuration since kernel does not support "
    "hidden_dim % 512 != 0."
)
def test_nemotron_fp4_moe_flashinfer_latency(monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for Nemotron NVFP4 with the flashinfer_trtllm backend.

    Currently skipped; the skip reason above records why.
    """
    model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
    backend_args = ["--moe-backend=flashinfer_trtllm"]
    can_initialize(model, hf_overrides=HF_OVERRIDE_TEXT, extra_args=backend_args)
...@@ -32,10 +32,10 @@ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( ...@@ -32,10 +32,10 @@ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
UnfusedOAITritonExperts, UnfusedOAITritonExperts,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel, FusedMoEKernel,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoDPEPModular,
) )
from .utils import _get_lora_device, try_get_optimal_moe_lora_config from .utils import _get_lora_device, try_get_optimal_moe_lora_config
...@@ -136,7 +136,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -136,7 +136,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if getattr(self.base_layer.quant_method, "supports_internal_mk", False): if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
# Use the existing modular kernel from the quant method # Use the existing modular kernel from the quant method
m_fused_moe_fn = self.base_layer.quant_method.moe_mk m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
# Don't let the kernel own shared experts so the runner can # Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream. # overlap them with routed experts via a separate CUDA stream.
m_fused_moe_fn.shared_experts = None m_fused_moe_fn.shared_experts = None
...@@ -144,8 +144,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -144,8 +144,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
# Create a new modular kernel via select_gemm_impl. # Create a new modular kernel via select_gemm_impl.
# Don't pass shared_experts to the kernel so the runner can # Don't pass shared_experts to the kernel so the runner can
# overlap them with routed experts via a separate CUDA stream. # overlap them with routed experts via a separate CUDA stream.
prepare_finalize = MoEPrepareAndFinalizeNoEP() prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
m_fused_moe_fn = FusedMoEModularKernel( m_fused_moe_fn = FusedMoEKernel(
prepare_finalize, prepare_finalize,
self.base_layer.quant_method.select_gemm_impl( self.base_layer.quant_method.select_gemm_impl(
prepare_finalize, self.base_layer prepare_finalize, self.base_layer
...@@ -154,10 +154,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -154,10 +154,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if quant_config.use_mxfp4_w4a16: if quant_config.use_mxfp4_w4a16:
assert isinstance( assert isinstance(
m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts) m_fused_moe_fn.impl.fused_experts,
(MarlinExperts, UnfusedOAITritonExperts),
) )
else: else:
assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts) assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)
def fwd_decorator(layer, func): def fwd_decorator(layer, func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
...@@ -337,9 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -337,9 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
return wrapper return wrapper
fused_experts = m_fused_moe_fn.fused_experts fused_experts = m_fused_moe_fn.impl.fused_experts
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward) m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
fused_experts.activation = act_decorator( fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation self.base_layer, fused_experts.activation
) )
......
...@@ -22,8 +22,8 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -22,8 +22,8 @@ from vllm.model_executor.layers.fused_moe.layer import (
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute, FusedMoEExpertsModular,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalizeModular,
) )
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter, FusedMoERouter,
...@@ -62,9 +62,9 @@ __all__ = [ ...@@ -62,9 +62,9 @@ __all__ = [
"MoEActivation", "MoEActivation",
"UnquantizedFusedMoEMethod", "UnquantizedFusedMoEMethod",
"FusedMoeWeightScaleSupported", "FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute", "FusedMoEExpertsModular",
"FusedMoEActivationFormat", "FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize", "FusedMoEPrepareAndFinalizeModular",
"GateLinear", "GateLinear",
"RoutingMethodType", "RoutingMethodType",
"SharedFusedMoE", "SharedFusedMoE",
......
...@@ -21,8 +21,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -21,8 +21,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP, make_moe_prepare_and_finalize_naive_dp_ep,
MoEPrepareAndFinalizeNoEP, make_moe_prepare_and_finalize_no_dp_ep,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori from vllm.utils.import_utils import has_deep_ep, has_mori
...@@ -77,6 +77,7 @@ def maybe_make_prepare_finalize( ...@@ -77,6 +77,7 @@ def maybe_make_prepare_finalize(
quant_config: FusedMoEQuantConfig | None, quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False, allow_new_interface: bool = False,
use_monolithic: bool = False,
) -> FusedMoEPrepareAndFinalize | None: ) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK # NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allows us to fall # in all cases. The allow_new_interface=False flag allows us to fall
...@@ -102,14 +103,15 @@ def maybe_make_prepare_finalize( ...@@ -102,14 +103,15 @@ def maybe_make_prepare_finalize(
"Detected DP deployment with no --enable-expert-parallel. " "Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine." "Falling back to AllGather+ReduceScatter dispatch/combine."
) )
return MoEPrepareAndFinalizeNaiveEP( return make_moe_prepare_and_finalize_naive_dp_ep(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=( num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size get_ep_group().device_communicator.all2all_manager.world_size
), ),
use_monolithic=use_monolithic,
) )
else: else:
return MoEPrepareAndFinalizeNoEP() return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
...@@ -201,8 +203,9 @@ def maybe_make_prepare_finalize( ...@@ -201,8 +203,9 @@ def maybe_make_prepare_finalize(
) )
elif moe.use_naive_all2all_kernels and allow_new_interface: elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP( prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel), use_monolithic=use_monolithic,
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=all2all_manager.world_size, num_dispatchers=all2all_manager.world_size,
) )
......
...@@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant( ...@@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant(
return y_q, y_s return y_q, y_s
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
......
...@@ -228,6 +228,7 @@ class FusedMoEQuantConfig: ...@@ -228,6 +228,7 @@ class FusedMoEQuantConfig:
_a2: FusedMoEQuantDesc _a2: FusedMoEQuantDesc
_w1: FusedMoEQuantDesc _w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc _w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True
def __post_init__(self): def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, ( assert not self.per_act_token_quant or self.block_shape is None, (
...@@ -475,6 +476,7 @@ class FusedMoEQuantConfig: ...@@ -475,6 +476,7 @@ class FusedMoEQuantConfig:
w1_zp: torch.Tensor | None = None, w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None, weight_dtype: torch.dtype | str | None = None,
is_nvfp4_scale_swizzled: bool = True,
) -> "FusedMoEQuantConfig": ) -> "FusedMoEQuantConfig":
""" """
General builder function for a FusedMoEQuantConfig. General builder function for a FusedMoEQuantConfig.
...@@ -504,6 +506,7 @@ class FusedMoEQuantConfig: ...@@ -504,6 +506,7 @@ class FusedMoEQuantConfig:
- w2_bias: Optional biases for w2 (GPT OSS Triton). - w2_bias: Optional biases for w2 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization. - w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization.
- is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scales.
""" """
assert not isinstance(quant_dtype, str) or quant_dtype in { assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4", "nvfp4",
...@@ -536,6 +539,7 @@ class FusedMoEQuantConfig: ...@@ -536,6 +539,7 @@ class FusedMoEQuantConfig:
_w2=FusedMoEQuantDesc( _w2=FusedMoEQuantDesc(
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
), ),
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
) )
assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant assert quant_config.per_out_ch_quant == per_out_ch_quant
...@@ -737,6 +741,7 @@ def nvfp4_moe_quant_config( ...@@ -737,6 +741,7 @@ def nvfp4_moe_quant_config(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
is_nvfp4_scale_swizzled: bool = True,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for mxfp4 activations and nvfp4 weights. Construct a quant config for mxfp4 activations and nvfp4 weights.
...@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config( ...@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config(
per_act_token_quant=False, per_act_token_quant=False,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=None, block_shape=None,
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
) )
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_unpermute, moe_unpermute,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoDPEPModular,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
...@@ -262,7 +262,7 @@ def run_cutlass_moe_fp8( ...@@ -262,7 +262,7 @@ def run_cutlass_moe_fp8(
) )
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
...@@ -661,7 +661,7 @@ def run_cutlass_moe_fp4( ...@@ -661,7 +661,7 @@ def run_cutlass_moe_fp4(
return return
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
"""CUTLASS FP4 fused MoE expert implementation.""" """CUTLASS FP4 fused MoE expert implementation."""
@property @property
...@@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8(
) )
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
def __init__( def __init__(
self, self,
out_dtype: torch.dtype | None, out_dtype: torch.dtype | None,
...@@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8( ...@@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8(
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoDPEPModular(),
CutlassExpertsW4A8Fp8( CutlassExpertsW4A8Fp8(
out_dtype=a.dtype, out_dtype=a.dtype,
a_strides1=a_strides1, a_strides1=a_strides1,
...@@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8( ...@@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8(
quant_config=quant_config, quant_config=quant_config,
group_size=group_size, group_size=group_size,
), ),
inplace=False,
) )
return fn( return fn.apply(
a, a,
w1_q, w1_q,
w2_q, w2_q,
......
...@@ -113,7 +113,7 @@ def _valid_deep_gemm( ...@@ -113,7 +113,7 @@ def _valid_deep_gemm(
return True return True
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class DeepGemmExperts(mk.FusedMoEExpertsModular):
"""DeepGemm-based fused MoE expert implementation.""" """DeepGemm-based fused MoE expert implementation."""
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
......
...@@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import ( ...@@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import (
) )
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
""" """
Prepare/Finalize using DeepEP High-Throughput kernels. Prepare/Finalize using DeepEP High-Throughput kernels.
""" """
...@@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_dtype=quant_config.quant_dtype, quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=quant_config.block_shape, block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
) )
return ( return (
......
...@@ -49,7 +49,7 @@ def dequant_fp8( ...@@ -49,7 +49,7 @@ def dequant_fp8(
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
""" """
Prepare/Finalize using DeepEP low-latency kernels. Prepare/Finalize using DeepEP low-latency kernels.
""" """
...@@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# time. This setting is handled by post_init_setup. # time. This setting is handled by post_init_setup.
self.use_ue8m0_dispatch = False self.use_ue8m0_dispatch = False
def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute): def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
if not fused_experts.supports_packed_ue8m0_act_scales(): if not fused_experts.supports_packed_ue8m0_act_scales():
# Early exit. # Early exit.
return return
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
    """
    Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.

    ``apply`` dispatches on the quant config: block-scaled FP8 when
    ``quant_config.block_shape`` is set, per-tensor FP8 when
    ``quant_config.is_per_tensor`` is set.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        """
        Cache the layer geometry and precompute the per-tensor output scales.

        Raises:
            NotImplementedError: if expert parallelism is requested together
                with per-tensor quantization.
        """
        super().__init__(moe_config, quant_config)

        if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
            raise NotImplementedError(
                "EP parallelism is not supported with TRTLLM "
                "per-tensor FP8 quantization."
            )

        self.routing_method_type = moe_config.routing_method
        self.topk = moe_config.experts_per_token
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.hidden_dim = moe_config.hidden_dim
        self.local_num_experts = moe_config.num_local_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank

        # Make additional scales for per-tensor interface.
        # These attributes only exist when the config is per-tensor.
        if self.quant_config.is_per_tensor:
            w1_scale = self.quant_config.w1_scale
            assert w1_scale is not None
            a1_scale = self.quant_config.a1_scale
            assert a1_scale is not None
            w2_scale = self.quant_config.w2_scale
            assert w2_scale is not None
            a2_scale = self.quant_config.a2_scale
            assert a2_scale is not None

            self._g1_alphas = (w1_scale * a1_scale).squeeze()
            self._g2_alphas = (w2_scale * a2_scale).squeeze()
            # Gated layers fold the gemm1 alpha into the scale; non-gated
            # layers only apply the inverse of the second activation scale.
            self._g1_scale_c = (
                self._g1_alphas / self.quant_config.a2_scale
                if moe_config.is_act_and_mul
                else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
            )

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation layout used by this implementation."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Supports only Blackwell-family GPUs."""
        p = current_platform
        # TODO: add a check that flashinfer trtllm is available.
        return p.is_cuda() and p.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Supports non-gated MoE (see RELU2_NO_MUL in _supports_activation)."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Supports Fp8 per-tensor and Fp8 block."""
        SUPPORTED_W_A = [
            (kFp8Static128BlockSym, kFp8Dynamic128Sym),
            (kFp8StaticTensorSym, kFp8StaticTensorSym),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Supports only SiLU and RELU^2 non-gated activation."""
        return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """
        Monolithic kernels need to express router support.

        Raises:
            ValueError: if the (weight, activation) pair is neither the
                block-scaled nor the per-tensor FP8 scheme.
        """
        # NOTE(dbari): TopK routing could also be enabled, but need to validate models
        # NOTE(dbari): Default is not implemented and should not be enabled until it is
        if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
            # NOTE(rob): potentially allow others here. This is a conservative list.
            return routing_method in [
                RoutingMethodType.DeepSeekV3,
                RoutingMethodType.Renormalize,
                RoutingMethodType.RenormalizeNaive,
            ]
        elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
            # NOTE(dbari): as above, potentially allow others here.
            return routing_method in [
                RoutingMethodType.DeepSeekV3,
                RoutingMethodType.Llama4,
                RoutingMethodType.Renormalize,
                RoutingMethodType.RenormalizeNaive,
            ]
        else:
            raise ValueError("Unsupported quantization scheme.")

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """Monolithic kernel so only use with naive DP/EP and TP."""
        return (
            not moe_parallel_config.use_all2all_kernels
            or moe_parallel_config.use_naive_all2all_kernels
        ) and not moe_parallel_config.enable_eplb

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """
        The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
        Only DeepSeekV3 routing supports float32 router_logits (which is converted
        internally in the kernel).
        """
        if router_logits_dtype == torch.float32:
            # Only DeepSeekV3 routing handles float32 logits
            # https://github.com/flashinfer-ai/flashinfer/issues/2469
            return routing_method == RoutingMethodType.DeepSeekV3
        return True

    def supports_chunking(self) -> bool:
        """Chunked execution is not supported."""
        return False

    def supports_expert_map(self) -> bool:
        """Expert maps are not supported."""
        return False

    def _apply_per_block(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Run the block-scaled FP8 kernel (routing + experts in one call).

        ``expert_map`` is accepted for interface parity but not used here.
        Only SiLU is accepted and router weights cannot be applied on input.
        """
        # Delay import for non-CUDA.
        import flashinfer

        assert not apply_router_weight_on_input
        assert activation == MoEActivation.SILU

        if e_score_correction_bias is not None:
            # NOTE(review): hidden_states is handed to the kernel together
            # with a1q_scale, so it looks already quantized here; casting the
            # bias to that dtype may lose precision. The NvFp4 sibling casts
            # to bfloat16 instead. Confirm the intended dtype.
            e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)

        if self.routing_method_type == RoutingMethodType.DeepSeekV3:
            router_logits = router_logits.to(torch.float32)

        assert self.topk <= global_num_experts
        assert self.topk <= 10
        assert global_num_experts % 4 == 0
        assert self.quant_config.block_shape == [128, 128]
        # Routing kernel expects #experts <= #threads 512
        assert global_num_experts <= 512

        # Kernel requires transposed hidden state scales
        # TODO: fuse into the quant kernel.
        assert a1q_scale is not None
        a1q_scale_t = a1q_scale.t().contiguous()

        return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale_t,
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=(num_expert_group or 0),
            topk_group=(topk_group or 0),
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=routed_scaling_factor,
            routing_method_type=self.routing_method_type,
            use_shuffled_weight=False,
        )

    def _apply_per_tensor(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Run the per-tensor FP8 kernel (routing + experts in one call).

        ``expert_map`` and ``a1q_scale`` are accepted for interface parity but
        not used here; the scales come from the values precomputed in
        ``__init__``.
        """
        # Delay import for non-CUDA.
        import flashinfer
        from flashinfer.fused_moe.core import ActivationType

        # Confirm supported activation function.
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]

        activation_type = ActivationType(activation_to_flashinfer_int(activation))

        # Confirm Llama-4 routing is proper.
        if self.routing_method_type == RoutingMethodType.Llama4:
            assert apply_router_weight_on_input
        else:
            assert not apply_router_weight_on_input

        # The DeepSeekV3 routing method requires float32 router logits.
        if self.routing_method_type == RoutingMethodType.DeepSeekV3:
            router_logits = router_logits.to(torch.float32)

        out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
            routing_logits=router_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states,
            gemm1_weights=w1,
            output1_scales_scalar=self._g1_scale_c,
            output1_scales_gate_scalar=self._g1_alphas,
            gemm2_weights=w2,
            output2_scales_scalar=self._g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=num_expert_group or 0,
            topk_group=topk_group or 0,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=routed_scaling_factor,
            use_routing_scales_on_input=apply_router_weight_on_input,
            routing_method_type=self.routing_method_type,
            activation_type=activation_type,
        )
        return out

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Dispatch to the block-scaled or per-tensor kernel.

        Raises:
            NotImplementedError: if the quant config is neither block-scaled
                nor per-tensor.
        """
        if self.quant_config.block_shape is not None:
            return self._apply_per_block(
                hidden_states,
                w1,
                w2,
                router_logits,
                activation,
                global_num_experts,
                expert_map,
                a1q_scale,
                apply_router_weight_on_input,
                num_expert_group=num_expert_group,
                e_score_correction_bias=e_score_correction_bias,
                routed_scaling_factor=routed_scaling_factor,
                topk_group=topk_group,
            )
        elif self.quant_config.is_per_tensor:
            return self._apply_per_tensor(
                hidden_states,
                w1,
                w2,
                router_logits,
                activation,
                global_num_experts,
                expert_map,
                a1q_scale,
                apply_router_weight_on_input,
                num_expert_group=num_expert_group,
                e_score_correction_bias=e_score_correction_bias,
                routed_scaling_factor=routed_scaling_factor,
                # Forward the group count so grouped routing is not silently
                # replaced by the default of 0 on this path.
                topk_group=topk_group,
            )
        else:
            raise NotImplementedError(
                "Only per-block and per-tensor quantization are supported in "
                f"{self.__class__.__name__}."
            )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import flashinfer
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
class TrtLlmNvFp4ExpertsBase:
    """
    NvFp4 TRTLLM-Gen MoE kernels. Supports modular and monolithic interface.

    Holds the state and capability checks shared by both variants; the
    concrete subclasses provide ``apply``.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        """Store the configs, cache layer geometry and derive ``g1_scale_c``."""
        self.moe_config = moe_config
        self.quant_config = quant_config

        # Geometry and routing information read by the kernel calls.
        self.routing_method_type = self.moe_config.routing_method
        self.topk = moe_config.experts_per_token
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.hidden_dim = moe_config.hidden_dim
        self.local_num_experts = moe_config.num_local_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank

        assert self.quant_config.g1_alphas is not None
        assert self.quant_config.a2_gscale is not None
        if not moe_config.is_act_and_mul:
            # Non-gated: only the second activation global scale applies.
            unit_scale = torch.ones_like(self.quant_config.a1_gscale)
            self.g1_scale_c = unit_scale * self.quant_config.a2_gscale
        else:
            # g1_alpha_s = a13_scale * w13_scale_2
            # a2_gscale = (1 / a2_scale)
            # g1_scale_c = a13_scale * w13_scale_2 / a2_scale
            self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale

    @staticmethod
    def _supports_current_device() -> bool:
        """Supports only Blackwell-family GPUs."""
        platform = current_platform
        return platform.is_cuda() and platform.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Supports non-gated MoE (i.e. Nemotron-Nano)."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Supports Nvfp4 quantization."""
        requested = (weight_key, activation_key)
        supported_pairs = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return requested in supported_pairs

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Supports only SiLU and RELU^2 non-gated activation."""
        supported = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        return activation in supported

    @staticmethod
    def _supports_shape(hidden_dim: int) -> bool:
        """Requires hidden dim to be multiple of 512."""
        remainder = hidden_dim % 512
        return remainder == 0

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation layout used by this implementation."""
        return mk.FusedMoEActivationFormat.Standard

    def supports_chunking(self) -> bool:
        """Chunked execution is not supported."""
        return False

    def supports_expert_map(self) -> bool:
        """Expert maps are not supported."""
        return False
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
    """
    Modular version of the implementation (just the experts).

    Routing is done by the caller; ``apply`` receives ``topk_ids`` and
    ``topk_weights`` and writes the result into ``output``.
    """

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """The modular implementation supports all parallel configs."""
        return True

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return ``(workspace1, workspace2, output)`` shapes for M tokens."""
        # The workspaces for this implementation are managed by flashinfer.
        workspace1 = (0,)
        workspace2 = (0,)

        # Hidden states are Nvfp4, packed into int8 dtype, so we
        # need to multiply K by 2 to get the output shape right.
        assert self.hidden_dim == K * 2
        output = (M, self.hidden_dim)

        return (workspace1, workspace2, output)

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op reducer; the kernel is called with do_finalize=True."""
        return TopKWeightAndReduceNoOP()

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """
        Run the routed NvFp4 kernel, writing the result into ``output``.

        During FlashInfer autotuning the kernel is skipped and
        ``hidden_states`` is returned unchanged; otherwise nothing is
        returned.
        """
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert a1q_scale is not None
        assert self.quant_config.w1_scale is not None
        assert self.quant_config.w2_scale is not None

        # Pack topk ids and weights into format expected by the kernel:
        # expert id in the upper 16 bits, bf16 weight bits in the lower 16.
        # The int16 view is signed, so promoting it to int32 sign-extends any
        # weight with the sign bit set (negative or -0.0) and would overwrite
        # the expert id. Mask to the low 16 bits before combining.
        weight_bits = (
            topk_weights.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
        )
        packed_tensor = (topk_ids.to(torch.int32) << 16) | weight_bits

        # trtllm_fp4_block_scale_routed_moe does not support autotuning
        # so skip this kernel during dummy run for autotuning.
        import vllm.utils.flashinfer as fi_utils

        if fi_utils._is_fi_autotuning:
            return hidden_states

        # Invoke kernel.
        flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
            topk_ids=packed_tensor,
            routing_bias=None,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
                *hidden_states.shape[:-1], -1
            ),
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=self.g1_scale_c,
            output1_scale_gate_scalar=self.quant_config.g1_alphas,
            output2_scale_scalar=self.quant_config.g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=0,
            topk_group=0,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=None,
            # NOTE(review): hard-coded routing id; presumably corresponds to
            # RoutingMethodType.Renormalize. Confirm against the enum.
            routing_method_type=1,
            do_finalize=True,
            activation_type=activation_to_flashinfer_int(activation),
            output=output,
        )
class TrtLlmNvFp4ExpertsMonolithic(
    TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
):
    """
    Monolithic version of the kernel (router + experts).
    """

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """The modular implementation should be used for the Dp/Ep or EPLB case."""
        return (
            not moe_parallel_config.use_all2all_kernels
            and not moe_parallel_config.enable_eplb
        )

    @staticmethod
    def _supports_routing_method(
        routing_method_type: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return whether the fused router implements ``routing_method_type``."""
        # NOTE(rob): this is a conservative list.
        return routing_method_type in [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
            RoutingMethodType.Llama4,
        ]

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """
        The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
        Only DeepSeekV3 routing supports float32 router_logits (which is converted
        internally in the kernel).
        """
        if router_logits_dtype == torch.float32:
            # Only DeepSeekV3 routing handles float32 logits
            # https://github.com/flashinfer-ai/flashinfer/issues/2469
            return routing_method == RoutingMethodType.DeepSeekV3
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Run routing and the NvFp4 experts in a single kernel call.

        ``expert_map`` is accepted for interface parity but not used here.
        Returns the first element of the kernel's result.
        """
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert a1q_scale is not None
        assert self.quant_config.w1_scale is not None
        assert self.quant_config.w2_scale is not None
        # Router weights are applied on the input if and only if the routing
        # method is Llama4.
        assert (
            apply_router_weight_on_input
            and self.routing_method_type == RoutingMethodType.Llama4
        ) or (
            not apply_router_weight_on_input
            and self.routing_method_type != RoutingMethodType.Llama4
        )

        # Prepare routing bias into kernel format.
        routing_bias = e_score_correction_bias
        if routing_bias is not None:
            routing_bias = routing_bias.to(torch.bfloat16)

        # DeepSeekV3 routing takes float32 logits; other methods keep the
        # incoming dtype.
        router_logits = (
            router_logits.to(torch.float32)
            if self.routing_method_type == RoutingMethodType.DeepSeekV3
            else router_logits
        )

        # Invoke kernel.
        return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=routing_bias,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
                *hidden_states.shape[:-1], -1
            ),
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=self.g1_scale_c,
            output1_scale_gate_scalar=self.quant_config.g1_alphas,
            output2_scale_scalar=self.quant_config.g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=(num_expert_group or 0),
            topk_group=(topk_group or 0),
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=routed_scaling_factor,
            routing_method_type=self.routing_method_type,
            do_finalize=True,
            # Forward the requested activation so RELU2_NO_MUL is not
            # silently replaced by the kernel's default activation.
            activation_type=activation_to_flashinfer_int(activation),
        )[0]
...@@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig ...@@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
"""Base class for runtime dispatching of expert implementations.""" """Base class for runtime dispatching of expert implementations."""
def __init__( def __init__(
self, self,
experts: mk.FusedMoEPermuteExpertsUnpermute, experts: mk.FusedMoEExpertsModular,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, fallback_experts: mk.FusedMoEExpertsModular,
): ):
super().__init__( super().__init__(
moe_config=experts.moe_config, quant_config=experts.quant_config moe_config=experts.moe_config, quant_config=experts.quant_config
...@@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): ...@@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
@staticmethod @staticmethod
def get_clses() -> tuple[ def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
]: ]:
""" """
Get the cls for the experts and fallback experts. Get the cls for the experts and fallback experts.
...@@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): ...@@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
raise NotImplementedError raise NotImplementedError
def apply( def apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment