Commit 7a80f565 (unverified), authored May 22, 2025 by fzyzcjy, committed by GitHub May 21, 2025
Support dynamically rebalancing experts using EPLB (#6469)
parent 9484eba4

Showing 6 changed files with 226 additions and 3 deletions (+226, −3)
python/sglang/srt/managers/eplb_manager.py                    +55   −0
python/sglang/srt/model_executor/expert_location_updater.py   +2    −0
python/sglang/srt/model_executor/model_runner.py              +13   −1
python/sglang/srt/server_args.py                              +13   −0
test/srt/test_eplb.py                                         +141  −0
test/srt/test_expert_location_updater.py                      +2    −2
python/sglang/srt/managers/eplb_manager.py (new file, 0 → 100644)
import logging
import time
from typing import TYPE_CHECKING

import torch.cuda

from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ExpertLocationMetadata

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

logger = logging.getLogger(__name__)


class EPLBManager:
    def __init__(self, model_runner: "ModelRunner"):
        super().__init__()

        self._model_runner = model_runner
        self._server_args = model_runner.server_args

        # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
        assert (
            self._server_args.eplb_rebalance_num_iterations
            <= self._server_args.expert_distribution_recorder_buffer_size
        ), "eplb_rebalance_num_iterations must be less than expert_distribution_recorder_buffer_size"

        get_global_expert_distribution_recorder().start_record()

        logger.info(
            f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
        )

    def on_forward_pass_end(self, forward_pass_id: int):
        if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
            self.rebalance()

    def rebalance(self):
        logger.info("[EPLBManager] rebalance start")
        torch.cuda.synchronize()
        time_start = time.time()

        logical_count = get_global_expert_distribution_recorder().dump_record(
            output_mode="object"
        )["logical_count"]
        expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
            self._server_args, self._model_runner.model_config, logical_count
        )
        self._model_runner.update_expert_location(expert_location_metadata)

        torch.cuda.synchronize()
        time_end = time.time()
        logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
python/sglang/srt/model_executor/expert_location_updater.py
@@ -95,6 +95,8 @@ def update_expert_weights_single_layer(
         tensor.shape[0] == num_local_physical_experts
         for tensor in routed_experts_weights
     ), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"
+    assert isinstance(old_physical_to_logical_map, list)
+    assert isinstance(new_physical_to_logical_map, list)
     output_logs = [] if debug else None
python/sglang/srt/model_executor/model_runner.py
@@ -51,6 +51,7 @@ from sglang.srt.layers.quantization.deep_gemm import (
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.eplb_manager import EPLBManager
 from sglang.srt.managers.expert_distribution import (
     ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,

@@ -255,6 +256,12 @@ class ModelRunner:
             )
         )
+        self.eplb_manager = (
+            EPLBManager(self)
+            if self.server_args.enable_eplb and (not self.is_draft_worker)
+            else None
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()

@@ -1152,10 +1159,15 @@ class ModelRunner:
             self.forward_pass_id,
             forward_batch,
         ):
-            return self._forward_raw(
+            output = self._forward_raw(
                 forward_batch, skip_attn_backend_init, pp_proxy_tensors
             )
+
+        if self.eplb_manager is not None:
+            self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
+
+        return output

     def _forward_raw(
         self,
         forward_batch: ForwardBatch,
python/sglang/srt/server_args.py
@@ -173,6 +173,8 @@ class ServerArgs:
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
     init_expert_location: str = "trivial"
+    enable_eplb: bool = False
+    eplb_rebalance_num_iterations: int = 1000
     expert_distribution_recorder_mode: Optional[
         Literal["stat", "per_pass", "per_token"]
     ] = None

@@ -1293,6 +1295,17 @@ class ServerArgs:
             default=ServerArgs.init_expert_location,
             help="Initial location of EP experts.",
         )
+        parser.add_argument(
+            "--enable-eplb",
+            action="store_true",
+            help="Enable EPLB algorithm",
+        )
+        parser.add_argument(
+            "--eplb-rebalance-num-iterations",
+            type=int,
+            default=ServerArgs.eplb_rebalance_num_iterations,
+            help="Number of iterations to automatically trigger a EPLB re-balance.",
+        )
         parser.add_argument(
             "--expert-distribution-recorder-mode",
             type=str,
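
Taken together, the new ServerArgs fields and CLI flags let a deployment opt into dynamic EPLB. Below is a hedged sketch of enabling it programmatically through sgl.Engine, whose keyword arguments map onto the ServerArgs fields above (equivalent to passing --enable-eplb and --eplb-rebalance-num-iterations on the command line). The model path and parallel sizes are placeholders, the flag combination mirrors what test/srt/test_eplb.py exercises, and per the EPLBManager assertion, eplb_rebalance_num_iterations must not exceed expert_distribution_recorder_buffer_size:

# Hedged sketch, not part of this commit: enabling dynamic EPLB via engine kwargs.
import sglang as sgl

engine = sgl.Engine(
    model_path="path/to/a-moe-model",           # placeholder
    trust_remote_code=True,
    enable_deepep_moe=True,
    deepep_mode="normal",
    tp_size=2,
    dp_size=2,
    enable_dp_attention=True,
    ep_num_redundant_experts=4,
    enable_eplb=True,                            # new in this commit
    eplb_rebalance_num_iterations=50,            # new in this commit; must be <= buffer size below
    expert_distribution_recorder_buffer_size=50,
    expert_distribution_recorder_mode="stat",
    ep_dispatch_algorithm="static",
)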
test/srt/test_eplb.py (new file, 0 → 100755)
import os
import tempfile
import unittest
from pathlib import Path
from types import SimpleNamespace

import sglang as sgl
from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestDynamicEPLB(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--dp",
                "2",
                "--enable-dp-attention",
                "--enable-deepep-moe",
                "--deepep-mode",
                "normal",
                "--disable-cuda-graph",
                "--enable-eplb",
                "--ep-num-redundant-experts",
                "4",
                "--eplb-rebalance-num-iterations",
                "50",
                "--expert-distribution-recorder-buffer-size",
                "50",
                # TODO pr-chain: enable later
                # "--enable-expert-distribution-metrics",
                # TODO auto determine these flags
                "--expert-distribution-recorder-mode",
                "stat",
                "--ep-dispatch-algorithm",
                "static",
            ],
            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)


class TestStaticEPLB(CustomTestCase):
    def test_save_expert_distribution_and_init_expert_location(self):
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
        with tempfile.TemporaryDirectory() as tmp_dir:
            engine_kwargs = dict(
                model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST,
                trust_remote_code=True,
                ep_num_redundant_experts=4,
                enable_dp_attention=True,
                enable_deepep_moe=True,
                deepep_mode="normal",
                disable_cuda_graph=True,
                expert_distribution_recorder_mode="stat",
                tp_size=2,
                dp_size=2,
                log_level="info",
                # TODO pr-chain: enable later
                # enable_expert_distribution_metrics=True,
            )

            print(f"Action: start engine")
            os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
            engine = sgl.Engine(
                **engine_kwargs,
                disable_overlap_schedule=True,
            )
            engine.start_expert_distribution_record()
            self._assert_engine_generate_correct(engine)
            print(f"Action: dump_expert_distribution_record")
            engine.dump_expert_distribution_record()
            snapshot_path = list(Path(tmp_dir).glob("*.pt"))[0]
            assert snapshot_path is not None
            print(f"{snapshot_path=}")
            print(f"Action: shutdown engine")
            engine.shutdown()
            del engine

            print(f"Action: start engine with init_expert_location")
            engine = sgl.Engine(
                **engine_kwargs,
                init_expert_location=str(snapshot_path),
                port=21000,
                # TODO auto determine these flags
                ep_dispatch_algorithm="static",
            )
            self._assert_engine_generate_correct(engine)
            print(f"Action: shutdown engine")
            engine.shutdown()
            del engine

    def _assert_engine_generate_correct(self, engine: sgl.Engine):
        output = engine.generate(
            prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"],
            sampling_params=dict(max_new_tokens=8, temperature=0.0),
        )
        print(f"engine.generate {output=}")
        self.assertEqual(
            [x["text"] for x in output],
            [", 4+4=8,", ", four plus four is eight, eight"],
        )


if __name__ == "__main__":
    unittest.main()
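
Stripped of the unittest scaffolding, TestStaticEPLB above boils down to a two-step flow: record an expert-distribution snapshot with one engine, then start a second engine that reuses it as a static placement via init_expert_location. A condensed, hedged sketch of that sequence follows; the model path, snapshot directory, and the minimal kwarg set are illustrative placeholders, and a real MoE deployment would also need the expert-parallel flags used in the test:

# Hedged sketch, not part of this commit: static EPLB from a recorded snapshot.
import os
from pathlib import Path

import sglang as sgl

os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = "/tmp/eplb_snapshots"  # placeholder dir

# 1) Record expert-distribution statistics while serving some traffic.
engine = sgl.Engine(
    model_path="path/to/a-moe-model",            # placeholder
    expert_distribution_recorder_mode="stat",
)
engine.start_expert_distribution_record()
engine.generate(
    prompt=["1+1=2, 2+2=4"],
    sampling_params=dict(max_new_tokens=8, temperature=0.0),
)
engine.dump_expert_distribution_record()
engine.shutdown()

# 2) Restart with the snapshot as the initial (static) expert placement.
snapshot_path = next(Path("/tmp/eplb_snapshots").glob("*.pt"))
engine = sgl.Engine(
    model_path="path/to/a-moe-model",            # placeholder
    init_expert_location=str(snapshot_path),
    ep_dispatch_algorithm="static",
)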
test/srt/test_expert_location_updater.py
@@ -210,8 +210,8 @@ def _execute_test(info: _TestInfo, rank: int, num_gpus: int, device: str):
         temp_buffers=expert_location_updater.create_temp_buffers(routed_experts_weights),
-        old_physical_to_logical_map=physical_to_logical_map,
-        new_physical_to_logical_map=new_physical_to_logical_map,
+        old_physical_to_logical_map=physical_to_logical_map.tolist(),
+        new_physical_to_logical_map=new_physical_to_logical_map.tolist(),
         num_local_physical_experts=num_local_physical_experts,
         num_gpu_per_node=num_gpu_per_node,
         rank=rank,