"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "6a512a00dfa306762c2878bffc3a5664a758d105"
Unverified commit ff1f83b0, authored by Michael Goin, committed by GitHub
Browse files

[Refactor] Replace `activation: str` with `MoEActivation` enum (#33843)


Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
parent 83b47f67
...@@ -12,6 +12,10 @@ from torch.nn.parameter import Parameter, UninitializedParameter ...@@ -12,6 +12,10 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -246,16 +250,13 @@ def _fused_moe_gguf( ...@@ -246,16 +250,13 @@ def _fused_moe_gguf(
qweight_type2: int, qweight_type2: int,
activation: str, activation: str,
) -> torch.Tensor: ) -> torch.Tensor:
activation_enum = MoEActivation.from_str(activation)
def act(x: torch.Tensor): def act(x: torch.Tensor):
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,) output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "silu": apply_moe_activation(activation_enum, out, x)
torch.ops._C.silu_and_mul(out, x)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(out, x)
else:
raise ValueError(f"Unsupported activation: {activation}")
return out return out
# lazy import to avoid triggering triton import in CPU backend # lazy import to avoid triggering triton import in CPU backend
...@@ -637,7 +638,6 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -637,7 +638,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input: if layer.apply_router_weight_on_input:
raise NotImplementedError( raise NotImplementedError(
"Apply router weight on input is not supported for" "Apply router weight on input is not supported for"
...@@ -652,7 +652,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -652,7 +652,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_ids, topk_ids,
layer.w13_qweight_type.weight_type, layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, layer.w2_qweight_type.weight_type,
layer.activation, layer.activation.value,
) )
......
...@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter ...@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -936,7 +937,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -936,7 +937,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
# TODO(rob): this validation should happen at kernel selection # TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here. # time in the oracle rather than here.
assert layer.activation == "silu", ( assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}" f"Expected 'silu' activation but got {layer.activation}"
) )
assert not layer.renormalize assert not layer.renormalize
...@@ -965,7 +966,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -965,7 +966,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): this validation should happen at kernel selection # TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here. # time in the oracle rather than here.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
assert layer.activation in ("silu", "relu2_no_mul"), ( assert layer.activation in (
MoEActivation.SILU,
MoEActivation.RELU2_NO_MUL,
), (
"Expected activation to be in ('silu', 'relu2_no_mul')," "Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}" f"but got {layer.activation}"
) )
......
...@@ -6,6 +6,7 @@ from typing import Any ...@@ -6,6 +6,7 @@ from typing import Any
import torch import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
...@@ -371,7 +372,9 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -371,7 +372,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
return fused_experts( return fused_experts(
x, x,
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
MoEActivation,
) )
from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -1141,8 +1142,9 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1141,8 +1142,9 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
assert layer.activation == "swigluoai", ( assert layer.activation == MoEActivation.SWIGLUOAI, (
"Only swiglu_oai activation is supported for XPU MXFP4 MoE" "Only swiglu_oai activation is supported for "
f"XPU MXFP4 MoE, not {layer.activation}."
) )
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
MoEActivation,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -438,7 +439,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -438,7 +439,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map, expert_map=layer.expert_map,
) )
elif self.use_marlin: elif self.use_marlin:
assert layer.activation == "silu", ( assert layer.activation == MoEActivation.SILU, (
f"{layer.activation} not supported for Marlin MoE." f"{layer.activation} not supported for Marlin MoE."
) )
return fused_marlin_moe( return fused_marlin_moe(
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -64,9 +65,9 @@ def _supports_quant_scheme( ...@@ -64,9 +65,9 @@ def _supports_quant_scheme(
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: str) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only.""" """Supports silu activation only."""
return activation in ["silu"] return activation in [MoEActivation.SILU]
def _supports_routing_method( def _supports_routing_method(
...@@ -267,7 +268,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -267,7 +268,7 @@ def flashinfer_trtllm_fp4_moe(
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
num_expert_group: int | None, num_expert_group: int | None,
topk_group: int | None, topk_group: int | None,
...@@ -297,7 +298,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -297,7 +298,7 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404 # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
assert activation == "silu", ( assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. " "Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
f"{activation} found instead." f"{activation} found instead."
) )
...@@ -365,7 +366,7 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -365,7 +366,7 @@ def flashinfer_trtllm_fp4_routed_moe(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
top_k: int, top_k: int,
activation: str, activation: MoEActivation,
global_num_experts: int, global_num_experts: int,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -387,7 +388,7 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -387,7 +388,7 @@ def flashinfer_trtllm_fp4_routed_moe(
import flashinfer import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535 # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == "silu", ( assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. " "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead." f"{activation} found instead."
) )
......
...@@ -6,6 +6,7 @@ from typing import Any ...@@ -6,6 +6,7 @@ from typing import Any
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
...@@ -88,7 +89,7 @@ def _can_support_mxfp4( ...@@ -88,7 +89,7 @@ def _can_support_mxfp4(
e_score_correction_bias: torch.Tensor | None = None, e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
scoring_func: str = "softmax", scoring_func: str = "softmax",
activation: str = "swigluoai", activation: MoEActivation = MoEActivation.SWIGLUOAI,
expert_load_view: torch.Tensor | None = None, expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
...@@ -101,7 +102,7 @@ def _can_support_mxfp4( ...@@ -101,7 +102,7 @@ def _can_support_mxfp4(
or e_score_correction_bias or e_score_correction_bias
or apply_router_weight_on_input or apply_router_weight_on_input
or scoring_func != "softmax" or scoring_func != "softmax"
or activation != "swigluoai" or activation != MoEActivation.SWIGLUOAI
or expert_load_view or expert_load_view
or logical_to_physical_map or logical_to_physical_map
or logical_replica_count or logical_replica_count
......
...@@ -33,8 +33,11 @@ from vllm.distributed.communication_op import tensor_model_parallel_all_gather ...@@ -33,8 +33,11 @@ from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul FusedMoE,
SharedFusedMoE,
activation_without_mul,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment