Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c2a26e72
Unverified
Commit
c2a26e72
authored
Aug 30, 2025
by
hzh0425
Committed by
GitHub
Aug 30, 2025
Browse files
feature(eplb): add min-rebalancing-utilization-threshold for eplb (#8345)
Co-authored-by:
yizhang2077
<
1109276519@qq.com
>
parent
591e6c59
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
4 deletions
+62
-4
python/sglang/srt/eplb/eplb_manager.py
python/sglang/srt/eplb/eplb_manager.py
+26
-2
python/sglang/srt/eplb/expert_distribution.py
python/sglang/srt/eplb/expert_distribution.py
+29
-2
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
No files found.
python/sglang/srt/eplb/eplb_manager.py
View file @
c2a26e72
...
@@ -58,9 +58,18 @@ class EPLBManager:
...
@@ -58,9 +58,18 @@ class EPLBManager:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
time_start
=
time
.
time
()
time_start
=
time
.
time
()
logical_coun
t
=
get_global_expert_distribution_recorder
().
dump_record
(
dump_record_outpu
t
=
get_global_expert_distribution_recorder
().
dump_record
(
output_mode
=
"object"
output_mode
=
"object"
)[
"logical_count"
]
)
logical_count
=
dump_record_output
[
"logical_count"
]
average_utilization_rate_over_window
=
dump_record_output
[
"average_utilization_rate_over_window"
]
# Check whether rebalancing is needed
if
not
self
.
_check_rebalance_needed
(
average_utilization_rate_over_window
):
return
expert_location_metadata
=
ExpertLocationMetadata
.
init_by_eplb
(
expert_location_metadata
=
ExpertLocationMetadata
.
init_by_eplb
(
self
.
_server_args
,
self
.
_model_runner
.
model_config
,
logical_count
self
.
_server_args
,
self
.
_model_runner
.
model_config
,
logical_count
)
)
...
@@ -81,6 +90,21 @@ class EPLBManager:
...
@@ -81,6 +90,21 @@ class EPLBManager:
msg
+=
f
" time=
{
time_end
-
time_start
:.
3
f
}
s"
msg
+=
f
" time=
{
time_end
-
time_start
:.
3
f
}
s"
logger
.
info
(
msg
)
logger
.
info
(
msg
)
def _check_rebalance_needed(self, average_utilization_rate_over_window):
    """Decide whether the current rebalance round should go ahead.

    Args:
        average_utilization_rate_over_window: Utilization figure taken from
            the recorder dump, or ``None`` when no figure was produced.

    Returns:
        ``True`` when no figure is available or when the figure does not
        exceed ``eplb_min_rebalancing_utilization_threshold``; ``False``
        (after logging the skip) when the figure is above that threshold.
    """
    observed_rate = average_utilization_rate_over_window

    # Without a utilization figure there is nothing to gate on: rebalance.
    if observed_rate is None:
        return True

    skip_above = self._server_args.eplb_min_rebalancing_utilization_threshold

    # Kept as a negated ">" (not "<=") so NaN compares the same way as before.
    if not (observed_rate > skip_above):
        return True

    logger.info(
        f"[EPLBManager] Skipped ep rebalancing: current GPU utilization {observed_rate:.2f} > minimum rebalance threshold {skip_above:.2f}"
    )
    return False
def
_compute_update_layer_ids_chunks
(
self
)
->
List
[
List
[
int
]]:
def
_compute_update_layer_ids_chunks
(
self
)
->
List
[
List
[
int
]]:
all_layer_ids
=
sorted
(
all_layer_ids
=
sorted
(
list
(
self
.
_model_runner
.
model
.
routed_experts_weights_of_layer
.
keys
())
list
(
self
.
_model_runner
.
model
.
routed_experts_weights_of_layer
.
keys
())
...
...
python/sglang/srt/eplb/expert_distribution.py
View file @
c2a26e72
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
import
logging
import
logging
import
math
import
os
import
os
import
time
import
time
from
abc
import
ABC
from
abc
import
ABC
...
@@ -614,8 +615,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
...
@@ -614,8 +615,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
self
.
_enable
=
self
.
_server_args
.
enable_expert_distribution_metrics
self
.
_enable
=
self
.
_server_args
.
enable_expert_distribution_metrics
if
self
.
_enable
:
if
self
.
_enable
:
window_sizes
=
[
10
,
100
,
1000
]
self
.
window_sizes
=
[
10
,
100
,
1000
]
self
.
_history
=
_DequeCollection
(
maxlens
=
window_sizes
)
self
.
_history
=
_DequeCollection
(
maxlens
=
self
.
window_sizes
)
self
.
_rank
=
torch
.
distributed
.
get_rank
()
self
.
_rank
=
torch
.
distributed
.
get_rank
()
def
append
(
def
append
(
...
@@ -787,6 +788,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
...
@@ -787,6 +788,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
output
=
dict
(
output
=
dict
(
rank
=
self
.
_rank
,
rank
=
self
.
_rank
,
logical_count
=
logical_count_of_buffered_step
,
logical_count
=
logical_count_of_buffered_step
,
average_utilization_rate_over_window
=
self
.
_get_global_average_utilization_rate
(),
)
)
if
output_mode
==
"file"
:
if
output_mode
==
"file"
:
...
@@ -797,6 +799,31 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
...
@@ -797,6 +799,31 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
def _get_global_average_utilization_rate(self):
    """Return rank 0's mean utilization rate, shared with every rank.

    Returns ``None`` when the accumulator is not enabled or when
    ``eplb_min_rebalancing_utilization_threshold`` is (close to) 1.0.
    Otherwise rank 0 reads the mean for the last entry of
    ``self.window_sizes`` from ``self._history`` (falling back to 0 when
    that window has no entry), the value is broadcast from rank 0, and the
    resulting Python float is returned.

    NOTE(review): the broadcast is a collective call, so presumably every
    rank must reach this method together -- confirm against callers.
    """
    if not self._enable:
        return None
    if math.isclose(
        self._server_args.eplb_min_rebalancing_utilization_threshold, 1.0
    ):
        return None

    tensor_spec = dict(dtype=torch.float32, device="cuda")

    if self._rank != 0:
        # Receiving ranks only need an uninitialised one-element buffer.
        rate_buffer = torch.empty(1, **tensor_spec)
    else:
        mean_by_window = self._history.mean()
        last_window = self.window_sizes[-1]
        if last_window in mean_by_window:
            rank0_rate = mean_by_window[last_window]
        else:
            rank0_rate = 0
        rate_buffer = torch.tensor([rank0_rate], **tensor_spec)

    # Rank 0 is the source; all other ranks get their buffer overwritten.
    torch.distributed.broadcast(rate_buffer, src=0)
    return rate_buffer.item()
def
_dump_to_file
(
name
,
data
):
def
_dump_to_file
(
name
,
data
):
save_dir
=
Path
(
os
.
environ
.
get
(
"SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"
,
"/tmp"
))
save_dir
=
Path
(
os
.
environ
.
get
(
"SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"
,
"/tmp"
))
...
...
python/sglang/srt/server_args.py
View file @
c2a26e72
...
@@ -274,6 +274,7 @@ class ServerArgs:
...
@@ -274,6 +274,7 @@ class ServerArgs:
eplb_algorithm
:
str
=
"auto"
eplb_algorithm
:
str
=
"auto"
eplb_rebalance_num_iterations
:
int
=
1000
eplb_rebalance_num_iterations
:
int
=
1000
eplb_rebalance_layers_per_chunk
:
Optional
[
int
]
=
None
eplb_rebalance_layers_per_chunk
:
Optional
[
int
]
=
None
eplb_min_rebalancing_utilization_threshold
:
float
=
1.0
expert_distribution_recorder_mode
:
Optional
[
expert_distribution_recorder_mode
:
Optional
[
Literal
[
"stat"
,
"stat_approx"
,
"per_pass"
,
"per_token"
]
Literal
[
"stat"
,
"stat_approx"
,
"per_pass"
,
"per_token"
]
]
=
None
]
=
None
...
@@ -1595,6 +1596,12 @@ class ServerArgs:
...
@@ -1595,6 +1596,12 @@ class ServerArgs:
default
=
ServerArgs
.
eplb_rebalance_layers_per_chunk
,
default
=
ServerArgs
.
eplb_rebalance_layers_per_chunk
,
help
=
"Number of layers to rebalance per forward pass."
,
help
=
"Number of layers to rebalance per forward pass."
,
)
)
parser
.
add_argument
(
"--eplb-min-rebalancing-utilization-threshold"
,
type
=
float
,
default
=
ServerArgs
.
eplb_min_rebalancing_utilization_threshold
,
help
=
"Minimum threshold for GPU average utilization to trigger EPLB rebalancing. Must be in the range [0.0, 1.0]."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--expert-distribution-recorder-mode"
,
"--expert-distribution-recorder-mode"
,
type
=
str
,
type
=
str
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment