Unverified commit ce399e15, authored by fzyzcjy, committed by GitHub

Make single-batch overlap compatible with NextN (#11804)

parent ea6275df
@@ -170,6 +170,7 @@ class DeepEPMoE(FusedMoE):
         forward_batch: ForwardBatch,
         forward_shared_experts=None,
         alt_stream=None,
+        disable_sbo=False,
     ):
         # We have to call SBO inside MoE to be compatible with hooks used in offloading
         return single_batch_overlap.execute_sbo(
@@ -181,6 +182,7 @@ class DeepEPMoE(FusedMoE):
             experts=self,
             forward_shared_experts=forward_shared_experts,
             alt_stream=alt_stream,
+            disable_sbo=disable_sbo,
         )

     def dispatch(
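The in-code comment above ("call SBO inside MoE to be compatible with hooks used in offloading") presumably refers to the standard PyTorch behavior that module hooks only run when execution goes through the module's __call__. A minimal, generic sketch of that behavior (plain PyTorch, not SGLang code; TinyMoE and the hook body are hypothetical stand-ins):

import torch
import torch.nn as nn

class TinyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x, disable_sbo=False):
        # Stand-in for the real dispatch into single_batch_overlap.execute_sbo
        # (or a plain MoE path when disable_sbo=True).
        return self.proj(x)

moe = TinyMoE()
moe.register_forward_pre_hook(lambda module, args: print("offload hook fired"))

moe(torch.randn(2, 4))          # hook fires: the call goes through nn.Module.__call__
moe.forward(torch.randn(2, 4))  # hook is skipped: forward() bypasses the module machinery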
@@ -902,6 +902,8 @@ class DeepseekV2MoE(nn.Module):
             dict(
                 forward_shared_experts=_forward_shared_experts_and_put_results,
                 alt_stream=self.alt_stream,
+                # SBO is not yet implemented for NextN
+                disable_sbo=self.is_nextn,
             )
             if self._fuse_shared_experts_inside_sbo
             else {}
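For context, a self-contained sketch of the conditional-kwargs pattern in this hunk: the SBO-specific arguments are only built when shared experts are fused inside SBO, and disable_sbo is driven by the NextN flag. All names below (run_moe, forward, fuse_inside_sbo, is_nextn) are illustrative stand-ins, not the real DeepseekV2MoE implementation:

def run_moe(hidden_states, *, forward_shared_experts=None, alt_stream=None, disable_sbo=False):
    # Stand-in for the MoE entry point shown in the first hunk:
    # fall back to the plain MoE path when SBO is disabled.
    if disable_sbo:
        return f"plain MoE({hidden_states})"
    return f"SBO MoE({hidden_states}), shared experts fused={forward_shared_experts is not None}"


def forward(hidden_states, *, fuse_inside_sbo, is_nextn):
    sbo_kwargs = (
        dict(
            forward_shared_experts=lambda: None,
            alt_stream=None,
            # SBO is not yet implemented for NextN, so disable it for the NextN layer
            disable_sbo=is_nextn,
        )
        if fuse_inside_sbo
        else {}
    )
    return run_moe(hidden_states, **sbo_kwargs)


print(forward("h", fuse_inside_sbo=True, is_nextn=False))  # SBO path
print(forward("h", fuse_inside_sbo=True, is_nextn=True))   # NextN: SBO disabled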
@@ -60,13 +60,14 @@ def execute_sbo(
     topk_weights: torch.Tensor,
     forward_batch: ForwardBatch,
     alt_stream: Optional = None,
+    disable_sbo: bool = False,
 ):
     dispatch_output = experts.dispatch(
         hidden_states, topk_idx, topk_weights, forward_batch
     )

     combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
-        _compute_overlap_args(dispatch_output, alt_stream)
+        _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
     )

     hidden_states = experts.moe_impl(
@@ -75,7 +76,7 @@ def execute_sbo(
     if (e := meta_overlap_args.get("record_event_after_down")) is not None:
         e.record()

-    if SboFlags.enable_combine_shared_two_stream_overlap():
+    if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
         # TODO reduce sm for non-deepgemm
         with deep_gemm_wrapper.configure_deep_gemm_num_sms(
             meta_overlap_args["compute_num_sms"]
@@ -93,8 +94,8 @@ def execute_sbo(
     return hidden_states


-def _compute_overlap_args(dispatch_output, alt_stream):
-    if not (
+def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
+    if disable_sbo or not (
         SboFlags.enable_combine_down_gemm_two_stream_overlap()
         or SboFlags.enable_combine_shared_two_stream_overlap()
     ):
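The net effect of the execute_sbo changes can be summarized with a small, self-contained sketch (simplified stand-ins, not the real overlap machinery): with disable_sbo=True, _compute_overlap_args returns empty overlap args and the combine/shared-expert two-stream branch is never entered, so the MoE runs as if SBO were off.

class SboFlags:
    # Stand-ins for the real feature flags; assume both overlaps are enabled.
    enable_combine_down_gemm_two_stream_overlap = staticmethod(lambda: True)
    enable_combine_shared_two_stream_overlap = staticmethod(lambda: True)


def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
    if disable_sbo or not (
        SboFlags.enable_combine_down_gemm_two_stream_overlap()
        or SboFlags.enable_combine_shared_two_stream_overlap()
    ):
        # No overlap: everything runs on the main stream; no events, no SM split.
        return {}, {}, {}
    # The real code would set up CUDA events, the alternate stream, and an SM split here.
    return {"combine": "..."}, {"down_gemm": "..."}, {"compute_num_sms": 100}


def execute_sbo(hidden_states, alt_stream=None, disable_sbo=False):
    combine_args, down_gemm_args, meta_args = _compute_overlap_args(
        None, alt_stream, disable_sbo=disable_sbo
    )
    if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
        hidden_states += " + shared experts overlapped with combine"
    return hidden_states, meta_args


print(execute_sbo("h", disable_sbo=True))  # ('h', {}) -- SBO fully bypassed
print(execute_sbo("h"))                    # overlap branch taken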