Commit 47e66c24 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 3b736e1c
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
...@@ -42,6 +43,7 @@ __all__ = [ ...@@ -42,6 +43,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute", "FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat", "FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize", "FusedMoEPrepareAndFinalize",
"SharedFusedMoE",
"activation_without_mul", "activation_without_mul",
"override_config", "override_config",
"get_config", "get_config",
......
...@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE): ...@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE):
def __init__( def __init__(
self, self,
shared_experts: torch.nn.Module, shared_experts: Optional[torch.nn.Module],
use_overlapped: bool = True, use_overlapped: bool = True,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self._shared_experts = shared_experts self._shared_experts = shared_experts
self.use_overlapped = use_overlapped # Disable shared expert overlap if EP is disabled or we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
self.use_overlapped = (
use_overlapped
and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
and self._shared_experts is not None
)
@property @property
def shared_experts(self) -> Optional[torch.nn.Module]: def shared_experts(self) -> Optional[torch.nn.Module]:
...@@ -36,16 +44,19 @@ class SharedFusedMoE(FusedMoE): ...@@ -36,16 +44,19 @@ class SharedFusedMoE(FusedMoE):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states)
# Reduce outputs if necessary, since the MLP should # Reduce shared expert outputs if necessary, since the MLP
# have been created with reduce_results=False. # should have been created with reduce_results=False.
if ( if (
self.reduce_results self.reduce_results
and self.tp_size > 1 and self.tp_size > 1
and self.must_reduce_shared_expert_outputs() and self.must_reduce_shared_expert_outputs()
): ):
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
shared_out = None
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
......
...@@ -741,6 +741,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -741,6 +741,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early. # Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import SharedFusedMoE
__all__ = ["SharedFusedMoE"]
...@@ -13,7 +13,7 @@ from vllm.config import VllmConfig ...@@ -13,7 +13,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -206,7 +206,7 @@ class AriaProjector(nn.Module): ...@@ -206,7 +206,7 @@ class AriaProjector(nn.Module):
return out return out
class AriaFusedMoE(FusedMoE): class AriaFusedMoE(SharedFusedMoE):
def weight_loader( def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
) -> None: ) -> None:
...@@ -260,7 +260,16 @@ class AriaTextMoELayer(nn.Module): ...@@ -260,7 +260,16 @@ class AriaTextMoELayer(nn.Module):
torch.empty((self.config.moe_num_experts, self.config.hidden_size)) torch.empty((self.config.moe_num_experts, self.config.hidden_size))
) )
self.shared_experts = LlamaMLP(
config.hidden_size,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
bias=config.mlp_bias,
)
self.experts = AriaFusedMoE( self.experts = AriaFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts, num_experts=config.moe_num_experts,
top_k=config.moe_topk, top_k=config.moe_topk,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -269,13 +278,6 @@ class AriaTextMoELayer(nn.Module): ...@@ -269,13 +278,6 @@ class AriaTextMoELayer(nn.Module):
reduce_results=True, reduce_results=True,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
) )
self.shared_experts = LlamaMLP(
config.hidden_size,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
bias=config.mlp_bias,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
...@@ -291,12 +293,12 @@ class AriaTextMoELayer(nn.Module): ...@@ -291,12 +293,12 @@ class AriaTextMoELayer(nn.Module):
router_output = torch.nn.functional.linear(hidden_states, self.router_weight) router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
hidden_states_copy = hidden_states.clone()
# NOTE: hidden_states will be modified inplace by `FusedMoE`
sparse_expert_output = self.experts(hidden_states, router_output) sparse_expert_output = self.experts(hidden_states, router_output)
shared_expert_output = self.shared_experts(hidden_states_copy)
return sparse_expert_output + shared_expert_output if self.shared_experts is not None:
return sparse_expert_output[0] + sparse_expert_output[1]
else:
return sparse_expert_output
class AriaTextDecoderLayer(LlamaDecoderLayer): class AriaTextDecoderLayer(LlamaDecoderLayer):
......
...@@ -43,7 +43,7 @@ from vllm.distributed import ( ...@@ -43,7 +43,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -276,22 +276,6 @@ class BailingMoE(nn.Module): ...@@ -276,22 +276,6 @@ class BailingMoE(nn.Module):
# default value for scoring_func # default value for scoring_func
self.score_function = "softmax" self.score_function = "softmax"
self.experts = FusedMoE(
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
if hasattr(config, "moe_shared_expert_intermediate_size"): if hasattr(config, "moe_shared_expert_intermediate_size"):
intermediate_size = config.moe_shared_expert_intermediate_size intermediate_size = config.moe_shared_expert_intermediate_size
...@@ -308,11 +292,27 @@ class BailingMoE(nn.Module): ...@@ -308,11 +292,27 @@ class BailingMoE(nn.Module):
else: else:
self.shared_experts = None self.shared_experts = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size) hidden_states = hidden_states.view(-1, hidden_size)
if self.shared_experts:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states.to(self.router_dtype)) router_logits = self.gate(hidden_states.to(self.router_dtype))
router_logits = router_logits.to(hidden_states.dtype) router_logits = router_logits.to(hidden_states.dtype)
...@@ -321,9 +321,14 @@ class BailingMoE(nn.Module): ...@@ -321,9 +321,14 @@ class BailingMoE(nn.Module):
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.shared_experts is not None:
shared_output, final_hidden_states = final_hidden_states
else:
shared_output = None
final_hidden_states *= self.routed_scaling_factor final_hidden_states *= self.routed_scaling_factor
if self.shared_experts: if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1: if self.tp_size > 1:
...@@ -475,7 +480,7 @@ class BailingMoeModel(nn.Module): ...@@ -475,7 +480,7 @@ class BailingMoeModel(nn.Module):
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping( return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
...@@ -49,7 +49,7 @@ from vllm.forward_context import get_forward_context ...@@ -49,7 +49,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -64,7 +64,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -64,7 +64,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -205,26 +204,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -205,26 +204,6 @@ class DeepseekV2MoE(nn.Module):
) )
if config.n_shared_experts is None: if config.n_shared_experts is None:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
self.shared_experts = None self.shared_experts = None
else: else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = config.moe_intermediate_size * config.n_shared_experts
...@@ -1306,7 +1285,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ...@@ -1306,7 +1285,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = [] self.moe_layers: list[SharedFusedMoE] = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -1394,7 +1373,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ...@@ -1394,7 +1373,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
...@@ -42,7 +42,7 @@ from vllm.distributed import ( ...@@ -42,7 +42,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -145,7 +145,21 @@ class Dots1MoE(nn.Module): ...@@ -145,7 +145,21 @@ class Dots1MoE(nn.Module):
else: else:
self.gate.e_score_correction_bias = None self.gate.e_score_correction_bias = None
self.experts = FusedMoE( if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = Dots1MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -163,29 +177,19 @@ class Dots1MoE(nn.Module): ...@@ -163,29 +177,19 @@ class Dots1MoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
) )
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = Dots1MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = ( final_hidden_states = (
self.experts(hidden_states=hidden_states, router_logits=router_logits) self.experts(hidden_states=hidden_states, router_logits=router_logits)
* self.routed_scaling_factor * self.routed_scaling_factor
) )
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output if self.shared_experts is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
...@@ -426,7 +430,7 @@ class Dots1Model(nn.Module): ...@@ -426,7 +430,7 @@ class Dots1Model(nn.Module):
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping( return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
...@@ -37,7 +37,7 @@ from vllm.config import CacheConfig, VllmConfig ...@@ -37,7 +37,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -145,18 +145,6 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -145,18 +145,6 @@ class Ernie4_5_MoeMoE(nn.Module):
torch.empty(config.moe_num_experts, dtype=torch.float32) torch.empty(config.moe_num_experts, dtype=torch.float32)
) )
self.experts = FusedMoE(
num_experts=config.moe_num_experts,
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
e_score_correction_bias=self.gate.e_score_correction_bias,
)
if self.has_shared_experts: if self.has_shared_experts:
intermediate_size = ( intermediate_size = (
config.moe_intermediate_size * config.moe_num_shared_experts config.moe_intermediate_size * config.moe_num_shared_experts
...@@ -167,16 +155,28 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -167,16 +155,28 @@ class Ernie4_5_MoeMoE(nn.Module):
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
reduce_results=self.experts.must_reduce_shared_expert_outputs(), reduce_results=False,
)
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts,
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
e_score_correction_bias=self.gate.e_score_correction_bias,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1] hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.has_shared_experts:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
...@@ -184,8 +184,8 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -184,8 +184,8 @@ class Ernie4_5_MoeMoE(nn.Module):
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.has_shared_experts and shared_output is not None: if self.has_shared_experts:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
...@@ -460,7 +460,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -460,7 +460,7 @@ class Ernie4_5_MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
...@@ -37,7 +37,7 @@ from vllm.attention import Attention ...@@ -37,7 +37,7 @@ from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
...@@ -74,7 +74,15 @@ logger = init_logger(__name__) ...@@ -74,7 +74,15 @@ logger = init_logger(__name__)
class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP):
pass def __init__(self, shared_experts: Optional[torch.nn.Module] = None, **kwargs):
super().__init__(**kwargs)
self.shared_experts = shared_experts
def forward(self, x):
if self.shared_experts is not None:
return self.shared_experts(x) + super().forward(x)
else:
return super().forward(x)
class Ernie4_5_VLMoeAttention(nn.Module): class Ernie4_5_VLMoeAttention(nn.Module):
...@@ -223,6 +231,21 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -223,6 +231,21 @@ class Ernie4_5_VLMoeMoE(nn.Module):
assert text_moe_layer_start_index <= text_moe_layer_end_index assert text_moe_layer_start_index <= text_moe_layer_end_index
if self.has_shared_experts:
intermediate_size = (
config.moe_intermediate_size[0] * config.moe_num_shared_experts
)
self.shared_experts = Ernie4_5_VLMoeMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.shared_experts",
reduce_results=False,
)
else:
self.shared_experts = None
if ( if (
layer_idx >= text_moe_layer_start_index layer_idx >= text_moe_layer_start_index
and layer_idx <= text_moe_layer_end_index and layer_idx <= text_moe_layer_end_index
...@@ -236,7 +259,8 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -236,7 +259,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix=f"{prefix}.text_experts_gate", prefix=f"{prefix}.text_experts_gate",
) )
self.text_experts = FusedMoE( self.text_experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts[0], num_experts=config.moe_num_experts[0],
top_k=config.moe_k, top_k=config.moe_k,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -249,6 +273,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -249,6 +273,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
) )
else: else:
self.text_experts = Ernie4_5_VLMoeMLP( self.text_experts = Ernie4_5_VLMoeMLP(
shared_experts=self.shared_experts,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -271,7 +296,8 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -271,7 +296,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix=f"{prefix}.vision_experts_gate", prefix=f"{prefix}.vision_experts_gate",
) )
self.vision_experts = FusedMoE( self.vision_experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts[1], num_experts=config.moe_num_experts[1],
top_k=config.moe_k, top_k=config.moe_k,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -284,6 +310,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -284,6 +310,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
) )
else: else:
self.vision_experts = Ernie4_5_VLMoeMLP( self.vision_experts = Ernie4_5_VLMoeMLP(
shared_experts=self.shared_experts,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -292,19 +319,6 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -292,19 +319,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
if self.has_shared_experts:
intermediate_size = (
config.moe_intermediate_size[0] * config.moe_num_shared_experts
)
self.shared_experts = Ernie4_5_VLMoeMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.shared_experts",
reduce_results=self.text_experts.must_reduce_shared_expert_outputs(),
)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -315,9 +329,6 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -315,9 +329,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
hidden_dim = hidden_states.shape[-1] hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.has_shared_experts:
shared_output = self.shared_experts(hidden_states)
if visual_token_mask is not None and visual_token_mask.all(): if visual_token_mask is not None and visual_token_mask.all():
# only vision modal input # only vision modal input
router_logits, _ = self.vision_experts_gate( router_logits, _ = self.vision_experts_gate(
...@@ -362,8 +373,8 @@ class Ernie4_5_VLMoeMoE(nn.Module): ...@@ -362,8 +373,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
hidden_states=hidden_states, router_logits=text_router_logits hidden_states=hidden_states, router_logits=text_router_logits
) )
if self.has_shared_experts and shared_output is not None: if self.has_shared_experts:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = ( final_hidden_states = (
...@@ -649,7 +660,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -649,7 +660,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
...@@ -42,7 +42,7 @@ from vllm.distributed import ( ...@@ -42,7 +42,7 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -52,7 +52,6 @@ from vllm.model_executor.layers.linear import ( ...@@ -52,7 +52,6 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -176,6 +175,9 @@ class Glm4MoE(nn.Module): ...@@ -176,6 +175,9 @@ class Glm4MoE(nn.Module):
reduce_results=False, reduce_results=False,
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
else:
self.shared_experts = None
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
...@@ -196,26 +198,6 @@ class Glm4MoE(nn.Module): ...@@ -196,26 +198,6 @@ class Glm4MoE(nn.Module):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
) )
else:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
...@@ -522,7 +504,7 @@ class Glm4MoeModel(nn.Module): ...@@ -522,7 +504,7 @@ class Glm4MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
...@@ -677,7 +659,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -677,7 +659,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = [] self.moe_layers: list[SharedFusedMoE] = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
......
...@@ -44,7 +44,7 @@ from vllm.distributed import ( ...@@ -44,7 +44,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -415,19 +415,6 @@ class HunYuanSparseMoeBlock(nn.Module): ...@@ -415,19 +415,6 @@ class HunYuanSparseMoeBlock(nn.Module):
self.physical_expert_start + self.n_local_physical_experts self.physical_expert_start + self.n_local_physical_experts
) )
self.experts = FusedMoE(
num_experts=self.n_routed_experts,
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=top_k > 1,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
)
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
config.hidden_size, config.hidden_size,
config.num_experts, config.num_experts,
...@@ -455,22 +442,34 @@ class HunYuanSparseMoeBlock(nn.Module): ...@@ -455,22 +442,34 @@ class HunYuanSparseMoeBlock(nn.Module):
else: else:
self.shared_mlp = None self.shared_mlp = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_mlp,
num_experts=self.n_routed_experts,
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=top_k > 1,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1] hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_mlp is not None:
shared_output = self.shared_mlp(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if shared_output is not None: if self.shared_mlp is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
...@@ -724,7 +723,7 @@ class HunYuanModel(nn.Module): ...@@ -724,7 +723,7 @@ class HunYuanModel(nn.Module):
if _is_moe(self.config): if _is_moe(self.config):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
...@@ -1008,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): ...@@ -1008,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.num_expert_groups = 1 self.num_expert_groups = 1
self.moe_layers: list[FusedMoE] = [] self.moe_layers: list[SharedFusedMoE] = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
......
...@@ -33,7 +33,7 @@ from vllm.distributed import ( ...@@ -33,7 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import ( ...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
...@@ -399,7 +398,7 @@ class Llama4Model(LlamaModel): ...@@ -399,7 +398,7 @@ class Llama4Model(LlamaModel):
params_dict: The dictionary of module parameters. params_dict: The dictionary of module parameters.
loaded_params: The set of already loaded parameters. loaded_params: The set of already loaded parameters.
expert_params_mapping: The mapping of expert parameters. Must be expert_params_mapping: The mapping of expert parameters. Must be
generated by FusedMoE.make_expert_params_mapping(). generated by SharedFusedMoE.make_expert_params_mapping().
fused: Whether the expert weights are fused into a single weight fused: Whether the expert weights are fused into a single weight
tensor or are separate weight tensors for each expert. tensor or are separate weight tensors for each expert.
When fused is True, loaded_weight should have shape of: When fused is True, loaded_weight should have shape of:
...@@ -522,7 +521,7 @@ class Llama4Model(LlamaModel): ...@@ -522,7 +521,7 @@ class Llama4Model(LlamaModel):
fused_experts_params = False fused_experts_params = False
# Expert parameter mapping for the case where the expert weights are # Expert parameter mapping for the case where the expert weights are
# not fused into a single weight tensor. # not fused into a single weight tensor.
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
...@@ -530,7 +529,7 @@ class Llama4Model(LlamaModel): ...@@ -530,7 +529,7 @@ class Llama4Model(LlamaModel):
) )
# Expert parameter mapping for the case where the expert weights are # Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor. # fused into a single weight tensor.
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_up_proj", ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj", ckpt_up_proj_name="gate_up_proj",
......
...@@ -40,7 +40,7 @@ from vllm.config import CacheConfig, VllmConfig ...@@ -40,7 +40,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -79,6 +79,7 @@ class Qwen2MoeMLP(nn.Module): ...@@ -79,6 +79,7 @@ class Qwen2MoeMLP(nn.Module):
hidden_act: str, hidden_act: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True, reduce_results: bool = True,
expert_gate: Optional[torch.nn.Linear] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -102,12 +103,17 @@ class Qwen2MoeMLP(nn.Module): ...@@ -102,12 +103,17 @@ class Qwen2MoeMLP(nn.Module):
f"Unsupported activation: {hidden_act}. Only silu is supported for now." f"Unsupported activation: {hidden_act}. Only silu is supported for now."
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
self.expert_gate = expert_gate
def forward(self, x): def forward(self, x):
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) out = self.act_fn(gate_up)
x, _ = self.down_proj(x) out, _ = self.down_proj(out)
return x
if self.expert_gate is not None:
out = F.sigmoid(self.expert_gate(x)) * out
return out
class Qwen2MoeSparseMoeBlock(nn.Module): class Qwen2MoeSparseMoeBlock(nn.Module):
...@@ -126,17 +132,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -126,17 +132,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}." f"the number of experts {config.num_experts}."
) )
self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
config.hidden_size, config.hidden_size,
config.num_experts, config.num_experts,
...@@ -144,39 +139,47 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -144,39 +139,47 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
quant_config=None, quant_config=None,
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
if config.shared_expert_intermediate_size > 0: if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP( self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size, intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(), reduce_results=False,
expert_gate=self.shared_expert_gate,
prefix=f"{prefix}.shared_expert", prefix=f"{prefix}.shared_expert",
) )
else: else:
self.shared_expert = None self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
self.experts = SharedFusedMoE(
shared_experts=self.shared_expert,
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1] hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = (
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if shared_output is not None: if self.shared_expert is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states final_hidden_states
...@@ -418,7 +421,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -418,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
...@@ -7,7 +7,6 @@ from itertools import islice ...@@ -7,7 +7,6 @@ from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
...@@ -36,7 +35,7 @@ from vllm.model_executor.layers.fla.ops import ( ...@@ -36,7 +35,7 @@ from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule, chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule, fused_recurrent_gated_delta_rule,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -136,20 +135,6 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -136,20 +135,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.physical_expert_start + self.n_local_physical_experts self.physical_expert_start + self.n_local_physical_experts
) )
self.experts = FusedMoE(
num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
config.hidden_size, config.hidden_size,
config.num_experts, config.num_experts,
...@@ -158,18 +143,35 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -158,18 +143,35 @@ class Qwen3NextSparseMoeBlock(nn.Module):
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
if config.shared_expert_intermediate_size > 0: if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen3NextMLP( self.shared_expert = Qwen3NextMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size, intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(), reduce_results=False,
expert_gate=self.shared_expert_gate,
prefix=f"{prefix}.shared_expert", prefix=f"{prefix}.shared_expert",
) )
else: else:
self.shared_expert = None self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
self.experts = SharedFusedMoE(
shared_experts=self.shared_expert,
num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
...@@ -180,22 +182,14 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -180,22 +182,14 @@ class Qwen3NextSparseMoeBlock(nn.Module):
if self.is_sequence_parallel: if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states) hidden_states = sequence_parallel_chunk(hidden_states)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = (
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if shared_output is not None: if self.shared_expert is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
...@@ -1008,7 +1002,7 @@ class Qwen3NextModel(nn.Module): ...@@ -1008,7 +1002,7 @@ class Qwen3NextModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
...@@ -1150,7 +1144,7 @@ class Qwen3NextForCausalLM( ...@@ -1150,7 +1144,7 @@ class Qwen3NextForCausalLM(
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers: list[SharedFusedMoE] = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment