Unverified Commit ff1f83b0 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Refactor] Replace `activation: str` with `MoEActivation` enum (#33843)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 83b47f67
......@@ -7,6 +7,10 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -25,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
......@@ -51,7 +54,7 @@ def run_cutlass_moe_fp8(
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None,
......@@ -73,7 +76,7 @@ def run_cutlass_moe_fp8(
):
a1q = hidden_states
assert not activation.endswith("_no_mul"), "Only gated activation is supported"
assert activation.is_gated, "Only gated activation is supported"
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
......@@ -310,8 +313,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
......@@ -325,7 +332,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -415,7 +422,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K))
......@@ -456,7 +463,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
......@@ -489,7 +496,7 @@ def run_cutlass_moe_fp4(
w2_alphas: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
m: int,
......@@ -612,7 +619,7 @@ def run_cutlass_moe_fp4(
blockscale_offsets[:-1],
)
del rep_a_fp4, rep_a_blockscale
if activation == "silu":
if activation == MoEActivation.SILU:
# Fused SiLU+Mul+NVFP4 quantization
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
# c3 reuses workspace13 after c1 is consumed.
......@@ -682,8 +689,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -716,7 +727,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
......@@ -731,7 +742,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, # unused
......@@ -776,7 +787,7 @@ def run_cutlass_moe_w4a8_fp8(
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None,
......@@ -970,7 +981,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
)
@staticmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
......@@ -1005,7 +1016,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K))
......@@ -1021,7 +1032,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -1094,7 +1105,7 @@ def cutlass_moe_w4a8_fp8(
s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
......@@ -1137,7 +1148,7 @@ def cutlass_moe_w4a8_fp8(
dtype: torch.int64
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- activation (MoEActivation): The activation function to use.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
......
......@@ -5,6 +5,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -145,8 +146,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "swiglustep"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -171,7 +172,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.block_shape is not None
block_m = self.block_shape[0]
......@@ -187,7 +188,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output)
def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: str
self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.block_shape is not None
block_k = self.block_shape[1]
......@@ -210,7 +211,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return a2q, a2q_scale
# 2. Hopper / non‑E8M0: prefer the fused SiLU+mul+quant kernel
if activation == "silu":
if activation == MoEActivation.SILU:
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input,
......@@ -235,7 +236,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
......@@ -76,7 +77,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
) and fallback_cls._supports_quant_scheme(weight_key, activation_key)
@classmethod
def _supports_activation(cls, activation: str) -> bool:
def _supports_activation(cls, activation: MoEActivation) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_activation(
activation
......@@ -138,7 +139,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
raise NotImplementedError
......@@ -159,7 +160,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -6,6 +6,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -72,8 +73,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SILU
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -101,7 +102,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
......@@ -135,7 +136,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -5,6 +5,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
......@@ -130,8 +131,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "relu2_no_mul"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -164,7 +165,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
......@@ -201,7 +202,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -214,8 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
from flashinfer.fused_moe.core import ActivationType
activation_str_to_value_map = {
"silu": ActivationType.Swiglu, # This is the default
"relu2_no_mul": ActivationType.Relu2,
MoEActivation.SILU: ActivationType.Swiglu, # This is the default
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
}
assert activation in activation_str_to_value_map, (
f"{activation=} missing from {activation_str_to_value_map.keys()=}"
......
......@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -50,9 +51,9 @@ def _supports_quant_scheme(
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only."""
return activation in ["silu"]
return activation == MoEActivation.SILU
def _supports_routing_method(
......
......@@ -5,6 +5,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -698,7 +699,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
@staticmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
......@@ -730,7 +731,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
......@@ -757,7 +758,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -942,14 +943,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
@staticmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
"silu",
"gelu",
"swigluoai",
"silu_no_mul",
"gelu_no_mul",
"relu2_no_mul",
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
@staticmethod
......@@ -975,7 +976,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
......@@ -996,7 +997,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -8,6 +8,10 @@ import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -23,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
......@@ -59,9 +62,9 @@ def _fused_marlin_moe(
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
[MoEActivation, torch.Tensor, torch.Tensor], None
] = apply_moe_activation,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
......@@ -83,7 +86,7 @@ def _fused_marlin_moe(
assert hidden_states.ndim == 2
M, K = hidden_states.size()
N = marlin_moe_intermediate_size(w1, w2)
w13_num_shards = 1 if "no_mul" in activation else 2
w13_num_shards = 2 if activation.is_gated else 1
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)
......@@ -215,9 +218,9 @@ def fused_marlin_moe(
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
[MoEActivation, torch.Tensor, torch.Tensor], None
] = apply_moe_activation,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None,
......@@ -377,7 +380,7 @@ def batched_fused_marlin_moe(
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: str | None = "silu",
activation: MoEActivation = MoEActivation.SILU,
expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
......@@ -579,14 +582,14 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
return weight_key in SUPPORTED_W
@staticmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
"silu",
"gelu",
"swigluoai",
"silu_no_mul",
"gelu_no_mul",
"relu2_no_mul",
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
@staticmethod
......@@ -661,7 +664,7 @@ class MarlinExperts(MarlinExpertsBase):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined
......@@ -692,7 +695,7 @@ class MarlinExperts(MarlinExpertsBase):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -788,7 +791,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
......@@ -808,7 +811,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -17,6 +17,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
......@@ -32,7 +36,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
disable_inplace,
moe_kernel_quantize_input,
)
......@@ -1468,6 +1471,7 @@ def outplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1521,7 +1525,7 @@ def fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
......@@ -1539,7 +1543,7 @@ def fused_experts(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
activation=activation.value,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=quant_config.use_fp8_w8a8,
use_int8_w8a8=quant_config.use_int8_w8a8,
......@@ -1618,6 +1622,9 @@ def fused_experts_impl(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
# Convert string activation to enum for internal use
activation_enum = MoEActivation.from_str(activation)
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
......@@ -1692,7 +1699,7 @@ def fused_experts_impl(
# This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
N, activation
N, activation_enum
)
intermediate_cache2 = torch.empty(
(M * top_k_num, activation_out_dim),
......@@ -1832,7 +1839,7 @@ def fused_experts_impl(
)
apply_moe_activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
)
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
......@@ -1932,8 +1939,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai", "swiglustep"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -1957,7 +1969,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M, topk, max(activation_out_dim, K))
......@@ -1973,7 +1985,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -2138,7 +2150,7 @@ class TritonWNA16Experts(TritonExperts):
)
@staticmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
......@@ -2159,7 +2171,7 @@ class TritonWNA16Experts(TritonExperts):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig,
......@@ -172,7 +173,7 @@ def triton_kernel_moe_forward(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SWIGLUOAI,
quant_config: FusedMoEQuantConfig | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
......@@ -211,7 +212,7 @@ def triton_kernel_fused_experts(
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
topk: int,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SWIGLUOAI,
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
......@@ -222,6 +223,9 @@ def triton_kernel_fused_experts(
a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Triton implementation of fused expert computation using OAI kernels."""
assert activation == MoEActivation.SWIGLUOAI, (
"Only SWIGLUOAI activation is supported"
)
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......@@ -379,7 +383,7 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
@staticmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
......@@ -463,7 +467,7 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel
activation_out_dim = self.adjust_N_for_activation(N, activation)
......@@ -480,7 +484,7 @@ class OAITritonExperts(BaseOAITritonExperts):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -547,7 +551,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel
activation_out_dim = self.adjust_N_for_activation(N, activation)
......@@ -567,7 +571,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -20,6 +20,7 @@ from vllm.distributed import (
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -500,7 +501,7 @@ class FusedMoE(CustomOp):
# TODO(bnell): end attributes
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
self.activation = MoEActivation.from_str(activation)
self.router = create_fused_moe_router(
top_k=top_k,
......@@ -554,7 +555,7 @@ class FusedMoE(CustomOp):
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
activation=activation,
activation=self.activation,
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype?
......
......@@ -12,6 +12,10 @@ import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -19,7 +23,6 @@ from vllm.model_executor.layers.fused_moe.config import (
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
count_expert_num_tokens,
disable_inplace,
)
......@@ -536,7 +539,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@staticmethod
@abstractmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
"""
Whether the kernel supports a particular act function.
"""
......@@ -658,7 +661,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
......@@ -690,7 +693,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise NotImplementedError
@staticmethod
def adjust_N_for_activation(N: int, activation: str) -> int:
def adjust_N_for_activation(N: int, activation: MoEActivation) -> int:
"""
Calculate the output dimension for the activation function.
......@@ -702,16 +705,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
Args:
N: The intermediate size (width of w1/w3 weights).
activation: The activation function name.
activation: The activation function enum.
Returns:
The output dimension after activation.
"""
is_no_mul = activation.endswith("_no_mul")
return N if is_no_mul else N // 2
return N if not activation.is_gated else N // 2
def activation(
self, activation: str, output: torch.Tensor, input: torch.Tensor
self, activation: MoEActivation, output: torch.Tensor, input: torch.Tensor
) -> None:
apply_moe_activation(activation, output, input)
......@@ -732,7 +734,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -892,7 +894,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Allocate temporary and output buffers for the fused experts op.
......@@ -1135,7 +1137,7 @@ class FusedMoEModularKernel(torch.nn.Module):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
......@@ -1309,7 +1311,7 @@ class FusedMoEModularKernel(torch.nn.Module):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
......@@ -1326,7 +1328,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
- activation (MoEActivation): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
......
......@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig,
......@@ -184,7 +185,7 @@ def rocm_aiter_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
apply_router_weight_on_input: bool = False,
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
......@@ -196,9 +197,13 @@ def rocm_aiter_fused_experts(
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
activation_method = (
ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
)
if activation == MoEActivation.SILU:
activation_method = ActivationMethod.SILU
elif activation == MoEActivation.GELU:
activation_method = ActivationMethod.GELU
else:
raise ValueError(f"Unsupported activation: {activation}")
# All AITER Fused MoE kernels are expecting the following datatypes
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)
......@@ -322,8 +327,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.GELU]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -347,7 +352,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Workspaces are managed internally by AITER.
workspace1 = (0,)
......@@ -363,7 +368,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -5,6 +5,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -45,7 +46,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Small batch fallback for sm100.
if self.is_sm100 and M <= 8:
......
......@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -45,7 +46,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
......
......@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -64,7 +65,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
@staticmethod
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
......@@ -95,7 +96,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
......@@ -111,7 +112,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -4,7 +4,6 @@ import functools
from math import prod
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
......@@ -341,65 +340,6 @@ def _validate_scale_shape(
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def activation_without_mul(activation: str) -> str:
return activation + "_no_mul"
RELU2_NO_MUL: str = activation_without_mul("relu2")
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
def apply_moe_activation(
activation: str,
output: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
"""
Apply MoE activation function.
For *_and_mul activations (silu, gelu, swigluoai):
- Expects output.size(-1) * 2 == input.size(-1)
For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):
- Expects output.size(-1) == input.size(-1)
"""
is_no_mul = activation.endswith("_no_mul")
if is_no_mul:
assert output.size(-1) == input.size(-1), (
f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}"
)
else:
assert output.size(-1) * 2 == input.size(-1), (
f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}"
)
# Activations with gated multiplication (gate × activation(up))
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
swiglustep_and_mul_triton(output, input)
# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))
elif activation == GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == RELU2_NO_MUL:
F.relu(input, inplace=True)
torch.square(input, out=output)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
return output
# Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378
......
......@@ -3,6 +3,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -55,8 +56,12 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
return False
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -92,7 +97,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (0,)
workspace2 = (0,)
......@@ -107,7 +112,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......@@ -129,7 +134,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_weights=topk_weights,
topk_ids=topk_ids,
n_experts_per_token=topk,
activation=activation,
activation=activation.value,
num_experts=self.moe_config.num_local_experts,
ep_rank=self.moe_config.ep_rank,
ep_size=self.moe_config.ep_size,
......
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -622,7 +623,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported."
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
......@@ -649,7 +652,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported."
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
......@@ -1025,7 +1030,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
assert layer.activation == "silu"
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
......@@ -2271,19 +2278,21 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
router_logits: torch.Tensor,
) -> torch.Tensor:
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert layer.activation in ("silu", "swigluoai", "swiglu"), (
"Only SiLU/SwiGLUGU/SwiGLUUG are supported."
)
assert layer.activation in (
MoEActivation.SILU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
), "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
assert layer.expert_map is None, """expert_map/EP not implemented
for CPU dyn-4bit MoE."""
def _act_kind(s: str) -> int:
def _act_kind(s: MoEActivation) -> int:
# 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
if s == "swiglu":
if s == MoEActivation.SWIGLUSTEP:
return 0
if s == "swigluoai":
if s == MoEActivation.SWIGLUOAI:
return 1
if s == "silu":
if s == MoEActivation.SILU:
return 2
raise ValueError(f"Unknown activation '{s}'")
......
......@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
......@@ -965,7 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment