change/sglang · Commits

Commit bc3f6db2 (unverified)
[Fix] DeepEP Compatibility with Low Latency (#5068)
Authored Apr 09, 2025 by Jinyan Chen; committed by GitHub Apr 08, 2025
Co-authored-by: ch-wan <cwan39@gatech.edu>
Parent: aac531c5
Showing 4 changed files with 148 additions and 120 deletions (+148 -120):

  python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py   +145 -118
  python/sglang/srt/model_executor/forward_batch_info.py      +1   -1
  python/sglang/srt/models/deepseek_v2.py                     +1   -1
  python/sglang/srt/server_args.py                            +1   -0
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -7,6 +7,7 @@ try:
 except ImportError:
     use_deepep = False
 
+from enum import IntEnum, auto
 from typing import Optional, Tuple
 
 import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
-_buffer_normal = None
-_buffer_low_latency = None
-
-
-def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
-    """
-    Copy from DeepEP example usage in model inference prefilling.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
-    """
-    global _buffer_normal
-
-    num_nvl_bytes, num_rdma_bytes = 0, 0
-    for config in (
-        Buffer.get_dispatch_config(group.size()),
-        Buffer.get_combine_config(group.size()),
-    ):
-        num_nvl_bytes = max(
-            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
-        )
-        num_rdma_bytes = max(
-            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
-        )
-
-    if (
-        _buffer_normal is None
-        or _buffer_normal.group != group
-        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
-        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
-    ):
-        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
-    return _buffer_normal
-
-
-def _get_buffer_low_latency(
-    group: dist.ProcessGroup,
-    num_max_dispatch_tokens_per_rank: int,
-    hidden: int,
-    num_experts: int,
-):
-    """
-    Copy from DeepEP example usage in model inference decoding.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-    """
-    global _buffer_low_latency
-    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
-    )
-
-    if (
-        _buffer_low_latency is None
-        or _buffer_low_latency.group != group
-        or not _buffer_low_latency.low_latency_mode
-        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
-    ):
-        assert num_experts % group.size() == 0
-        _buffer_low_latency = Buffer(
-            group,
-            num_rdma_bytes=num_rdma_bytes,
-            low_latency_mode=True,
-            num_qps_per_rank=num_experts // group.size(),
-        )
-    return _buffer_low_latency
+
+class DeepEPDispatchMode(IntEnum):
+    NORMAL = auto()
+    LOW_LATENCY = auto()
+
+
+class DeepEPBuffer:
+
+    _buffer: Optional[Buffer] = None
+    _dispatch_mode: Optional[DeepEPDispatchMode] = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
+
+    @classmethod
+    def get_deepep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = None,
+        num_experts: int = None,
+    ):
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        if deepep_mode.enable_normal():
+            hidden_bytes = hidden_size * param_bytes
+            for config in (
+                Buffer.get_dispatch_config(group.size()),
+                Buffer.get_combine_config(group.size()),
+            ):
+                num_nvl_bytes = max(
+                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+                    num_nvl_bytes,
+                )
+                num_rdma_bytes = max(
+                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+                    num_rdma_bytes,
+                )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank is not None
+            assert num_experts is not None and num_experts % group.size() == 0
+            num_rdma_bytes = max(
+                Buffer.get_low_latency_rdma_size_hint(
+                    num_max_dispatch_tokens_per_rank,
+                    hidden_size,
+                    group.size(),
+                    num_experts,
+                ),
+                num_rdma_bytes,
+            )
+
+        cls._buffer = Buffer(
+            group,
+            num_nvl_bytes,
+            num_rdma_bytes,
+            low_latency_mode=deepep_mode.enable_low_latency(),
+            num_qps_per_rank=(
+                num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+            ),
+        )
+        return cls._buffer
+
+    @classmethod
+    def clean_buffer(cls):
+        if not cls._buffer.low_latency_mode:
+            return
+        cls._buffer.clean_low_latency_buffer(
+            cls._num_max_dispatch_tokens_per_rank,
+            cls._hidden_size,
+            cls._num_experts,
+        )
+
+    @classmethod
+    def set_dispatch_mode_as_normal(cls):
+        cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+    @classmethod
+    def set_dispatch_mode_as_low_latency(cls):
+        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+            cls.clean_buffer()
+        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
 class _DeepEPDispatcherImplBase:
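Note: the refactor replaces the two module-level buffers with one shared DeepEP Buffer, sized once for whichever modes are enabled and reused by both dispatcher impls; switching from normal to low-latency dispatch cleans the low-latency region first. A minimal standalone sketch of that singleton-plus-mode-switch pattern (plain Python, no deep_ep dependency; FakeBuffer and the byte sizes are illustrative stand-ins, not the real API):

from enum import IntEnum, auto
from typing import Optional


class Mode(IntEnum):
    NORMAL = auto()
    LOW_LATENCY = auto()


class FakeBuffer:
    """Stand-in for deep_ep.Buffer: records sizing and low-latency cleanups."""

    def __init__(self, num_nvl_bytes: int, num_rdma_bytes: int, low_latency_mode: bool):
        self.num_nvl_bytes = num_nvl_bytes
        self.num_rdma_bytes = num_rdma_bytes
        self.low_latency_mode = low_latency_mode
        self.cleans = 0

    def clean_low_latency_buffer(self):
        self.cleans += 1


class SharedBuffer:
    """One buffer serves both dispatch modes; allocate once with the max sizes."""

    _buffer: Optional[FakeBuffer] = None
    _mode: Optional[Mode] = None

    @classmethod
    def get(cls, enable_normal: bool, enable_low_latency: bool) -> FakeBuffer:
        if cls._buffer is not None:
            return cls._buffer  # allocated once, then reused by every caller
        nvl = 1024 if enable_normal else 0  # illustrative size hints only
        rdma = max(2048 if enable_normal else 0, 4096 if enable_low_latency else 0)
        cls._buffer = FakeBuffer(nvl, rdma, low_latency_mode=enable_low_latency)
        return cls._buffer

    @classmethod
    def set_mode(cls, mode: Mode):
        # Switching NORMAL -> LOW_LATENCY must clean the low-latency region first.
        if mode == Mode.LOW_LATENCY and cls._mode == Mode.NORMAL:
            cls._buffer.clean_low_latency_buffer()
        cls._mode = mode


buf = SharedBuffer.get(enable_normal=True, enable_low_latency=True)
SharedBuffer.set_mode(Mode.NORMAL)       # normal dispatch path
SharedBuffer.set_mode(Mode.LOW_LATENCY)  # triggers one clean before switching
assert buf.cleans == 1 and buf is SharedBuffer.get(True, True)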
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
         num_local_experts: int,
         hidden_size: int,
         params_dtype: torch.dtype,
+        deepep_mode: DeepEPMode,
     ):
         if not use_deepep:
             raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
         self.params_dtype = params_dtype
+        self.deepep_mode = deepep_mode
         self.params_bytes = 2
 
+        self.num_max_dispatch_tokens_per_rank = 128
+
         self.handle = None
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         raise NotImplementedError
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
     def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def _get_buffer(self) -> Buffer:
+        raise NotImplementedError
+
 
 class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def __init__(self, async_finish: bool, **kwargs):
         super().__init__(**kwargs)
 
-        self.buffer_normal = _get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
         self.async_finish = async_finish
         self.src2dst = None
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        return hidden_states, topk_idx, topk_weights, previous_event
 
-    def dispatch_b(
-        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
-    ):
+    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_core(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
-        )
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 (0,), device=hidden_states.device, dtype=torch.int64
             )
             seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
             masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
         previous_event,
     ):
+        buffer = self._get_buffer()
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
             num_tokens_per_expert,
             is_token_in_rank,
             previous_event,
-        ) = self.buffer_normal.get_dispatch_layout(
+        ) = buffer.get_dispatch_layout(
             topk_idx,
-            num_experts,
+            self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
         (
             recv_x,
             recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             _,  # num_recv_tokens_per_expert_list
             self.handle,
             event,
-        ) = self.buffer_normal.dispatch(
+        ) = buffer.dispatch(
             x,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
-        combined_x, _, event = self.buffer_normal.combine(
+        buffer = self._get_buffer()
+        combined_x, _, event = buffer.combine(
             x,
             self.handle,
             async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return combined_x, event
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_normal()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
        """
-        # TODO(ch-wan): allow users to set this value
-        self.num_max_dispatch_tokens_per_rank = 128
-        self.buffer_low_latency = _get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size,
-            self.num_experts,
-        )
         self.return_recv_hook = return_recv_hook
 
     def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
+        buffer = self._get_buffer()
         topk_idx = topk_idx.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0]
-            * self.buffer_low_latency.group_size
-            * topk_idx.shape[1]
-            + num_experts
-        ) // num_experts
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            num_experts,
             use_fp8=True,
         )
         return (
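Note: expected_m estimates how many tokens each expert should expect after dispatch: the local token count times the communication group size times top-k, divided by the number of experts, biased upward by adding num_experts before the floor division. A worked example with made-up sizes (all numbers illustrative):

num_local_tokens = 128  # hidden_states.shape[0] on this rank (illustrative)
group_size = 8          # buffer.group_size (illustrative)
topk = 8                # topk_idx.shape[1] (illustrative)
num_experts = 256

expected_m = (num_local_tokens * group_size * topk + num_experts) // num_experts
# (128 * 8 * 8 + 256) // 256 = (8192 + 256) // 256 = 33
assert expected_m == 33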
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        num_max_dispatch_tokens_per_rank: int,
-        num_experts: int,
         use_fp8: bool = False,
     ):
         """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
        """
+        buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
-            self.buffer_low_latency.low_latency_dispatch(
+            buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        combined_hidden_states, event, hook = (
-            self.buffer_low_latency.low_latency_combine(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                self.handle,
-                async_finish=not self.return_recv_hook,
-                return_recv_hook=self.return_recv_hook,
-            )
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.low_latency_combine(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
         )
         self.handle = None
         return combined_hidden_states, event, hook
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_low_latency()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class DeepEPDispatcher:
     def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
             num_local_experts=num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
         )
 
-        if self.deepep_mode.enable_normal():
-            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
-                async_finish=async_finish,
-                **common_kwargs,
-            )
         if self.deepep_mode.enable_low_latency():
             self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                 return_recv_hook=return_recv_hook,
                 **common_kwargs,
            )
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
     ):
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            num_experts=num_experts,
-            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
         self._dispatch_intermediate_state = forward_mode, inner_state
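Note: DeepEPDispatcher keeps its two-phase shape: dispatch_a starts the (possibly async) communication and returns an opaque inner state, which dispatch() stores next to the forward mode so dispatch_b can resume on the same impl. A minimal standalone sketch of that split (illustrative names, no DeepEP dependency):

class TwoPhaseDispatcher:
    """Phase A starts async work and returns state; phase B waits and finishes."""

    def dispatch(self, x, forward_mode=None):
        inner_state = self.dispatch_a(x, forward_mode)
        self._intermediate = (forward_mode, inner_state)  # kept between phases
        return self.dispatch_b()

    def dispatch_a(self, x, forward_mode):
        return ("issued", x)  # stand-in for launching dispatch communication

    def dispatch_b(self):
        forward_mode, inner_state = self._intermediate
        del self._intermediate
        _tag, x = inner_state
        return x  # stand-in for waiting on the event and permuting tokens


assert TwoPhaseDispatcher().dispatch([1, 2, 3]) == [1, 2, 3]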
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
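Note: _get_impl picks the dispatcher from the configured deepep_mode resolved against the batch's forward mode; per the --deepep-mode help text, auto means low_latency for decode batches and normal for prefill. A standalone sketch of that resolution rule (the real resolve() takes a ForwardMode; here it is reduced to an is_decode flag, and the enum itself is illustrative):

from enum import Enum


class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def resolve(self, is_decode: bool) -> "DeepEPMode":
        if self != DeepEPMode.auto:
            return self
        return DeepEPMode.low_latency if is_decode else DeepEPMode.normal


assert DeepEPMode.auto.resolve(is_decode=True) is DeepEPMode.low_latency
assert DeepEPMode.auto.resolve(is_decode=False) is DeepEPMode.normal
assert DeepEPMode.normal.resolve(is_decode=True) is DeepEPMode.normal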
python/sglang/srt/model_executor/forward_batch_info.py
@@ -72,7 +72,7 @@ class ForwardMode(IntEnum):
     DUMMY_FIRST = auto()
 
     def is_prefill(self):
-        return self == ForwardMode.PREFILL
+        return self.is_extend()
 
     def is_extend(self):
         return (
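Note: is_prefill() is now an alias for is_extend(), so extend batches count as prefill too. A minimal model of the relationship (the member list is trimmed and illustrative; only the is_prefill/is_extend wiring mirrors the diff):

from enum import IntEnum, auto


class MiniForwardMode(IntEnum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()

    def is_extend(self):
        return self in (MiniForwardMode.PREFILL, MiniForwardMode.EXTEND)

    def is_prefill(self):
        return self.is_extend()  # was: self == MiniForwardMode.PREFILL


assert MiniForwardMode.EXTEND.is_prefill()       # true after the change
assert not MiniForwardMode.DECODE.is_prefill()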
python/sglang/srt/models/deepseek_v2.py
@@ -324,6 +324,7 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.correction_bias,
             )
         if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
                 hidden_states,
                 topk_idx,
@@ -336,7 +337,6 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                self.num_experts,
                 forward_mode=forward_mode,
             )
             final_hidden_states = (
python/sglang/srt/server_args.py
@@ -1101,6 +1101,7 @@ class ServerArgs:
             "--deepep-mode",
             type=str,
             choices=["normal", "low_latency", "auto"],
+            default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )
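Note: the flag accepts normal, low_latency, or auto and defaults to auto. A minimal argparse sketch of the flag's behavior, outside of sglang's ServerArgs plumbing (help text paraphrased):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--deepep-mode",
    type=str,
    choices=["normal", "low_latency", "auto"],
    default="auto",
    help="DeepEP MoE mode; auto uses low_latency for decode and normal for prefill.",
)

assert parser.parse_args([]).deepep_mode == "auto"
assert parser.parse_args(["--deepep-mode", "low_latency"]).deepep_mode == "low_latency"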