Unverified Commit ff1f83b0 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Refactor] Replace `activation: str` with `MoEActivation` enum (#33843)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 83b47f67
......@@ -12,6 +12,10 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -246,16 +250,13 @@ def _fused_moe_gguf(
qweight_type2: int,
activation: str,
) -> torch.Tensor:
activation_enum = MoEActivation.from_str(activation)
def act(x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "silu":
torch.ops._C.silu_and_mul(out, x)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(out, x)
else:
raise ValueError(f"Unsupported activation: {activation}")
apply_moe_activation(activation_enum, out, x)
return out
# lazy import to avoid triggering triton import in CPU backend
......@@ -637,7 +638,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
......@@ -652,7 +652,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type,
layer.activation,
layer.activation.value,
)
......
......@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -936,7 +937,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
assert layer.activation == "silu", (
assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}"
)
assert not layer.renormalize
......@@ -965,7 +966,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
assert layer.activation in ("silu", "relu2_no_mul"), (
assert layer.activation in (
MoEActivation.SILU,
MoEActivation.RELU2_NO_MUL,
), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}"
)
......
......@@ -6,6 +6,7 @@ from typing import Any
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
......@@ -371,7 +372,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == "silu", "Only SiLU activation is supported."
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
return fused_experts(
x,
......
......@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
......@@ -1141,8 +1142,9 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
assert layer.activation == "swigluoai", (
"Only swiglu_oai activation is supported for XPU MXFP4 MoE"
assert layer.activation == MoEActivation.SWIGLUOAI, (
"Only swiglu_oai activation is supported for "
f"XPU MXFP4 MoE, not {layer.activation}."
)
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
......
......@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
......@@ -438,7 +439,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
)
elif self.use_marlin:
assert layer.activation == "silu", (
assert layer.activation == MoEActivation.SILU, (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe(
......
......@@ -9,6 +9,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -64,9 +65,9 @@ def _supports_quant_scheme(
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: str) -> bool:
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only."""
return activation in ["silu"]
return activation in [MoEActivation.SILU]
def _supports_routing_method(
......@@ -267,7 +268,7 @@ def flashinfer_trtllm_fp4_moe(
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
activation: str,
activation: MoEActivation,
global_num_experts: int,
num_expert_group: int | None,
topk_group: int | None,
......@@ -297,7 +298,7 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
assert activation == "silu", (
assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
f"{activation} found instead."
)
......@@ -365,7 +366,7 @@ def flashinfer_trtllm_fp4_routed_moe(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
activation: str,
activation: MoEActivation,
global_num_experts: int,
) -> torch.Tensor:
"""
......@@ -387,7 +388,7 @@ def flashinfer_trtllm_fp4_routed_moe(
import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == "silu", (
assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead."
)
......
......@@ -6,6 +6,7 @@ from typing import Any
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.import_utils import has_triton_kernels
......@@ -88,7 +89,7 @@ def _can_support_mxfp4(
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
scoring_func: str = "softmax",
activation: str = "swigluoai",
activation: MoEActivation = MoEActivation.SWIGLUOAI,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
......@@ -101,7 +102,7 @@ def _can_support_mxfp4(
or e_score_correction_bias
or apply_router_weight_on_input
or scoring_func != "softmax"
or activation != "swigluoai"
or activation != MoEActivation.SWIGLUOAI
or expert_load_view
or logical_to_physical_map
or logical_replica_count
......
......@@ -33,8 +33,11 @@ from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
SharedFusedMoE,
activation_without_mul,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment