Unverified Commit 726efe17 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Move the shared/fused expert output sum into MoERunnerBase (#35949)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 59556265
...@@ -216,7 +216,6 @@ class NemotronHMoE(nn.Module): ...@@ -216,7 +216,6 @@ class NemotronHMoE(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=self.moe_hidden_size, hidden_size=self.moe_hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
use_grouped_topk=True, use_grouped_topk=True,
...@@ -231,6 +230,9 @@ class NemotronHMoE(nn.Module): ...@@ -231,6 +230,9 @@ class NemotronHMoE(nn.Module):
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
routed_input_transform=self.fc1_latent_proj, routed_input_transform=self.fc1_latent_proj,
routed_output_transform=self.fc2_latent_proj,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
router_logits_dtype=self.gate.out_dtype, router_logits_dtype=self.gate.out_dtype,
) )
...@@ -244,38 +246,15 @@ class NemotronHMoE(nn.Module): ...@@ -244,38 +246,15 @@ class NemotronHMoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
# SharedFusedMoE handles: final_hidden_states = self.experts(
# - shared experts (with original hidden_states)
# - routed_input_transform (fc1_latent_proj) for latent MoE
# - multistream parallelism between shared and routed experts
shared_output, final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
shared_output *= 1.0 / self.routed_scaling_factor
# TODO: See SharedFusedMoE.apply_routed_input_transform
# for bandwidth optimization
if self.use_latent_moe:
final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
if self.shared_experts is not None:
final_hidden_states += shared_output
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
final_hidden_states = final_hidden_states[:num_tokens] final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
......
...@@ -98,7 +98,6 @@ class OlmoeMoE(nn.Module): ...@@ -98,7 +98,6 @@ class OlmoeMoE(nn.Module):
top_k=top_k, top_k=top_k,
hidden_size=hidden_size, hidden_size=hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False, renormalize=False,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size, tp_size=tp_size,
......
...@@ -206,7 +206,6 @@ class OpenPanguMoE(nn.Module): ...@@ -206,7 +206,6 @@ class OpenPanguMoE(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
use_grouped_topk=True, use_grouped_topk=True,
...@@ -214,8 +213,8 @@ class OpenPanguMoE(nn.Module): ...@@ -214,8 +213,8 @@ class OpenPanguMoE(nn.Module):
topk_group=1, topk_group=1,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func="sigmoid", scoring_func="sigmoid",
# we do scaling outside, set factor to 1.0 to avoid double mul routed_scaling_factor=self.routed_scaling_factor,
routed_scaling_factor=1.0, apply_routed_scale_to_output=True,
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
...@@ -234,33 +233,15 @@ class OpenPanguMoE(nn.Module): ...@@ -234,33 +233,15 @@ class OpenPanguMoE(nn.Module):
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
final_hidden_states = final_hidden_states[:num_tokens] final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
......
...@@ -359,7 +359,6 @@ class Param2MoEMoEBlock(nn.Module): ...@@ -359,7 +359,6 @@ class Param2MoEMoEBlock(nn.Module):
top_k=self.top_k, top_k=self.top_k,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob, renormalize=self.norm_expert_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -388,24 +387,11 @@ class Param2MoEMoEBlock(nn.Module): ...@@ -388,24 +387,11 @@ class Param2MoEMoEBlock(nn.Module):
self.gate.weight.float(), self.gate.weight.float(),
).to(hidden_states.dtype) ).to(hidden_states.dtype)
final_hidden = self.experts( expert_output = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
if self.shared_experts is not None:
shared_output, expert_output = final_hidden
else:
shared_output, expert_output = None, final_hidden
if shared_output is not None:
expert_output = expert_output + shared_output
if self.tp_size > 1:
expert_output = self.experts.maybe_all_reduce_tensor_model_parallel(
expert_output
)
return expert_output.view(num_tokens, hidden_dim) return expert_output.view(num_tokens, hidden_dim)
......
...@@ -281,7 +281,6 @@ class PhiMoE(nn.Module): ...@@ -281,7 +281,6 @@ class PhiMoE(nn.Module):
hidden_size=hidden_size, hidden_size=hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
params_dtype=params_dtype, params_dtype=params_dtype,
reduce_results=True,
renormalize=False, renormalize=False,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size, tp_size=tp_size,
......
...@@ -170,7 +170,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -170,7 +170,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -187,12 +186,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -187,12 +186,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.shared_expert is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
return final_hidden_states.view(orig_shape) return final_hidden_states.view(orig_shape)
......
...@@ -212,7 +212,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -212,7 +212,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -234,22 +233,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -234,22 +233,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
shared_out, fused_out = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
final_hidden_states = (
shared_out + fused_out if shared_out is not None else fused_out
)
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
final_hidden_states = final_hidden_states[:num_tokens] final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
# return to 1d if input is 1d # return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
......
...@@ -153,7 +153,6 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -153,7 +153,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=getattr(config, "norm_topk_prob", True), renormalize=getattr(config, "norm_topk_prob", True),
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -183,18 +182,11 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -183,18 +182,11 @@ class Qwen3NextSparseMoeBlock(nn.Module):
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.shared_expert is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
final_hidden_states = final_hidden_states[:num_tokens] final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
return final_hidden_states.view(orig_shape) return final_hidden_states.view(orig_shape)
......
...@@ -341,7 +341,6 @@ class SarvamMLAMoE(nn.Module): ...@@ -341,7 +341,6 @@ class SarvamMLAMoE(nn.Module):
top_k=self.top_k, top_k=self.top_k,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob, renormalize=self.norm_expert_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -370,20 +369,7 @@ class SarvamMLAMoE(nn.Module): ...@@ -370,20 +369,7 @@ class SarvamMLAMoE(nn.Module):
router_logits=router_logits, router_logits=router_logits,
) )
if self.shared_experts is not None: return final_hidden.view(num_tokens, hidden_dim)
shared_output, expert_output = final_hidden
else:
shared_output, expert_output = None, final_hidden
if shared_output is not None:
expert_output = expert_output + shared_output
if self.tp_size > 1:
expert_output = self.experts.maybe_all_reduce_tensor_model_parallel(
expert_output
)
return expert_output.view(num_tokens, hidden_dim)
class SarvamMLABlock(nn.Module): class SarvamMLABlock(nn.Module):
......
...@@ -14,7 +14,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig ...@@ -14,7 +14,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_pp_group, get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -71,7 +70,6 @@ class FusedMoEBlock(nn.Module): ...@@ -71,7 +70,6 @@ class FusedMoEBlock(nn.Module):
top_k=config.moe_top_k, top_k=config.moe_top_k,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_expert_weight, renormalize=config.norm_expert_weight,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -94,8 +92,6 @@ class FusedMoEBlock(nn.Module): ...@@ -94,8 +92,6 @@ class FusedMoEBlock(nn.Module):
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(orig_shape) return final_hidden_states.view(orig_shape)
......
...@@ -379,7 +379,6 @@ class FusedMoEBlock(nn.Module): ...@@ -379,7 +379,6 @@ class FusedMoEBlock(nn.Module):
top_k=config.moe_top_k, top_k=config.moe_top_k,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_expert_weight, renormalize=config.norm_expert_weight,
quant_config=quant_config, quant_config=quant_config,
activation=activation, activation=activation,
...@@ -397,30 +396,16 @@ class FusedMoEBlock(nn.Module): ...@@ -397,30 +396,16 @@ class FusedMoEBlock(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.experts.is_internal_router: if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts(
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states hidden_states=hidden_states, router_logits=hidden_states
) )
else: else:
# router_logits: (num_tokens, n_experts) # TODO(bnell): this gate could be moved into the FusedMoE?
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
shared_output, final_hidden_states = fused_moe_out
if self.share_expert is None:
assert shared_output is None
if self.share_expert is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
......
...@@ -204,8 +204,6 @@ class MoEMixin(MixtureOfExperts): ...@@ -204,8 +204,6 @@ class MoEMixin(MixtureOfExperts):
) )
assert intermediate_size is not None assert intermediate_size is not None
# If there are shared experts, the results are
# reduced after mlp.forward() not inside FusedMoE
num_shared_experts = getattr_iter( num_shared_experts = getattr_iter(
text_config, text_config,
[ [
...@@ -214,17 +212,6 @@ class MoEMixin(MixtureOfExperts): ...@@ -214,17 +212,6 @@ class MoEMixin(MixtureOfExperts):
], ],
0, 0,
) )
reduce_results = num_shared_experts == 0
def add_all_reduce(mlp: nn.Module):
"""Adds an all-reduce to the output of `mlp.forward()`."""
class MLPWithAllReduce(mlp.__class__):
def forward(self, *args, **kwargs):
output = super().forward(*args, **kwargs)
return self.experts.maybe_all_reduce_tensor_model_parallel(output)
mlp.__class__ = MLPWithAllReduce
# Unused kwargs since we use custom_routing_function: # Unused kwargs since we use custom_routing_function:
# - `scoring_func` and `e_score_correction_bias` only used for grouped # - `scoring_func` and `e_score_correction_bias` only used for grouped
...@@ -289,14 +276,11 @@ class MoEMixin(MixtureOfExperts): ...@@ -289,14 +276,11 @@ class MoEMixin(MixtureOfExperts):
if "bias" in experts_param_name: if "bias" in experts_param_name:
has_bias = True has_bias = True
break break
# Double check there are no shared experts
nonlocal reduce_results
if reduce_results:
for mlp_param_name, _ in mlp.named_parameters():
if "shared_expert" in mlp_param_name:
reduce_results = False
# If the config does not specify num_shared_experts, but # If the config does not specify num_shared_experts, but
# the model has shared experts, we assume there is one. # the model has shared experts, we assume there is one.
if self.num_shared_experts == 0:
for mlp_param_name, _ in mlp.named_parameters():
if "shared_expert" in mlp_param_name:
self.num_shared_experts = 1 self.num_shared_experts = 1
break break
# Replace experts module with FusedMoE # Replace experts module with FusedMoE
...@@ -305,7 +289,6 @@ class MoEMixin(MixtureOfExperts): ...@@ -305,7 +289,6 @@ class MoEMixin(MixtureOfExperts):
top_k=top_k, top_k=top_k,
hidden_size=hidden_size, hidden_size=hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
reduce_results=reduce_results,
renormalize=renormalize, renormalize=renormalize,
# Hard coded because topk happens in Transformers # Hard coded because topk happens in Transformers
use_grouped_topk=False, use_grouped_topk=False,
...@@ -326,13 +309,6 @@ class MoEMixin(MixtureOfExperts): ...@@ -326,13 +309,6 @@ class MoEMixin(MixtureOfExperts):
self.moe_layers.append(fused_experts) self.moe_layers.append(fused_experts)
self.expert_weights.append(fused_experts.get_expert_weights()) self.expert_weights.append(fused_experts.get_expert_weights())
self.num_moe_layers += 1 self.num_moe_layers += 1
# If results are not all-reduced in FusedMoE, ensure they
# are all-reduced at the end of mlp.forward() if tensor
# parallel or expert parallel is enabled
if not reduce_results and (
fused_experts.tp_size > 1 or fused_experts.ep_size > 1
):
add_all_reduce(mlp)
else: else:
_recursive_replace(child_module, prefix=qual_name) _recursive_replace(child_module, prefix=qual_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment