Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9484eba4
Unverified
Commit
9484eba4
authored
May 22, 2025
by
fzyzcjy
Committed by
GitHub
May 21, 2025
Browse files
Support logging expert balancedness metrics (#6482)
parent
e9feb488
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
113 additions
and
1 deletion
+113
-1
python/sglang/srt/managers/expert_distribution.py
python/sglang/srt/managers/expert_distribution.py
+107
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+6
-0
No files found.
python/sglang/srt/managers/expert_distribution.py
View file @
9484eba4
...
@@ -15,10 +15,12 @@ import logging
...
@@ -15,10 +15,12 @@ import logging
import
os
import
os
import
time
import
time
from
abc
import
ABC
from
abc
import
ABC
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
Type
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
Type
import
einops
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -472,7 +474,80 @@ class _Accumulator(ABC):
...
@@ -472,7 +474,80 @@ class _Accumulator(ABC):
pass
pass
class
_StatAccumulator
(
_Accumulator
):
class
_UtilizationRateAccumulatorMixin
(
_Accumulator
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_enable
=
self
.
_server_args
.
enable_expert_distribution_metrics
if
self
.
_enable
:
window_sizes
=
[
10
,
100
,
1000
]
self
.
_history
=
_DequeCollection
(
maxlens
=
window_sizes
)
self
.
_rank
=
torch
.
distributed
.
get_rank
()
def
append
(
self
,
forward_pass_id
:
int
,
gatherer_key
:
str
,
single_pass_data
:
Dict
,
):
super
().
append
(
forward_pass_id
,
gatherer_key
,
single_pass_data
)
if
self
.
_enable
:
self
.
_append_utilization_rate
(
forward_pass_id
,
single_pass_data
[
"global_physical_count"
]
)
def
reset
(
self
):
super
().
reset
()
if
self
.
_enable
:
self
.
_history
.
clear
()
def
_append_utilization_rate
(
self
,
forward_pass_id
:
int
,
single_pass_global_physical_count
:
torch
.
Tensor
):
gpu_physical_count
=
compute_gpu_physical_count
(
single_pass_global_physical_count
,
num_gpu
=
self
.
_expert_location_metadata
.
ep_size
,
)
gpu_physical_count
=
gpu_physical_count
.
to
(
self
.
_server_args
.
device
)
torch
.
distributed
.
reduce
(
gpu_physical_count
,
dst
=
0
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
)
if
self
.
_rank
==
0
:
utilization_rate_tensor
=
compute_utilization_rate
(
gpu_physical_count
)
utilization_rate
=
torch
.
mean
(
utilization_rate_tensor
).
item
()
self
.
_history
.
append
(
utilization_rate
)
gpu_physical_count_sum
=
gpu_physical_count
.
sum
().
item
()
logger
.
info
(
f
"[Expert Balancedness] "
f
"forward_pass_id=
{
forward_pass_id
}
"
f
"current_pass_balancedness=
{
utilization_rate
:.
03
f
}
"
f
"
{
''
.
join
(
f
'last_
{
size
}
_average_balancedness
=
{
value
:.
03
f
}
' for size, value in self._history.mean().items())
}
"
f
"gpu_physical_count_sum=
{
gpu_physical_count_sum
}
"
# f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
)
class
_DequeCollection
:
def
__init__
(
self
,
maxlens
:
List
[
int
]):
self
.
_dequeues
=
[
deque
(
maxlen
=
maxlen
)
for
maxlen
in
maxlens
]
def
append
(
self
,
value
):
for
d
in
self
.
_dequeues
:
d
.
append
(
value
)
def
clear
(
self
):
for
d
in
self
.
_dequeues
:
d
.
clear
()
def
mean
(
self
)
->
Dict
[
int
,
float
]:
return
{
d
.
maxlen
:
sum
(
d
)
/
len
(
d
)
for
d
in
self
.
_dequeues
}
class
_StatAccumulator
(
_UtilizationRateAccumulatorMixin
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_global_physical_count_of_buffered_step
=
_Buffer
.
init_new
(
self
.
_global_physical_count_of_buffered_step
=
_Buffer
.
init_new
(
...
@@ -619,3 +694,34 @@ def _convert_global_physical_count_to_logical_count(
...
@@ -619,3 +694,34 @@ def _convert_global_physical_count_to_logical_count(
src
=
global_physical_count
,
src
=
global_physical_count
,
)
)
return
logical_count
return
logical_count
def
compute_gpu_physical_count
(
physical_count_of_whatever
:
torch
.
Tensor
,
# (..., num_layer, num_physical_expert)
num_gpu
:
int
,
):
"""output: gpu_physical_count_of_batch (..., num_layer, num_gpu)"""
return
einops
.
reduce
(
physical_count_of_whatever
,
"... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu"
,
"sum"
,
num_gpu
=
num_gpu
,
)
def
compute_utilization_rate
(
gpu_physical_count_of_batch
:
torch
.
Tensor
,
# (..., num_layer, num_gpu)
):
"""output: utilization_rate (..., num_layer)"""
gpu_physical_count_of_batch
=
gpu_physical_count_of_batch
.
float
()
max_gpu_physical_count
=
einops
.
reduce
(
gpu_physical_count_of_batch
,
"... num_layer num_gpu -> ... num_layer"
,
"max"
,
)
avg_gpu_physical_count
=
einops
.
reduce
(
gpu_physical_count_of_batch
,
"... num_layer num_gpu -> ... num_layer"
,
"mean"
,
)
return
(
avg_gpu_physical_count
+
1e-5
)
/
(
max_gpu_physical_count
+
1e-5
)
python/sglang/srt/server_args.py
View file @
9484eba4
...
@@ -177,6 +177,7 @@ class ServerArgs:
...
@@ -177,6 +177,7 @@ class ServerArgs:
Literal
[
"stat"
,
"per_pass"
,
"per_token"
]
Literal
[
"stat"
,
"per_pass"
,
"per_token"
]
]
=
None
]
=
None
expert_distribution_recorder_buffer_size
:
Optional
[
int
]
=
None
expert_distribution_recorder_buffer_size
:
Optional
[
int
]
=
None
enable_expert_distribution_metrics
:
bool
=
False
deepep_config
:
Optional
[
str
]
=
None
deepep_config
:
Optional
[
str
]
=
None
enable_torch_compile
:
bool
=
False
enable_torch_compile
:
bool
=
False
torch_compile_max_bs
:
int
=
32
torch_compile_max_bs
:
int
=
32
...
@@ -1304,6 +1305,11 @@ class ServerArgs:
...
@@ -1304,6 +1305,11 @@ class ServerArgs:
default
=
ServerArgs
.
expert_distribution_recorder_buffer_size
,
default
=
ServerArgs
.
expert_distribution_recorder_buffer_size
,
help
=
"Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer."
,
help
=
"Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer."
,
)
)
parser
.
add_argument
(
"--enable-expert-distribution-metrics"
,
action
=
"store_true"
,
help
=
"Enable logging metrics for expert balancedness"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--deepep-config"
,
"--deepep-config"
,
type
=
str
,
type
=
str
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment