Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2ac453b0
Unverified
Commit
2ac453b0
authored
Oct 02, 2025
by
fzyzcjy
Committed by
GitHub
Oct 02, 2025
Browse files
Tiny detect slow ranks (#10508)
parent
f35def86
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
0 deletions
+75
-0
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+4
-0
python/sglang/srt/slow_rank_detector.py
python/sglang/srt/slow_rank_detector.py
+71
-0
No files found.
python/sglang/srt/model_executor/model_runner.py
View file @
2ac453b0
...
...
@@ -31,6 +31,7 @@ import requests
import
torch
import
torch.distributed
as
dist
from
sglang.srt
import
slow_rank_detector
from
sglang.srt.configs.device_config
import
DeviceConfig
from
sglang.srt.configs.load_config
import
LoadConfig
,
LoadFormat
from
sglang.srt.configs.model_config
import
AttentionArch
,
ModelConfig
...
...
@@ -283,6 +284,9 @@ class ModelRunner:
# CPU offload
set_offloader
(
create_offloader_from_server_args
(
server_args
,
dp_rank
=
dp_rank
))
if
get_bool_env_var
(
"SGLANG_DETECT_SLOW_RANK"
):
slow_rank_detector
.
execute
()
# Update deep gemm configure
if
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
:
deep_gemm_wrapper
.
update_deep_gemm_config
(
gpu_id
,
server_args
)
...
...
python/sglang/srt/slow_rank_detector.py
0 → 100644
View file @
2ac453b0
import
logging
from
typing
import
Any
,
Dict
,
List
import
torch
import
torch.distributed
as
dist
import
triton
logger
=
logging
.
getLogger
(
__name__
)
def execute():
    """Benchmark this rank, gather timings on rank 0, and report slow ranks.

    Runs every benchmark in ``_BENCH_NAMES`` locally, gathers the per-rank
    metric dicts onto rank 0 via ``dist.gather_object``, and lets rank 0
    analyze relative speeds. Must be called collectively on every rank of
    an initialized default process group.
    """
    # Hoist: the original queried dist.get_rank() three times.
    rank = dist.get_rank()
    if rank == 0:
        # Fixed: was an f-string with no placeholders (F541).
        logger.info("[slow_rank_detector] Start benchmarking...")
    local_metrics = {
        bench_name: _compute_local_metric(bench_name) for bench_name in _BENCH_NAMES
    }
    all_metrics = [None for _ in range(dist.get_world_size())]
    # Only the destination rank (0) supplies a gather buffer; others pass None.
    dist.gather_object(local_metrics, all_metrics if rank == 0 else None)
    if rank == 0:
        _analyze_metrics(all_metrics)
class _GemmExecutor:
    """Benchmark workload: an 8192x8192 bf16 GEMM on the current CUDA device."""

    def __init__(self):
        shape = (8192, 8192)
        self.lhs = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
        self.rhs = torch.randn(shape, dtype=torch.bfloat16, device="cuda")

    def __call__(self):
        # Result is intentionally discarded — only the kernel time matters.
        torch.matmul(self.lhs, self.rhs)
class _ElementwiseExecutor:
    """Benchmark workload: an in-place elementwise add over a large int32 buffer."""

    def __init__(self):
        num_elements = 128 * 1024**2  # 128 Mi int32 elements (512 MiB)
        self.value = torch.randint(
            0, 10000, (num_elements,), dtype=torch.int32, device="cuda"
        )

    def __call__(self):
        # In-place increment — memory-bandwidth-bound, result kept in place.
        self.value += 1
# Registry mapping each benchmark name to the executor class implementing it.
_EXECUTOR_CLS_OF_BENCH = {
    "gemm": _GemmExecutor,
    "elementwise": _ElementwiseExecutor,
}

# Benchmarks run in registry insertion order; iterating a dict yields its keys.
_BENCH_NAMES = list(_EXECUTOR_CLS_OF_BENCH)
def _compute_local_metric(bench_name, rep: int = 20):
    """Run one benchmark on this rank and return its mean latency.

    Args:
        bench_name: Key into ``_EXECUTOR_CLS_OF_BENCH`` selecting the workload.
        rep: Repetition budget forwarded to triton's ``do_bench_cudagraph``
            (default 20 — the value previously hard-coded here), kept as a
            parameter so callers can trade benchmark time for stability.

    Returns:
        Mean kernel time in milliseconds, as reported by
        ``triton.testing.do_bench_cudagraph`` with ``return_mode="mean"``.
    """
    executor = _EXECUTOR_CLS_OF_BENCH[bench_name]()
    ms = triton.testing.do_bench_cudagraph(executor, return_mode="mean", rep=rep)
    return ms
def _analyze_metrics(all_metrics: List[Dict[str, Any]], threshold: float = 0.9):
    """Compare per-rank benchmark timings and warn when some ranks lag.

    Args:
        all_metrics: One dict per rank mapping bench name -> mean time (ms),
            as gathered by ``execute`` onto rank 0.
        threshold: Minimum acceptable relative speed of the slowest rank
            versus the fastest (default 0.9 — the value previously
            hard-coded — i.e. warn when any rank is >10% slower).
    """
    for bench_name in _BENCH_NAMES:
        time_of_rank = torch.tensor([m[bench_name] for m in all_metrics])
        speed_of_rank = 1 / time_of_rank
        # Normalize so the fastest rank has relative speed 1.0.
        rel_speed_of_rank = speed_of_rank / speed_of_rank.max()
        slowest_rel_speed = rel_speed_of_rank.min().item()
        logger.info(
            f"[slow_rank_detector] {bench_name=} {slowest_rel_speed=} {rel_speed_of_rank=} {time_of_rank=}"
        )
        if slowest_rel_speed < threshold:
            logger.warning(
                "[slow_rank_detector] Some ranks are too slow compared with others"
            )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment