zhaoyu6 / sglang · Commits · df7f61ee

Unverified commit df7f61ee, authored Jun 03, 2025 by fzyzcjy; committed by GitHub Jun 02, 2025
Speed up rebalancing when using non-static dispatch algorithms (#6812)
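As the hunks below show, `ExpertLocationMetadata.logical_to_rank_dispatch_physical_map` becomes `Optional[torch.Tensor]` and is computed only when `server_args.ep_dispatch_algorithm == "static"`. Non-static algorithms (e.g. `"random"`) never read this map, so rebalancing skips the cost of `compute_logical_to_rank_dispatch_physical_map` entirely; to make that decision at construction time, `server_args` is threaded into `_init_raw`, and the update and per-layer dispatch paths learn to handle the `None` case.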
Parent: ef21729c
Showing 2 changed files with 30 additions and 17 deletions:

python/sglang/srt/managers/expert_location.py (+21 -13)
python/sglang/srt/managers/expert_location_dispatch.py (+9 -4)
python/sglang/srt/managers/expert_location.py
@@ -35,7 +35,8 @@ class ExpertLocationMetadata:
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
-    logical_to_rank_dispatch_physical_map: torch.Tensor  # (layers, num_logical_experts)
+    # (layers, num_logical_experts)
+    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
 
     # -------------------------------- properties ------------------------------------
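For orientation, a minimal sketch of the fields this hunk touches; `_MetadataSketch` is a hypothetical, reduced stand-in for `ExpertLocationMetadata` (the real class carries more fields and methods), with shapes taken from the annotations above:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _MetadataSketch:
    # Hypothetical stand-in for ExpertLocationMetadata: only the fields in this hunk.
    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
    # None whenever the dispatch algorithm is not "static" (see _init_raw below).
    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
```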
@@ -70,11 +71,8 @@ class ExpertLocationMetadata:
         num_layers_2, num_logical_experts_1 = (
             self.logical_to_all_physical_map_num_valid.shape
         )
-        num_layers_3, num_logical_experts_2 = (
-            self.logical_to_rank_dispatch_physical_map.shape
-        )
-        assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
-        assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
+        assert num_layers_0 == num_layers_1 == num_layers_2
+        assert num_logical_experts_0 == num_logical_experts_1
         assert num_physical_experts_0 == num_physical_experts_1
 
     # -------------------------------- construction ------------------------------------
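The dropped shape unpack would now fail outright on the non-static path, since `None` has no `.shape`; rather than gaining a guard, the check is removed along with the tensor it validated. A one-liner showing the failure mode it avoids (the variable name is illustrative):

```python
maybe_map = None  # logical_to_rank_dispatch_physical_map under non-static dispatch
try:
    num_layers, num_logical_experts = maybe_map.shape
except AttributeError as exc:
    print(exc)  # 'NoneType' object has no attribute 'shape'
```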
@@ -117,6 +115,7 @@ class ExpertLocationMetadata:
         )
         return ExpertLocationMetadata._init_raw(
+            server_args=server_args,
             ep_size=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
@@ -154,6 +153,7 @@ class ExpertLocationMetadata:
         )
         return ExpertLocationMetadata._init_raw(
+            server_args=server_args,
             ep_size=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map.to(server_args.device),
             logical_to_all_physical_map=logical_to_all_physical_map.to(
@@ -184,6 +184,7 @@ class ExpertLocationMetadata:
     @staticmethod
     def _init_raw(
+        server_args: ServerArgs,
         ep_size: int,
         physical_to_logical_map: torch.Tensor,
         logical_to_all_physical_map: torch.Tensor,
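These three hunks are pure plumbing: both `_init_raw` call sites now pass `server_args` through, and `_init_raw` gains the parameter, because the construction step in the next hunk needs `server_args.ep_dispatch_algorithm` to decide whether the dispatch map is worth computing at all.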
@@ -204,12 +205,16 @@ class ExpertLocationMetadata:
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map_padded,
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
-            logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
-                logical_to_all_physical_map=logical_to_all_physical_map,
-                num_gpus=ep_size,
-                num_physical_experts=num_physical_experts,
-                # TODO improve when we have real EP rank
-                ep_rank=torch.distributed.get_rank() % ep_size,
+            logical_to_rank_dispatch_physical_map=(
+                compute_logical_to_rank_dispatch_physical_map(
+                    logical_to_all_physical_map=logical_to_all_physical_map,
+                    num_gpus=ep_size,
+                    num_physical_experts=num_physical_experts,
+                    # TODO improve when we have real EP rank
+                    ep_rank=torch.distributed.get_rank() % ep_size,
+                )
+                if server_args.ep_dispatch_algorithm == "static"
+                else None
             ),
         )
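This is the hunk that delivers the speedup: `compute_logical_to_rank_dispatch_physical_map` runs only for the `"static"` algorithm, and every other algorithm gets `None`. A self-contained sketch of the same gating pattern; `build_expensive_map` and `make_dispatch_map` are illustrative names, not functions from the codebase:

```python
from typing import Optional

import torch


def build_expensive_map(num_layers: int, num_logical_experts: int) -> torch.Tensor:
    # Stand-in for compute_logical_to_rank_dispatch_physical_map: any costly
    # construction whose output only the "static" algorithm ever reads.
    return torch.zeros(num_layers, num_logical_experts, dtype=torch.int64)


def make_dispatch_map(algorithm: str) -> Optional[torch.Tensor]:
    # Mirrors the diff: pay for the map only when it will actually be used.
    return build_expensive_map(8, 64) if algorithm == "static" else None


assert make_dispatch_map("static") is not None
assert make_dispatch_map("random") is None  # rebalancing skips the work entirely
```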
@@ -230,8 +235,11 @@
             "logical_to_all_physical_map_num_valid",
             "logical_to_rank_dispatch_physical_map",
         ]:
+            src = getattr(other, field)
             dst = getattr(self, field)
-            dst[...] = getattr(other, field)
+            assert (src is not None) == (dst is not None)
+            if dst is not None:
+                dst[...] = src
 
     # -------------------------------- usage ------------------------------------
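The field-copy loop (inside what appears to be the metadata-update path used during rebalancing) must now tolerate the field being `None` on both sides. A sketch of the copy pattern it uses, where `dst[...] = src` overwrites the destination tensor's elements in place so that views held elsewhere observe the new values; `copy_field` is a made-up helper name:

```python
from typing import Optional

import torch


def copy_field(src: Optional[torch.Tensor], dst: Optional[torch.Tensor]) -> None:
    # As in the diff: both sides must agree on whether the field exists,
    # i.e. the dispatch algorithm did not change between the two metadata objects.
    assert (src is not None) == (dst is not None)
    if dst is not None:
        dst[...] = src  # in-place elementwise copy; dst keeps its storage


dst = torch.zeros(3)
copy_field(src=torch.arange(3.0), dst=dst)
assert dst.tolist() == [0.0, 1.0, 2.0]
copy_field(src=None, dst=None)  # no-op on the non-static path
```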
python/sglang/srt/managers/expert_location_dispatch.py
@@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 class ExpertLocationDispatchInfo:
     ep_dispatch_algorithm: Literal["static", "random"]
     # (num_logical_experts,)
-    partial_logical_to_rank_dispatch_physical_map: torch.Tensor
+    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
     # (num_logical_experts, X)
     partial_logical_to_all_physical_map: torch.Tensor
     # (num_logical_experts,)
@@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo:
         return cls(
             ep_dispatch_algorithm=ep_dispatch_algorithm,
-            partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
-                layer_id, :
-            ],
+            partial_logical_to_rank_dispatch_physical_map=(
+                expert_location_metadata.logical_to_rank_dispatch_physical_map[
+                    layer_id, :
+                ]
+                if expert_location_metadata.logical_to_rank_dispatch_physical_map
+                is not None
+                else None
+            ),
             partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                 layer_id, :
             ],
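The per-layer view applies the same rule: slice the static map only when it exists, and let `None` propagate to the dispatch code that already branches on `ep_dispatch_algorithm`. A tiny sketch of the guarded slicing; `slice_layer` is a name invented for illustration:

```python
from typing import Optional

import torch


def slice_layer(full_map: Optional[torch.Tensor], layer_id: int) -> Optional[torch.Tensor]:
    # Mirrors the conditional above: a missing map simply stays missing per layer.
    return full_map[layer_id, :] if full_map is not None else None


full = torch.arange(6).reshape(2, 3)  # (layers=2, num_logical_experts=3)
assert slice_layer(full, 1).tolist() == [3, 4, 5]
assert slice_layer(None, 1) is None  # non-static dispatch never sees this tensor
```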