Unverified Commit ce399e15 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Make single-batch overlap compatible with NextN (#11804)

parent ea6275df
......@@ -170,6 +170,7 @@ class DeepEPMoE(FusedMoE):
forward_batch: ForwardBatch,
forward_shared_experts=None,
alt_stream=None,
disable_sbo=False,
):
# We have to call SBO inside MoE to be compatible with hooks used in offloading
return single_batch_overlap.execute_sbo(
......@@ -181,6 +182,7 @@ class DeepEPMoE(FusedMoE):
experts=self,
forward_shared_experts=forward_shared_experts,
alt_stream=alt_stream,
disable_sbo=disable_sbo,
)
def dispatch(
......
......@@ -902,6 +902,8 @@ class DeepseekV2MoE(nn.Module):
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
# SBO is not yet implemented for NextN
disable_sbo=self.is_nextn,
)
if self._fuse_shared_experts_inside_sbo
else {}
......
......@@ -60,13 +60,14 @@ def execute_sbo(
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
alt_stream: Optional = None,
disable_sbo: bool = False,
):
dispatch_output = experts.dispatch(
hidden_states, topk_idx, topk_weights, forward_batch
)
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
_compute_overlap_args(dispatch_output, alt_stream)
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
)
hidden_states = experts.moe_impl(
......@@ -75,7 +76,7 @@ def execute_sbo(
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
e.record()
if SboFlags.enable_combine_shared_two_stream_overlap():
if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
# TODO reduce sm for non-deepgemm
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
meta_overlap_args["compute_num_sms"]
......@@ -93,8 +94,8 @@ def execute_sbo(
return hidden_states
def _compute_overlap_args(dispatch_output, alt_stream):
if not (
def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
if disable_sbo or not (
SboFlags.enable_combine_down_gemm_two_stream_overlap()
or SboFlags.enable_combine_shared_two_stream_overlap()
):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment