Unverified commit febe21ce, authored Apr 04, 2025 by fzyzcjy, committed by GitHub on Apr 04, 2025.
Small refactor DeepEPDispatcher into subclasses (#4994)
Parent: a995a773
Showing 2 changed files with 309 additions and 191 deletions:

python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py  (+291 −162)
python/sglang/srt/models/deepseek_v2.py  (+18 −29)
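
At a high level, the commit splits the single DeepEPDispatcher, which previously branched on the resolved DeepEP mode inside dispatch and combine, into a shared base class, one implementation per mode, and a thin facade that routes each call. A minimal runnable sketch of that layering, using simplified illustrative names rather than the real sglang signatures (see the diff below for the actual classes):

class _DispatcherImplBase:
    # Shared interface; concrete subclasses implement both methods.
    def dispatch(self, hidden_states):
        raise NotImplementedError

    def combine(self, hidden_states):
        raise NotImplementedError


class _DispatcherImplNormal(_DispatcherImplBase):
    # Stands in for the "normal" (prefill) path.
    def dispatch(self, hidden_states):
        return ("normal-dispatch", hidden_states)

    def combine(self, hidden_states):
        return ("normal-combine", hidden_states)


class _DispatcherImplLowLatency(_DispatcherImplBase):
    # Stands in for the low-latency (decode) path.
    def dispatch(self, hidden_states):
        return ("low-latency-dispatch", hidden_states)

    def combine(self, hidden_states):
        return ("low-latency-combine", hidden_states)


class Dispatcher:
    # Facade: keeps one public entry point and picks the impl per call,
    # mirroring DeepEPDispatcher._get_dispatcher(forward_mode) in the diff.
    def __init__(self):
        self._normal = _DispatcherImplNormal()
        self._low_latency = _DispatcherImplLowLatency()

    def dispatch(self, hidden_states, is_decode=False):
        return self._impl(is_decode).dispatch(hidden_states)

    def combine(self, hidden_states, is_decode=False):
        return self._impl(is_decode).combine(hidden_states)

    def _impl(self, is_decode):
        return self._low_latency if is_decode else self._normal


print(Dispatcher().dispatch([0.1, 0.2]))       # ('normal-dispatch', [0.1, 0.2])
print(Dispatcher().combine([0.1, 0.2], True))  # ('low-latency-combine', [0.1, 0.2])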
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py  (view file @ febe21ce)
@@ -23,7 +23,7 @@ _buffer_normal = None
 _buffer_low_latency = None


-def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
+def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     """
     Copy from DeepEP example usage in model inference prefilling.
     https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
@@ -53,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     return _buffer_normal


-def get_buffer_low_latency(
+def _get_buffer_low_latency(
     group: dist.ProcessGroup,
     num_max_dispatch_tokens_per_rank: int,
     hidden: int,
@@ -85,24 +85,16 @@ def get_buffer_low_latency(
     return _buffer_low_latency


-class DeepEPDispatcher:
-    """
-    Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-    """
-
+class _DeepEPDispatcherImplBase:
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
         router_topk: int,
-        permute_fusion: bool = False,
-        num_experts: int = None,
-        num_local_experts: int = None,
-        hidden_size: int = None,
-        params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
-        async_finish: bool = False,
-        return_recv_hook: bool = False,
+        permute_fusion: bool,
+        num_experts: int,
+        num_local_experts: int,
+        hidden_size: int,
+        params_dtype: torch.dtype,
     ):
         if not use_deepep:
             raise ImportError(
@@ -119,63 +111,36 @@ class DeepEPDispatcher:
         self.params_dtype = params_dtype
         self.params_bytes = 2
-        self.deepep_mode = deepep_mode

         self.handle = None

-        if self.deepep_mode.enable_normal():
-            self.buffer_normal = get_buffer_normal(
-                self.group, self.hidden_size * self.params_bytes
-            )
-            self.async_finish = async_finish
-            self.src2dst = None
-        if self.deepep_mode.enable_low_latency():
-            """
-            num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
-            https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-            """
-            # TODO(ch-wan): allow users to set this value
-            self.num_max_dispatch_tokens_per_rank = 128
-            self.buffer_low_latency = get_buffer_low_latency(
-                self.group,
-                self.num_max_dispatch_tokens_per_rank,
-                self.hidden_size,
-                self.num_experts,
-            )
-            self.return_recv_hook = return_recv_hook
-
-    def deepep_permute(
+    def dispatch(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
     ):
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
+        raise NotImplementedError
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
+    def __init__(self, async_finish: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        self.buffer_normal = _get_buffer_normal(
+            self.group, self.hidden_size * self.params_bytes
+        )
+        self.async_finish = async_finish
+        self.src2dst = None

     def dispatch(
         self,
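
The subclass __init__ above takes its mode-specific argument (async_finish) explicitly and forwards everything shared to the base via **kwargs and super().__init__; the low-latency subclass later in the diff does the same with return_recv_hook. A tiny standalone illustration of that pattern, with generic placeholder names rather than the sglang classes:

class Base:
    def __init__(self, group, hidden_size):
        # Shared state is initialized once, in the base class.
        self.group = group
        self.hidden_size = hidden_size


class NormalImpl(Base):
    def __init__(self, async_finish, **kwargs):
        super().__init__(**kwargs)        # shared kwargs pass straight through
        self.async_finish = async_finish  # mode-specific argument stays here


impl = NormalImpl(async_finish=True, group="tp-group", hidden_size=4096)
print(impl.group, impl.hidden_size, impl.async_finish)  # tp-group 4096 True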
@@ -183,51 +148,34 @@ class DeepEPDispatcher:
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
-        forward_mode: ForwardMode = None,
-    ) -> Tuple:
+        num_max_dispatch_tokens_per_rank: int,
+    ):
         topk_idx = topk_idx.to(torch.int64)
-        reorder_topk_ids = torch.empty(
-            (0,), device=hidden_states.device, dtype=torch.int64
-        )
-        seg_indptr = torch.zeros(
-            (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-        )
-        masked_m = torch.empty(
-            (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
-        )
-        expected_m = 0
-
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
-        if resolved_deepep_mode == DeepEPMode.normal:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                event,
-            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            expected_m = (
-                hidden_states.shape[0]
-                * self.buffer_low_latency.group_size
-                * topk_idx.shape[1]
-                + num_experts
-            ) // num_experts
-            hidden_states, masked_m, event, hook = self.dispatch_low_latency(
-                hidden_states,
-                topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
-                use_fp8=True,
-            )
-            hook() if self.return_recv_hook else event.current_stream_wait()
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            event,
+        ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+        event.current_stream_wait() if self.async_finish else ()
+        if hidden_states.shape[0] > 0:
+            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+
+        # TODO
+        # masked_m = torch.empty(
+        #     (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
+        # )
+        # expected_m = 0
+        masked_m = expected_m = None

         return (
             hidden_states,
@@ -239,7 +187,7 @@ class DeepEPDispatcher:
             expected_m,
         )

-    def dispatch_normal(
+    def _dispatch_normal(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -292,7 +240,156 @@ class DeepEPDispatcher:
             event,
         )

-    def dispatch_low_latency(
+    def _deepep_permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        fp8_dtype: Optional[torch.dtype] = None,
+        use_fp8_w8a8: bool = False,
+        use_block_quant: bool = False,
+    ):
+        """
+        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
+        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
+        """
+        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+            topk_idx, self.num_experts
+        )
+        num_total_tokens = reorder_topk_ids.numel()
+        gateup_input = torch.empty(
+            (int(num_total_tokens), hidden_states.shape[1]),
+            device=hidden_states.device,
+            dtype=(
+                fp8_dtype
+                if (use_fp8_w8a8 and not use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        # PreReorder
+        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+            hidden_states,
+            gateup_input,
+            self.src2dst,
+            topk_idx,
+            None,
+            self.router_topk,
+            hidden_states.shape[1],
+            BLOCK_SIZE=512,
+        )
+        return reorder_topk_ids, seg_indptr, gateup_input
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        if hidden_states.shape[0] > 0:
+            num_tokens = self.src2dst.shape[0] // self.router_topk
+            output = torch.empty(
+                (num_tokens, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            deepep_post_reorder_triton_kernel[(num_tokens,)](
+                hidden_states,
+                output,
+                self.src2dst,
+                topk_idx,
+                topk_weights,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+        else:
+            output = torch.zeros(
+                (0, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        hidden_states, event = self._combine_normal(
+            output,
+        )
+        event.current_stream_wait() if self.async_finish else ()
+        return hidden_states
+
+    def _combine_normal(self, x: torch.Tensor):
+        previous_event = Buffer.capture() if self.async_finish else None
+        combined_x, _, event = self.buffer_normal.combine(
+            x,
+            self.handle,
+            async_finish=self.async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=previous_event is not None,
+        )
+        return combined_x, event
+
+
+class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
+    def __init__(self, return_recv_hook: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        """
+        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
+        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+        """
+        # TODO(ch-wan): allow users to set this value
+        self.num_max_dispatch_tokens_per_rank = 128
+        self.buffer_low_latency = _get_buffer_low_latency(
+            self.group,
+            self.num_max_dispatch_tokens_per_rank,
+            self.hidden_size,
+            self.num_experts,
+        )
+        self.return_recv_hook = return_recv_hook
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ):
+        topk_idx = topk_idx.to(torch.int64)
+        expected_m = (
+            hidden_states.shape[0]
+            * self.buffer_low_latency.group_size
+            * topk_idx.shape[1]
+            + num_experts
+        ) // num_experts
+        hidden_states, masked_m, event, hook = self._dispatch_low_latency(
+            hidden_states,
+            topk_idx,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+            use_fp8=True,
+        )
+        hook() if self.return_recv_hook else event.current_stream_wait()
+
+        # TODO
+        # reorder_topk_ids = torch.empty(
+        #     (0,), device=hidden_states.device, dtype=torch.int64
+        # )
+        # seg_indptr = torch.zeros(
+        #     (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+        # )
+        reorder_topk_ids = seg_indptr = None
+
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_low_latency(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
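
The expected_m computed in the low-latency dispatch above is a padded per-expert token estimate: (local tokens × dispatch group size × top-k + num_experts) // num_experts, i.e. the floor of the per-expert average plus one. A quick arithmetic check with made-up sizes (none of these numbers come from the commit):

num_local_tokens = 8   # hidden_states.shape[0] on this rank
group_size = 16        # buffer_low_latency.group_size (ranks in the dispatch group)
topk = 8               # topk_idx.shape[1] (experts chosen per token)
num_experts = 256

routed_copies = num_local_tokens * group_size * topk        # 1024 token copies
expected_m = (routed_copies + num_experts) // num_experts   # 1024 // 256 + 1 = 5
print(expected_m)  # 5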
@@ -351,62 +448,17 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
     ) -> torch.Tensor:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    topk_idx,
-                    topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-            hidden_states, event = self.combine_normal(
-                output,
-            )
-            event.current_stream_wait() if self.async_finish else ()
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            hidden_states, event, hook = self.combine_low_latency(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-            )
-            hook() if self.return_recv_hook else event.current_stream_wait()
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
-
+        hidden_states, event, hook = self._combine_low_latency(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+        )
+        hook() if self.return_recv_hook else event.current_stream_wait()
         return hidden_states

-    def combine_normal(self, x: torch.Tensor):
-        previous_event = Buffer.capture() if self.async_finish else None
-        combined_x, _, event = self.buffer_normal.combine(
-            x,
-            self.handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
-        )
-        return combined_x, event
-
-    def combine_low_latency(
+    def _combine_low_latency(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -423,3 +475,80 @@ class DeepEPDispatcher:
             )
         )
         return combined_hidden_states, event, hook
+
+
+class DeepEPDispatcher:
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        num_experts: int = None,
+        num_local_experts: int = None,
+        hidden_size: int = None,
+        params_dtype: torch.dtype = None,
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        async_finish: bool = False,
+        return_recv_hook: bool = False,
+    ):
+        self.deepep_mode = deepep_mode
+
+        common_kwargs = dict(
+            group=group,
+            router_topk=router_topk,
+            permute_fusion=permute_fusion,
+            num_experts=num_experts,
+            num_local_experts=num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+        )
+
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
+                return_recv_hook=return_recv_hook,
+                **common_kwargs,
+            )
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int = 128,
+        forward_mode: ForwardMode = None,
+    ) -> Tuple:
+        return self._get_dispatcher(forward_mode).dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_experts=num_experts,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
+        )
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode,
+    ) -> torch.Tensor:
+        return self._get_dispatcher(forward_mode).combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+        )
+
+    def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
+            return self._normal_dispatcher
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
+            return self._low_latency_dispatcher
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
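
For reference, a hedged caller-side sketch of the refactored facade. The constructor and method arguments follow the signatures in the hunk above; the tensor names, the model sizes, and the surrounding setup are placeholders, and the function assumes torch, DeepEP, and sglang's DeepEPDispatcher / DeepEPMode / forward modes are available:

def _example_usage(tp_group, hidden_states, topk_idx, topk_weights, forward_mode):
    # Illustrative only: requires an initialized sglang EP-MoE environment.
    dispatcher = DeepEPDispatcher(
        group=tp_group,
        router_topk=8,
        num_experts=256,
        num_local_experts=32,
        hidden_size=7168,
        params_dtype=torch.bfloat16,
        deepep_mode=DeepEPMode.auto,
        async_finish=True,
        return_recv_hook=False,
    )
    (
        hidden_states,
        topk_idx,
        topk_weights,
        reorder_topk_ids,
        seg_indptr,
        masked_m,
        expected_m,
    ) = dispatcher.dispatch(
        hidden_states,
        topk_idx,
        topk_weights,
        num_experts=256,
        forward_mode=forward_mode,
    )
    # ... run the local experts on the dispatched tokens here ...
    return dispatcher.combine(hidden_states, topk_idx, topk_weights, forward_mode)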
python/sglang/srt/models/deepseek_v2.py  (view file @ febe21ce)
@@ -188,35 +188,24 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )

-        if not global_server_args_dict["enable_deepep_moe"]:
-            self.experts = MoEImpl(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                prefix=add_prefix("experts", prefix),
-            )
-        else:
-            self.experts = MoEImpl(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                prefix=add_prefix("experts", prefix),
-                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-            )
+        self.experts = MoEImpl(
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
+        )

         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
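
The rewritten construction above collapses the two near-identical MoEImpl calls into one by splatting an optional keyword argument: **(dict(deepep_mode=...) if enable_deepep_moe else {}). A tiny generic illustration of the same pattern, with placeholder names rather than sglang code:

def build_experts(num_experts, deepep_mode="auto"):
    # Stand-in for the MoEImpl constructor; only the kwarg plumbing matters here.
    return {"num_experts": num_experts, "deepep_mode": deepep_mode}


enable_deepep_moe = True
experts = build_experts(
    num_experts=256,
    # Pass deepep_mode only when DeepEP is enabled; otherwise splat nothing
    # and let the callee's default apply.
    **(dict(deepep_mode="low_latency") if enable_deepep_moe else {}),
)
print(experts)  # {'num_experts': 256, 'deepep_mode': 'low_latency'}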