Unverified Commit 517b769b authored by czhu-cohere's avatar czhu-cohere Committed by GitHub
Browse files

[Perf] Fix DBO overlap: capture DeepEP event before yield (#38451)


Signed-off-by: default avatarroot <conway.zhu@cohere.com>
parent d9b90a07
......@@ -107,15 +107,17 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
) -> Callable:
has_scales = token_scales is not None
# Capture a DeepEP event on the compute stream before yielding.
# This must happen before the yield so the event only covers this
# ubatch's compute work. If captured after, the compute stream tail
# may include the other ubatch's work, preventing overlap.
previous_event = dbo_get_previous_event(self.buffer.capture)
# We yield before launching the dispatch kernel since the dispatch
# kernel will block the CPU so we want to queue up all the compute
# for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm()
# capture a DeepEP event and pass it as previous_event so
# DeepEP honors the dependency internally.
previous_event = dbo_get_previous_event(self.buffer.capture)
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
......@@ -357,11 +359,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
previous_event = dbo_get_previous_event(self.buffer.capture)
dbo_yield_and_switch_from_compute_to_comm()
assert fused_expert_output.dtype == torch.bfloat16, (
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
)
previous_event = dbo_get_previous_event(self.buffer.capture)
combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment