Unverified Commit 517b769b authored by czhu-cohere's avatar czhu-cohere Committed by GitHub
Browse files

[Perf] Fix DBO overlap: capture DeepEP event before yield (#38451)


Signed-off-by: default avatarroot <conway.zhu@cohere.com>
parent d9b90a07
...@@ -107,15 +107,17 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): ...@@ -107,15 +107,17 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
) -> Callable: ) -> Callable:
has_scales = token_scales is not None has_scales = token_scales is not None
# Capture a DeepEP event on the compute stream before yielding.
# This must happen before the yield so the event only covers this
# ubatch's compute work. If captured after, the compute stream tail
# may include the other ubatch's work, preventing overlap.
previous_event = dbo_get_previous_event(self.buffer.capture)
# We yield before launching the dispatch kernel since the dispatch # We yield before launching the dispatch kernel since the dispatch
# kernel will block the CPU so we want to queue up all the compute # kernel will block the CPU so we want to queue up all the compute
# for the other ubatch before the dispatch kernel starts. # for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm() dbo_yield_and_switch_from_compute_to_comm()
# capture a DeepEP event and pass it as previous_event so
# DeepEP honors the dependency internally.
previous_event = dbo_get_previous_event(self.buffer.capture)
( (
num_tokens_per_rank, num_tokens_per_rank,
num_tokens_per_rdma_rank, num_tokens_per_rdma_rank,
...@@ -357,11 +359,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): ...@@ -357,11 +359,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
topk_ids=topk_ids, topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
previous_event = dbo_get_previous_event(self.buffer.capture)
dbo_yield_and_switch_from_compute_to_comm() dbo_yield_and_switch_from_compute_to_comm()
assert fused_expert_output.dtype == torch.bfloat16, ( assert fused_expert_output.dtype == torch.bfloat16, (
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}" f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
) )
previous_event = dbo_get_previous_event(self.buffer.capture)
combined_x, _, event = self.buffer.combine( combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16 # HT combine only supports BF16
x=fused_expert_output, x=fused_expert_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment