Unverified Commit c277fbdf authored by TomerBN-Nvidia's avatar TomerBN-Nvidia Committed by GitHub
Browse files

[Feat] Support non-gated MoE with Marlin, NVFP4 CUTLASS, FP8, INT8, compressed-tensors (#32257)


Signed-off-by: default avatarTomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarTomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarTomer Natan <tbarnatan@ipp1-1429.ipp1a1.colossus.nvidia.com>
parent aca5c514
...@@ -526,7 +526,7 @@ def test_run_cutlass_moe_fp8( ...@@ -526,7 +526,7 @@ def test_run_cutlass_moe_fp8(
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) activation = "silu"
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
) )
......
...@@ -1079,6 +1079,86 @@ def test_fused_marlin_moe_with_bias(m): ...@@ -1079,6 +1079,86 @@ def test_fused_marlin_moe_with_bias(m):
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 64, 256])
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
"""Test Marlin MoE with non-gated activation (relu2_no_mul).
Non-gated activations like relu2 don't have the gate-up projection pattern,
so w1 has shape (e, n, k) instead of (e, 2*n, k).
"""
torch.cuda.manual_seed(42)
group_size = 16 # NVFP4 group size
is_k_full = True
quant_type = scalar_types.float4_e2m1f
dtype = torch.bfloat16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
# Non-gated: w1 shape is (e, n, k) not (e, 2*n, k)
w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
w1_data = MarlinMoEWeightData.make(
w=w1,
quant_type=quant_type,
group_size=group_size,
act_order=False,
)
w2_data = MarlinMoEWeightData.make(
w=w2,
quant_type=quant_type,
group_size=group_size,
act_order=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(
a,
w1_data.w_ref,
w2_data.w_ref,
score,
topk,
activation="relu2",
)
marlin_output = fused_marlin_moe(
a,
w1_data.qweight,
w2_data.qweight,
None, # bias1
None, # bias2
w1_data.scales,
w2_data.scales,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=None,
global_scale1=w1_data.global_scale,
global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx,
sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
is_k_full=is_k_full,
activation="relu2_no_mul",
)
torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
@pytest.mark.parametrize("ep_size", [1, 2]) @pytest.mark.parametrize("ep_size", [1, 2])
def test_moe_align_block_size_opcheck(ep_size): def test_moe_align_block_size_opcheck(ep_size):
num_experts = 4 num_experts = 4
......
...@@ -1451,6 +1451,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1451,6 +1451,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
# - "marlin": use marlin GEMM backend (for GPUs without native FP4 support)
# - <none>: automatically pick an available backend # - <none>: automatically pick an available backend
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND", "VLLM_NVFP4_GEMM_BACKEND",
...@@ -1460,6 +1461,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1460,6 +1461,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"flashinfer-trtllm", "flashinfer-trtllm",
"flashinfer-cutlass", "flashinfer-cutlass",
"cutlass", "cutlass",
"marlin",
], ],
), ),
# Controls garbage collection during CUDA graph capture. # Controls garbage collection during CUDA graph capture.
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CUTLASS based Fused MoE kernels.""" """CUTLASS based Fused MoE kernels."""
from collections.abc import Callable
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...@@ -21,7 +19,10 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -21,7 +19,10 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,7 +34,7 @@ def run_cutlass_moe_fp8( ...@@ -33,7 +34,7 @@ def run_cutlass_moe_fp8(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation_callable: Callable, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None, w1_scale: torch.Tensor | None,
...@@ -55,6 +56,7 @@ def run_cutlass_moe_fp8( ...@@ -55,6 +56,7 @@ def run_cutlass_moe_fp8(
): ):
a1q = hidden_states a1q = hidden_states
assert not activation.endswith("_no_mul"), "Only gated activation is supported"
assert w1_scale is not None assert w1_scale is not None
assert w2_scale is not None assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn assert w1.dtype == torch.float8_e4m3fn
...@@ -198,7 +200,7 @@ def run_cutlass_moe_fp8( ...@@ -198,7 +200,7 @@ def run_cutlass_moe_fp8(
per_out_ch, per_out_ch,
) )
activation_callable(act_out, mm1_out) apply_moe_activation(activation, act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant( a2q, a2q_scale = ops.scaled_fp8_quant(
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
...@@ -288,8 +290,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -288,8 +290,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
if expert_tokens_meta is not None: if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = ( use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
) )
...@@ -301,7 +301,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -301,7 +301,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
w1, w1,
w2, w2,
topk_ids, topk_ids,
activation_callable, activation,
global_num_experts, global_num_experts,
expert_map, expert_map,
self.w1_scale, self.w1_scale,
...@@ -436,6 +436,7 @@ def run_cutlass_moe_fp4( ...@@ -436,6 +436,7 @@ def run_cutlass_moe_fp4(
w2_alphas: torch.Tensor, w2_alphas: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str,
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
m: int, m: int,
...@@ -544,8 +545,7 @@ def run_cutlass_moe_fp4( ...@@ -544,8 +545,7 @@ def run_cutlass_moe_fp4(
num_topk, num_topk,
) )
c1 = _resize_cache(workspace13, (m * topk, n * 2)) c1 = _resize_cache(workspace13, (m * topk, n * 2))
# Note: c2 workspace is no longer needed since SiLU is fused with quantization. c2 = _resize_cache(workspace2, (m * topk, n))
# c3 reuses workspace13 after c1 is consumed.
c3 = _resize_cache(workspace13, (m * topk, k)) c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm( ops.cutlass_fp4_moe_mm(
c1, c1,
...@@ -559,10 +559,18 @@ def run_cutlass_moe_fp4( ...@@ -559,10 +559,18 @@ def run_cutlass_moe_fp4(
blockscale_offsets[:-1], blockscale_offsets[:-1],
) )
del rep_a_fp4, rep_a_blockscale del rep_a_fp4, rep_a_blockscale
if activation == "silu":
# Fused SiLU+Mul+NVFP4 quantization # Fused SiLU+Mul+NVFP4 quantization
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
# c3 reuses workspace13 after c1 is consumed.
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant( int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
) )
else:
apply_moe_activation(activation, c2, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
ops.cutlass_fp4_moe_mm( ops.cutlass_fp4_moe_mm(
c3, c3,
...@@ -693,6 +701,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -693,6 +701,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
w2_alphas=self.g2_alphas, w2_alphas=self.g2_alphas,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation,
workspace13=workspace13, workspace13=workspace13,
workspace2=workspace2, workspace2=workspace2,
m=m, m=m,
...@@ -711,7 +720,7 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -711,7 +720,7 @@ def run_cutlass_moe_w4a8_fp8(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation_callable: Callable, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None, w1_scale: torch.Tensor | None,
...@@ -815,7 +824,7 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -815,7 +824,7 @@ def run_cutlass_moe_w4a8_fp8(
s_strides1, s_strides1,
) )
activation_callable(act_out, mm1_out) apply_moe_activation(activation, act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant( a2q, a2q_scale = ops.scaled_fp8_quant(
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
...@@ -936,7 +945,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -936,7 +945,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None expert_num_tokens = None
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = ( use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
...@@ -951,7 +959,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -951,7 +959,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
w1, w1,
w2, w2,
topk_ids, topk_ids,
activation_callable, activation,
global_num_experts, global_num_experts,
expert_map, expert_map,
self.w1_scale, self.w1_scale,
......
...@@ -17,7 +17,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -17,7 +17,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_intermediate_size, marlin_moe_intermediate_size,
...@@ -27,21 +31,6 @@ from vllm.platforms import current_platform ...@@ -27,21 +31,6 @@ from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
def default_activation_func(
activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
def _fused_marlin_moe( def _fused_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -62,7 +51,7 @@ def _fused_marlin_moe( ...@@ -62,7 +51,7 @@ def _fused_marlin_moe(
activation: str = "silu", activation: str = "silu",
activation_func: Callable[ activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None [str, torch.Tensor, torch.Tensor], None
] = default_activation_func, ] = apply_moe_activation,
input_global_scale1: torch.Tensor | None = None, input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None, input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
...@@ -83,13 +72,13 @@ def _fused_marlin_moe( ...@@ -83,13 +72,13 @@ def _fused_marlin_moe(
assert hidden_states.ndim == 2 assert hidden_states.ndim == 2
M, K = hidden_states.size() M, K = hidden_states.size()
N = marlin_moe_intermediate_size(w1, w2) N = marlin_moe_intermediate_size(w1, w2)
w13_num_shards = 1 if "no_mul" in activation else 2
if workspace is None: if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4) workspace = marlin_make_workspace_new(hidden_states.device, 4)
if intermediate_cache13 is None: if intermediate_cache13 is None:
intermediate_cache13 = torch.empty( intermediate_cache13 = torch.empty(
(M * num_topk * max(2 * N, K),), (M * num_topk * max(w13_num_shards * N, K),),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
...@@ -101,7 +90,9 @@ def _fused_marlin_moe( ...@@ -101,7 +90,9 @@ def _fused_marlin_moe(
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N)) intermediate_cache1 = _resize_cache(
intermediate_cache13, (M * num_topk, w13_num_shards * N)
)
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K)) intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
...@@ -137,16 +128,17 @@ def _fused_marlin_moe( ...@@ -137,16 +128,17 @@ def _fused_marlin_moe(
mul_topk_weights=apply_router_weight_on_input, mul_topk_weights=apply_router_weight_on_input,
b_q_type=quant_type, b_q_type=quant_type,
size_m=M, size_m=M,
size_n=2 * N, size_n=w13_num_shards * N,
size_k=K, size_k=K,
is_k_full=is_k_full, is_k_full=is_k_full,
use_atomic_add=False, use_atomic_add=False,
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False, is_zp_float=False,
) )
activation_func( activation_func(
activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N) activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
) )
if output is None: if output is None:
...@@ -216,7 +208,7 @@ def fused_marlin_moe( ...@@ -216,7 +208,7 @@ def fused_marlin_moe(
activation: str = "silu", activation: str = "silu",
activation_func: Callable[ activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None [str, torch.Tensor, torch.Tensor], None
] = default_activation_func, ] = apply_moe_activation,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
input_global_scale1: torch.Tensor | None = None, input_global_scale1: torch.Tensor | None = None,
......
...@@ -619,26 +619,7 @@ class FusedMoE(CustomOp): ...@@ -619,26 +619,7 @@ class FusedMoE(CustomOp):
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
self.quant_method: FusedMoEMethodBase = _get_quant_method() self.quant_method: FusedMoEMethodBase = _get_quant_method()
if not self.moe_config.is_act_and_mul: if not self.moe_config.is_act_and_mul and not current_platform.is_cuda():
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
)
if not isinstance(
self.quant_method,
(
UnquantizedFusedMoEMethod,
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
),
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
)
if not current_platform.is_cuda():
raise NotImplementedError( raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now" "is_act_and_mul=False is supported only for CUDA for now"
) )
......
...@@ -52,7 +52,7 @@ def select_fp8_moe_backend( ...@@ -52,7 +52,7 @@ def select_fp8_moe_backend(
block_quant: bool, block_quant: bool,
tp_size: int, tp_size: int,
with_lora_support: bool, with_lora_support: bool,
is_act_and_mul: bool = True, is_act_and_mul: bool,
allow_vllm_cutlass: bool = False, allow_vllm_cutlass: bool = False,
) -> Fp8MoeBackend: ) -> Fp8MoeBackend:
""" """
...@@ -128,7 +128,7 @@ def select_fp8_moe_backend( ...@@ -128,7 +128,7 @@ def select_fp8_moe_backend(
scope="local", scope="local",
) )
if use_deep_gemm and moe_use_deep_gemm and block_quant: if use_deep_gemm and moe_use_deep_gemm and block_quant and is_act_and_mul:
if not has_deep_gemm(): if not has_deep_gemm():
logger.warning_once( logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local" "DeepGEMM backend requested but not available.", scope="local"
...@@ -141,7 +141,12 @@ def select_fp8_moe_backend( ...@@ -141,7 +141,12 @@ def select_fp8_moe_backend(
logger.info_once(_make_log_backend("ROCm AITER"), scope="local") logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
return Fp8MoeBackend.AITER return Fp8MoeBackend.AITER
if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported(): if (
allow_vllm_cutlass
and not block_quant
and cutlass_group_gemm_supported()
and is_act_and_mul
):
logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local") logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
return Fp8MoeBackend.VLLM_CUTLASS return Fp8MoeBackend.VLLM_CUTLASS
......
...@@ -178,6 +178,7 @@ def convert_to_nvfp4_moe_kernel_format( ...@@ -178,6 +178,7 @@ def convert_to_nvfp4_moe_kernel_format(
w2=w2, w2=w2,
w2_scale=w2_scale, w2_scale=w2_scale,
w2_scale_2=w2_scale_2, w2_scale_2=w2_scale_2,
is_act_and_mul=is_act_and_mul,
) )
else: else:
raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}") raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")
......
...@@ -367,7 +367,8 @@ def apply_moe_activation( ...@@ -367,7 +367,8 @@ def apply_moe_activation(
elif activation == GELU_NO_MUL: elif activation == GELU_NO_MUL:
output.copy_(F.gelu(input)) output.copy_(F.gelu(input))
elif activation == RELU2_NO_MUL: elif activation == RELU2_NO_MUL:
torch.square(F.relu(input), out=output) F.relu(input, inplace=True)
torch.square(input, out=output)
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
......
...@@ -764,8 +764,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -764,8 +764,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -370,12 +370,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -370,12 +370,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer_name: str | None = None, layer_name: str | None = None,
use_marlin: bool = False, use_marlin: bool = False,
): ):
if not moe.is_act_and_mul:
raise ValueError(
"CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
"support non gated MoE models."
)
super().__init__(moe) super().__init__(moe)
self.group_size = 16 self.group_size = 16
if use_marlin: if use_marlin:
...@@ -388,6 +382,16 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -388,6 +382,16 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
else: else:
self.nvfp4_backend = select_nvfp4_moe_backend() self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle.
if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.MARLIN,
]:
raise NotImplementedError(
"Non-gated activations are only supported by FlashInfer "
f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
)
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
...@@ -404,11 +408,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -404,11 +408,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
): ):
layer.num_experts = num_experts layer.num_experts = num_experts
layer.params_dtype = params_dtype layer.params_dtype = params_dtype
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension # 2 fp4 items are packed in the input dimension
hidden_size // 2, hidden_size // 2,
requires_grad=False, requires_grad=False,
...@@ -436,7 +441,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -436,7 +441,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension # 2 fp4 items are packed in the input dimension
hidden_size // self.group_size, hidden_size // self.group_size,
dtype=torch.float8_e4m3fn, dtype=torch.float8_e4m3fn,
...@@ -467,7 +472,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -467,7 +472,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# Weight Global Scales # Weight Global Scales
w13_weight_scale_2 = torch.nn.Parameter( w13_weight_scale_2 = torch.nn.Parameter(
torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
) )
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
extra_weight_attrs.update( extra_weight_attrs.update(
...@@ -486,7 +492,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -486,7 +492,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# Input Global Scales # Input Global Scales
w13_input_scale = torch.nn.Parameter( w13_input_scale = torch.nn.Parameter(
torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
) )
layer.register_parameter("w13_input_global_scale", w13_input_scale) layer.register_parameter("w13_input_global_scale", w13_input_scale)
extra_weight_attrs.update( extra_weight_attrs.update(
...@@ -640,6 +647,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -640,6 +647,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=layer.top_k, top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group, topk_group=layer.topk_group,
...@@ -666,6 +674,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -666,6 +674,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids, topk_ids=topk_ids,
topk_weights=topk_weights, topk_weights=topk_weights,
top_k=layer.top_k, top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
...@@ -722,6 +731,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -722,6 +731,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
block_quant=self.block_quant, block_quant=self.block_quant,
tp_size=moe.tp_size, tp_size=moe.tp_size,
with_lora_support=moe.is_lora_enabled, with_lora_support=moe.is_lora_enabled,
is_act_and_mul=moe.is_act_and_mul,
# TODO(rob): enable selecting this externally. # TODO(rob): enable selecting this externally.
allow_vllm_cutlass=True, allow_vllm_cutlass=True,
) )
...@@ -760,6 +770,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -760,6 +770,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.weight_block_size = None layer.weight_block_size = None
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
if self.block_quant: if self.block_quant:
assert self.weight_block_size is not None assert self.weight_block_size is not None
...@@ -791,7 +802,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -791,7 +802,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
hidden_size, hidden_size,
dtype=params_dtype, dtype=params_dtype,
), ),
...@@ -814,10 +825,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -814,10 +825,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# WEIGHT_SCALES # WEIGHT_SCALES
if self.weight_quant.strategy == QuantizationStrategy.TENSOR: if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively. # For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading. # They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False torch.ones(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
) )
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter( w2_weight_scale = torch.nn.Parameter(
...@@ -835,7 +848,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -835,7 +848,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
1, 1,
dtype=torch.float32, dtype=torch.float32,
), ),
...@@ -858,7 +871,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -858,7 +871,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) // block_n), w13_num_shards
* ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k, (hidden_size + block_k - 1) // block_k,
dtype=torch.float32, dtype=torch.float32,
), ),
...@@ -930,11 +944,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -930,11 +944,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Per-tensor kernels use a single scale, for W13, but on disk there # Per-tensor kernels use a single scale, for W13, but on disk there
# is a separate scale for W1 and W3. Requantize with the max scale. # is a separate scale for W1 and W3. Requantize with the max scale.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR: if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
process_fp8_weight_tensor_strategy_moe( w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
w13, w13,
w13_scale, w13_scale,
shard_size=layer.intermediate_size_per_partition, shard_size=layer.intermediate_size_per_partition,
num_experts=layer.num_local_experts, num_experts=layer.num_local_experts,
is_act_and_mul=self.moe.is_act_and_mul,
) )
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
...@@ -1166,12 +1181,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1166,12 +1181,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
**extra_weight_attrs, **extra_weight_attrs,
): ):
params_dtype = torch.int8 params_dtype = torch.int8
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
hidden_size, hidden_size,
dtype=params_dtype, dtype=params_dtype,
), ),
...@@ -1196,7 +1212,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1196,7 +1212,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 num_experts,
w13_num_shards * intermediate_size_per_partition,
1,
dtype=torch.float32,
), ),
requires_grad=False, requires_grad=False,
) )
...@@ -1296,6 +1315,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1296,6 +1315,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
**extra_weight_attrs, **extra_weight_attrs,
): ):
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
# Will transpose the loaded weight along the # Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will # intermediate and hidden dim sizes. Will
...@@ -1307,7 +1327,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1307,7 +1327,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
torch.empty( torch.empty(
num_experts, num_experts,
hidden_size // self.packed_factor, hidden_size // self.packed_factor,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, requires_grad=False,
...@@ -1352,7 +1372,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1352,7 +1372,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
torch.ones( torch.ones(
num_experts, num_experts,
num_groups_w13, num_groups_w13,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
...@@ -1600,10 +1620,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1600,10 +1620,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -1625,6 +1641,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1625,6 +1641,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map, expert_map=layer.expert_map,
g_idx1=layer.w13_weight_g_idx, g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx, g_idx2=layer.w2_weight_g_idx,
...@@ -1675,11 +1692,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1675,11 +1692,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
extra_weight_attrs.update( extra_weight_attrs.update(
{"is_transposed": True, "quant_method": self.strategy} {"is_transposed": True, "quant_method": self.strategy}
) )
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
hidden_size // self.packed_factor, hidden_size // self.packed_factor,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, requires_grad=False,
...@@ -1712,7 +1730,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1712,7 +1730,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
torch.ones( torch.ones(
num_experts, num_experts,
num_groups_w13, num_groups_w13,
2 * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
......
...@@ -637,6 +637,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -637,6 +637,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_quant=self.block_quant, block_quant=self.block_quant,
tp_size=layer.moe_parallel_config.tp_size, tp_size=layer.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled, with_lora_support=self.moe.is_lora_enabled,
is_act_and_mul=self.moe.is_act_and_mul,
) )
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
......
...@@ -900,8 +900,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -900,8 +900,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -733,6 +733,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -733,6 +733,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
block_quant=False, block_quant=False,
tp_size=moe_config.moe_parallel_config.tp_size, tp_size=moe_config.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled, with_lora_support=self.moe.is_lora_enabled,
is_act_and_mul=self.moe.is_act_and_mul,
) )
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEModularKernel | None = None
...@@ -789,15 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -789,15 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
if self.moe.is_act_and_mul: w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
w13_weight = ModelWeightParameter( w13_weight = ModelWeightParameter(
data=torch.empty( data=torch.empty(
num_experts, num_experts,
w13_up_dim, w13_num_shards * intermediate_size_per_partition,
hidden_size, hidden_size,
dtype=weight_dtype, dtype=weight_dtype,
), ),
...@@ -826,7 +824,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -826,7 +824,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# For non-gated MoE, allocate 1 scale for w13. # For non-gated MoE, allocate 1 scale for w13.
w13_weight_scale = PerTensorScaleParameter( w13_weight_scale = PerTensorScaleParameter(
data=torch.full( data=torch.full(
(num_experts, 2 if self.moe.is_act_and_mul else 1), (num_experts, w13_num_shards),
1.0, 1.0,
dtype=torch.float32, dtype=torch.float32,
), ),
...@@ -1132,6 +1130,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1132,6 +1130,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass": elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
self.backend = "cutlass" self.backend = "cutlass"
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}" assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin":
self.backend = "marlin"
assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"
if self.backend == "none": if self.backend == "none":
raise ValueError( raise ValueError(
...@@ -1337,13 +1338,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1337,13 +1338,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
self.nvfp4_backend = select_nvfp4_moe_backend() self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle. # TODO: move this type of check into the oracle.
if ( if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
not self.moe.is_act_and_mul NvFp4MoeBackend.FLASHINFER_CUTLASS,
and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS NvFp4MoeBackend.MARLIN,
): ]:
raise NotImplementedError( raise NotImplementedError(
"Non-gated activations are only supported by FlashInfer " "Non-gated activations are only supported by FlashInfer "
"CUTLASS NvFP4 MoE backend." f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
) )
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
...@@ -1409,11 +1410,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1409,11 +1410,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
weight_scale_dtype = torch.float8_e4m3fn weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
global_num_experts = extra_weight_attrs.get("global_num_experts") global_num_experts = extra_weight_attrs.get("global_num_experts")
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
# GEMM 1 # GEMM 1
w13_weight = ModelWeightParameter( w13_weight = ModelWeightParameter(
data=torch.empty( data=torch.empty(
num_experts, num_experts,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension # 2 fp4 items are packed in the input dimension
hidden_size // 2, hidden_size // 2,
dtype=weight_dtype, dtype=weight_dtype,
...@@ -1442,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1442,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight_scale = ModelWeightParameter( w13_weight_scale = ModelWeightParameter(
data=torch.empty( data=torch.empty(
num_experts, num_experts,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition, w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension # 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size, hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype, dtype=weight_scale_dtype,
...@@ -1472,9 +1474,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1472,9 +1474,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
w13_weight_scale_2 = PerTensorScaleParameter( w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty( data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
),
weight_loader=weight_loader, weight_loader=weight_loader,
) )
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
...@@ -1495,7 +1495,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1495,7 +1495,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_input_scale = PerTensorScaleParameter( w13_input_scale = PerTensorScaleParameter(
data=torch.empty( data=torch.empty(
global_sf_num_experts, global_sf_num_experts,
2 if self.moe.is_act_and_mul else 1, w13_num_shards,
dtype=torch.float32, dtype=torch.float32,
), ),
weight_loader=weight_loader, weight_loader=weight_loader,
...@@ -1616,6 +1616,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1616,6 +1616,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=layer.top_k, top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group, topk_group=layer.topk_group,
...@@ -1642,6 +1643,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1642,6 +1643,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_ids=topk_ids, topk_ids=topk_ids,
topk_weights=topk_weights, topk_weights=topk_weights,
top_k=layer.top_k, top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
......
...@@ -255,6 +255,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -255,6 +255,7 @@ def flashinfer_trtllm_fp4_moe(
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
activation: str,
global_num_experts: int, global_num_experts: int,
num_expert_group: int | None, num_expert_group: int | None,
topk_group: int | None, topk_group: int | None,
...@@ -269,6 +270,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -269,6 +270,7 @@ def flashinfer_trtllm_fp4_moe(
x: Input tensor x: Input tensor
router_logits: Router logits for expert selection router_logits: Router logits for expert selection
top_k: Number of experts to select per token top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks global_num_experts: Total number of experts across all ranks
num_expert_group: Number of expert groups (for grouped routing) num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group topk_group: Top-k within each group
...@@ -282,6 +284,12 @@ def flashinfer_trtllm_fp4_moe( ...@@ -282,6 +284,12 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
assert activation == "silu", (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
f"{activation} found instead."
)
# Quantize input to FP4 # Quantize input to FP4
if isinstance(x, tuple): if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x hidden_states_fp4, hidden_states_scale_linear_fp4 = x
...@@ -352,6 +360,7 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -352,6 +360,7 @@ def flashinfer_trtllm_fp4_routed_moe(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
top_k: int, top_k: int,
activation: str,
global_num_experts: int, global_num_experts: int,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -364,6 +373,7 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -364,6 +373,7 @@ def flashinfer_trtllm_fp4_routed_moe(
x: Input tensor x: Input tensor
topk_ids: Ids of selected experts topk_ids: Ids of selected experts
top_k: Number of experts to select per token top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks global_num_experts: Total number of experts across all ranks
Returns: Returns:
...@@ -371,6 +381,12 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -371,6 +381,12 @@ def flashinfer_trtllm_fp4_routed_moe(
""" """
import flashinfer import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == "silu", (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead."
)
# Pack top k ids and expert weights into a single int32 tensor, as # Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM # required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
......
...@@ -233,8 +233,6 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: ...@@ -233,8 +233,6 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
intermediate_size_per_partition = layer.intermediate_size_per_partition intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin # apply_router_weight_on_input is not supported for moe marlin
supports_router_weight = not layer.apply_router_weight_on_input supports_router_weight = not layer.apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu"
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition) # down: (n, k) = (hidden_size, intermediate_size_per_partition)
...@@ -244,12 +242,7 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: ...@@ -244,12 +242,7 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
and intermediate_size_per_partition % max(64, group_size) == 0 and intermediate_size_per_partition % max(64, group_size) == 0
) )
supports_group_size = group_size in [-1, 32, 64, 128] supports_group_size = group_size in [-1, 32, 64, 128]
return ( return supports_shape and supports_group_size and supports_router_weight
supports_shape
and supports_group_size
and supports_router_weight
and supports_activation
)
def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor): def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor):
......
...@@ -235,6 +235,7 @@ def prepare_nvfp4_moe_layer_for_marlin( ...@@ -235,6 +235,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
w2: torch.Tensor, w2: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor, w2_scale_2: torch.Tensor,
is_act_and_mul: bool,
) -> tuple[ ) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]: ]:
...@@ -266,8 +267,9 @@ def prepare_nvfp4_moe_layer_for_marlin( ...@@ -266,8 +267,9 @@ def prepare_nvfp4_moe_layer_for_marlin(
# Repack weights to marlin format # Repack weights to marlin format
def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor: def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
tensor_list = [] tensor_list = []
num_shards = 2 if is_act_and_mul else 1
if "w13" in name: if "w13" in name:
size_n, size_k = N * 2, K size_n, size_k = N * num_shards, K
else: else:
size_n, size_k = K, N size_n, size_k = K, N
...@@ -300,8 +302,9 @@ def prepare_nvfp4_moe_layer_for_marlin( ...@@ -300,8 +302,9 @@ def prepare_nvfp4_moe_layer_for_marlin(
g_scales = g_scales.to(param_dtype) g_scales = g_scales.to(param_dtype)
tensor_list = [] tensor_list = []
num_shards = 2 if is_act_and_mul else 1
if "w13" in name: if "w13" in name:
size_n, size_k = N * 2, K size_n, size_k = N * num_shards, K
else: else:
size_n, size_k = K, N size_n, size_k = K, N
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment