Unverified Commit ff1f83b0 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Refactor] Replace `activation: str` with `MoEActivation` enum (#33843)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 83b47f67
...@@ -7,6 +7,10 @@ import torch ...@@ -7,6 +7,10 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -25,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -25,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
apply_moe_activation,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
...@@ -51,7 +54,7 @@ def run_cutlass_moe_fp8( ...@@ -51,7 +54,7 @@ def run_cutlass_moe_fp8(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None, w1_scale: torch.Tensor | None,
...@@ -73,7 +76,7 @@ def run_cutlass_moe_fp8( ...@@ -73,7 +76,7 @@ def run_cutlass_moe_fp8(
): ):
a1q = hidden_states a1q = hidden_states
assert not activation.endswith("_no_mul"), "Only gated activation is supported" assert activation.is_gated, "Only gated activation is supported"
assert w1_scale is not None assert w1_scale is not None
assert w2_scale is not None assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn assert w1.dtype == torch.float8_e4m3fn
...@@ -310,8 +313,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -310,8 +313,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu", "gelu", "swigluoai"] return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. # Let PrepareAndFinalize::finalize() decide the impl.
...@@ -325,7 +332,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -325,7 +332,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -415,7 +422,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -415,7 +422,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation) activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K)) workspace1 = (M * topk, max(N, K))
...@@ -456,7 +463,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -456,7 +463,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
assert num_dp is not None assert num_dp is not None
...@@ -489,7 +496,7 @@ def run_cutlass_moe_fp4( ...@@ -489,7 +496,7 @@ def run_cutlass_moe_fp4(
w2_alphas: torch.Tensor, w2_alphas: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
m: int, m: int,
...@@ -612,7 +619,7 @@ def run_cutlass_moe_fp4( ...@@ -612,7 +619,7 @@ def run_cutlass_moe_fp4(
blockscale_offsets[:-1], blockscale_offsets[:-1],
) )
del rep_a_fp4, rep_a_blockscale del rep_a_fp4, rep_a_blockscale
if activation == "silu": if activation == MoEActivation.SILU:
# Fused SiLU+Mul+NVFP4 quantization # Fused SiLU+Mul+NVFP4 quantization
# Note: c2 workspace is no longer needed since SiLU is fused with quantization. # Note: c2 workspace is no longer needed since SiLU is fused with quantization.
# c3 reuses workspace13 after c1 is consumed. # c3 reuses workspace13 after c1 is consumed.
...@@ -682,8 +689,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -682,8 +689,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic) return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu", "gelu", "swigluoai"] return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -716,7 +727,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -716,7 +727,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(2 * N, K)) workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N) workspace2 = (M * topk, N)
...@@ -731,7 +742,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -731,7 +742,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, # unused a1q_scale: torch.Tensor | None, # unused
...@@ -776,7 +787,7 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -776,7 +787,7 @@ def run_cutlass_moe_w4a8_fp8(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None, w1_scale: torch.Tensor | None,
...@@ -970,7 +981,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -970,7 +981,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
) )
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError( raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called." "This method should not be called."
...@@ -1005,7 +1016,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1005,7 +1016,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation) activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K)) workspace1 = (M * topk, max(N, K))
...@@ -1021,7 +1032,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1021,7 +1032,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -1094,7 +1105,7 @@ def cutlass_moe_w4a8_fp8( ...@@ -1094,7 +1105,7 @@ def cutlass_moe_w4a8_fp8(
s_strides2: torch.Tensor, s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
activation: str = "silu", activation: MoEActivation = MoEActivation.SILU,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1137,7 +1148,7 @@ def cutlass_moe_w4a8_fp8( ...@@ -1137,7 +1148,7 @@ def cutlass_moe_w4a8_fp8(
dtype: torch.int64 dtype: torch.int64
- per_act_token (Optional[bool]): Whether the scale is per-token or - per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor. per-tensor.
- activation (str): The activation function to use. - activation (MoEActivation): The activation function to use.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel, - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i] mapping from global expert-id to local expert-id. When expert_map[i]
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -145,8 +146,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -145,8 +146,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu", "swiglustep"] return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -171,7 +172,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -171,7 +172,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.block_shape is not None assert self.block_shape is not None
block_m = self.block_shape[0] block_m = self.block_shape[0]
...@@ -187,7 +188,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -187,7 +188,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
def _act_mul_quant( def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: str self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert self.block_shape is not None assert self.block_shape is not None
block_k = self.block_shape[1] block_k = self.block_shape[1]
...@@ -210,7 +211,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -210,7 +211,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return a2q, a2q_scale return a2q, a2q_scale
# 2. Hopper / non‑E8M0: prefer the fused SiLU+mul+quant kernel # 2. Hopper / non‑E8M0: prefer the fused SiLU+mul+quant kernel
if activation == "silu": if activation == MoEActivation.SILU:
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
return silu_mul_per_token_group_quant_fp8_colmajor( return silu_mul_per_token_group_quant_fp8_colmajor(
input=input, input=input,
...@@ -235,7 +236,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -235,7 +236,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod ...@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
...@@ -76,7 +77,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): ...@@ -76,7 +77,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
) and fallback_cls._supports_quant_scheme(weight_key, activation_key) ) and fallback_cls._supports_quant_scheme(weight_key, activation_key)
@classmethod @classmethod
def _supports_activation(cls, activation: str) -> bool: def _supports_activation(cls, activation: MoEActivation) -> bool:
experts_cls, fallback_cls = cls.get_clses() experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_activation( return experts_cls._supports_activation(
activation activation
...@@ -138,7 +139,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): ...@@ -138,7 +139,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
raise NotImplementedError raise NotImplementedError
...@@ -159,7 +160,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): ...@@ -159,7 +160,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -72,8 +73,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -72,8 +73,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu"] return activation == MoEActivation.SILU
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -101,7 +102,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -101,7 +102,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles # We use global_num_experts due to how moe_align_block_size handles
# expert_maps. # expert_maps.
...@@ -135,7 +136,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -135,7 +136,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -130,8 +131,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -130,8 +131,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu", "relu2_no_mul"] return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -164,7 +165,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -164,7 +165,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles # We use global_num_experts due to how moe_align_block_size handles
# expert_maps. # expert_maps.
...@@ -201,7 +202,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -201,7 +202,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -214,8 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -214,8 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
from flashinfer.fused_moe.core import ActivationType from flashinfer.fused_moe.core import ActivationType
activation_str_to_value_map = { activation_str_to_value_map = {
"silu": ActivationType.Swiglu, # This is the default MoEActivation.SILU: ActivationType.Swiglu, # This is the default
"relu2_no_mul": ActivationType.Relu2, MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
} }
assert activation in activation_str_to_value_map, ( assert activation in activation_str_to_value_map, (
f"{activation=} missing from {activation_str_to_value_map.keys()=}" f"{activation=} missing from {activation_str_to_value_map.keys()=}"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -50,9 +51,9 @@ def _supports_quant_scheme( ...@@ -50,9 +51,9 @@ def _supports_quant_scheme(
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only.""" """Supports silu activation only."""
return activation in ["silu"] return activation == MoEActivation.SILU
def _supports_routing_method( def _supports_routing_method(
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -698,7 +699,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -698,7 +699,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError( raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. " "NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called." "This method should not be called."
...@@ -730,7 +731,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -730,7 +731,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None assert self.num_dispatchers is not None
assert self.max_num_tokens is not None assert self.max_num_tokens is not None
...@@ -757,7 +758,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -757,7 +758,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -942,14 +943,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -942,14 +943,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in [ return activation in [
"silu", MoEActivation.SILU,
"gelu", MoEActivation.GELU,
"swigluoai", MoEActivation.SWIGLUOAI,
"silu_no_mul", MoEActivation.SILU_NO_MUL,
"gelu_no_mul", MoEActivation.GELU_NO_MUL,
"relu2_no_mul", MoEActivation.RELU2_NO_MUL,
] ]
@staticmethod @staticmethod
...@@ -975,7 +976,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -975,7 +976,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None assert self.num_dispatchers is not None
assert self.max_num_tokens is not None assert self.max_num_tokens is not None
...@@ -996,7 +997,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -996,7 +997,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -8,6 +8,10 @@ import torch ...@@ -8,6 +8,10 @@ import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -23,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -23,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
apply_moe_activation,
disable_inplace, disable_inplace,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...@@ -59,9 +62,9 @@ def _fused_marlin_moe( ...@@ -59,9 +62,9 @@ def _fused_marlin_moe(
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
activation: str = "silu", activation: MoEActivation = MoEActivation.SILU,
activation_func: Callable[ activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None [MoEActivation, torch.Tensor, torch.Tensor], None
] = apply_moe_activation, ] = apply_moe_activation,
input_global_scale1: torch.Tensor | None = None, input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None, input_global_scale2: torch.Tensor | None = None,
...@@ -83,7 +86,7 @@ def _fused_marlin_moe( ...@@ -83,7 +86,7 @@ def _fused_marlin_moe(
assert hidden_states.ndim == 2 assert hidden_states.ndim == 2
M, K = hidden_states.size() M, K = hidden_states.size()
N = marlin_moe_intermediate_size(w1, w2) N = marlin_moe_intermediate_size(w1, w2)
w13_num_shards = 1 if "no_mul" in activation else 2 w13_num_shards = 2 if activation.is_gated else 1
if workspace is None: if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4) workspace = marlin_make_workspace_new(hidden_states.device, 4)
...@@ -215,9 +218,9 @@ def fused_marlin_moe( ...@@ -215,9 +218,9 @@ def fused_marlin_moe(
quant_type_id: int, quant_type_id: int,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
activation: str = "silu", activation: MoEActivation = MoEActivation.SILU,
activation_func: Callable[ activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None [MoEActivation, torch.Tensor, torch.Tensor], None
] = apply_moe_activation, ] = apply_moe_activation,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
...@@ -377,7 +380,7 @@ def batched_fused_marlin_moe( ...@@ -377,7 +380,7 @@ def batched_fused_marlin_moe(
quant_type_id: int, quant_type_id: int,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
activation: str | None = "silu", activation: MoEActivation = MoEActivation.SILU,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None,
...@@ -579,14 +582,14 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -579,14 +582,14 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
return weight_key in SUPPORTED_W return weight_key in SUPPORTED_W
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in [ return activation in [
"silu", MoEActivation.SILU,
"gelu", MoEActivation.GELU,
"swigluoai", MoEActivation.SWIGLUOAI,
"silu_no_mul", MoEActivation.SILU_NO_MUL,
"gelu_no_mul", MoEActivation.GELU_NO_MUL,
"relu2_no_mul", MoEActivation.RELU2_NO_MUL,
] ]
@staticmethod @staticmethod
...@@ -661,7 +664,7 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -661,7 +664,7 @@ class MarlinExperts(MarlinExpertsBase):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Modular Kernel provisions output buffer from workspace1. However in # Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined # the fused_marlin_moe() function, the final torch.sum(), is defined
...@@ -692,7 +695,7 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -692,7 +695,7 @@ class MarlinExperts(MarlinExpertsBase):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -788,7 +791,7 @@ class BatchedMarlinExperts(MarlinExpertsBase): ...@@ -788,7 +791,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None assert self.num_dispatchers is not None
assert self.max_num_tokens is not None assert self.max_num_tokens is not None
...@@ -808,7 +811,7 @@ class BatchedMarlinExperts(MarlinExpertsBase): ...@@ -808,7 +811,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -17,6 +17,10 @@ from vllm.logger import init_logger ...@@ -17,6 +17,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig, FusedMoEConfig,
...@@ -32,7 +36,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -32,7 +36,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
apply_moe_activation,
disable_inplace, disable_inplace,
moe_kernel_quantize_input, moe_kernel_quantize_input,
) )
...@@ -1468,6 +1471,7 @@ def outplace_fused_experts_fake( ...@@ -1468,6 +1471,7 @@ def outplace_fused_experts_fake(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
...@@ -1521,7 +1525,7 @@ def fused_experts( ...@@ -1521,7 +1525,7 @@ def fused_experts(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: MoEActivation = MoEActivation.SILU,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
...@@ -1539,7 +1543,7 @@ def fused_experts( ...@@ -1539,7 +1543,7 @@ def fused_experts(
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation.value,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=quant_config.use_fp8_w8a8, use_fp8_w8a8=quant_config.use_fp8_w8a8,
use_int8_w8a8=quant_config.use_int8_w8a8, use_int8_w8a8=quant_config.use_int8_w8a8,
...@@ -1618,6 +1622,9 @@ def fused_experts_impl( ...@@ -1618,6 +1622,9 @@ def fused_experts_impl(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
# Convert string activation to enum for internal use
activation_enum = MoEActivation.from_str(activation)
# Check constraints. # Check constraints.
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
...@@ -1692,7 +1699,7 @@ def fused_experts_impl( ...@@ -1692,7 +1699,7 @@ def fused_experts_impl(
# This needs separate memory since it's used concurrently with cache1 # This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation( activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
N, activation N, activation_enum
) )
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * top_k_num, activation_out_dim), (M * top_k_num, activation_out_dim),
...@@ -1832,7 +1839,7 @@ def fused_experts_impl( ...@@ -1832,7 +1839,7 @@ def fused_experts_impl(
) )
apply_moe_activation( apply_moe_activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N) activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
) )
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
...@@ -1932,8 +1939,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1932,8 +1939,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu", "gelu", "swigluoai", "swiglustep"] return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -1957,7 +1969,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1957,7 +1969,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation) activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M, topk, max(activation_out_dim, K)) workspace1 = (M, topk, max(activation_out_dim, K))
...@@ -1973,7 +1985,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1973,7 +1985,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -2138,7 +2150,7 @@ class TritonWNA16Experts(TritonExperts): ...@@ -2138,7 +2150,7 @@ class TritonWNA16Experts(TritonExperts):
) )
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError( raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. " "TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called." "This method should not be called."
...@@ -2159,7 +2171,7 @@ class TritonWNA16Experts(TritonExperts): ...@@ -2159,7 +2171,7 @@ class TritonWNA16Experts(TritonExperts):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -172,7 +173,7 @@ def triton_kernel_moe_forward( ...@@ -172,7 +173,7 @@ def triton_kernel_moe_forward(
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
activation: str = "silu", activation: MoEActivation = MoEActivation.SWIGLUOAI,
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -211,7 +212,7 @@ def triton_kernel_fused_experts( ...@@ -211,7 +212,7 @@ def triton_kernel_fused_experts(
gather_indx, # GatherIndx gather_indx, # GatherIndx
scatter_indx, # ScatterIndx scatter_indx, # ScatterIndx
topk: int, topk: int,
activation: str = "silu", activation: MoEActivation = MoEActivation.SWIGLUOAI,
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702, swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0, swiglu_limit: float = 7.0,
...@@ -222,6 +223,9 @@ def triton_kernel_fused_experts( ...@@ -222,6 +223,9 @@ def triton_kernel_fused_experts(
a1q_scale: torch.Tensor | None = None, a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Triton implementation of fused expert computation using OAI kernels.""" """Triton implementation of fused expert computation using OAI kernels."""
assert activation == MoEActivation.SWIGLUOAI, (
"Only SWIGLUOAI activation is supported"
)
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
...@@ -379,7 +383,7 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -379,7 +383,7 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError( raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. " "OAITritonExperts is not yet used by an Oracle. "
"This method should not be called." "This method should not be called."
...@@ -463,7 +467,7 @@ class OAITritonExperts(BaseOAITritonExperts): ...@@ -463,7 +467,7 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
activation_out_dim = self.adjust_N_for_activation(N, activation) activation_out_dim = self.adjust_N_for_activation(N, activation)
...@@ -480,7 +484,7 @@ class OAITritonExperts(BaseOAITritonExperts): ...@@ -480,7 +484,7 @@ class OAITritonExperts(BaseOAITritonExperts):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -547,7 +551,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -547,7 +551,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
activation_out_dim = self.adjust_N_for_activation(N, activation) activation_out_dim = self.adjust_N_for_activation(N, activation)
...@@ -567,7 +571,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -567,7 +571,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -20,6 +20,7 @@ from vllm.distributed import ( ...@@ -20,6 +20,7 @@ from vllm.distributed import (
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -500,7 +501,7 @@ class FusedMoE(CustomOp): ...@@ -500,7 +501,7 @@ class FusedMoE(CustomOp):
# TODO(bnell): end attributes # TODO(bnell): end attributes
self.apply_router_weight_on_input = apply_router_weight_on_input self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation self.activation = MoEActivation.from_str(activation)
self.router = create_fused_moe_router( self.router = create_fused_moe_router(
top_k=top_k, top_k=top_k,
...@@ -554,7 +555,7 @@ class FusedMoE(CustomOp): ...@@ -554,7 +555,7 @@ class FusedMoE(CustomOp):
has_bias=has_bias, has_bias=has_bias,
is_act_and_mul=is_act_and_mul, is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None, is_lora_enabled=vllm_config.lora_config is not None,
activation=activation, activation=self.activation,
device=vllm_config.device_config.device, device=vllm_config.device_config.device,
routing_method=self.routing_method_type, routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype? # TODO: in_dtype == out_dtype?
......
...@@ -12,6 +12,10 @@ import torch ...@@ -12,6 +12,10 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -19,7 +23,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -19,7 +23,6 @@ from vllm.model_executor.layers.fused_moe.config import (
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
apply_moe_activation,
count_expert_num_tokens, count_expert_num_tokens,
disable_inplace, disable_inplace,
) )
...@@ -536,7 +539,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -536,7 +539,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
""" """
Whether the kernel supports a particular act function. Whether the kernel supports a particular act function.
""" """
...@@ -658,7 +661,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -658,7 +661,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
""" """
Compute the shapes for the temporary and final outputs of the two gemms Compute the shapes for the temporary and final outputs of the two gemms
...@@ -690,7 +693,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -690,7 +693,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def adjust_N_for_activation(N: int, activation: str) -> int: def adjust_N_for_activation(N: int, activation: MoEActivation) -> int:
""" """
Calculate the output dimension for the activation function. Calculate the output dimension for the activation function.
...@@ -702,16 +705,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -702,16 +705,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
Args: Args:
N: The intermediate size (width of w1/w3 weights). N: The intermediate size (width of w1/w3 weights).
activation: The activation function name. activation: The activation function enum.
Returns: Returns:
The output dimension after activation. The output dimension after activation.
""" """
is_no_mul = activation.endswith("_no_mul") return N if not activation.is_gated else N // 2
return N if is_no_mul else N // 2
def activation( def activation(
self, activation: str, output: torch.Tensor, input: torch.Tensor self, activation: MoEActivation, output: torch.Tensor, input: torch.Tensor
) -> None: ) -> None:
apply_moe_activation(activation, output, input) apply_moe_activation(activation, output, input)
...@@ -732,7 +734,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -732,7 +734,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -892,7 +894,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -892,7 +894,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Allocate temporary and output buffers for the fused experts op. Allocate temporary and output buffers for the fused experts op.
...@@ -1135,7 +1137,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1135,7 +1137,7 @@ class FusedMoEModularKernel(torch.nn.Module):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
...@@ -1309,7 +1311,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1309,7 +1311,7 @@ class FusedMoEModularKernel(torch.nn.Module):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: MoEActivation = MoEActivation.SILU,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
...@@ -1326,7 +1328,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1326,7 +1328,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- topk_weights (torch.Tensor): The topk weights applied at the end of - topk_weights (torch.Tensor): The topk weights applied at the end of
the layer. the layer.
- topk_ids (torch.Tensor): A map of row to expert id. - topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first - activation (MoEActivation): The activation function to apply after the first
MoE layer. MoE layer.
- global_num_experts (int): The total number of experts in the global - global_num_experts (int): The total number of experts in the global
expert space. expert space.
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -184,7 +185,7 @@ def rocm_aiter_fused_experts( ...@@ -184,7 +185,7 @@ def rocm_aiter_fused_experts(
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: MoEActivation = MoEActivation.SILU,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
...@@ -196,9 +197,13 @@ def rocm_aiter_fused_experts( ...@@ -196,9 +197,13 @@ def rocm_aiter_fused_experts(
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
activation_method = ( if activation == MoEActivation.SILU:
ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU activation_method = ActivationMethod.SILU
) elif activation == MoEActivation.GELU:
activation_method = ActivationMethod.GELU
else:
raise ValueError(f"Unsupported activation: {activation}")
# All AITER Fused MoE kernels are expecting the following datatypes # All AITER Fused MoE kernels are expecting the following datatypes
topk_weights = topk_weights.to(torch.float32) topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
...@@ -322,8 +327,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -322,8 +327,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu", "gelu"] return activation in [MoEActivation.SILU, MoEActivation.GELU]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -347,7 +352,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -347,7 +352,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Workspaces are managed internally by AITER. # Workspaces are managed internally by AITER.
workspace1 = (0,) workspace1 = (0,)
...@@ -363,7 +368,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -363,7 +368,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -45,7 +46,7 @@ class TritonOrCutlassExperts(FallbackExperts): ...@@ -45,7 +46,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Small batch fallback for sm100. # Small batch fallback for sm100.
if self.is_sm100 and M <= 8: if self.is_sm100 and M <= 8:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -45,7 +46,7 @@ class TritonOrDeepGemmExperts(FallbackExperts): ...@@ -45,7 +46,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -64,7 +65,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -64,7 +65,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
raise NotImplementedError( raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. " "TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called." "This method should not be called."
...@@ -95,7 +96,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -95,7 +96,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer. # The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,) workspace1 = (0,)
...@@ -111,7 +112,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -111,7 +112,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
......
...@@ -4,7 +4,6 @@ import functools ...@@ -4,7 +4,6 @@ import functools
from math import prod from math import prod
import torch import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -341,65 +340,6 @@ def _validate_scale_shape( ...@@ -341,65 +340,6 @@ def _validate_scale_shape(
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def activation_without_mul(activation: str) -> str:
return activation + "_no_mul"
RELU2_NO_MUL: str = activation_without_mul("relu2")
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
def apply_moe_activation(
activation: str,
output: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
"""
Apply MoE activation function.
For *_and_mul activations (silu, gelu, swigluoai):
- Expects output.size(-1) * 2 == input.size(-1)
For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):
- Expects output.size(-1) == input.size(-1)
"""
is_no_mul = activation.endswith("_no_mul")
if is_no_mul:
assert output.size(-1) == input.size(-1), (
f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}"
)
else:
assert output.size(-1) * 2 == input.size(-1), (
f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}"
)
# Activations with gated multiplication (gate × activation(up))
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
swiglustep_and_mul_triton(output, input)
# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))
elif activation == GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == RELU2_NO_MUL:
F.relu(input, inplace=True)
torch.square(input, out=output)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
return output
# Torch custom ops can't deal with outputs aliasing inputs so we need to # Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9. # disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378 # See https://github.com/vllm-project/vllm/issues/26378
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -55,8 +56,12 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -55,8 +56,12 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
return False return False
@staticmethod @staticmethod
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in ["silu", "gelu", "swigluoai"] return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -92,7 +97,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -92,7 +97,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (0,) workspace1 = (0,)
workspace2 = (0,) workspace2 = (0,)
...@@ -107,7 +112,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -107,7 +112,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
...@@ -129,7 +134,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -129,7 +134,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
n_experts_per_token=topk, n_experts_per_token=topk,
activation=activation, activation=activation.value,
num_experts=self.moe_config.num_local_experts, num_experts=self.moe_config.num_local_experts,
ep_rank=self.moe_config.ep_rank, ep_rank=self.moe_config.ep_rank,
ep_size=self.moe_config.ep_size, ep_size=self.moe_config.ep_size,
......
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
) )
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -622,7 +623,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -622,7 +623,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
assert ( assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb and not layer.enable_eplb
...@@ -649,7 +652,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -649,7 +652,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
# EPLB path # EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
...@@ -1025,7 +1030,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1025,7 +1030,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
assert layer.activation == "silu" assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
...@@ -2271,19 +2278,21 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2271,19 +2278,21 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert layer.activation in ("silu", "swigluoai", "swiglu"), ( assert layer.activation in (
"Only SiLU/SwiGLUGU/SwiGLUUG are supported." MoEActivation.SILU,
) MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
), "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
assert layer.expert_map is None, """expert_map/EP not implemented assert layer.expert_map is None, """expert_map/EP not implemented
for CPU dyn-4bit MoE.""" for CPU dyn-4bit MoE."""
def _act_kind(s: str) -> int: def _act_kind(s: MoEActivation) -> int:
# 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
if s == "swiglu": if s == MoEActivation.SWIGLUSTEP:
return 0 return 0
if s == "swigluoai": if s == MoEActivation.SWIGLUOAI:
return 1 return 1
if s == "silu": if s == MoEActivation.SILU:
return 2 return 2
raise ValueError(f"Unknown activation '{s}'") raise ValueError(f"Unknown activation '{s}'")
......
...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
MoEActivation,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -965,7 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -965,7 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): convert this to MK. # TODO(rob): convert this to MK.
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", ( assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}" f"Expected 'silu' activation but got {layer.activation}"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment