change / sglang · Commits

Commit df7f61ee (Unverified), authored Jun 03, 2025 by fzyzcjy, committed by GitHub Jun 02, 2025
Speed up rebalancing when using non-static dispatch algorithms (#6812)
Parent: ef21729c
Showing 2 changed files with 30 additions and 17 deletions (+30, −17)
python/sglang/srt/managers/expert_location.py (+21, −13)
python/sglang/srt/managers/expert_location_dispatch.py (+9, −4)
python/sglang/srt/managers/expert_location.py
```diff
@@ -35,7 +35,8 @@ class ExpertLocationMetadata:
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
-    logical_to_rank_dispatch_physical_map: torch.Tensor  # (layers, num_logical_experts)
+    # (layers, num_logical_experts)
+    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]

     # -------------------------------- properties ------------------------------------
@@ -70,11 +71,8 @@ class ExpertLocationMetadata:
         num_layers_2, num_logical_experts_1 = (
             self.logical_to_all_physical_map_num_valid.shape
         )
-        num_layers_3, num_logical_experts_2 = (
-            self.logical_to_rank_dispatch_physical_map.shape
-        )
-        assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
-        assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
+        assert num_layers_0 == num_layers_1 == num_layers_2
+        assert num_logical_experts_0 == num_logical_experts_1
         assert num_physical_experts_0 == num_physical_experts_1

         # -------------------------------- construction ------------------------------------
```
```diff
@@ -117,6 +115,7 @@ class ExpertLocationMetadata:
         )

         return ExpertLocationMetadata._init_raw(
+            server_args=server_args,
             ep_size=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map,
@@ -154,6 +153,7 @@ class ExpertLocationMetadata:
         )

         return ExpertLocationMetadata._init_raw(
+            server_args=server_args,
             ep_size=common["ep_size"],
             physical_to_logical_map=physical_to_logical_map.to(server_args.device),
             logical_to_all_physical_map=logical_to_all_physical_map.to(
@@ -184,6 +184,7 @@ class ExpertLocationMetadata:
     @staticmethod
     def _init_raw(
+        server_args: ServerArgs,
         ep_size: int,
         physical_to_logical_map: torch.Tensor,
         logical_to_all_physical_map: torch.Tensor,
```
```diff
@@ -204,12 +205,16 @@ class ExpertLocationMetadata:
             physical_to_logical_map=physical_to_logical_map,
             logical_to_all_physical_map=logical_to_all_physical_map_padded,
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
-            logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
-                logical_to_all_physical_map=logical_to_all_physical_map,
-                num_gpus=ep_size,
-                num_physical_experts=num_physical_experts,
-                # TODO improve when we have real EP rank
-                ep_rank=torch.distributed.get_rank() % ep_size,
+            logical_to_rank_dispatch_physical_map=(
+                compute_logical_to_rank_dispatch_physical_map(
+                    logical_to_all_physical_map=logical_to_all_physical_map,
+                    num_gpus=ep_size,
+                    num_physical_experts=num_physical_experts,
+                    # TODO improve when we have real EP rank
+                    ep_rank=torch.distributed.get_rank() % ep_size,
+                )
+                if server_args.ep_dispatch_algorithm == "static"
+                else None
             ),
         )
```
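The hunk above carries the actual speedup: `compute_logical_to_rank_dispatch_physical_map` is now evaluated only when `server_args.ep_dispatch_algorithm == "static"`, so rebalancing under a non-static algorithm such as `"random"` stores `None` instead of rebuilding the per-rank static dispatch map. Below is a minimal, self-contained sketch of that conditional-construction pattern; `ExpertLocationMetadataSketch`, `build_metadata`, and the cheap "first replica" placeholder are illustrative stand-ins for the sglang code, not its actual APIs.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ExpertLocationMetadataSketch:
    # Hypothetical stand-in for ExpertLocationMetadata; only the field that
    # became Optional in this commit is modeled here.
    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]


def build_metadata(
    ep_dispatch_algorithm: str,
    logical_to_all_physical_map: torch.Tensor,  # (layers, num_logical_experts, X)
) -> ExpertLocationMetadataSketch:
    """Illustrative: pay for the static dispatch map only when it will be used."""
    if ep_dispatch_algorithm == "static":
        # Placeholder for the expensive compute_logical_to_rank_dispatch_physical_map;
        # here we simply pick the first physical replica of every logical expert.
        dispatch_map = logical_to_all_physical_map[:, :, 0].clone()
    else:
        # Non-static algorithms (e.g. "random") never read this map, so skip building it.
        dispatch_map = None
    return ExpertLocationMetadataSketch(
        logical_to_rank_dispatch_physical_map=dispatch_map
    )


if __name__ == "__main__":
    maps = torch.randint(0, 16, (2, 8, 3))
    assert build_metadata("static", maps).logical_to_rank_dispatch_physical_map is not None
    assert build_metadata("random", maps).logical_to_rank_dispatch_physical_map is None
```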
```diff
@@ -230,8 +235,11 @@ class ExpertLocationMetadata:
             "logical_to_all_physical_map_num_valid",
             "logical_to_rank_dispatch_physical_map",
         ]:
+            src = getattr(other, field)
             dst = getattr(self, field)
-            dst[...] = getattr(other, field)
+            assert (src is not None) == (dst is not None)
+            if dst is not None:
+                dst[...] = src

         # -------------------------------- usage ------------------------------------
```
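The update hunk above mirrors the new `Optional` field: when a rebalance copies freshly computed tensors into the existing metadata, the dispatch map may legitimately be `None` on both sides, so the in-place copy is guarded. A short sketch of that guard follows; `copy_optional_fields` and the toy `_Meta` container are illustrative, not sglang code. The in-place `dst[...] = src` write is what keeps tensors that other components already hold references to up to date.

```python
from typing import Iterable, Optional

import torch


def copy_optional_fields(dst_obj, src_obj, fields: Iterable[str]) -> None:
    """Illustrative helper: guarded in-place copy of possibly-None tensor fields."""
    for field in fields:
        src: Optional[torch.Tensor] = getattr(src_obj, field)
        dst: Optional[torch.Tensor] = getattr(dst_obj, field)
        # Both sides must agree on whether the field was ever built.
        assert (src is not None) == (dst is not None)
        if dst is not None:
            # Write into the existing storage so views held elsewhere see the update.
            dst[...] = src


if __name__ == "__main__":
    class _Meta:  # toy container for the sketch
        def __init__(self, t: Optional[torch.Tensor]):
            self.logical_to_rank_dispatch_physical_map = t

    view = torch.zeros(4)
    old, new = _Meta(view), _Meta(torch.arange(4.0))
    copy_optional_fields(old, new, ["logical_to_rank_dispatch_physical_map"])
    print(view)  # tensor([0., 1., 2., 3.]) -- the pre-existing tensor was updated in place
    copy_optional_fields(_Meta(None), _Meta(None), ["logical_to_rank_dispatch_physical_map"])  # no-op
```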
python/sglang/srt/managers/expert_location_dispatch.py
```diff
@@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 class ExpertLocationDispatchInfo:
     ep_dispatch_algorithm: Literal["static", "random"]
     # (num_logical_experts,)
-    partial_logical_to_rank_dispatch_physical_map: torch.Tensor
+    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
     # (num_logical_experts, X)
     partial_logical_to_all_physical_map: torch.Tensor
     # (num_logical_experts,)
@@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo:
         return cls(
             ep_dispatch_algorithm=ep_dispatch_algorithm,
-            partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
-                layer_id, :
-            ],
+            partial_logical_to_rank_dispatch_physical_map=(
+                expert_location_metadata.logical_to_rank_dispatch_physical_map[
+                    layer_id, :
+                ]
+                if expert_location_metadata.logical_to_rank_dispatch_physical_map
+                is not None
+                else None
+            ),
             partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                 layer_id, :
             ],
```
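On the dispatch side, `ExpertLocationDispatchInfo` now slices the per-layer row out of `logical_to_rank_dispatch_physical_map` only when that map exists, and otherwise propagates `None`. A small sketch of the guarded per-layer slice, assuming the (layers, num_logical_experts) layout noted in the dataclass comments; `slice_layer_dispatch_map` is an illustrative helper rather than an sglang function.

```python
from typing import Optional

import torch


def slice_layer_dispatch_map(
    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor],  # (layers, num_logical_experts)
    layer_id: int,
) -> Optional[torch.Tensor]:
    """Illustrative: take one layer's row of the static dispatch map, or None if it was never built."""
    return (
        logical_to_rank_dispatch_physical_map[layer_id, :]
        if logical_to_rank_dispatch_physical_map is not None
        else None
    )


if __name__ == "__main__":
    full_map = torch.arange(12).reshape(3, 4)
    print(slice_layer_dispatch_map(full_map, layer_id=1))  # tensor([4, 5, 6, 7])
    print(slice_layer_dispatch_map(None, layer_id=1))      # None (non-static dispatch)
```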