"verl/workers/reward/function.py" did not exist on "7121d0b089a891aef1067393368f1e7a18990fde"
Unverified commit b1e13e7c, authored by Cheng Wan and committed by GitHub

[hotfix] Incorrect CombineOverlapArgs in SBO (#12230)
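This change removes overlap_args from the combine-input NamedTuples (DeepEPNormalCombineInput, DeepEPLLCombineInput) and instead threads it through DeepEPDispatcher.combine() / combine_a() as an explicit keyword argument defaulting to None; the SBO path in execute_sbo now passes combine_overlap_args directly at the call site.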

parent cc7b04a2
@@ -235,7 +235,6 @@ class DeepEPMoE(FusedMoE):
             hidden_states=output,
             topk_ids=dispatch_output.topk_ids,
             topk_weights=dispatch_output.topk_weights,
-            overlap_args=down_gemm_overlap_args,
         )
 
     def combine(
@@ -97,7 +97,6 @@ class DeepEPNormalCombineInput(NamedTuple):
     hidden_states: torch.Tensor
     topk_ids: torch.Tensor
     topk_weights: torch.Tensor
-    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -110,7 +109,6 @@ class DeepEPLLCombineInput(NamedTuple):
     hidden_states: torch.Tensor
     topk_ids: torch.Tensor
     topk_weights: torch.Tensor
-    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
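Dropping the field changes the NamedTuples' arity, which is why the positional unpack in DeepEPDispatcher.combine_a (further down in this diff) shrinks from four names to three. A minimal, self-contained sketch of that mechanic; the class names and object-typed fields below are stand-ins, not the real tensor-typed classes:

from typing import NamedTuple, Optional

class CombineInputOld(NamedTuple):
    # Sketch of the pre-fix shape: overlap args ride inside the tuple.
    hidden_states: object
    topk_ids: object
    topk_weights: object
    overlap_args: Optional[object] = None

class CombineInputNew(NamedTuple):
    # Sketch of the post-fix shape: three fields only.
    hidden_states: object
    topk_ids: object
    topk_weights: object

# A NamedTuple unpacks positionally with fixed arity, so the unpacking
# site must name exactly as many targets as the class has fields:
h, ids, w, overlap = CombineInputOld("h", "ids", "w")  # four fields
h, ids, w = CombineInputNew("h", "ids", "w")           # three fields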
@@ -333,7 +331,7 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         raise NotImplementedError
@@ -463,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
@@ -619,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
@@ -645,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         buffer = self._get_buffer()
@@ -762,16 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)
 
-    def combine(self, combine_input: CombineInput) -> Tuple:
-        self.combine_a(combine_input)
+    def combine(
+        self,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
+    ) -> Tuple:
+        self.combine_a(combine_input, overlap_args)
         ret = self.combine_b()
         return ret
 
     def combine_a(
         self,
         combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
-        hidden_states, topk_ids, topk_weights, overlap_args = combine_input
+        hidden_states, topk_ids, topk_weights = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
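Taken together, the dispatcher's public surface now looks roughly like the sketch below. SketchDispatcher and this CombineOverlapArgs placeholder are hypothetical stand-ins for the real sglang classes, kept just concrete enough to run:

from typing import NamedTuple, Optional, Tuple

class CombineOverlapArgs(NamedTuple):
    # Placeholder for the real overlap-args structure (fields assumed).
    num_sms: int

class CombineInput(NamedTuple):
    hidden_states: object
    topk_ids: object
    topk_weights: object

class SketchDispatcher:
    def combine(
        self,
        combine_input: CombineInput,
        overlap_args: Optional[CombineOverlapArgs] = None,
    ) -> Tuple:
        # overlap_args now travels alongside the input tuple rather than
        # inside it, so callers that never overlap can simply omit it.
        hidden_states, topk_ids, topk_weights = combine_input
        return hidden_states, overlap_args

out, args = SketchDispatcher().combine(
    CombineInput("h", "ids", "w"),
    overlap_args=CombineOverlapArgs(num_sms=8),
)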
@@ -98,7 +98,10 @@ def execute_sbo(
     ):
         forward_shared_experts()
 
-    hidden_states = experts.dispatcher.combine(combine_input=combine_input)
+    hidden_states = experts.dispatcher.combine(
+        combine_input=combine_input,
+        overlap_args=combine_overlap_args,
+    )
 
     return hidden_states
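With the explicit parameter, execute_sbo supplies combine_overlap_args at combine time instead of relying on whatever was baked into the combine input when it was built; that bundled value appears to be the "Incorrect CombineOverlapArgs" the commit title refers to.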