Commit b281794e authored by laibao, committed by zhangzbb
Browse files

[BUGFIX] 修复 fused MoE modular kernel 路径中 shared_output 和 routed_scaling_factor 透传不完整的问题

parent be03cbe8
...@@ -98,6 +98,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -98,6 +98,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.fused_experts( return self.fused_experts(
hidden_states=x, hidden_states=x,
...@@ -110,4 +112,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -110,4 +112,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
) use_nn_moe=use_nn_moe,
\ No newline at end of file shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
...@@ -735,6 +735,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -735,6 +735,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None: ) -> None:
""" """
This function computes the intermediate result of a Mixture of Experts This function computes the intermediate result of a Mixture of Experts
...@@ -1155,6 +1157,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1155,6 +1157,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size( _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids a1q, w1, w2, topk_ids
...@@ -1216,7 +1220,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1216,7 +1220,13 @@ class FusedMoEModularKernel(torch.nn.Module):
c_fused_out = self._slice_output_tensor( c_fused_out = self._slice_output_tensor(
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
) )
c_shared_output = (
None
if shared_output is None
else self._slice_output_tensor(
shared_output, chunk_idx, num_chunks, CHUNK_SIZE, M_full
)
)
self.fused_experts.apply( self.fused_experts.apply(
output=c_fused_out, output=c_fused_out,
hidden_states=a1q[s:e], hidden_states=a1q[s:e],
...@@ -1234,6 +1244,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1234,6 +1244,8 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=c_shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
return fused_out return fused_out
...@@ -1246,13 +1258,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1246,13 +1258,12 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
shared_output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
The _finalize method is a wrapper around self.prepare_finalize.finalize The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap. that handles DBO, async and shared expert overlap.
""" """
shared_output: torch.Tensor | None = None
if not self.prepare_finalize.supports_async(): if not self.prepare_finalize.supports_async():
assert not dbo_enabled() assert not dbo_enabled()
...@@ -1264,11 +1275,11 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1264,11 +1275,11 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
if self.shared_experts is not None: if shared_output is None and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
else: else:
self.alt_event.record() self.alt_event.record()
if self.shared_experts is not None: if shared_output is None and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
...@@ -1327,6 +1338,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1327,6 +1338,8 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1389,6 +1402,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1389,6 +1402,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
return self._finalize( return self._finalize(
...@@ -1398,4 +1413,5 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1398,4 +1413,5 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights, topk_weights,
topk_ids, topk_ids,
apply_router_weight_on_input, apply_router_weight_on_input,
shared_output=shared_output,
) )
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment