Unverified commit 0de5e7d4, authored Jun 05, 2025 by fzyzcjy, committed by GitHub on Jun 05, 2025

Support layerwise rebalancing experts (#6851)

Parent: 72a110f6
Showing 6 changed files with 114 additions and 37 deletions (+114 -37).
python/sglang/srt/managers/eplb_manager.py                   +51 -11
python/sglang/srt/managers/expert_location.py                +14  -5
python/sglang/srt/model_executor/expert_location_updater.py  +20 -16
python/sglang/srt/model_executor/model_runner.py              +5  -2
python/sglang/srt/server_args.py                              +7  -0
test/srt/test_eplb.py                                        +17  -3
python/sglang/srt/managers/eplb_manager.py

 import logging
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List

 import torch.cuda

@@ -20,6 +20,10 @@ class EPLBManager:
         super().__init__()

         self._model_runner = model_runner
         self._server_args = model_runner.server_args
+        self._rebalance_layers_per_chunk = (
+            self._server_args.eplb_rebalance_layers_per_chunk
+        )
+        self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations

         # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
         assert (

@@ -31,17 +35,30 @@ class EPLBManager:
         get_global_expert_distribution_recorder().start_record()
         logger.info(
-            f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
+            f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations."
         )
+        self._main_generator = self._entrypoint()

-    def on_forward_pass_end(self, forward_pass_id: int):
-        if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
-            self.rebalance()
+    def on_forward_pass_end(self):
+        next(self._main_generator)
+
+    # can be more complex if needed
+    def _entrypoint(self):
+        while True:
+            for _ in range(self._rebalance_num_iterations):
+                yield
+            yield from self.rebalance()

     def rebalance(self):
         logger.info("[EPLBManager] rebalance start")
-        torch.cuda.synchronize()
-        time_start = time.time()
+        enable_timing = self._rebalance_layers_per_chunk is None
+        if enable_timing:
+            torch.cuda.synchronize()
+            time_start = time.time()
         logical_count = get_global_expert_distribution_recorder().dump_record(
             output_mode="object"

@@ -49,8 +66,31 @@ class EPLBManager:
         expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
             self._server_args, self._model_runner.model_config, logical_count
         )
-        self._model_runner.update_expert_location(expert_location_metadata)
-
-        torch.cuda.synchronize()
-        time_end = time.time()
-        logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
+        update_layer_ids_chunks = self._compute_update_layer_ids_chunks()
+        for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks):
+            if len(update_layer_ids_chunks) > 1:
+                yield
+            self._model_runner.update_expert_location(
+                expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+            )
+
+        msg = f"[EPLBManager] rebalance end"
+        if enable_timing:
+            torch.cuda.synchronize()
+            time_end = time.time()
+            msg += f" time={time_end - time_start:.3f}s"
+        logger.info(msg)
+
+    def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
+        all_layer_ids = sorted(
+            list(self._model_runner.model.routed_experts_weights_of_layer.keys())
+        )
+        chunk_size = self._rebalance_layers_per_chunk or 1000000
+        return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))
+
+
+def _chunk_list(items: List, chunk_size):
+    for start_index in range(0, len(items), chunk_size):
+        yield items[start_index : start_index + chunk_size]
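With this change, rebalance becomes a generator that the model runner advances once per forward pass, so when --eplb-rebalance-layers-per-chunk is set, one rebalance is spread across several passes. Below is a minimal, self-contained sketch of that scheduling pattern; the ToyManager class and its numbers are illustrative stand-ins, not part of sglang.

# Minimal sketch of the generator-driven scheduling used by EPLBManager above.
# ToyManager and its parameters are illustrative only.
from typing import Iterator, List


def _chunk_list(items: List, chunk_size: int) -> Iterator[List]:
    for start_index in range(0, len(items), chunk_size):
        yield items[start_index : start_index + chunk_size]


class ToyManager:
    def __init__(self, num_layers: int, rebalance_every: int, layers_per_chunk: int):
        self._layer_ids = list(range(num_layers))
        self._rebalance_every = rebalance_every
        self._layers_per_chunk = layers_per_chunk
        self._main_generator = self._entrypoint()

    def on_forward_pass_end(self):
        # Called once per forward pass; advances the schedule by one step.
        next(self._main_generator)

    def _entrypoint(self):
        while True:
            for _ in range(self._rebalance_every):
                yield  # idle passes between rebalances
            yield from self._rebalance()

    def _rebalance(self):
        chunks = list(_chunk_list(self._layer_ids, self._layers_per_chunk))
        for i, chunk in enumerate(chunks):
            if len(chunks) > 1:
                yield  # spread the chunks over consecutive forward passes
            print(f"updating layers {chunk}")


manager = ToyManager(num_layers=4, rebalance_every=3, layers_per_chunk=2)
for pass_id in range(10):
    manager.on_forward_pass_end()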
python/sglang/srt/managers/expert_location.py

@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class ExpertLocationMetadata:
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
+    physical_to_logical_map_cpu: torch.Tensor
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
     # (layers, num_logical_experts)

@@ -203,6 +204,7 @@ class ExpertLocationMetadata:
         return ExpertLocationMetadata(
             physical_to_logical_map=physical_to_logical_map,
+            physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
             logical_to_all_physical_map=logical_to_all_physical_map_padded,
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=(

@@ -223,6 +225,7 @@ class ExpertLocationMetadata:
     def update(
         self,
         other: "ExpertLocationMetadata",
+        update_layer_ids: List[int],
     ):
         for field in [
             "ep_size",

@@ -231,15 +234,21 @@ class ExpertLocationMetadata:
         for field in [
             "physical_to_logical_map",
+            "physical_to_logical_map_cpu",
             "logical_to_all_physical_map",
             "logical_to_all_physical_map_num_valid",
             "logical_to_rank_dispatch_physical_map",
         ]:
-            src = getattr(other, field)
-            dst = getattr(self, field)
-            assert (src is not None) == (dst is not None)
-            if dst is not None:
-                dst[...] = src
+            other_field = getattr(other, field)
+            self_field = getattr(self, field)
+            assert (other_field is not None) == (self_field is not None)
+            if self_field is not None:
+                mask_update = torch.tensor(
+                    [i in update_layer_ids for i in range(self.num_layers)]
+                )
+                mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
+                mask_update = mask_update.to(self_field.device, non_blocking=True)
+                self_field[...] = torch.where(mask_update, other_field, self_field)


 # -------------------------------- usage ------------------------------------
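With update_layer_ids, ExpertLocationMetadata.update now overwrites only the rows of each tensor field that belong to the layers being updated, using a boolean layer mask broadcast over the trailing dimensions and torch.where. A minimal sketch of that masking pattern with toy tensors (field names, dtypes, and shapes here are illustrative, not the real sglang fields):

# Sketch of the per-layer masked merge used in ExpertLocationMetadata.update.
import torch

num_layers, num_physical_experts = 4, 3
self_field = torch.zeros(num_layers, num_physical_experts, dtype=torch.int64)
other_field = torch.ones(num_layers, num_physical_experts, dtype=torch.int64)
update_layer_ids = [1, 3]

# Boolean mask over layers, reshaped so it broadcasts over all trailing dims.
mask_update = torch.tensor([i in update_layer_ids for i in range(num_layers)])
mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
mask_update = mask_update.to(self_field.device, non_blocking=True)

# Rows 1 and 3 come from other_field; rows 0 and 2 keep their old values.
self_field[...] = torch.where(mask_update, other_field, self_field)
print(self_field)
# tensor([[0, 0, 0],
#         [1, 1, 1],
#         [0, 0, 0],
#         [1, 1, 1]])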
python/sglang/srt/model_executor/expert_location_updater.py

@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)

@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
         self,
         routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
         new_expert_location_metadata: ExpertLocationMetadata,
+        update_layer_ids: List[int],
         nnodes: int,
         rank: int,
     ):

@@ -46,45 +48,47 @@ class ExpertLocationUpdater:
         old_expert_location_metadata = get_global_expert_location_metadata()
         _update_expert_weights(
-            routed_experts_weights_of_layer,
-            old_expert_location_metadata,
-            new_expert_location_metadata,
-            nnodes,
-            rank,
+            routed_experts_weights_of_layer=routed_experts_weights_of_layer,
+            old_expert_location_metadata=old_expert_location_metadata,
+            new_expert_location_metadata=new_expert_location_metadata,
+            update_layer_ids=update_layer_ids,
+            nnodes=nnodes,
+            rank=rank,
         )
-        old_expert_location_metadata.update(new_expert_location_metadata)
+        old_expert_location_metadata.update(
+            new_expert_location_metadata,
+            update_layer_ids=update_layer_ids,
+        )


 def _update_expert_weights(
     routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
     old_expert_location_metadata: ExpertLocationMetadata,
     new_expert_location_metadata: ExpertLocationMetadata,
+    update_layer_ids: List[int],
     nnodes: int,
     rank: int,
 ):
     log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")

     temp_buffers = create_temp_buffers(
-        next(iter(routed_experts_weights_of_layer.values()))
+        routed_experts_weights_of_layer[update_layer_ids[0]]
     )

     world_size = torch.distributed.get_world_size()
     num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
     num_gpu_per_node = world_size // nnodes

-    old_physical_to_logical_map = (
-        old_expert_location_metadata.physical_to_logical_map.tolist()
-    )
-    new_physical_to_logical_map = (
-        new_expert_location_metadata.physical_to_logical_map.tolist()
-    )
-
-    for layer_id in sorted(routed_experts_weights_of_layer.keys()):
+    for layer_id in update_layer_ids:
         update_expert_weights_single_layer(
             routed_experts_weights=routed_experts_weights_of_layer[layer_id],
             temp_buffers=temp_buffers,
-            old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
-            new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
+            old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[
+                layer_id
+            ].tolist(),
+            new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[
+                layer_id
+            ].tolist(),
             num_local_physical_experts=num_local_physical_experts,
             num_gpu_per_node=num_gpu_per_node,
             rank=rank,
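The updater now iterates only over update_layer_ids and reads each layer's map from the cached physical_to_logical_map_cpu, instead of converting the whole device-side map with .tolist() on every rebalance. A small self-contained sketch of the difference; the tensor below is a stand-in for the real metadata field, shapes and values are made up, and no GPU is required to run it:

# Sketch contrasting the old and new ways of reading the physical-to-logical map.
import torch

num_layers, num_physical_experts = 8, 16
physical_to_logical_map = torch.randint(0, 4, (num_layers, num_physical_experts))
physical_to_logical_map_cpu = physical_to_logical_map.cpu()  # materialized once, at metadata construction time

update_layer_ids = [2, 3]  # only the current chunk of layers

# Old approach: convert the entire map on every rebalance, then index it.
old_all = physical_to_logical_map.tolist()
old_layer_map = old_all[update_layer_ids[0]]

# New approach: index the cached CPU copy per updated layer only.
new_layer_map = physical_to_logical_map_cpu[update_layer_ids[0]].tolist()

assert old_layer_map == new_layer_map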
python/sglang/srt/model_executor/model_runner.py

@@ -611,11 +611,14 @@ class ModelRunner:
         ) from None

     def update_expert_location(
-        self, new_expert_location_metadata: ExpertLocationMetadata
+        self,
+        new_expert_location_metadata: ExpertLocationMetadata,
+        update_layer_ids: List[int],
     ):
         self.expert_location_updater.update(
             self.model.routed_experts_weights_of_layer,
             new_expert_location_metadata,
+            update_layer_ids=update_layer_ids,
             nnodes=self.server_args.nnodes,
             rank=self.tp_rank,
         )

@@ -1203,7 +1206,7 @@ class ModelRunner:
         )
         if self.eplb_manager is not None:
-            self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
+            self.eplb_manager.on_forward_pass_end()
         return output
python/sglang/srt/server_args.py

@@ -180,6 +180,7 @@ class ServerArgs:
     enable_eplb: bool = False
     eplb_algorithm: str = "auto"
     eplb_rebalance_num_iterations: int = 1000
+    eplb_rebalance_layers_per_chunk: Optional[int] = None
     expert_distribution_recorder_mode: Optional[
         Literal["stat", "per_pass", "per_token"]
     ] = None

@@ -1367,6 +1368,12 @@ class ServerArgs:
             default=ServerArgs.eplb_rebalance_num_iterations,
             help="Number of iterations to automatically trigger a EPLB re-balance.",
         )
+        parser.add_argument(
+            "--eplb-rebalance-layers-per-chunk",
+            type=int,
+            default=ServerArgs.eplb_rebalance_layers_per_chunk,
+            help="Number of layers to rebalance per forward pass.",
+        )
         parser.add_argument(
             "--expert-distribution-recorder-mode",
             type=str,
View file @
0de5e7d4
...
@@ -5,7 +5,6 @@ from pathlib import Path
...
@@ -5,7 +5,6 @@ from pathlib import Path
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
import
sglang
as
sgl
import
sglang
as
sgl
from
sglang.srt.managers.expert_distribution_storage
import
ExpertDistributionStorage
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
...
@@ -17,7 +16,9 @@ from sglang.test.test_utils import (
...
@@ -17,7 +16,9 @@ from sglang.test.test_utils import (
)
)
class
TestDynamicEPLB
(
CustomTestCase
):
class
_BaseTestDynamicEPLB
(
CustomTestCase
):
extra_args
=
[]
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MLA_MODEL_NAME_FOR_TEST
...
@@ -51,8 +52,13 @@ class TestDynamicEPLB(CustomTestCase):
...
@@ -51,8 +52,13 @@ class TestDynamicEPLB(CustomTestCase):
"stat"
,
"stat"
,
"--ep-dispatch-algorithm"
,
"--ep-dispatch-algorithm"
,
"static"
,
"static"
,
*
cls
.
extra_args
,
],
],
env
=
{
"SGL_ENABLE_JIT_DEEPGEMM"
:
"0"
,
**
os
.
environ
},
env
=
{
"SGL_ENABLE_JIT_DEEPGEMM"
:
"0"
,
"SGLANG_EXPERT_LOCATION_UPDATER_CANARY"
:
"1"
,
**
os
.
environ
,
},
)
)
@
classmethod
@
classmethod
...
@@ -72,6 +78,14 @@ class TestDynamicEPLB(CustomTestCase):
...
@@ -72,6 +78,14 @@ class TestDynamicEPLB(CustomTestCase):
self
.
assertGreater
(
metrics
[
"score"
],
0.5
)
self
.
assertGreater
(
metrics
[
"score"
],
0.5
)
class
TestDynamicEPLBSimple
(
_BaseTestDynamicEPLB
):
pass
class
TestDynamicEPLBMultiChunk
(
_BaseTestDynamicEPLB
):
extra_args
=
[
"--eplb-rebalance-layers-per-chunk"
,
"1"
]
class
TestStaticEPLB
(
CustomTestCase
):
class
TestStaticEPLB
(
CustomTestCase
):
def
test_save_expert_distribution_and_init_expert_location
(
self
):
def
test_save_expert_distribution_and_init_expert_location
(
self
):
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"0"
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"0"
...
...
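The test is parametrized by subclassing: _BaseTestDynamicEPLB holds the server launch and evaluation logic, and each subclass only overrides extra_args, so the same end-to-end scenario runs both without chunking and with --eplb-rebalance-layers-per-chunk 1. A minimal stand-alone illustration of that pattern; the launch and check logic below is a stub, not the real test:

# Toy illustration of parametrizing a test suite via subclassing with an extra_args hook.
import unittest


class _BaseServerTest(unittest.TestCase):
    extra_args = []  # subclasses override this to add CLI flags

    @classmethod
    def setUpClass(cls):
        # The real test launches a server here; this stub only records the args.
        cls.launch_args = ["--enable-eplb", *cls.extra_args]

    def test_launch_args(self):
        self.assertIn("--enable-eplb", self.launch_args)


class ServerTestSimple(_BaseServerTest):
    pass


class ServerTestMultiChunk(_BaseServerTest):
    extra_args = ["--eplb-rebalance-layers-per-chunk", "1"]


if __name__ == "__main__":
    unittest.main()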