Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ed0fdbf3
Unverified
Commit
ed0fdbf3
authored
Jul 27, 2025
by
fzyzcjy
Committed by
GitHub
Jul 27, 2025
Browse files
Tool to dump and compare internal activation tensors (#7976)
parent
b602f423
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
239 additions
and
0 deletions
+239
-0
python/sglang/srt/debug_utils/__init__.py
python/sglang/srt/debug_utils/__init__.py
+0
-0
python/sglang/srt/debug_utils/dump_comparator.py
python/sglang/srt/debug_utils/dump_comparator.py
+131
-0
python/sglang/srt/debug_utils/dumper.py
python/sglang/srt/debug_utils/dumper.py
+108
-0
No files found.
python/sglang/srt/debug_utils/__init__.py
0 → 100644
View file @
ed0fdbf3
python/sglang/srt/debug_utils/dump_comparator.py
0 → 100644
View file @
ed0fdbf3
import
argparse
import
functools
import
re
from
pathlib
import
Path
import
polars
as
pl
import
torch
from
sglang.srt.debug_utils.dumper
import
get_truncated_value
def main(args):
    """Compare every dumped tensor of the target run against its counterpart
    in the baseline run.

    Rows are matched on all filename-encoded metadata columns except
    ``forward_pass_id`` (which is re-based between the two runs),
    ``dump_index`` and ``filename``. For each matched pair the tensors are
    loaded and their differences reported via ``check_tensor_pair``.
    """
    df_target = read_meta(args.target_path)
    # Deterministic iteration order: by rank, then by dump order within a rank.
    df_target = df_target.sort("rank", "dump_index")
    # Restrict the comparison to the requested forward-pass window.
    df_target = df_target.filter(
        (pl.col("forward_pass_id") >= args.start_id)
        & (pl.col("forward_pass_id") <= args.end_id)
    )
    # These columns are produced by the dumper's filename encoding; fail fast
    # if the dump directory does not look like dumper output.
    assert all(
        c in df_target.columns
        for c in ["rank", "forward_pass_id", "dump_index", "name"]
    )

    df_baseline = read_meta(args.baseline_path)
    print("df_target", df_target)
    print("df_baseline", df_baseline)

    for row in df_target.iter_rows(named=True):
        # Find the unique baseline row matching this target row:
        # - forward_pass_id is shifted from the target window into the
        #   baseline window (start_id -> baseline_start_id),
        # - every other metadata column (except dump_index/filename, which
        #   legitimately differ between runs) must match exactly.
        rows_baseline = df_baseline.filter(
            (
                pl.col("forward_pass_id")
                == row["forward_pass_id"] - args.start_id + args.baseline_start_id
            )
            & functools.reduce(
                lambda a, b: a & b,
                [
                    pl.col(col) == row[col]
                    for col in row.keys()
                    if col not in ["forward_pass_id", "dump_index", "filename"]
                ],
            )
        )
        # Exactly one counterpart must exist; zero or multiple matches mean
        # the two dump directories are not comparable.
        assert len(rows_baseline) == 1, f"{rows_baseline=}"
        row_baseline = rows_baseline.to_dicts()[0]

        path_baseline = Path(args.baseline_path) / row_baseline["filename"]
        path_target = Path(args.target_path) / row["filename"]
        print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
        check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
        print()
def read_meta(directory):
    """Scan *directory* for dumped ``*.pt`` files and build a metadata frame.

    Each file stem encodes ``key=value`` pairs joined by ``"___"`` (the
    format written by ``dumper.py``), e.g.
    ``forward_pass_id=1___rank=0___name=x.pt``.

    Returns a polars DataFrame with one row per file containing the parsed
    key/value pairs plus the original ``filename``, with ``forward_pass_id``
    and ``rank`` cast to integers.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            # maxsplit=1 so a value that itself contains "=" (e.g. a dumped
            # kwarg like name=a=b) does not raise a too-many-values error.
            k, v = kv.split("=", 1)
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
    )
    return df
def check_tensor_pair(path_baseline, path_target):
    """Load two dumped tensors and report how much they differ.

    Prints shapes and dtypes, then rel_diff / max_abs_diff / mean_abs_diff
    (each flagged ❌ when above 1e-3, ✅ otherwise). When the max absolute
    difference is large, sample values of both tensors are printed too.
    """
    x_baseline = torch.load(path_baseline, weights_only=True)
    x_target = torch.load(path_target, weights_only=True)

    print(
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )

    # Diff metrics are meaningless across shapes; bail out early.
    if x_baseline.shape != x_target.shape:
        print(f"❌ Shape mismatch")
        return

    abs_diff = (x_target - x_baseline).abs()
    max_abs_diff = abs_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    rel_diff = _calc_rel_diff(x_target, x_baseline)

    needs_print = max_abs_diff > 1e-3

    metrics = [
        ("rel_diff", rel_diff),
        ("max_abs_diff", max_abs_diff),
        ("mean_abs_diff", mean_abs_diff),
    ]
    parts = []
    for name, value in metrics:
        marker = "❌" if value > 1e-3 else "✅"
        parts.append(f"{marker} {name}={value}")
    print("\t".join(parts))

    if needs_print:
        print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
        print(f"x_target(sample)={get_truncated_value(x_target)}")
# Copied from DeepGEMM
def
_calc_rel_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
):
x
,
y
=
x
.
double
(),
y
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
1
-
sim
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compare tensors dumped by sglang.srt.debug_utils.dumper"
    )
    # The two paths are mandatory: without required=True a missing flag would
    # only surface later as a confusing `Path(None)` TypeError in read_meta
    # instead of a clean argparse usage error.
    parser.add_argument(
        "--baseline-path",
        type=str,
        required=True,
        help="Directory containing the baseline *.pt dumps",
    )
    parser.add_argument(
        "--target-path",
        type=str,
        required=True,
        help="Directory containing the target *.pt dumps",
    )
    parser.add_argument(
        "--start-id",
        type=int,
        default=0,
        help="First forward_pass_id (inclusive) of the target dumps to compare",
    )
    parser.add_argument(
        "--end-id",
        type=int,
        default=1000000,
        help="Last forward_pass_id (inclusive) of the target dumps to compare",
    )
    parser.add_argument(
        "--baseline-start-id",
        type=int,
        default=0,
        help="forward_pass_id in the baseline dumps that aligns with --start-id",
    )
    args = parser.parse_args()
    main(args)
python/sglang/srt/debug_utils.py
→
python/sglang/srt/debug_utils
/dumper
.py
View file @
ed0fdbf3
import
os
import
os
import
time
import
time
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Optional
import
torch
import
torch
import
torch.distributed
as
dist
from
sglang.srt.utils
import
get_bool_env_var
class
_Dumper
:
class
_Dumper
:
"""Utility to dump tensors, which can be useful when comparison checking models.
"""Utility to dump tensors, which can be useful when comparison checking models.
Example usage:
Example usage:
debug_utils.dumper.dump("layer_start_hidden_states", hidden_states, layer_id=self.layer_id)
dumper.on_forward_pass_start()
dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)
Import from non-SGLang system:
```
import sys
sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
from dumper import dumper
```
Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_enable
=
get_bool_env_var
(
"SGLANG_DUMPER_ENABLE"
,
"true"
)
# Do not import `sglang` to make this file standalone
self
.
_enable
=
bool
(
int
(
os
.
environ
.
get
(
"SGLANG_DUMPER_ENABLE"
,
"1"
)))
self
.
_base_dir
=
Path
(
os
.
environ
.
get
(
"SGLANG_DUMPER_DIR"
,
"/tmp"
))
self
.
_base_dir
=
Path
(
os
.
environ
.
get
(
"SGLANG_DUMPER_DIR"
,
"/tmp"
))
self
.
_enable_write_file
=
get_bool_env_var
(
"SGLANG_DUMPER_WRITE_FILE"
,
"1"
)
self
.
_enable_write_file
=
bool
(
self
.
_partial_name
=
str
(
time
.
time
())
int
(
os
.
environ
.
get
(
"SGLANG_DUMPER_WRITE_FILE"
,
"1"
))
self
.
forward_pass_id
=
None
)
self
.
_partial_name
:
Optional
[
str
]
=
None
self
.
_dump_index
=
0
self
.
_forward_pass_id
=
0
def
on_forward_pass_start
(
self
):
self
.
_forward_pass_id
+=
1
print
(
f
"[Dumper] [
{
time
.
time
()
}
] on_forward_pass_start id=
{
self
.
_forward_pass_id
}
"
)
def
dump
(
self
,
name
,
value
,
**
kwargs
):
def
dump
(
self
,
name
,
value
,
**
kwargs
):
if
not
self
.
_enable
:
if
not
self
.
_enable
:
return
return
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
assert
(
self
.
_forward_pass_id
>=
1
),
"Do you forget to call `dumper.on_forward_pass_start()`?"
self
.
_dump_index
+=
1
if
self
.
_partial_name
is
None
:
self
.
_partial_name
=
_get_partial_name
()
rank
=
get_tensor_model_parallel
_rank
()
rank
=
dist
.
get
_rank
()
full_kwargs
=
dict
(
full_kwargs
=
dict
(
forward_pass_id
=
self
.
forward_pass_id
,
forward_pass_id
=
self
.
_forward_pass_id
,
rank
=
rank
,
name
=
name
,
name
=
name
,
dump_index
=
self
.
_dump_index
,
**
kwargs
,
**
kwargs
,
)
)
full_filename
=
"___"
.
join
(
f
"
{
k
}
=
{
v
}
"
for
k
,
v
in
full_kwargs
.
items
())
+
".pt"
full_filename
=
"___"
.
join
(
f
"
{
k
}
=
{
v
}
"
for
k
,
v
in
full_kwargs
.
items
())
+
".pt"
path
=
(
path
=
self
.
_base_dir
/
f
"sglang_dump_
{
self
.
_partial_name
}
"
/
full_filename
self
.
_base_dir
/
f
"sglang_dump_
{
self
.
_partial_name
}
_
{
rank
}
"
/
full_filename
)
sample_value
=
self
.
_get_sample
_value
(
name
,
value
)
sample_value
=
get_truncated
_value
(
value
)
print
(
print
(
f
"[
{
rank
}
,
{
time
.
time
()
}
]
{
path
}
"
f
"
[Dumper]
[
{
rank
}
,
{
time
.
time
()
}
]
{
path
}
"
f
"type=
{
type
(
value
)
}
"
f
"type=
{
type
(
value
)
}
"
f
"shape=
{
value
.
shape
if
isinstance
(
value
,
torch
.
Tensor
)
else
None
}
"
f
"shape=
{
value
.
shape
if
isinstance
(
value
,
torch
.
Tensor
)
else
None
}
"
f
"dtype=
{
value
.
dtype
if
isinstance
(
value
,
torch
.
Tensor
)
else
None
}
"
f
"dtype=
{
value
.
dtype
if
isinstance
(
value
,
torch
.
Tensor
)
else
None
}
"
...
@@ -52,23 +78,31 @@ class _Dumper:
...
@@ -52,23 +78,31 @@ class _Dumper:
path
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
path
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
torch
.
save
(
value
,
str
(
path
))
torch
.
save
(
value
,
str
(
path
))
def
_get_sample_value
(
self
,
name
,
value
):
if
value
is
None
:
return
None
if
isinstance
(
value
,
tuple
):
def _get_partial_name():
    # Produce one name component shared by all ranks of this run, so every
    # rank writes under a directory carrying the same timestamp: rank 0
    # samples time.time() and broadcasts the string to all other ranks.
    rank = dist.get_rank()
    object_list = [str(time.time()) if rank == 0 else None]
    # NOTE(review): device="cuda" assumes a CUDA-capable backend (e.g. NCCL)
    # — confirm behavior on CPU-only / gloo setups.
    dist.broadcast_object_list(object_list, device="cuda")
    return object_list[0]
def get_truncated_value(value):
    """Return a small, printable sample of *value* for debug output.

    - ``None`` stays ``None``; values that are neither tensors nor
      tuples/lists map to ``None``.
    - Tuples and lists are processed element-wise (the result is a list,
      matching the original tuple behavior).
    - Tensors with fewer than 200 elements are returned unchanged; larger
      tensors keep only the first 5 entries along every dimension whose size
      exceeds 200.
    """
    if value is None:
        return None
    # Generalized from tuple-only: lists of tensors are handled the same way
    # (backward compatible — tuple input already produced a list).
    if isinstance(value, (tuple, list)):
        return [get_truncated_value(x) for x in value]
    if not isinstance(value, torch.Tensor):
        return None
    if value.numel() < 200:
        return value
    slices = [
        slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
    ]
    return value[tuple(slices)]
dumper
=
_Dumper
()
dumper
=
_Dumper
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment