Commit b281794e authored by laibao, committed by zhangzbb
Browse files

[BUGFIX] 修复 fused MoE modular kernel 路径中 shared_output 和 routed_scaling_factor 透传不完整的问题

parent be03cbe8
......@@ -98,6 +98,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.fused_experts(
hidden_states=x,
......@@ -110,4 +112,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
)
\ No newline at end of file
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -735,6 +735,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor,
expert_tokens_meta: ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None:
"""
This function computes the intermediate result of a Mixture of Experts
......@@ -1155,6 +1157,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids
......@@ -1216,7 +1220,13 @@ class FusedMoEModularKernel(torch.nn.Module):
c_fused_out = self._slice_output_tensor(
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
)
c_shared_output = (
None
if shared_output is None
else self._slice_output_tensor(
shared_output, chunk_idx, num_chunks, CHUNK_SIZE, M_full
)
)
self.fused_experts.apply(
output=c_fused_out,
hidden_states=a1q[s:e],
......@@ -1234,6 +1244,8 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
shared_output=c_shared_output,
routed_scaling_factor=routed_scaling_factor,
)
return fused_out
......@@ -1246,13 +1258,12 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
shared_output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
"""
shared_output: torch.Tensor | None = None
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
......@@ -1264,11 +1275,11 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
if shared_output is None and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
else:
self.alt_event.record()
if self.shared_experts is not None:
if shared_output is None and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
current_stream = torch.cuda.current_stream()
......@@ -1327,6 +1338,8 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
......@@ -1389,6 +1402,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
return self._finalize(
......@@ -1398,4 +1413,5 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
shared_output=shared_output,
)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment