Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc2979c5
Unverified
Commit
dc2979c5
authored
Sep 18, 2025
by
bnellnm
Committed by
GitHub
Sep 18, 2025
Browse files
[Kernels] Overlap shared experts with combine instead of dispatch (#24254)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
027d37df
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
203 additions
and
36 deletions
+203
-36
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+45
-5
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+47
-8
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+76
-19
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+35
-4
No files found.
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
dc2979c5
...
...
@@ -240,7 +240,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config
)
return
receiver
()
def
finalize
(
def
_
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
...
...
@@ -248,7 +248,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
do_async
:
bool
,
)
->
Optional
[
Callable
]:
assert
self
.
handle
is
not
None
...
...
@@ -271,7 +272,46 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
=
None
,
config
=
self
.
_get_combine_config
(),
previous_event
=
None
,
async_finish
=
False
,
async_finish
=
do_async
,
allocate_on_comm_stream
=
False
)
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
if
do_async
:
def
_receiver
():
event
.
current_stream_wait
()
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
lambda
:
_receiver
()
else
:
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
None
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
Callable
:
receiver
=
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
weight_and_reduce_impl
,
True
)
assert
receiver
is
not
None
return
receiver
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
weight_and_reduce_impl
,
False
)
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
dc2979c5
...
...
@@ -12,8 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from
vllm.model_executor.layers.fused_moe.utils
import
(
moe_kernel_quantize_input
,
normalize_batched_scales_shape
)
from
vllm.v1.worker.ubatching
import
(
dbo_current_ubatch_id
,
dbo_enabled
,
dbo_maybe_run_recv_hook
,
dbo_register_recv_hook
,
dbo_yield
)
dbo_maybe_run_recv_hook
)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE
=
128
...
...
@@ -198,7 +197,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
hook
()
return
receiver
()
def
finalize
(
def
_
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
...
...
@@ -206,13 +205,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
do_async
:
bool
,
)
->
Optional
[
Callable
]:
assert
isinstance
(
weight_and_reduce_impl
,
TopKWeightAndReduceDelegate
),
(
"Weight application and reduction happens in the combine kernel."
)
a2a_idx
=
dbo_current_ubatch_id
()
do_recv_hook
=
dbo_enabled
()
do_recv_hook
=
dbo_enabled
()
or
do_async
handle
=
self
.
handles
[
a2a_idx
]
assert
handle
is
not
None
...
...
@@ -232,6 +232,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
zero_copy
=
False
,
return_recv_hook
=
do_recv_hook
,
out
=
output
)
if
recv_hook
is
not
None
:
dbo_register_recv_hook
(
recv_hook
)
dbo_yield
()
return
recv_hook
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
Callable
:
recv_hook
=
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
weight_and_reduce_impl
,
do_async
=
True
,
)
assert
recv_hook
is
not
None
return
recv_hook
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
weight_and_reduce_impl
,
do_async
=
False
,
)
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
dc2979c5
...
...
@@ -209,7 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
def
supports_async
(
self
)
->
bool
:
"""
Indicates whether or not this class implements prepare_async.
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return
False
...
...
@@ -275,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise
NotImplementedError
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
TopKWeightAndReduce
,
)
->
Callable
:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output but do not wait for results from other workers.
- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_experts_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.
Returns a callback that when invoked waits for results from other
workers and has the same return signature as `finalize`, e.g.
receiver = obj.finalize_async(output, ...)
... output not valid yet ...
receiver()
... output valid here ...
is equivalent to:
obj.finalize(output, ...)
"""
raise
NotImplementedError
@
property
@
abstractmethod
def
activation_format
(
self
)
->
FusedMoEActivationFormat
:
...
...
@@ -814,23 +851,20 @@ class FusedMoEModularKernel(torch.nn.Module):
"""
a1
=
hidden_states
output
=
a1
if
inplace
else
torch
.
zeros_like
(
a1
)
if
inplace
and
self
.
shared_experts
is
None
:
output
=
a1
else
:
output
=
torch
.
zeros_like
(
a1
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
shared_output
:
torch
.
Tensor
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
assert
not
dbo_enabled
()
# Run shared experts serially with dispatch.
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
(
a1q
,
a1q_scale
,
expert_tokens_meta
,
_expert_topk_ids
,
_expert_topk_weights
)
=
self
.
prepare_finalize
.
prepare
(
a1
,
...
...
@@ -854,9 +888,6 @@ class FusedMoEModularKernel(torch.nn.Module):
self
.
fused_experts
.
quant_config
,
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
# If DBO is being used, register the hook with the ubatch context
# and call it in dbo_maybe_run_recv_hook instead of passing it to
# the receiver.
...
...
@@ -900,16 +931,42 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
prepare_finalize
.
supports_async
():
assert
not
dbo_enabled
()
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
else
:
recv_hook
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
a1
)
assert
recv_hook
is
not
None
dbo_register_recv_hook
(
recv_hook
)
dbo_yield
()
if
not
dbo_enabled
():
recv_hook
()
if
self
.
shared_experts
is
None
:
return
output
else
:
assert
shared_output
is
not
None
return
shared_output
,
output
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
dc2979c5
...
...
@@ -272,7 +272,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
hook
()
return
receiver
()
def
finalize
(
def
finalize
_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
...
...
@@ -280,7 +280,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
Non
e
:
)
->
Callabl
e
:
assert
isinstance
(
weight_and_reduce_impl
,
TopKWeightAndReduceDelegate
),
(
"Weight application and reduction happens in the combine kernel."
)
...
...
@@ -303,8 +303,39 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if
apply_router_weight_on_input
:
topk_weights
=
torch
.
ones_like
(
topk_weights
)
topk_ids_u32
=
topk_ids
.
view
(
dtype
=
torch
.
uint32
)
self
.
a2a
.
combine
(
out_tokens
=
output
,
indices
=
topk_ids
.
view
(
dtype
=
torch
.
uint
32
)
,
indices
=
topk_ids
_u
32
,
weights
=
topk_weights
,
expert_y
=
fused_expert_output
,
bound_m
=
bound_m
)
bound_m
=
bound_m
,
do_send
=
True
,
do_recv
=
False
)
return
lambda
:
self
.
a2a
.
combine
(
out_tokens
=
output
,
indices
=
topk_ids_u32
,
weights
=
topk_weights
,
expert_y
=
fused_expert_output
,
bound_m
=
bound_m
,
do_send
=
False
,
do_recv
=
True
)
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
receiver
=
self
.
finalize_async
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
weight_and_reduce_impl
,
)
receiver
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment