Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9f00ec44
Unverified
Commit
9f00ec44
authored
Sep 05, 2025
by
fzyzcjy
Committed by
GitHub
Sep 05, 2025
Browse files
Fix and enhance dumper (#8725)
parent
8e85ee88
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
189 additions
and
47 deletions
+189
-47
python/sglang/srt/debug_utils/dump_comparator.py
python/sglang/srt/debug_utils/dump_comparator.py
+81
-44
python/sglang/srt/debug_utils/dump_loader.py
python/sglang/srt/debug_utils/dump_loader.py
+97
-0
python/sglang/srt/debug_utils/dumper.py
python/sglang/srt/debug_utils/dumper.py
+11
-3
No files found.
python/sglang/srt/debug_utils/dump_comparator.py
View file @
9f00ec44
import
argparse
import
functools
import
re
from
pathlib
import
Path
import
polars
as
pl
import
torch
from
sglang.srt.debug_utils.dump_loader
import
find_row
,
read_meta
from
sglang.srt.debug_utils.dumper
import
get_truncated_value
...
...
@@ -26,66 +26,77 @@ def main(args):
print
(
"df_baseline"
,
df_baseline
)
for
row
in
df_target
.
iter_rows
(
named
=
True
):
rows_baseline
=
df_baseline
.
filter
(
(
pl
.
col
(
"forward_pass_id"
)
==
row
[
"forward_pass_id"
]
-
args
.
start_id
+
args
.
baseline_start_id
)
&
functools
.
reduce
(
lambda
a
,
b
:
a
&
b
,
[
pl
.
col
(
col
)
==
row
[
col
]
for
col
in
row
.
keys
()
if
col
not
in
[
"forward_pass_id"
,
"dump_index"
,
"filename"
]
],
)
path_target
=
Path
(
args
.
target_path
)
/
row
[
"filename"
]
row_baseline
=
find_row
(
df_baseline
,
conditions
=
dict
(
forward_pass_id
=
row
[
"forward_pass_id"
]
-
args
.
start_id
+
args
.
baseline_start_id
,
**
{
k
:
v
for
k
,
v
in
row
.
items
()
if
k
not
in
[
"forward_pass_id"
,
"dump_index"
,
"filename"
]
},
),
)
assert
len
(
rows_baseline
)
==
1
,
f
"
{
rows_baseline
=
}
"
row_baseline
=
rows_baseline
.
to_dicts
()[
0
]
if
row_baseline
is
None
:
print
(
f
"Skip: target=
{
str
(
path_target
)
}
since no baseline"
)
x_target
=
_load_object
(
path_target
)
if
x_target
is
not
None
:
print
(
f
"x_target(sample)=
{
get_truncated_value
(
x_target
)
}
"
)
continue
path_baseline
=
Path
(
args
.
baseline_path
)
/
row_baseline
[
"filename"
]
path_target
=
Path
(
args
.
target_path
)
/
row
[
"filename"
]
print
(
f
"Check: target=
{
str
(
path_target
)
}
baseline=
{
str
(
path_baseline
)
}
"
)
check_tensor_pair
(
path_baseline
=
path_baseline
,
path_target
=
path_target
)
check_tensor_pair
(
path_baseline
=
path_baseline
,
path_target
=
path_target
,
name
=
row
[
"name"
]
)
print
()
def
read_meta
(
directory
):
directory
=
Path
(
directory
)
assert
directory
.
is_dir
(),
f
"
{
directory
=
}
should be a directory"
rows
=
[]
for
p
in
directory
.
glob
(
"*.pt"
):
full_kwargs
=
{}
for
kv
in
p
.
stem
.
split
(
"___"
):
k
,
v
=
kv
.
split
(
"="
)
full_kwargs
[
k
]
=
v
rows
.
append
(
{
"filename"
:
str
(
p
.
name
),
**
full_kwargs
,
}
)
def
check_tensor_pair
(
path_baseline
,
path_target
,
name
=
""
):
x_baseline
=
_load_object
(
path_baseline
)
x_target
=
_load_object
(
path_target
)
df
=
pl
.
DataFrame
(
rows
)
df
=
df
.
with_columns
(
pl
.
col
(
"forward_pass_id"
).
cast
(
int
),
pl
.
col
(
"rank"
).
cast
(
int
),
print
(
f
"Raw "
f
"[shape]
{
x_baseline
.
shape
}
vs
{
x_target
.
shape
}
\t
"
f
"[dtype]
{
x_baseline
.
dtype
}
vs
{
x_target
.
dtype
}
"
)
return
df
def
check_tensor_pair
(
path_baseline
,
path_target
):
x_baseline
=
torch
.
load
(
path_baseline
,
weights_only
=
True
)
x_target
=
torch
.
load
(
path_target
,
weights_only
=
True
)
x_baseline
,
x_target
=
_comparison_preprocessor
(
x_baseline
,
x_target
,
name
=
name
)
x_baseline
=
_try_unify_shape
(
x_baseline
,
target_shape
=
x_target
.
shape
)
print
(
f
"After preprocessor "
f
"[shape]
{
x_baseline
.
shape
}
vs
{
x_target
.
shape
}
\t
"
f
"[dtype]
{
x_baseline
.
dtype
}
vs
{
x_target
.
dtype
}
"
)
x_target
=
x_target
.
float
()
x_baseline
=
x_baseline
.
float
()
for
name
,
fn
in
(
(
"mean"
,
torch
.
mean
),
(
"std"
,
torch
.
std
),
(
"min"
,
torch
.
min
),
(
"max"
,
torch
.
max
),
(
"p1"
,
functools
.
partial
(
torch
.
quantile
,
q
=
0.01
)),
(
"p5"
,
functools
.
partial
(
torch
.
quantile
,
q
=
0.05
)),
(
"p95"
,
functools
.
partial
(
torch
.
quantile
,
q
=
0.95
)),
(
"p99"
,
functools
.
partial
(
torch
.
quantile
,
q
=
0.99
)),
):
value_baseline
=
fn
(
x_baseline
).
item
()
value_target
=
fn
(
x_target
).
item
()
print
(
f
"[
{
name
}
]
{
value_baseline
:.
4
f
}
vs
{
value_target
:.
4
f
}
(diff:
{
value_target
-
value_baseline
:.
4
f
}
)"
)
if
x_baseline
.
shape
!=
x_target
.
shape
:
print
(
f
"
❌
Shape mismatch"
)
print
(
f
"
⚠️
Shape mismatch"
)
return
raw_abs_diff
=
(
x_target
-
x_baseline
).
abs
()
...
...
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
print
(
f
"x_target(sample)=
{
get_truncated_value
(
x_target
)
}
"
)
def
_try_unify_shape
(
x
:
torch
.
Tensor
,
target_shape
):
x_shape
=
x
.
shape
num_dim_to_remove
=
len
(
x_shape
)
-
len
(
target_shape
)
if
(
x_shape
[
num_dim_to_remove
:]
==
target_shape
)
and
all
(
val
==
1
for
val
in
x_shape
[:
num_dim_to_remove
]
):
out
=
functools
.
reduce
(
lambda
a
,
_
:
a
.
squeeze
(
0
),
range
(
num_dim_to_remove
),
x
)
print
(
f
"Unify shape:
{
x_shape
}
->
{
out
.
shape
}
(to match
{
target_shape
}
)"
)
return
out
return
x
# Copied from DeepGEMM
def
_calc_rel_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
):
x
,
y
=
x
.
double
(),
y
.
double
()
...
...
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
return
1
-
sim
def _comparison_preprocessor(x_baseline, x_target, name):
    """Hook for ad-hoc adjustments to a tensor pair before comparison.

    Intentionally a no-op by default; edit locally while debugging (e.g. to
    slice or permute one side). ``name`` identifies which dump is being
    compared so the hook can special-case specific tensors.

    Returns the (possibly modified) ``(x_baseline, x_target)`` pair.
    """
    # can insert arbitrary adhoc postprocessing logic here
    return x_baseline, x_target
def
_load_object
(
path
):
x
=
torch
.
load
(
path
,
weights_only
=
False
)
if
not
isinstance
(
x
,
torch
.
Tensor
):
print
(
f
"Skip load
{
path
}
since
{
type
(
x
)
=
}
is not a Tensor"
)
return
None
return
x
.
cuda
()
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--baseline-path"
,
type
=
str
)
...
...
python/sglang/srt/debug_utils/dump_loader.py
0 → 100644
View file @
9f00ec44
import
functools
import
os
from
pathlib
import
Path
from
typing
import
Any
,
Dict
import
polars
as
pl
import
torch
class DumpLoader:
    """Loads tensors previously written by the dumper, looked up by metadata.

    Enabled iff the ``SGLANG_DUMP_LOADER_DIR`` environment variable is set;
    it must point at a directory of ``*.pt`` files whose filenames encode
    ``key=value`` metadata (see ``read_meta``).
    """

    def __init__(self):
        # Enablement is decided once, at construction, from the environment.
        directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = directory is not None
        if self._enable:
            self._directory = Path(directory)
            # Metadata index over all dump files in the directory.
            self._df = read_meta(directory)

    @property
    def enable(self):
        # True when SGLANG_DUMP_LOADER_DIR was set at construction time.
        return self._enable

    def load(self, name, **kwargs):
        """Load the unique dump matching ``name``, the dumper's current
        forward pass id, and any extra metadata filters in ``kwargs``.

        Raises AssertionError when the loader is disabled or when no
        matching dump exists.
        """
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported lazily to avoid a circular import with the dumper module.
        from sglang.srt.debug_utils.dumper import dumper

        forward_pass_id = dumper._forward_pass_id
        conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
        row = find_row(self._df, conditions=conditions)
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / row["filename"]
        # weights_only=False: dumps may contain arbitrary (trusted) objects.
        output = torch.load(path, weights_only=False)
        print(f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})")
        return output
def read_meta(directory):
    """Scan ``directory`` for ``*.pt`` dump files and index their metadata.

    Each filename stem is expected to be ``key=value`` pairs joined by
    ``___``; every file becomes one row of the returned Polars DataFrame,
    with a ``filename`` column added and ``forward_pass_id``, ``rank`` and
    ``dump_index`` cast to integers.

    Raises AssertionError when ``directory`` is not a directory.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            # maxsplit=1 keeps values that themselves contain "=" intact;
            # the unbounded split raised ValueError on such filenames.
            k, v = kv.split("=", 1)
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df
def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of ``df`` matching every condition, or None.

    Each condition value is cast to the corresponding column's dtype before
    comparison. Asserts that at most one row matches.
    """
    clauses = [
        pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
        for col in conditions
    ]
    df_sub = df.filter(functools.reduce(lambda acc, clause: acc & clause, clauses))

    assert len(df_sub) <= 1
    if len(df_sub) == 0:
        return None
    return df_sub.to_dicts()[0]
def _cast_to_polars_dtype(value, target_dtype):
    """Cast ``value`` (typically a string parsed from a dump filename) to the
    Python type matching a Polars column dtype, so equality filters compare
    like with like. Unknown dtypes pass ``value`` through unchanged.
    """
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        return int(value)
    elif target_dtype in (pl.Float64, pl.Float32):
        return float(value)
    elif target_dtype == pl.Boolean:
        # bool("False") is True, so parse common string spellings explicitly
        # instead of relying on truthiness of a non-empty string.
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1")
        return bool(value)
    elif target_dtype == pl.String:
        return str(value)
    else:
        return value
# Module-level singleton; constructing it here reads SGLANG_DUMP_LOADER_DIR
# once at import time, so the env var must be set before this module loads.
dump_loader = DumpLoader()
python/sglang/srt/debug_utils/dumper.py
View file @
9f00ec44
...
...
@@ -53,7 +53,7 @@ class _Dumper:
if
self
.
_partial_name
is
None
:
self
.
_partial_name
=
_get_partial_name
()
rank
=
dist
.
get_rank
()
rank
=
_
get_rank
()
full_kwargs
=
dict
(
forward_pass_id
=
self
.
_forward_pass_id
,
rank
=
rank
,
...
...
@@ -80,12 +80,20 @@ class _Dumper:
def _get_partial_name():
    """Produce a run-unique name shared by all ranks.

    Rank 0 picks a timestamp-based string; other ranks receive it via
    broadcast so every rank dumps under the same name.
    """
    rank = _get_rank()
    object_list = [str(time.time()) if rank == 0 else None]
    # Only broadcast when torch.distributed is initialized; single-process
    # runs (where rank is 0 anyway) skip it.
    # NOTE(review): device="cuda" assumes a CUDA backend — confirm for CPU-only runs.
    if dist.is_initialized():
        dist.broadcast_object_list(object_list, device="cuda")
    return object_list[0]
def
_get_rank
():
if
dist
.
is_initialized
():
return
dist
.
get_rank
()
else
:
return
0
def
get_truncated_value
(
value
):
if
value
is
None
:
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment