Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
517b769b
Unverified
Commit
517b769b
Authored Mar 31, 2026 by czhu-cohere
Committed by GitHub on Mar 31, 2026
Browse files
[Perf] Fix DBO overlap: capture DeepEP event before yield (#38451)
Signed-off-by: root <conway.zhu@cohere.com>
parent
d9b90a07
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing 1 changed file with 7 additions and 5 deletions
+7
-5
vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ht.py
...l_executor/layers/fused_moe/prepare_finalize/deepep_ht.py
+7
-5
No files found.
vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ht.py
View file @
517b769b
...
@@ -107,15 +107,17 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
...
@@ -107,15 +107,17 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
)
->
Callable
:
)
->
Callable
:
has_scales
=
token_scales
is
not
None
has_scales
=
token_scales
is
not
None
# Capture a DeepEP event on the compute stream before yielding.
# This must happen before the yield so the event only covers this
# ubatch's compute work. If captured after, the compute stream tail
# may include the other ubatch's work, preventing overlap.
previous_event
=
dbo_get_previous_event
(
self
.
buffer
.
capture
)
# We yield before launching the dispatch kernel since the dispatch
# We yield before launching the dispatch kernel since the dispatch
# kernel will block the CPU so we want to queue up all the compute
# kernel will block the CPU so we want to queue up all the compute
# for the other ubatch before the dispatch kernel starts.
# for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm
()
dbo_yield_and_switch_from_compute_to_comm
()
# capture a DeepEP event and pass it as previous_event so
# DeepEP honors the dependency internally.
previous_event
=
dbo_get_previous_event
(
self
.
buffer
.
capture
)
(
(
num_tokens_per_rank
,
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
num_tokens_per_rdma_rank
,
...
@@ -357,11 +359,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
...
@@ -357,11 +359,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
)
previous_event
=
dbo_get_previous_event
(
self
.
buffer
.
capture
)
dbo_yield_and_switch_from_compute_to_comm
()
dbo_yield_and_switch_from_compute_to_comm
()
assert
fused_expert_output
.
dtype
==
torch
.
bfloat16
,
(
assert
fused_expert_output
.
dtype
==
torch
.
bfloat16
,
(
f
"Expected fused_expert_output bfloat16, got
{
fused_expert_output
.
dtype
}
"
f
"Expected fused_expert_output bfloat16, got
{
fused_expert_output
.
dtype
}
"
)
)
previous_event
=
dbo_get_previous_event
(
self
.
buffer
.
capture
)
combined_x
,
_
,
event
=
self
.
buffer
.
combine
(
combined_x
,
_
,
event
=
self
.
buffer
.
combine
(
# HT combine only supports BF16
# HT combine only supports BF16
x
=
fused_expert_output
,
x
=
fused_expert_output
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.