Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a40aecc5
Unverified
Commit
a40aecc5
authored
May 21, 2025
by
fzyzcjy
Committed by
GitHub
May 21, 2025
Browse files
Fix num_qps_per_rank computation when providing custom DeepEP configuration (#6468)
parent
d6e1d28c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
9 deletions
+21
-9
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+21
-9
No files found.
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
View file @
a40aecc5
...
...
@@ -67,9 +67,9 @@ class DeepEPBuffer:
if
deepep_mode
.
enable_normal
():
hidden_bytes
=
hidden_size
*
param_bytes
for
config
in
(
_
DeepEPConfig
.
get_instance
().
normal_dispatch_config
DeepEPConfig
.
get_instance
().
normal_dispatch_config
or
Buffer
.
get_dispatch_config
(
group
.
size
()),
_
DeepEPConfig
.
get_instance
().
normal_combine_config
DeepEPConfig
.
get_instance
().
normal_combine_config
or
Buffer
.
get_combine_config
(
group
.
size
()),
):
num_nvl_bytes
=
max
(
...
...
@@ -97,7 +97,12 @@ class DeepEPBuffer:
num_nvl_bytes
,
num_rdma_bytes
,
low_latency_mode
=
deepep_mode
.
enable_low_latency
(),
num_qps_per_rank
=
(
max
(
num_experts
//
group
.
size
(),
Buffer
.
num_sms
//
2
)),
num_qps_per_rank
=
(
max
(
num_experts
//
group
.
size
(),
DeepEPConfig
.
get_instance
().
num_sms
//
2
,
)
),
)
return
cls
.
_buffer
...
...
@@ -122,7 +127,7 @@ class DeepEPBuffer:
cls
.
_dispatch_mode
=
DeepEPDispatchMode
.
LOW_LATENCY
class
_
DeepEPConfig
:
class
DeepEPConfig
:
_instance
=
None
def
__init__
(
self
):
...
...
@@ -131,16 +136,23 @@ class _DeepEPConfig:
config_parsed
=
load_json_config
(
config_str
)
if
torch
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
f
"Use DeepEP Config:
{
config_parsed
}
"
)
self
.
normal_dispatch_config
=
Config
(
**
config_parsed
[
"normal_dispatch"
])
self
.
normal_combine_config
=
Config
(
**
config_parsed
[
"normal_combine"
])
config_dispatch
=
config_parsed
[
"normal_dispatch"
]
config_combine
=
config_parsed
[
"normal_combine"
]
self
.
normal_dispatch_config
=
Config
(
**
config_dispatch
)
self
.
normal_combine_config
=
Config
(
**
config_combine
)
assert
config_dispatch
[
"num_sms"
]
==
config_combine
[
"num_sms"
]
self
.
num_sms
=
config_dispatch
[
"num_sms"
]
else
:
self
.
normal_dispatch_config
=
None
self
.
normal_combine_config
=
None
self
.
num_sms
=
Buffer
.
num_sms
@
classmethod
def
get_instance
(
cls
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
_
DeepEPConfig
()
cls
.
_instance
=
DeepEPConfig
()
return
cls
.
_instance
...
...
@@ -326,7 +338,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish
=
self
.
async_finish
,
allocate_on_comm_stream
=
(
previous_event
is
not
None
)
and
self
.
async_finish
,
expert_alignment
=
128
if
_ENABLE_JIT_DEEPGEMM
else
1
,
config
=
_
DeepEPConfig
.
get_instance
().
normal_dispatch_config
,
config
=
DeepEPConfig
.
get_instance
().
normal_dispatch_config
,
)
get_global_expert_distribution_recorder
().
on_deepep_dispatch_normal
(
...
...
@@ -433,7 +445,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish
=
self
.
async_finish
,
previous_event
=
previous_event
,
allocate_on_comm_stream
=
previous_event
is
not
None
,
config
=
_
DeepEPConfig
.
get_instance
().
normal_combine_config
,
config
=
DeepEPConfig
.
get_instance
().
normal_combine_config
,
)
return
combined_x
,
event
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment