Unverified Commit 5e584ce9 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Remove SharedFusedMoE class (#35782)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 1842447c
......@@ -32,7 +32,7 @@ from transformers import PretrainedConfig
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -260,7 +260,7 @@ class Glm4MoeLiteMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -42,7 +42,7 @@ from vllm.distributed import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -438,7 +438,7 @@ class HunYuanSparseMoeBlock(nn.Module):
else:
self.shared_mlp = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_mlp,
num_experts=self.n_routed_experts,
top_k=top_k,
......@@ -712,7 +712,7 @@ class HunYuanModel(nn.Module, EagleModelMixin):
if _is_moe(self.config):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -14,7 +14,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.kda import KimiDeltaAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -144,7 +144,7 @@ class KimiMoE(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=num_experts,
top_k=config.num_experts_per_token,
......@@ -476,7 +476,7 @@ class KimiLinearModel(nn.Module):
if self.config.is_moe:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
......
......@@ -36,7 +36,7 @@ from vllm.model_executor.layers.attention import (
Attention,
ChunkedLocalAttention,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
......@@ -127,7 +127,7 @@ class Llama4MoE(nn.Module):
self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_expert,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
......@@ -414,7 +414,7 @@ class Llama4Model(LlamaModel):
params_dict: The dictionary of module parameters.
loaded_params: The set of already loaded parameters.
expert_params_mapping: The mapping of expert parameters. Must be
generated by SharedFusedMoE.make_expert_params_mapping().
generated by FusedMoE.make_expert_params_mapping().
fused: Whether the expert weights are fused into a single weight
tensor or are separate weight tensors for each expert.
When fused is True, loaded_weight should have shape of:
......@@ -554,7 +554,7 @@ class Llama4Model(LlamaModel):
fused_experts_params = False
# Expert parameter mapping for the case where the expert weights are
# not fused into a single weight tensor.
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -564,7 +564,7 @@ class Llama4Model(LlamaModel):
)
# Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor.
expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -34,8 +34,8 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
GateLinear,
SharedFusedMoE,
activation_without_mul,
)
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -210,7 +210,7 @@ class NemotronHMoE(nn.Module):
self.fc1_latent_proj = None
self.fc2_latent_proj = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -652,7 +652,7 @@ class NemotronHModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if self.has_moe:
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
# - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
# what the activation is applied to
# - FusedMoe.w3 (aka up_proj) should be ignored since we're
......
......@@ -44,7 +44,7 @@ from vllm.model_executor.layers.attention import (
Attention,
StaticSinkAttention,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -200,7 +200,7 @@ class OpenPanguMoE(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -1149,7 +1149,7 @@ class OpenPanguModel(nn.Module):
]
has_experts = hasattr(self.config, "n_routed_experts")
if has_experts:
expert_merge_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_merge_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -32,7 +32,7 @@ from vllm.distributed import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -353,7 +353,7 @@ class Param2MoEMoEBlock(nn.Module):
else:
self.shared_experts = None # type: ignore[assignment]
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=self.num_experts,
top_k=self.top_k,
......@@ -370,7 +370,7 @@ class Param2MoEMoEBlock(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
)
def maybe_get_fused_moe(self) -> SharedFusedMoE:
def maybe_get_fused_moe(self) -> FusedMoE:
return self.experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -690,7 +690,7 @@ class Param2MoEModel(nn.Module):
return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -40,7 +40,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -164,7 +164,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
else:
self.shared_expert = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_expert,
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
......@@ -418,7 +418,7 @@ class Qwen2MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -43,7 +43,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -205,7 +205,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.shared_expert_gate = None
self.shared_expert = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_expert,
gate=self.gate,
num_experts=self.n_routed_experts,
......@@ -508,7 +508,7 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -23,7 +23,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
......@@ -146,7 +146,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
else:
self.shared_expert = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_expert,
gate=self.gate,
num_experts=self.n_routed_experts,
......@@ -533,7 +533,7 @@ class Qwen3NextModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -35,7 +35,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -335,7 +335,7 @@ class SarvamMLAMoE(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=self.num_experts,
top_k=self.top_k,
......@@ -352,7 +352,7 @@ class SarvamMLAMoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
)
def maybe_get_fused_moe(self) -> SharedFusedMoE:
def maybe_get_fused_moe(self) -> FusedMoE:
return self.experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -529,7 +529,7 @@ class SarvamMLAModel(nn.Module):
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -24,7 +24,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -372,7 +371,7 @@ class FusedMoEBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.share_expert",
)
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.share_expert,
gate=self.gate,
num_experts=config.moe_num_experts,
......
......@@ -34,3 +34,16 @@ def length_from_prompt_token_ids_or_embeds(
f" prompt_embeds={prompt_embeds_len}"
)
return prompt_token_len
def is_moe_layer(module: torch.nn.Module) -> bool:
# TODO(bnell): Should use isinstance but can't due to circular dependencies.
def _check_bases(cls):
if cls.__name__ == "FusedMoE":
return True
for b in cls.__bases__:
if _check_bases(b):
return True
return _check_bases(module.__class__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment