Unverified commit 726efe17, authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Move the shared/fused expert output sum into MoERunnerBase (#35949)


Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 59556265
......@@ -216,7 +216,6 @@ class NemotronHMoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=self.moe_hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -231,6 +230,9 @@ class NemotronHMoE(nn.Module):
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routed_input_transform=self.fc1_latent_proj,
routed_output_transform=self.fc2_latent_proj,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
router_logits_dtype=self.gate.out_dtype,
)
......@@ -244,38 +246,15 @@ class NemotronHMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
# SharedFusedMoE handles:
# - shared experts (with original hidden_states)
# - routed_input_transform (fc1_latent_proj) for latent MoE
# - multistream parallelism between shared and routed experts
shared_output, final_hidden_states = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
shared_output *= 1.0 / self.routed_scaling_factor
# TODO: See SharedFusedMoE.apply_routed_input_transform
# for bandwidth optimization
if self.use_latent_moe:
final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
if self.shared_experts is not None:
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -98,7 +98,6 @@ class OlmoeMoE(nn.Module):
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
......
......@@ -206,7 +206,6 @@ class OpenPanguMoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -214,8 +213,8 @@ class OpenPanguMoE(nn.Module):
topk_group=1,
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
......@@ -234,33 +233,15 @@ class OpenPanguMoE(nn.Module):
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -359,7 +359,6 @@ class Param2MoEMoEBlock(nn.Module):
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -388,24 +387,11 @@ class Param2MoEMoEBlock(nn.Module):
self.gate.weight.float(),
).to(hidden_states.dtype)
final_hidden = self.experts(
expert_output = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
if self.shared_experts is not None:
shared_output, expert_output = final_hidden
else:
shared_output, expert_output = None, final_hidden
if shared_output is not None:
expert_output = expert_output + shared_output
if self.tp_size > 1:
expert_output = self.experts.maybe_all_reduce_tensor_model_parallel(
expert_output
)
return expert_output.view(num_tokens, hidden_dim)
......
......@@ -281,7 +281,6 @@ class PhiMoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
......
......@@ -170,7 +170,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -187,12 +186,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_expert is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
return final_hidden_states.view(orig_shape)
......
......@@ -212,7 +212,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -234,22 +233,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
shared_out, fused_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
final_hidden_states = (
shared_out + fused_out if shared_out is not None else fused_out
)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
# return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
......
......@@ -153,7 +153,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=getattr(config, "norm_topk_prob", True),
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -183,18 +182,11 @@ class Qwen3NextSparseMoeBlock(nn.Module):
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_expert is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
return final_hidden_states.view(orig_shape)
......
......@@ -341,7 +341,6 @@ class SarvamMLAMoE(nn.Module):
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -370,20 +369,7 @@ class SarvamMLAMoE(nn.Module):
router_logits=router_logits,
)
if self.shared_experts is not None:
shared_output, expert_output = final_hidden
else:
shared_output, expert_output = None, final_hidden
if shared_output is not None:
expert_output = expert_output + shared_output
if self.tp_size > 1:
expert_output = self.experts.maybe_all_reduce_tensor_model_parallel(
expert_output
)
return expert_output.view(num_tokens, hidden_dim)
return final_hidden.view(num_tokens, hidden_dim)
class SarvamMLABlock(nn.Module):
......
......@@ -14,7 +14,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -71,7 +70,6 @@ class FusedMoEBlock(nn.Module):
top_k=config.moe_top_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_expert_weight,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -94,8 +92,6 @@ class FusedMoEBlock(nn.Module):
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(orig_shape)
......
......@@ -379,7 +379,6 @@ class FusedMoEBlock(nn.Module):
top_k=config.moe_top_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_expert_weight,
quant_config=quant_config,
activation=activation,
......@@ -397,30 +396,16 @@ class FusedMoEBlock(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# router_logits: (num_tokens, n_experts)
# TODO(bnell): this gate could be moved into the FusedMoE?
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.share_expert is None:
assert shared_output is None
if self.share_expert is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -204,8 +204,6 @@ class MoEMixin(MixtureOfExperts):
)
assert intermediate_size is not None
# If there are shared experts, the results are
# reduced after mlp.forward() not inside FusedMoE
num_shared_experts = getattr_iter(
text_config,
[
......@@ -214,17 +212,6 @@ class MoEMixin(MixtureOfExperts):
],
0,
)
reduce_results = num_shared_experts == 0
def add_all_reduce(mlp: nn.Module):
"""Adds an all-reduce to the output of `mlp.forward()`."""
class MLPWithAllReduce(mlp.__class__):
def forward(self, *args, **kwargs):
output = super().forward(*args, **kwargs)
return self.experts.maybe_all_reduce_tensor_model_parallel(output)
mlp.__class__ = MLPWithAllReduce
# Unused kwargs since we use custom_routing_function:
# - `scoring_func` and `e_score_correction_bias` only used for grouped
......@@ -289,14 +276,11 @@ class MoEMixin(MixtureOfExperts):
if "bias" in experts_param_name:
has_bias = True
break
# Double check there are no shared experts
nonlocal reduce_results
if reduce_results:
# If the config does not specify num_shared_experts, but
# the model has shared experts, we assume there is one.
if self.num_shared_experts == 0:
for mlp_param_name, _ in mlp.named_parameters():
if "shared_expert" in mlp_param_name:
reduce_results = False
# If the config does not specify num_shared_experts, but
# the model has shared experts, we assume there is one.
self.num_shared_experts = 1
break
# Replace experts module with FusedMoE
......@@ -305,7 +289,6 @@ class MoEMixin(MixtureOfExperts):
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=reduce_results,
renormalize=renormalize,
# Hard coded because topk happens in Transformers
use_grouped_topk=False,
......@@ -326,13 +309,6 @@ class MoEMixin(MixtureOfExperts):
self.moe_layers.append(fused_experts)
self.expert_weights.append(fused_experts.get_expert_weights())
self.num_moe_layers += 1
# If results are not all-reduced in FusedMoE, ensure they
# are all-reduced at the end of mlp.forward() if tensor
# parallel or expert parallel is enabled
if not reduce_results and (
fused_experts.tp_size > 1 or fused_experts.ep_size > 1
):
add_all_reduce(mlp)
else:
_recursive_replace(child_module, prefix=qual_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment