Unverified Commit 5e584ce9 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Remove SharedFusedMoE class (#35782)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent 1842447c
...@@ -32,7 +32,7 @@ from transformers import PretrainedConfig ...@@ -32,7 +32,7 @@ from transformers import PretrainedConfig
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -260,7 +260,7 @@ class Glm4MoeLiteMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts): ...@@ -260,7 +260,7 @@ class Glm4MoeLiteMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
] ]
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -42,7 +42,7 @@ from vllm.distributed import ( ...@@ -42,7 +42,7 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -438,7 +438,7 @@ class HunYuanSparseMoeBlock(nn.Module): ...@@ -438,7 +438,7 @@ class HunYuanSparseMoeBlock(nn.Module):
else: else:
self.shared_mlp = None self.shared_mlp = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_mlp, shared_experts=self.shared_mlp,
num_experts=self.n_routed_experts, num_experts=self.n_routed_experts,
top_k=top_k, top_k=top_k,
...@@ -712,7 +712,7 @@ class HunYuanModel(nn.Module, EagleModelMixin): ...@@ -712,7 +712,7 @@ class HunYuanModel(nn.Module, EagleModelMixin):
if _is_moe(self.config): if _is_moe(self.config):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping( return FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -14,7 +14,7 @@ from vllm.distributed import ( ...@@ -14,7 +14,7 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.kda import KimiDeltaAttention from vllm.model_executor.layers.kda import KimiDeltaAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -144,7 +144,7 @@ class KimiMoE(nn.Module): ...@@ -144,7 +144,7 @@ class KimiMoE(nn.Module):
else: else:
self.shared_experts = None self.shared_experts = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
num_experts=num_experts, num_experts=num_experts,
top_k=config.num_experts_per_token, top_k=config.num_experts_per_token,
...@@ -476,7 +476,7 @@ class KimiLinearModel(nn.Module): ...@@ -476,7 +476,7 @@ class KimiLinearModel(nn.Module):
if self.config.is_moe: if self.config.is_moe:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="w1", ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2", ckpt_down_proj_name="w2",
......
...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.attention import ( ...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.attention import (
Attention, Attention,
ChunkedLocalAttention, ChunkedLocalAttention,
) )
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
...@@ -127,7 +127,7 @@ class Llama4MoE(nn.Module): ...@@ -127,7 +127,7 @@ class Llama4MoE(nn.Module):
self.n_physical_experts = self.n_local_experts + self.n_redundant_experts self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_expert, shared_experts=self.shared_expert,
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -414,7 +414,7 @@ class Llama4Model(LlamaModel): ...@@ -414,7 +414,7 @@ class Llama4Model(LlamaModel):
params_dict: The dictionary of module parameters. params_dict: The dictionary of module parameters.
loaded_params: The set of already loaded parameters. loaded_params: The set of already loaded parameters.
expert_params_mapping: The mapping of expert parameters. Must be expert_params_mapping: The mapping of expert parameters. Must be
generated by SharedFusedMoE.make_expert_params_mapping(). generated by FusedMoE.make_expert_params_mapping().
fused: Whether the expert weights are fused into a single weight fused: Whether the expert weights are fused into a single weight
tensor or are separate weight tensors for each expert. tensor or are separate weight tensors for each expert.
When fused is True, loaded_weight should have shape of: When fused is True, loaded_weight should have shape of:
...@@ -554,7 +554,7 @@ class Llama4Model(LlamaModel): ...@@ -554,7 +554,7 @@ class Llama4Model(LlamaModel):
fused_experts_params = False fused_experts_params = False
# Expert parameter mapping for the case where the expert weights are # Expert parameter mapping for the case where the expert weights are
# not fused into a single weight tensor. # not fused into a single weight tensor.
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
...@@ -564,7 +564,7 @@ class Llama4Model(LlamaModel): ...@@ -564,7 +564,7 @@ class Llama4Model(LlamaModel):
) )
# Expert parameter mapping for the case where the expert weights are # Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor. # fused into a single weight tensor.
expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping( expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_up_proj", ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -34,8 +34,8 @@ from vllm.distributed.parallel_state import get_pp_group ...@@ -34,8 +34,8 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE,
GateLinear, GateLinear,
SharedFusedMoE,
activation_without_mul, activation_without_mul,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -210,7 +210,7 @@ class NemotronHMoE(nn.Module): ...@@ -210,7 +210,7 @@ class NemotronHMoE(nn.Module):
self.fc1_latent_proj = None self.fc1_latent_proj = None
self.fc2_latent_proj = None self.fc2_latent_proj = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -652,7 +652,7 @@ class NemotronHModel(nn.Module): ...@@ -652,7 +652,7 @@ class NemotronHModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if self.has_moe: if self.has_moe:
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
# - FusedMoe.w1 (aka gate_proj) should be up_proj since that's # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
# what the activation is applied to # what the activation is applied to
# - FusedMoe.w3 (aka up_proj) should be ignored since we're # - FusedMoe.w3 (aka up_proj) should be ignored since we're
......
...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.attention import ( ...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.attention import (
Attention, Attention,
StaticSinkAttention, StaticSinkAttention,
) )
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -200,7 +200,7 @@ class OpenPanguMoE(nn.Module): ...@@ -200,7 +200,7 @@ class OpenPanguMoE(nn.Module):
else: else:
self.shared_experts = None self.shared_experts = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -1149,7 +1149,7 @@ class OpenPanguModel(nn.Module): ...@@ -1149,7 +1149,7 @@ class OpenPanguModel(nn.Module):
] ]
has_experts = hasattr(self.config, "n_routed_experts") has_experts = hasattr(self.config, "n_routed_experts")
if has_experts: if has_experts:
expert_merge_mapping = SharedFusedMoE.make_expert_params_mapping( expert_merge_mapping = FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -32,7 +32,7 @@ from vllm.distributed import ( ...@@ -32,7 +32,7 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -353,7 +353,7 @@ class Param2MoEMoEBlock(nn.Module): ...@@ -353,7 +353,7 @@ class Param2MoEMoEBlock(nn.Module):
else: else:
self.shared_experts = None # type: ignore[assignment] self.shared_experts = None # type: ignore[assignment]
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
num_experts=self.num_experts, num_experts=self.num_experts,
top_k=self.top_k, top_k=self.top_k,
...@@ -370,7 +370,7 @@ class Param2MoEMoEBlock(nn.Module): ...@@ -370,7 +370,7 @@ class Param2MoEMoEBlock(nn.Module):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
) )
def maybe_get_fused_moe(self) -> SharedFusedMoE: def maybe_get_fused_moe(self) -> FusedMoE:
return self.experts return self.experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -690,7 +690,7 @@ class Param2MoEModel(nn.Module): ...@@ -690,7 +690,7 @@ class Param2MoEModel(nn.Module):
return loaded_params return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return SharedFusedMoE.make_expert_params_mapping( return FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -40,7 +40,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size ...@@ -40,7 +40,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -164,7 +164,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -164,7 +164,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
else: else:
self.shared_expert = None self.shared_expert = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_expert, shared_experts=self.shared_expert,
num_experts=config.num_experts, num_experts=config.num_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -418,7 +418,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -418,7 +418,7 @@ class Qwen2MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping( return FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -43,7 +43,7 @@ from vllm.distributed import ( ...@@ -43,7 +43,7 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -205,7 +205,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -205,7 +205,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.shared_expert_gate = None self.shared_expert_gate = None
self.shared_expert = None self.shared_expert = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_expert, shared_experts=self.shared_expert,
gate=self.gate, gate=self.gate,
num_experts=self.n_routed_experts, num_experts=self.n_routed_experts,
...@@ -508,7 +508,7 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin): ...@@ -508,7 +508,7 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping( return FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -23,7 +23,7 @@ from vllm.distributed import ( ...@@ -23,7 +23,7 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm, GemmaRMSNorm as Qwen3NextRMSNorm,
) )
...@@ -146,7 +146,7 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -146,7 +146,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
else: else:
self.shared_expert = None self.shared_expert = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_expert, shared_experts=self.shared_expert,
gate=self.gate, gate=self.gate,
num_experts=self.n_routed_experts, num_experts=self.n_routed_experts,
...@@ -533,7 +533,7 @@ class Qwen3NextModel(nn.Module, EagleModelMixin): ...@@ -533,7 +533,7 @@ class Qwen3NextModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping( return FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -35,7 +35,7 @@ from vllm.distributed import ( ...@@ -35,7 +35,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -335,7 +335,7 @@ class SarvamMLAMoE(nn.Module): ...@@ -335,7 +335,7 @@ class SarvamMLAMoE(nn.Module):
else: else:
self.shared_experts = None self.shared_experts = None
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
num_experts=self.num_experts, num_experts=self.num_experts,
top_k=self.top_k, top_k=self.top_k,
...@@ -352,7 +352,7 @@ class SarvamMLAMoE(nn.Module): ...@@ -352,7 +352,7 @@ class SarvamMLAMoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
) )
def maybe_get_fused_moe(self) -> SharedFusedMoE: def maybe_get_fused_moe(self) -> FusedMoE:
return self.experts return self.experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -529,7 +529,7 @@ class SarvamMLAModel(nn.Module): ...@@ -529,7 +529,7 @@ class SarvamMLAModel(nn.Module):
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return SharedFusedMoE.make_expert_params_mapping( return FusedMoE.make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -24,7 +24,6 @@ from vllm.logger import init_logger ...@@ -24,7 +24,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -372,7 +371,7 @@ class FusedMoEBlock(nn.Module): ...@@ -372,7 +371,7 @@ class FusedMoEBlock(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.share_expert", prefix=f"{prefix}.share_expert",
) )
self.experts = SharedFusedMoE( self.experts = FusedMoE(
shared_experts=self.share_expert, shared_experts=self.share_expert,
gate=self.gate, gate=self.gate,
num_experts=config.moe_num_experts, num_experts=config.moe_num_experts,
......
...@@ -34,3 +34,16 @@ def length_from_prompt_token_ids_or_embeds( ...@@ -34,3 +34,16 @@ def length_from_prompt_token_ids_or_embeds(
f" prompt_embeds={prompt_embeds_len}" f" prompt_embeds={prompt_embeds_len}"
) )
return prompt_token_len return prompt_token_len
def is_moe_layer(module: torch.nn.Module) -> bool:
    """Return whether ``module`` is an instance of a class named ``FusedMoE``.

    The module's class and every one of its ancestors are checked by class
    *name* rather than with ``isinstance``, so subclasses of ``FusedMoE``
    match too. Because only the name is compared, any class called
    ``FusedMoE`` matches, regardless of which module defines it.

    Args:
        module: The object whose class hierarchy is inspected.

    Returns:
        ``True`` if the class of ``module`` or any of its base classes is
        named ``FusedMoE``; ``False`` otherwise (always a real ``bool``,
        never ``None``).
    """
    # TODO(bnell): Should use isinstance but can't due to circular dependencies.
    # The MRO lists the class itself followed by all of its ancestors, which
    # is the same set a recursive walk over ``__bases__`` would visit.
    return any(cls.__name__ == "FusedMoE" for cls in module.__class__.__mro__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment