Unverified commit 726efe17, authored by bnellnm and committed by GitHub
Browse files

[MoE Refactor] Move the shared/fused expert output sum into MoERunnerBase (#35949)


Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 59556265
......@@ -194,7 +194,6 @@ class Ernie4_5_MoeMoE(nn.Module):
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -215,16 +214,6 @@ class Ernie4_5_MoeMoE(nn.Module):
hidden_states=hidden_states, router_logits=router_logits
)
if self.has_shared_experts:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
else:
final_hidden_states = final_hidden_states[1]
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(orig_shape)
......
......@@ -263,7 +263,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size[0],
reduce_results=False,
renormalize=True,
quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[0],
......@@ -301,7 +300,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size[1],
reduce_results=False,
renormalize=True,
quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[1],
......@@ -342,9 +340,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
text_token_mask = ~visual_token_mask
final_experts_hidden_states = torch.zeros_like(hidden_states)
final_shared_output = (
torch.zeros_like(hidden_states) if self.has_shared_experts else None
)
text_hidden_states = hidden_states[text_token_mask].reshape(
-1, self.hidden_size
......@@ -356,26 +351,20 @@ class Ernie4_5_VLMoeMoE(nn.Module):
text_router_logits, _ = self.text_experts_gate(
text_hidden_states.to(dtype=torch.float32)
)
text_shared_output, text_experts_output = self.text_experts(
text_output = self.text_experts(
hidden_states=text_hidden_states, router_logits=text_router_logits
)
final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
if self.has_shared_experts:
final_shared_output[text_token_mask] = text_shared_output.flatten()
final_experts_hidden_states[text_token_mask] = text_output.flatten()
vision_router_logits, _ = self.vision_experts_gate(
vision_hidden_states.to(dtype=torch.float32)
)
vision_shared_output, vision_experts_output = self.vision_experts(
vision_output = self.vision_experts(
hidden_states=vision_hidden_states, router_logits=vision_router_logits
)
final_experts_hidden_states[visual_token_mask] = (
vision_experts_output.flatten()
)
if self.has_shared_experts:
final_shared_output[visual_token_mask] = vision_shared_output.flatten()
final_experts_hidden_states[visual_token_mask] = vision_output.flatten()
final_hidden_states = (final_shared_output, final_experts_hidden_states)
final_hidden_states = final_experts_hidden_states
else:
# only text modal input
text_router_logits, _ = self.text_experts_gate(
......@@ -386,20 +375,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
hidden_states=hidden_states, router_logits=text_router_logits
)
if self.has_shared_experts:
# for shared_experts model
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
else:
# for not shared_experts model
final_hidden_states = final_hidden_states[1]
if self.tp_size > 1:
final_hidden_states = (
self.text_experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
)
return final_hidden_states.view(orig_shape)
......
......@@ -31,6 +31,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
......@@ -116,12 +117,26 @@ class ExaoneMoe(nn.Module):
self.physical_expert_start + self.n_local_physical_experts
)
self.experts = FusedMoE(
if getattr(config, "num_shared_experts", 0) > 0:
intermediate_size = config.moe_intermediate_size * config.num_shared_experts
self.shared_experts = ExaoneMoeGatedMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
gate=self.gate,
num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -135,41 +150,16 @@ class ExaoneMoe(nn.Module):
num_redundant_experts=self.n_redundant_experts,
)
if getattr(config, "num_shared_experts", 0) > 0:
intermediate_size = config.moe_intermediate_size * config.num_shared_experts
self.shared_experts = ExaoneMoeGatedMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(),
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, router_logits=hidden_states
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
return final_hidden_states.view(orig_shape)
......
......@@ -76,7 +76,6 @@ class FlexOlmoMoE(nn.Module):
top_k=hf_config.num_experts_per_tok,
hidden_size=hf_config.hidden_size,
intermediate_size=hf_config.intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=None,
tp_size=tp_size,
......
......@@ -349,7 +349,6 @@ class Gemma4MoE(nn.Module):
"moe_intermediate_size",
getattr(config, "expert_intermediate_size", None),
),
reduce_results=True,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......
......@@ -184,7 +184,6 @@ class Glm4MoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -192,8 +191,8 @@ class Glm4MoE(nn.Module):
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
......@@ -207,23 +206,9 @@ class Glm4MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out
assert shared_output is not None
final_hidden_states = (
final_hidden_states * self.routed_scaling_factor + shared_output
)
else:
final_hidden_states = fused_moe_out * self.routed_scaling_factor
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -189,7 +189,6 @@ class MLPBlock(torch.nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......
......@@ -104,7 +104,6 @@ class GraniteMoeMoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
......
......@@ -209,7 +209,6 @@ class Grok1MoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=renormalize,
quant_config=quant_config,
tp_size=tp_size,
......
......@@ -39,7 +39,6 @@ from vllm.distributed import (
get_ep_group,
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
......@@ -445,7 +444,6 @@ class HunYuanSparseMoeBlock(nn.Module):
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=top_k > 1,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -464,11 +462,6 @@ class HunYuanSparseMoeBlock(nn.Module):
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_mlp is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(orig_shape)
......
......@@ -176,7 +176,6 @@ class InternS1ProMoeSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......
......@@ -90,7 +90,6 @@ class JambaMoE(nn.Module):
self.intermediate_size,
tp_size=tp_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=False,
use_grouped_topk=False,
quant_config=quant_config,
......
......@@ -11,11 +11,10 @@ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.kda import KimiDeltaAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -132,12 +131,25 @@ class KimiMoE(nn.Module):
self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))
self.experts = FusedMoE(
if self.num_shared_experts is not None:
intermediate_size = moe_intermediate_size * self.num_shared_experts
self.shared_experts = KimiMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=num_experts,
top_k=config.num_experts_per_token,
hidden_size=hidden_size,
intermediate_size=moe_intermediate_size,
reduce_results=False,
renormalize=moe_renormalize,
quant_config=quant_config,
use_grouped_topk=config.use_grouped_topk,
......@@ -146,34 +158,16 @@ class KimiMoE(nn.Module):
prefix=f"{prefix}.experts",
scoring_func=config.moe_router_activation_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if self.num_shared_experts is not None:
intermediate_size = moe_intermediate_size * self.num_shared_experts
self.shared_experts = KimiMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.num_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = (
self.experts(hidden_states=hidden_states, router_logits=router_logits)
* self.routed_scaling_factor
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
......@@ -482,7 +476,7 @@ class KimiLinearModel(nn.Module):
if self.config.is_moe:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
......
......@@ -150,7 +150,6 @@ class Lfm2MoeSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True, # needed for softmax score func
......@@ -161,6 +160,7 @@ class Lfm2MoeSparseMoeBlock(nn.Module):
num_redundant_experts=self.n_redundant_experts,
scoring_func="sigmoid",
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -170,16 +170,10 @@ class Lfm2MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = (
self.experts(hidden_states=hidden_states, router_logits=router_logits)
* self.routed_scaling_factor
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
return final_hidden_states.view(orig_shape)
......
......@@ -135,7 +135,6 @@ class Llama4MoE(nn.Module):
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
apply_router_weight_on_input=True,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -151,19 +150,14 @@ class Llama4MoE(nn.Module):
router_logits, _ = self.router(hidden_states)
shared_out, routed_out = self.experts(
experts_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
experts_out = routed_out + shared_out
if self.is_sequence_parallel:
experts_out = tensor_model_parallel_all_gather(experts_out, 0)
experts_out = experts_out[:num_tokens]
elif self.tp_size > 1:
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out
)
return experts_out
......
......@@ -300,7 +300,6 @@ class LongcatMoe(nn.Module):
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
params_dtype=params_dtype,
renormalize=False,
quant_config=quant_config,
......
......@@ -162,7 +162,6 @@ class MiMoV2MoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......
......@@ -36,7 +36,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
......@@ -104,7 +103,6 @@ class MiniMaxM2MoE(nn.Module):
e_score_correction_bias=self.e_score_correction_bias,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
reduce_results=False,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -134,9 +132,6 @@ class MiniMaxM2MoE(nn.Module):
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
final_hidden_states = final_hidden_states
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -162,7 +162,6 @@ class MiniMaxText01MoE(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size * self.tp_size,
params_dtype=self.params_dtype,
reduce_results=True,
renormalize=True,
quant_config=self.quant_config,
tp_size=self.tp_size,
......
......@@ -132,7 +132,6 @@ class MixtralMoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment