Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0de5e7d4
"examples/vscode:/vscode.git/clone" did not exist on "3f0ed0385616ac1aa4b4646d33eee300eeb6ae77"
Unverified
Commit
0de5e7d4
authored
Jun 05, 2025
by
fzyzcjy
Committed by
GitHub
Jun 05, 2025
Browse files
Support layerwise rebalancing experts (#6851)
parent
72a110f6
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
114 additions
and
37 deletions
+114
-37
python/sglang/srt/managers/eplb_manager.py
python/sglang/srt/managers/eplb_manager.py
+51
-11
python/sglang/srt/managers/expert_location.py
python/sglang/srt/managers/expert_location.py
+14
-5
python/sglang/srt/model_executor/expert_location_updater.py
python/sglang/srt/model_executor/expert_location_updater.py
+20
-16
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+5
-2
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
test/srt/test_eplb.py
test/srt/test_eplb.py
+17
-3
No files found.
python/sglang/srt/managers/eplb_manager.py
View file @
0de5e7d4
import
logging
import
time
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
List
import
torch.cuda
...
...
@@ -20,6 +20,10 @@ class EPLBManager:
super
().
__init__
()
self
.
_model_runner
=
model_runner
self
.
_server_args
=
model_runner
.
server_args
self
.
_rebalance_layers_per_chunk
=
(
self
.
_server_args
.
eplb_rebalance_layers_per_chunk
)
self
.
_rebalance_num_iterations
=
self
.
_server_args
.
eplb_rebalance_num_iterations
# Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
assert
(
...
...
@@ -31,15 +35,28 @@ class EPLBManager:
get_global_expert_distribution_recorder
().
start_record
()
logger
.
info
(
f
"[EPLBManager] system started, will rebalance per
{
self
.
_
server_args
.
eplb_
rebalance_num_iterations
}
iterations."
f
"[EPLBManager] system started, will rebalance per
{
self
.
_rebalance_num_iterations
}
iterations."
)
def
on_forward_pass_end
(
self
,
forward_pass_id
:
int
):
if
forward_pass_id
%
self
.
_server_args
.
eplb_rebalance_num_iterations
==
0
:
self
.
rebalance
()
self
.
_main_generator
=
self
.
_entrypoint
()
def on_forward_pass_end(self):
    """Advance the EPLB scheduler generator by one step after each forward pass.

    The generator (built by ``_entrypoint``) decides internally whether this
    step is a no-op or performs (part of) a rebalance.
    """
    # NOTE(review): assumes _main_generator never terminates — _entrypoint
    # loops forever, so StopIteration is not expected here.
    generator = self._main_generator
    next(generator)
def
_entrypoint
(
self
):
while
True
:
for
_
in
range
(
self
.
_rebalance_num_iterations
):
yield
yield
from
self
.
rebalance
()
def
rebalance
(
self
):
logger
.
info
(
"[EPLBManager] rebalance start"
)
enable_timing
=
self
.
_rebalance_layers_per_chunk
is
None
if
enable_timing
:
torch
.
cuda
.
synchronize
()
time_start
=
time
.
time
()
...
...
@@ -49,8 +66,31 @@ class EPLBManager:
expert_location_metadata
=
ExpertLocationMetadata
.
init_by_eplb
(
self
.
_server_args
,
self
.
_model_runner
.
model_config
,
logical_count
)
self
.
_model_runner
.
update_expert_location
(
expert_location_metadata
)
update_layer_ids_chunks
=
self
.
_compute_update_layer_ids_chunks
()
for
chunk_index
,
update_layer_ids
in
enumerate
(
update_layer_ids_chunks
):
if
len
(
update_layer_ids_chunks
)
>
1
:
yield
self
.
_model_runner
.
update_expert_location
(
expert_location_metadata
,
update_layer_ids
=
update_layer_ids
,
)
msg
=
f
"[EPLBManager] rebalance end"
if
enable_timing
:
torch
.
cuda
.
synchronize
()
time_end
=
time
.
time
()
logger
.
info
(
f
"[EPLBManager] rebalance end time=
{
time_end
-
time_start
:.
3
f
}
s"
)
msg
+=
f
" time=
{
time_end
-
time_start
:.
3
f
}
s"
logger
.
info
(
msg
)
def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
    """Partition the routed-expert layer ids into rebalance chunks.

    Each chunk holds at most ``_rebalance_layers_per_chunk`` layer ids (one
    chunk is applied per forward pass); when that option is unset, a single
    huge chunk covers every layer.
    """
    layer_ids = sorted(self._model_runner.model.routed_experts_weights_of_layer.keys())
    # None → effectively "all layers in one chunk"
    effective_chunk_size = self._rebalance_layers_per_chunk or 1000000
    return list(_chunk_list(layer_ids, chunk_size=effective_chunk_size))
def
_chunk_list
(
items
:
List
,
chunk_size
):
for
start_index
in
range
(
0
,
len
(
items
),
chunk_size
):
yield
items
[
start_index
:
start_index
+
chunk_size
]
python/sglang/srt/managers/expert_location.py
View file @
0de5e7d4
...
...
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
@
dataclass
class
ExpertLocationMetadata
:
physical_to_logical_map
:
torch
.
Tensor
# (layers, num_physical_experts)
physical_to_logical_map_cpu
:
torch
.
Tensor
logical_to_all_physical_map
:
torch
.
Tensor
# (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid
:
torch
.
Tensor
# (layers, num_logical_experts)
# (layers, num_logical_experts)
...
...
@@ -203,6 +204,7 @@ class ExpertLocationMetadata:
return
ExpertLocationMetadata
(
physical_to_logical_map
=
physical_to_logical_map
,
physical_to_logical_map_cpu
=
physical_to_logical_map
.
cpu
(),
logical_to_all_physical_map
=
logical_to_all_physical_map_padded
,
logical_to_all_physical_map_num_valid
=
logical_to_all_physical_map_num_valid
,
logical_to_rank_dispatch_physical_map
=
(
...
...
@@ -223,6 +225,7 @@ class ExpertLocationMetadata:
def
update
(
self
,
other
:
"ExpertLocationMetadata"
,
update_layer_ids
:
List
[
int
],
):
for
field
in
[
"ep_size"
,
...
...
@@ -231,15 +234,21 @@ class ExpertLocationMetadata:
for
field
in
[
"physical_to_logical_map"
,
"physical_to_logical_map_cpu"
,
"logical_to_all_physical_map"
,
"logical_to_all_physical_map_num_valid"
,
"logical_to_rank_dispatch_physical_map"
,
]:
src
=
getattr
(
other
,
field
)
dst
=
getattr
(
self
,
field
)
assert
(
src
is
not
None
)
==
(
dst
is
not
None
)
if
dst
is
not
None
:
dst
[...]
=
src
other_field
=
getattr
(
other
,
field
)
self_field
=
getattr
(
self
,
field
)
assert
(
other_field
is
not
None
)
==
(
self_field
is
not
None
)
if
self_field
is
not
None
:
mask_update
=
torch
.
tensor
(
[
i
in
update_layer_ids
for
i
in
range
(
self
.
num_layers
)]
)
mask_update
=
mask_update
.
view
(
*
([
-
1
]
+
[
1
]
*
(
self_field
.
dim
()
-
1
)))
mask_update
=
mask_update
.
to
(
self_field
.
device
,
non_blocking
=
True
)
self_field
[...]
=
torch
.
where
(
mask_update
,
other_field
,
self_field
)
# -------------------------------- usage ------------------------------------
...
...
python/sglang/srt/model_executor/expert_location_updater.py
View file @
0de5e7d4
...
...
@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
ExpertLocationMetadata
,
get_global_expert_location_metadata
,
)
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.utils
import
get_bool_env_var
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
self
,
routed_experts_weights_of_layer
:
Dict
[
int
,
List
[
torch
.
Tensor
]],
new_expert_location_metadata
:
ExpertLocationMetadata
,
update_layer_ids
:
List
[
int
],
nnodes
:
int
,
rank
:
int
,
):
...
...
@@ -46,45 +48,47 @@ class ExpertLocationUpdater:
old_expert_location_metadata
=
get_global_expert_location_metadata
()
_update_expert_weights
(
routed_experts_weights_of_layer
,
old_expert_location_metadata
,
routed_experts_weights_of_layer
=
routed_experts_weights_of_layer
,
old_expert_location_metadata
=
old_expert_location_metadata
,
new_expert_location_metadata
=
new_expert_location_metadata
,
update_layer_ids
=
update_layer_ids
,
nnodes
=
nnodes
,
rank
=
rank
,
)
old_expert_location_metadata
.
update
(
new_expert_location_metadata
,
nnodes
,
rank
,
update_layer_ids
=
update_layer_ids
,
)
old_expert_location_metadata
.
update
(
new_expert_location_metadata
)
def
_update_expert_weights
(
routed_experts_weights_of_layer
:
Dict
[
int
,
List
[
torch
.
Tensor
]],
old_expert_location_metadata
:
ExpertLocationMetadata
,
new_expert_location_metadata
:
ExpertLocationMetadata
,
update_layer_ids
:
List
[
int
],
nnodes
:
int
,
rank
:
int
,
):
log_metrics
=
get_bool_env_var
(
"SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS"
)
temp_buffers
=
create_temp_buffers
(
next
(
iter
(
routed_experts_weights_of_layer
.
values
()))
routed_experts_weights_of_layer
[
update_layer_ids
[
0
]]
)
world_size
=
torch
.
distributed
.
get_world_size
()
num_local_physical_experts
=
old_expert_location_metadata
.
num_local_physical_experts
num_gpu_per_node
=
world_size
//
nnodes
old_physical_to_logical_map
=
(
old_expert_location_metadata
.
physical_to_logical_map
.
tolist
()
)
new_physical_to_logical_map
=
(
new_expert_location_metadata
.
physical_to_logical_map
.
tolist
()
)
for
layer_id
in
sorted
(
routed_experts_weights_of_layer
.
keys
()):
for
layer_id
in
update_layer_ids
:
update_expert_weights_single_layer
(
routed_experts_weights
=
routed_experts_weights_of_layer
[
layer_id
],
temp_buffers
=
temp_buffers
,
old_physical_to_logical_map
=
old_physical_to_logical_map
[
layer_id
],
new_physical_to_logical_map
=
new_physical_to_logical_map
[
layer_id
],
old_physical_to_logical_map
=
old_expert_location_metadata
.
physical_to_logical_map_cpu
[
layer_id
].
tolist
(),
new_physical_to_logical_map
=
new_expert_location_metadata
.
physical_to_logical_map_cpu
[
layer_id
].
tolist
(),
num_local_physical_experts
=
num_local_physical_experts
,
num_gpu_per_node
=
num_gpu_per_node
,
rank
=
rank
,
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
0de5e7d4
...
...
@@ -611,11 +611,14 @@ class ModelRunner:
)
from
None
def update_expert_location(
    self,
    new_expert_location_metadata: ExpertLocationMetadata,
    update_layer_ids: List[int],
):
    """Apply a (possibly partial, per-layer-chunk) expert relocation.

    Forwards the new location metadata plus the layer ids to update to the
    expert_location_updater, together with the distributed topology info
    (node count and this rank).
    """
    weights_of_layer = self.model.routed_experts_weights_of_layer
    self.expert_location_updater.update(
        weights_of_layer,
        new_expert_location_metadata,
        update_layer_ids=update_layer_ids,
        nnodes=self.server_args.nnodes,
        rank=self.tp_rank,
    )
...
...
@@ -1203,7 +1206,7 @@ class ModelRunner:
)
if
self
.
eplb_manager
is
not
None
:
self
.
eplb_manager
.
on_forward_pass_end
(
self
.
forward_pass_id
)
self
.
eplb_manager
.
on_forward_pass_end
()
return
output
...
...
python/sglang/srt/server_args.py
View file @
0de5e7d4
...
...
@@ -180,6 +180,7 @@ class ServerArgs:
enable_eplb
:
bool
=
False
eplb_algorithm
:
str
=
"auto"
eplb_rebalance_num_iterations
:
int
=
1000
eplb_rebalance_layers_per_chunk
:
Optional
[
int
]
=
None
expert_distribution_recorder_mode
:
Optional
[
Literal
[
"stat"
,
"per_pass"
,
"per_token"
]
]
=
None
...
...
@@ -1367,6 +1368,12 @@ class ServerArgs:
default
=
ServerArgs
.
eplb_rebalance_num_iterations
,
help
=
"Number of iterations to automatically trigger a EPLB re-balance."
,
)
parser
.
add_argument
(
"--eplb-rebalance-layers-per-chunk"
,
type
=
int
,
default
=
ServerArgs
.
eplb_rebalance_layers_per_chunk
,
help
=
"Number of layers to rebalance per forward pass."
,
)
parser
.
add_argument
(
"--expert-distribution-recorder-mode"
,
type
=
str
,
...
...
test/srt/test_eplb.py
View file @
0de5e7d4
...
...
@@ -5,7 +5,6 @@ from pathlib import Path
from
types
import
SimpleNamespace
import
sglang
as
sgl
from
sglang.srt.managers.expert_distribution_storage
import
ExpertDistributionStorage
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
...
...
@@ -17,7 +16,9 @@ from sglang.test.test_utils import (
)
class
TestDynamicEPLB
(
CustomTestCase
):
class
_BaseTestDynamicEPLB
(
CustomTestCase
):
extra_args
=
[]
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MLA_MODEL_NAME_FOR_TEST
...
...
@@ -51,8 +52,13 @@ class TestDynamicEPLB(CustomTestCase):
"stat"
,
"--ep-dispatch-algorithm"
,
"static"
,
*
cls
.
extra_args
,
],
env
=
{
"SGL_ENABLE_JIT_DEEPGEMM"
:
"0"
,
**
os
.
environ
},
env
=
{
"SGL_ENABLE_JIT_DEEPGEMM"
:
"0"
,
"SGLANG_EXPERT_LOCATION_UPDATER_CANARY"
:
"1"
,
**
os
.
environ
,
},
)
@
classmethod
...
...
@@ -72,6 +78,14 @@ class TestDynamicEPLB(CustomTestCase):
self
.
assertGreater
(
metrics
[
"score"
],
0.5
)
class
TestDynamicEPLBSimple
(
_BaseTestDynamicEPLB
):
pass
class
TestDynamicEPLBMultiChunk
(
_BaseTestDynamicEPLB
):
extra_args
=
[
"--eplb-rebalance-layers-per-chunk"
,
"1"
]
class
TestStaticEPLB
(
CustomTestCase
):
def
test_save_expert_distribution_and_init_expert_location
(
self
):
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"0"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment