Unverified Commit ff1f83b0 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Refactor] Replace `activation: str` with `MoEActivation` enum (#33843)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 83b47f67
......@@ -11,6 +11,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
......@@ -161,7 +162,7 @@ def bench_run(
w2_fp8q_cutlass,
topk_weights,
topk_ids,
activation="silu",
activation=MoEActivation.SILU,
global_num_experts=num_experts,
)
torch.cuda.synchronize()
......
......@@ -16,6 +16,7 @@ import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -211,7 +212,8 @@ def benchmark_config(
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
activation="silu",
num_logical_experts=num_experts,
activation=MoEActivation.SILU,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
......
......@@ -22,6 +22,7 @@ from vllm.distributed import (
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
......@@ -599,7 +600,7 @@ def make_modular_kernel(
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
activation="silu",
activation=MoEActivation.SILU,
device=vllm_config.device_config.device,
routing_method=RoutingMethodType.DeepSeekV3,
)
......
......@@ -6,6 +6,7 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT_FN
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
......@@ -19,7 +20,7 @@ EXPERT_NUM = [
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = ["silu", "swigluoai"]
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
......@@ -33,7 +34,7 @@ def ref_fused_moe(
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
) -> torch.Tensor:
len_experts = w13.size(0)
......@@ -103,7 +104,7 @@ def test_cpu_fused_moe(
intermediate_size: int,
use_bias: bool,
dtype: torch.dtype,
act: str,
act: MoEActivation,
isa: str,
):
set_random_seed(0)
......@@ -153,7 +154,7 @@ def test_cpu_fused_moe(
w2_bias,
topk_weight,
topk_ids,
act,
act.value,
isa,
)
......
......@@ -12,6 +12,7 @@ from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
......@@ -531,7 +532,7 @@ def test_run_cutlass_moe_fp8(
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
activation = "silu"
activation = MoEActivation.SILU
a1q, a1q_scale = moe_kernel_quantize_input(
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
)
......
......@@ -16,6 +16,7 @@ from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
......@@ -324,7 +325,7 @@ def deepep_deepgemm_moe_impl(
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
activation="silu",
activation=MoEActivation.SILU,
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
......
......@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
......@@ -260,7 +261,7 @@ def deep_ep_moe_impl(
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
activation="silu",
activation=MoEActivation.SILU,
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
......
......@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -93,9 +94,14 @@ class TestData:
@staticmethod
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
m: int,
k: int,
n: int,
e: int,
is_trtllm: bool,
activation: MoEActivation = MoEActivation.SILU,
) -> "TestData":
is_gated = activation != "relu2_no_mul"
is_gated = activation.is_gated
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn(
......@@ -194,7 +200,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
activation=MoEActivation.SILU,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
......@@ -219,21 +225,19 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
activation: str,
activation: MoEActivation,
monkeypatch,
workspace_init,
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
assert activation in ["silu", "relu2_no_mul"]
is_act_and_mul = activation == "silu_and_mul"
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=False, activation=activation
......@@ -292,7 +296,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=torch.bfloat16,
is_act_and_mul=is_act_and_mul,
is_act_and_mul=activation.is_gated,
routing_method=RoutingMethodType.TopK,
)
......
......@@ -13,6 +13,7 @@ from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -54,7 +55,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
m: int,
......@@ -63,7 +64,7 @@ def test_flashinfer_fp4_moe_no_graph(
e: int,
topk: int,
dtype: torch.dtype,
activation: str,
activation: MoEActivation,
workspace_init,
):
set_random_seed(7)
......@@ -73,7 +74,7 @@ def test_flashinfer_fp4_moe_no_graph(
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
is_gated_act = activation == "silu_and_mul"
is_gated_act = activation.is_gated
w1_q, w2_q, quant_config = make_test_quant_config(
e,
......@@ -112,15 +113,13 @@ def test_flashinfer_fp4_moe_no_graph(
inplace=False,
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=fi_activation,
activation=activation,
)
# Reference check:
......
......@@ -7,6 +7,7 @@ Test modular OAI Triton MoE
import pytest
import torch
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.utils.import_utils import has_triton_kernels
if not has_triton_kernels():
......@@ -192,7 +193,7 @@ def oai_triton_moe_impl(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation="swigluoai",
activation=MoEActivation.SWIGLUOAI,
global_num_experts=num_experts,
expert_map=None,
apply_router_weight_on_input=False,
......
......@@ -29,6 +29,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.fused_moe import (
MoEActivation,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
......@@ -1155,7 +1156,10 @@ def test_fused_marlin_moe_with_bias(m):
@pytest.mark.parametrize("m", [1, 64, 256])
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
@pytest.mark.parametrize("activation", [MoEActivation.RELU2_NO_MUL])
def test_fused_marlin_moe_non_gated(
m: int, n: int, k: int, e: int, topk: int, activation: MoEActivation
):
"""Test Marlin MoE with non-gated activation (relu2_no_mul).
Non-gated activations like relu2 don't have the gate-up projection pattern,
......@@ -1198,7 +1202,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
w2_data.w_ref,
score,
topk,
activation="relu2",
activation=activation,
)
marlin_output = fused_marlin_moe(
......@@ -1223,7 +1227,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
is_k_full=is_k_full,
activation="relu2_no_mul",
activation=activation,
)
torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
......@@ -1330,9 +1334,18 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("activation", ["silu"])
@pytest.mark.parametrize("activation", [MoEActivation.SILU])
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
def test_cpu_fused_moe_basic(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
with_bias: bool,
activation: MoEActivation,
):
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
device = "cpu"
......@@ -1608,6 +1621,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
hidden_dim=k,
intermediate_size_per_partition=n,
num_local_experts=e,
num_logical_experts=e,
activation="silu",
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
......
......@@ -9,6 +9,7 @@ from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -149,7 +150,7 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation="silu",
activation=MoEActivation.SILU,
in_dtype=torch.bfloat16,
device="cuda",
routing_method=RoutingMethodType.Llama4,
......
......@@ -11,15 +11,11 @@ import pytest
import torch
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import (
GELU_NO_MUL,
RELU2_NO_MUL,
SILU_NO_MUL,
)
from vllm.platforms import current_platform
# Test parameters
......@@ -28,7 +24,11 @@ N_SIZES = [128, 256]
K_SIZES = [64, 128]
TOPK_VALUES = [1, 2]
NUM_EXPERTS = 8
NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
NO_MUL_ACTIVATIONS = [
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
def make_test_tensors(
......@@ -73,7 +73,7 @@ def test_triton_experts_no_mul_activation(
n: int,
k: int,
topk: int,
activation: str,
activation: MoEActivation,
):
hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
m, n, k, NUM_EXPERTS, topk
......@@ -161,11 +161,11 @@ def test_workspace_shapes_no_mul_vs_gated():
)
ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, SILU_NO_MUL
M, N, K, topk, 8, 8, None, MoEActivation.SILU_NO_MUL
)
ws1_gated, _, out_gated = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, "silu"
M, N, K, topk, 8, 8, None, MoEActivation.SILU
)
# For no_mul: activation_out_dim = N
......@@ -202,10 +202,10 @@ def test_adjust_n_for_activation():
N = 256
# Gated activations should return N // 2
assert experts.adjust_N_for_activation(N, "silu") == N // 2
assert experts.adjust_N_for_activation(N, "gelu") == N // 2
assert experts.adjust_N_for_activation(N, MoEActivation.SILU) == N // 2
assert experts.adjust_N_for_activation(N, MoEActivation.GELU) == N // 2
# Non-gated activations should return N
assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N
assert experts.adjust_N_for_activation(N, MoEActivation.SILU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, MoEActivation.GELU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, MoEActivation.RELU2_NO_MUL) == N
......@@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -54,7 +55,7 @@ def make_dummy_moe_config(
num_local_experts=num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation="silu",
activation=MoEActivation.SILU,
in_dtype=in_dtype,
device="cuda",
routing_method=RoutingMethodType.TopK,
......
......@@ -15,6 +15,7 @@ from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.model_executor.custom_op import op_registry
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.attention.backend import AttentionType
......@@ -840,7 +841,7 @@ def torch_experts(
per_act_token_quant=False,
block_shape: list[int] | None = None,
apply_router_weights_on_input: bool = False,
activation: str = "silu_and_mul",
activation: MoEActivation = MoEActivation.SILU,
) -> torch.Tensor:
assert (
global_num_experts == -1
......@@ -883,7 +884,7 @@ def torch_experts(
f32 = torch.float32
act = op_registry[activation]
act = op_registry[activation.custom_op_name]
for i in range(num_experts):
mask = topk_ids == i
......@@ -973,7 +974,7 @@ def torch_moe(
b_bias2: torch.Tensor | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
activation: str = "silu_and_mul",
activation: MoEActivation = MoEActivation.SILU,
) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
......
......@@ -4,6 +4,11 @@
from contextlib import contextmanager
from typing import Any
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
activation_without_mul,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
RoutingMethodType,
......@@ -27,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
ZeroExpertFusedMoE,
)
......@@ -54,6 +58,7 @@ __all__ = [
"FusedMoERouter",
"FusedMoEConfig",
"FusedMoEMethodBase",
"MoEActivation",
"UnquantizedFusedMoEMethod",
"FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
......@@ -63,6 +68,7 @@ __all__ = [
"SharedFusedMoE",
"ZeroExpertFusedMoE",
"activation_without_mul",
"apply_moe_activation",
"override_config",
"get_config",
]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MoE activation function enum and utilities."""
from enum import Enum
import torch
import torch.nn.functional as F
class MoEActivation(Enum):
"""Activation functions for MoE layers."""
# Gated activations (gate * activation(up)) expect input of shape [..., 2*d]
# and produce output of shape [..., d]
SILU = "silu"
GELU = "gelu"
RELU2 = "relu2"
SWIGLUOAI = "swigluoai"
SWIGLUSTEP = "swiglustep"
# Non-gated activations (no mul with gate) expect input of shape [..., d]
# and produce output of shape [..., d].
# NOTE: Non-gated activations require the "_no_mul" suffix to be present.
SILU_NO_MUL = "silu_no_mul"
GELU_NO_MUL = "gelu_no_mul"
RELU2_NO_MUL = "relu2_no_mul"
@property
def is_gated(self) -> bool:
"""Returns True if activation expects gate*activation(up) pattern.
Gated activations expect input tensor with 2x the output size,
where the first half is the gate and second half is the up projection.
"""
return not self.value.endswith("_no_mul")
@property
def custom_op_name(self) -> str:
"""Maps to the CustomOp name of activations
in vllm/model_executor/layers/activation.py."""
return _CUSTOM_OP_NAMES[self]
def without_mul(self) -> "MoEActivation":
"""Get the non-gated variant of this activation.
For activations that have a _no_mul variant, returns that variant.
For activations without a _no_mul variant (or already _no_mul),
returns self.
"""
return _WITHOUT_MUL.get(self, self)
@classmethod
def from_str(cls, s: str) -> "MoEActivation":
"""Parse from string for backward compatibility."""
for member in cls:
if member.value == s:
return member
valid = [m.value for m in cls]
raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")
# Module-level lookup tables used by MoEActivation functions.
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
MoEActivation.SILU: "silu_and_mul",
MoEActivation.GELU: "gelu_and_mul",
MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
MoEActivation.RELU2: "relu2",
MoEActivation.SILU_NO_MUL: "silu_and_mul",
MoEActivation.GELU_NO_MUL: "gelu_and_mul",
MoEActivation.RELU2_NO_MUL: "relu2",
}
_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
}
def activation_without_mul(activation: str) -> str:
"""Get the non-gated variant of an activation function.
Args:
activation: The activation function name (e.g., "silu", "gelu")
Returns:
The non-gated activation name (e.g., "silu_no_mul", "gelu_no_mul")
"""
return MoEActivation.from_str(activation).without_mul().value
def apply_moe_activation(
activation: MoEActivation,
output: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
"""Apply MoE activation function."""
assert input.dim() == 2, "Input must be 2D"
assert output.dim() == 2, "Output must be 2D"
if activation.is_gated:
assert output.size(-1) * 2 == input.size(-1), (
f"{activation.value} expects 2x ratio: "
f"{output.size(-1) * 2} vs {input.size(-1)}"
)
else:
assert output.size(-1) == input.size(-1), (
f"{activation.value} expects equal sizes: "
f"{output.size(-1)} vs {input.size(-1)}"
)
# Activations with gated multiplication (gate × activation(up))
if activation == MoEActivation.SILU:
torch.ops._C.silu_and_mul(output, input)
elif activation == MoEActivation.GELU:
torch.ops._C.gelu_and_mul(output, input)
elif activation == MoEActivation.SWIGLUOAI:
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == MoEActivation.SWIGLUSTEP:
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
swiglustep_and_mul_triton(output, input)
# Activations without gated multiplication
elif activation == MoEActivation.SILU_NO_MUL:
output.copy_(F.silu(input))
elif activation == MoEActivation.GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == MoEActivation.RELU2_NO_MUL:
F.relu(input, inplace=True)
torch.square(input, out=output)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
return output
......@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -303,8 +304,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SILU
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......@@ -338,7 +339,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
......@@ -389,7 +390,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
......
......@@ -14,6 +14,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_DTYPES,
OCP_MX_Scheme,
......@@ -1132,7 +1133,7 @@ class FusedMoEConfig:
intermediate_size_per_partition: int
num_local_experts: int
num_logical_experts: int
activation: str
activation: MoEActivation
device: torch.device | str
routing_method: RoutingMethodType
moe_parallel_config: FusedMoEParallelConfig
......
......@@ -9,6 +9,7 @@ from torch.nn import functional as F
from vllm import _custom_ops as ops
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
from vllm.utils.torch_utils import direct_register_custom_op
......@@ -36,9 +37,9 @@ def _swigluoai_forward_native(
# Map activation names to their native forward functions.
# Uses static methods or standalone functions to avoid instantiating CustomOp
# classes, which would call get_current_vllm_config() before config is set.
_CPU_MOE_ACT_FN: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
"silu": SiluAndMul.forward_native,
"swigluoai": _swigluoai_forward_native,
_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
MoEActivation.SILU: SiluAndMul.forward_native,
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
}
......@@ -168,9 +169,9 @@ class SGLFusedMOE:
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert activation == MoEActivation.SILU, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
......@@ -235,7 +236,7 @@ class CPUFusedMOE:
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
activation: MoEActivation = MoEActivation.SILU,
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
......@@ -353,7 +354,7 @@ class CPUFusedMOE:
input: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor:
......@@ -371,7 +372,7 @@ class CPUFusedMOE:
getattr(layer, "w2_bias", None),
topk_weights,
topk_ids,
activation,
activation.value,
self.isa,
skip_weighted,
)
......@@ -383,7 +384,7 @@ class CPUFusedMOE:
input: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor:
......@@ -419,6 +420,7 @@ def cpu_fused_moe_torch(
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> None:
act = MoEActivation.from_str(activation)
layer = _CPU_MOE_LAYER_CACHE[layer_id]()
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
......@@ -442,7 +444,7 @@ def cpu_fused_moe_torch(
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
gate_up = layer.gate_up_linear[i](tokens_for_this_expert) # type: ignore
gate_up = _CPU_MOE_ACT_FN[activation](gate_up)
gate_up = _CPU_MOE_ACT_FN[act](gate_up)
expert_out = layer.down_linear[i](gate_up) # type: ignore
outputs.append(expert_out)
start_idx = end_idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment