Unverified commit 995bbf38, authored by TomerBN-Nvidia, committed by GitHub
Browse files

[Bugfix] Fix shared expert input for latent MoE in EP+DP (Nemotron-H) (#34087)


Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
parent d4f123cc
...@@ -139,7 +139,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -139,7 +139,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# work with SP. This will be removed in follow up after we get # work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function. # rid of the FlashInfer specific P/F function.
# TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As. # TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
return not moe_parallel_config.is_sequence_parallel return True
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
......
...@@ -101,4 +101,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -101,4 +101,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
shared_experts_input=layer._get_shared_experts_input(x),
) )
...@@ -1228,13 +1228,28 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1228,13 +1228,28 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
The _finalize method is a wrapper around self.prepare_finalize.finalize The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap. that handles DBO, async and shared expert overlap.
Args:
shared_experts_input: Optional separate input for shared experts.
When latent MoE is used, hidden_states is the latent-projected
tensor (smaller dimension) used by routed experts, while
shared_experts_input is the original hidden_states (full
dimension) needed by the shared expert MLP.
""" """
shared_output: torch.Tensor | None = None shared_output: torch.Tensor | None = None
# For latent MoE: shared experts need the original hidden_states
# (full hidden_size), not the latent-projected version used by
# routed experts.
se_hidden_states = (
shared_experts_input if shared_experts_input is not None else hidden_states
)
if not self.prepare_finalize.supports_async(): if not self.prepare_finalize.supports_async():
assert not dbo_enabled() assert not dbo_enabled()
...@@ -1247,7 +1262,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1247,7 +1262,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(se_hidden_states)
else: else:
finalize_ret = self.prepare_finalize.finalize_async( finalize_ret = self.prepare_finalize.finalize_async(
output, output,
...@@ -1258,7 +1273,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1258,7 +1273,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(se_hidden_states)
# TODO(lucas): refactor this in the alternative schedules followup # TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just # currently unpack if we have hook + receiver pair or just
...@@ -1298,6 +1313,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1298,6 +1313,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1320,6 +1336,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1320,6 +1336,9 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are - apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is applied directly on the inputs. This is only applicable when topk is
1. 1.
- shared_experts_input (Optional[torch.Tensor]): Optional separate
input for shared experts. For latent MoE, this is the original
hidden_states before latent projection.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
...@@ -1368,4 +1387,5 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1368,4 +1387,5 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights, topk_weights,
topk_ids, topk_ids,
apply_router_weight_on_input, apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
) )
...@@ -361,6 +361,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -361,6 +361,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
) )
...@@ -672,6 +673,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -672,6 +673,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
) )
...@@ -1077,6 +1079,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1077,6 +1079,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
) )
@property @property
......
...@@ -1023,6 +1023,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1023,6 +1023,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
) )
......
...@@ -980,6 +980,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -980,6 +980,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
) )
...@@ -1550,6 +1551,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1550,6 +1551,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment