Commit b1e13e7c (unverified), authored by Cheng Wan, committed by GitHub
Browse files

[hotfix] Incorrect CombineOverlapArgs in SBO (#12230)

parent cc7b04a2
...@@ -235,7 +235,6 @@ class DeepEPMoE(FusedMoE): ...@@ -235,7 +235,6 @@ class DeepEPMoE(FusedMoE):
hidden_states=output, hidden_states=output,
topk_ids=dispatch_output.topk_ids, topk_ids=dispatch_output.topk_ids,
topk_weights=dispatch_output.topk_weights, topk_weights=dispatch_output.topk_weights,
overlap_args=down_gemm_overlap_args,
) )
def combine( def combine(
......
...@@ -97,7 +97,6 @@ class DeepEPNormalCombineInput(NamedTuple): ...@@ -97,7 +97,6 @@ class DeepEPNormalCombineInput(NamedTuple):
hidden_states: torch.Tensor hidden_states: torch.Tensor
topk_ids: torch.Tensor topk_ids: torch.Tensor
topk_weights: torch.Tensor topk_weights: torch.Tensor
overlap_args: Optional[CombineOverlapArgs] = None
@property @property
def format(self) -> CombineInputFormat: def format(self) -> CombineInputFormat:
...@@ -110,7 +109,6 @@ class DeepEPLLCombineInput(NamedTuple): ...@@ -110,7 +109,6 @@ class DeepEPLLCombineInput(NamedTuple):
hidden_states: torch.Tensor hidden_states: torch.Tensor
topk_ids: torch.Tensor topk_ids: torch.Tensor
topk_weights: torch.Tensor topk_weights: torch.Tensor
overlap_args: Optional[CombineOverlapArgs] = None
@property @property
def format(self) -> CombineInputFormat: def format(self) -> CombineInputFormat:
...@@ -333,7 +331,7 @@ class _DeepEPDispatcherImplBase: ...@@ -333,7 +331,7 @@ class _DeepEPDispatcherImplBase:
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"], overlap_args: Optional[CombineOverlapArgs] = None,
): ):
raise NotImplementedError raise NotImplementedError
...@@ -463,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -463,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"], overlap_args: Optional[CombineOverlapArgs] = None,
): ):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu: if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
...@@ -619,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -619,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"], overlap_args: Optional[CombineOverlapArgs] = None,
): ):
hidden_states, event, hook = self._combine_core( hidden_states, event, hook = self._combine_core(
hidden_states, hidden_states,
...@@ -645,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -645,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"], overlap_args: Optional[CombineOverlapArgs] = None,
): ):
buffer = self._get_buffer() buffer = self._get_buffer()
...@@ -762,16 +760,21 @@ class DeepEPDispatcher(BaseDispatcher): ...@@ -762,16 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
del self._dispatch_intermediate_state del self._dispatch_intermediate_state
return self._get_impl().dispatch_b(*inner_state) return self._get_impl().dispatch_b(*inner_state)
def combine(self, combine_input: CombineInput) -> Tuple: def combine(
self.combine_a(combine_input) self,
combine_input: CombineInput,
overlap_args: Optional[CombineOverlapArgs] = None,
) -> Tuple:
self.combine_a(combine_input, overlap_args)
ret = self.combine_b() ret = self.combine_b()
return ret return ret
def combine_a( def combine_a(
self, self,
combine_input: CombineInput, combine_input: CombineInput,
overlap_args: Optional[CombineOverlapArgs] = None,
): ):
hidden_states, topk_ids, topk_weights, overlap_args = combine_input hidden_states, topk_ids, topk_weights = combine_input
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl().combine_a( inner_state = self._get_impl().combine_a(
hidden_states=hidden_states, hidden_states=hidden_states,
......
...@@ -98,7 +98,10 @@ def execute_sbo( ...@@ -98,7 +98,10 @@ def execute_sbo(
): ):
forward_shared_experts() forward_shared_experts()
hidden_states = experts.dispatcher.combine(combine_input=combine_input) hidden_states = experts.dispatcher.combine(
combine_input=combine_input,
overlap_args=combine_overlap_args,
)
return hidden_states return hidden_states
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment