change / sglang / Commits / 77e929a1

Unverified commit 77e929a1, authored Apr 04, 2025 by fzyzcjy, committed by GitHub on Apr 04, 2025
Support async DeepEP by splitting into two stages (#4995)
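The refactor splits each blocking dispatch/combine call into an `_a` stage that launches communication and a `_b` stage that waits on its completion, so callers can slot independent work between the two. A minimal, self-contained sketch of the pattern (illustrative names only, not the DeepEP API):

    # Split-phase pattern introduced by this commit: stage A starts the work
    # and stashes in-flight state; stage B consumes that state exactly once.
    class TwoStageOp:
        def dispatch_a(self, payload):
            # Stage A: begin the (conceptually asynchronous) operation.
            self._inflight = payload

        def dispatch_b(self):
            # Stage B: finish it; deleting the attribute enforces single use,
            # mirroring the `del self._dispatch_intermediate_state` below.
            state = self._inflight
            del self._inflight
            return state

        def dispatch(self, payload):
            # The combined call preserves the old synchronous behavior.
            self.dispatch_a(payload)
            return self.dispatch_b()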
Parent: febe21ce

1 changed file with 86 additions and 32 deletions:
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py (+86 -32)
@@ -113,7 +113,7 @@ class _DeepEPDispatcherImplBase:
         self.handle = None
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -123,12 +123,18 @@ class _DeepEPDispatcherImplBase:
     ):
         raise NotImplementedError
 
-    def combine(
+    def dispatch_b(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
+    ):
+        raise NotImplementedError
+
+    def combine_b(self, *args, **kwargs):
         raise NotImplementedError
@@ -142,7 +148,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         self.async_finish = async_finish
         self.src2dst = None
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -151,12 +157,20 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
+        previous_event = Buffer.capture() if self.async_finish else None
+        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+
+    def dispatch_b(
+        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
+    ):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+        ) = self._dispatch_core(
+            hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        )
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
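In the normal path, stage A now captures the stream event (`Buffer.capture()`) before returning, and stage B threads it into `_dispatch_core` so the communication is ordered after all prior stream work. A rough analogy using plain CUDA events (torch API stand-ins, not DeepEP's `Buffer`; assumes a CUDA device is available):

    import torch

    def capture_event():
        # Record a marker after all work queued so far on the current stream.
        ev = torch.cuda.Event()
        ev.record()
        return ev

    def chain_event(ev):
        # Make the current stream wait for that marker before new kernels run.
        torch.cuda.current_stream().wait_event(ev)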
@@ -187,15 +201,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             expected_m,
         )
 
-    def _dispatch_normal(
+    def _dispatch_core(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
+        previous_event,
     ):
-        previous_event = Buffer.capture() if self.async_finish else None
-
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -279,12 +292,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return reorder_topk_ids, seg_indptr, gateup_input
 
-    def combine(
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
+    ):
         if hidden_states.shape[0] > 0:
             num_tokens = self.src2dst.shape[0] // self.router_topk
             output = torch.empty(
@@ -308,16 +321,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             device=hidden_states.device,
             dtype=hidden_states.dtype,
         )
-        hidden_states, event = self._combine_normal(
-            output,
-        )
-        event.current_stream_wait() if self.async_finish else ()
+        previous_event = Buffer.capture() if self.async_finish else None
+        return output, previous_event
+
+    def combine_b(self, output, previous_event):
+        hidden_states, event = self._combine_core(output, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
         return hidden_states
 
-    def _combine_normal(self, x: torch.Tensor):
-        previous_event = Buffer.capture() if self.async_finish else None
-
+    def _combine_core(self, x: torch.Tensor, previous_event):
         combined_x, _, event = self.buffer_normal.combine(
             x,
             self.handle,
@@ -346,7 +358,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
         self.return_recv_hook = return_recv_hook
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -361,13 +373,33 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             * topk_idx.shape[1]
             + num_experts
         ) // num_experts
-        hidden_states, masked_m, event, hook = self._dispatch_low_latency(
+        hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
             num_max_dispatch_tokens_per_rank,
             num_experts,
             use_fp8=True,
         )
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+            event,
+            hook,
+        )
+
+    def dispatch_b(
+        self,
+        hidden_states,
+        topk_idx,
+        topk_weights,
+        masked_m,
+        expected_m,
+        event,
+        hook,
+    ):
         hook() if self.return_recv_hook else event.current_stream_wait()
         # TODO
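In the low-latency path, stage B completes in one of two ways depending on `return_recv_hook`: calling the returned `hook` fires the deferred receive, while the event path falls back to a stream-ordered wait. A condensed restatement of the line above (same logic, standalone form for clarity):

    def finish_dispatch(event, hook, return_recv_hook):
        # return_recv_hook=True: the receive was deferred; fire it now, so any
        # compute issued since dispatch_a overlapped with the network transfer.
        # return_recv_hook=False: block on the recorded communication event.
        hook() if return_recv_hook else event.current_stream_wait()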
@@ -389,7 +421,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             expected_m,
         )
 
-    def _dispatch_low_latency(
+    def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -443,22 +475,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
         return packed_recv_hidden, packed_recv_count, event, hook
 
-    def combine(
+    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        hidden_states, event, hook = self._combine_low_latency(
+    ):
+        hidden_states, event, hook = self._combine_core(
             hidden_states,
             topk_idx,
             topk_weights,
         )
-        hook() if self.return_recv_hook else event.current_stream_wait()
+        return hidden_states, event, hook
+
+    def combine_b(self, hidden_states, event, hook):
+        hook() if self.return_recv_hook else event.current_stream_wait()
         return hidden_states
 
-    def _combine_low_latency(
+    def _combine_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -514,7 +548,11 @@ class DeepEPDispatcher:
             **common_kwargs,
         )
 
-    def dispatch(
+    def dispatch(self, *args, **kwargs) -> Tuple:
+        self.dispatch_a(*args, **kwargs)
+        return self.dispatch_b()
+
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -522,29 +560,45 @@ class DeepEPDispatcher:
         num_experts: int,
         num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
-    ) -> Tuple:
-        return self._get_dispatcher(forward_mode).dispatch(
+    ):
+        inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             num_experts=num_experts,
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
+        self._dispatch_intermediate_state = forward_mode, inner_state
 
-    def combine(
+    def dispatch_b(self):
+        forward_mode, inner_state = self._dispatch_intermediate_state
+        del self._dispatch_intermediate_state
+        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+
+    def combine(self, *args, **kwargs) -> Tuple:
+        self.combine_a(*args, **kwargs)
+        return self.combine_b()
+
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode,
-    ) -> torch.Tensor:
-        return self._get_dispatcher(forward_mode).combine(
+    ):
+        inner_state = self._get_impl(forward_mode).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
+        self._combine_intermediate_state = forward_mode, inner_state
 
-    def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
+    def combine_b(self):
+        forward_mode, inner_state = self._combine_intermediate_state
+        del self._combine_intermediate_state
+        return self._get_impl(forward_mode).combine_b(*inner_state)
+
+    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
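With the `DeepEPDispatcher` wrapper split the same way, a caller can start the all-to-all, run unrelated work, and only then block on the result. A hypothetical usage sketch (method names match this commit; `overlapped_compute` is a placeholder for independent computation):

    def moe_dispatch_with_overlap(dispatcher, hidden_states, topk_idx,
                                  topk_weights, num_experts, forward_mode,
                                  overlapped_compute):
        # Stage A: launch the token all-to-all and stash intermediate state.
        dispatcher.dispatch_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_experts=num_experts,
            forward_mode=forward_mode,
        )
        # Anything executed here overlaps with the in-flight communication.
        overlapped_compute()
        # Stage B: wait for completion and collect the dispatched tokens.
        return dispatcher.dispatch_b()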