change / sglang · Commits · ab95d35f

Unverified commit ab95d35f, authored Oct 31, 2025 by Yuhong Guo; committed by GitHub on Oct 31, 2025.
feat: Add Non-intrusive Tensor Dumping for Model Inference (#10566)
Parent: 34c286b8 · Changes: 6 files
Showing 6 changed files with 272 additions and 15 deletions (+272 −15):

  python/sglang/srt/debug_utils/tensor_dump_forward_hook.py   +149   -0
  python/sglang/srt/layers/logits_processor.py                  +1  -14
  python/sglang/srt/model_executor/model_runner.py              +12   -0
  python/sglang/srt/server_args.py                              +16   -1
  test/srt/debug_utils/test_tensor_dump_forward_hook.py         +93   -0
  test/srt/run_suite.py                                          +1   -0
python/sglang/srt/debug_utils/tensor_dump_forward_hook.py (new file, mode 100644)
"""
This file provides a function `register_forward_hook_for_model` that registers a forward hook on every operator of the model.
After registration, during model inference, all tensors generated throughout the forward pass will be recorded.
Usage:
Specify the output directory for dumping tensors using the argument `--debug-tensor-dump-output-folder`.
A separate directory will be created for each GPU rank, named in the format `f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}"`.
Each complete forward pass of the model generates a `.pt` file named `f"Pass{pass_num}.pt"`, which can be loaded using `torch.load`.
The file contains a series of key-value pairs, where the keys correspond to operator names in the model
(similar to those in model.safetensors.index.json), and the values are the outputs produced by the respective operators.
"""
import logging
import os
from pathlib import Path

import torch

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors

logger = logging.getLogger(__name__)


class TensorDumper:
    def __init__(
        self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
    ):
        self._dump_layers = dump_layers
        self._forward_pass_id = 0
        self._pid = os.getpid()
        self._current_tensors = {}
        self._base_dir = Path(dump_dir)
        rank = tp_size * pp_rank + tp_rank
        self._process_dir = (
            self._base_dir / f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{self._pid}"
        )
        self._process_dir.mkdir(parents=True, exist_ok=True)

    def get_dump_dir(self):
        return str(self._process_dir)

    def add_tensor(self, name, tensor_item):
        if isinstance(tensor_item, (tuple, list)):
            tensors = [t.cpu() for t in tensor_item if t is not None]
            if len(tensors) == 1:
                self._current_tensors[name] = tensors[0]
            else:
                self._current_tensors[name] = tensors
        elif isinstance(tensor_item, torch.Tensor):
            self._current_tensors[name] = tensor_item.cpu()
        elif isinstance(tensor_item, LogitsProcessorOutput):
            self._current_tensors[name] = tensor_item.next_token_logits.cpu()
        elif isinstance(tensor_item, ForwardBatch):
            self._current_tensors[name + ".forward_batch_info.input_ids"] = (
                tensor_item.input_ids.cpu()
            )
            self._current_tensors[name + ".forward_batch_info.seq_lens"] = (
                tensor_item.seq_lens.cpu()
            )
            self._current_tensors[name + ".forward_batch_info.positions"] = (
                tensor_item.positions.cpu()
            )
        elif isinstance(tensor_item, PPProxyTensors):
            for tensor_name in tensor_item.tensors.keys():
                self._current_tensors[name + ".pp_proxy_tensors." + tensor_name] = (
                    tensor_item.tensors[tensor_name].cpu()
                )
        else:
            logger.warning(f"Unsupported type: {type(tensor_item)}: {tensor_item}")

    def dump_current_tensors(self):
        if len(self._current_tensors) == 0:
            return
        tensor_file_for_pass = (
            self._process_dir / f"Pass{self._forward_pass_id:05d}.pt"
        )
        logger.info(
            f"Dump {self._forward_pass_id:05d}th pass to {tensor_file_for_pass}"
        )
        torch.save(self._current_tensors, str(tensor_file_for_pass))
        self._current_tensors = {}
        self._forward_pass_id += 1

    def _add_hook_recursive(
        self, model, prefix, top_level_module_name, layers_module_name
    ):
        model_top_level_module_matched = False
        layers_prefix = top_level_module_name + "." + layers_module_name
        for name, module in model._modules.items():
            top_level_model = False
            if len(prefix) == 0:
                cur_name = name
                if cur_name == top_level_module_name:
                    model_top_level_module_matched = True
                    top_level_model = True
            else:
                cur_name = prefix + "." + name
            if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix:
                # If we only need n layers, skip the rest of the layers.
                # Most models' layout is like model.layers.0.
                cur_layer = int(name)
                if cur_layer >= self._dump_layers:
                    continue
            if module is not None:
                _, sub_count = self._add_hook_recursive(
                    module, cur_name, top_level_module_name, layers_module_name
                )
                if sub_count == 0 or top_level_model:
                    # Avoid duplicated output hooks, e.g. self_attn may contain:
                    # self_attn.qkv_proj, self_attn.attn & self_attn.o_proj.
                    # Therefore, we do not need to add output hooks for self_attn,
                    # since the output of self_attn should be the same as self_attn.o_proj.
                    module.register_forward_hook(
                        self._dump_hook(cur_name, top_level_model)
                    )
        return model_top_level_module_matched, len(model._modules.items())

    def _dump_hook(self, tensor_name, do_dump):
        def inner_dump_hook(module, input, output):
            if do_dump:
                # This is the top-level model, so we will record the input for it.
                for item in input:
                    if isinstance(item, ForwardBatch):
                        self.add_tensor(tensor_name, item)
                self.dump_current_tensors()
            if output is not None:
                self.add_tensor(tensor_name, output)

        return inner_dump_hook


def register_forward_hook_for_model(
    model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
):
    tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
    # Most models have a layout like:
    #   XxxxForCausalLM
    #     (model): XxxxModel
    #       (layers): ModuleList
    # If the model is not constructed with this layout,
    # environment variables can be used to specify the module names.
    top_level_module_name = os.getenv("TENSOR_DUMP_TOP_LEVEL_MODULE_NAME", "model")
    layers_module_name = os.getenv("TENSOR_DUMP_LAYERS_MODULE_NAME", "layers")
    model_top_level_module_matched, _ = tensor_dumper._add_hook_recursive(
        model, "", top_level_module_name, layers_module_name
    )
    assert (
        model_top_level_module_matched
    ), f"model should have a module named {top_level_module_name}"
    return tensor_dumper
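
As a quick check of what a run produced, a dumped pass file can be inspected directly with `torch.load`; the following is a minimal editorial sketch, where the dump root, rank directory, and pass index are placeholders following the naming scheme described in the module docstring, and the available keys depend on the model:

from pathlib import Path

import torch

# Hypothetical output folder passed via --debug-tensor-dump-output-folder.
dump_root = Path("/tmp/sglang_dump")
# One sub-directory is created per rank: TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}.
rank_dir = next(dump_root.glob("TP0_PP0_Rank0_pid*"))
# Each complete forward pass is saved as Pass{pass_num:05d}.pt.
tensors = torch.load(rank_dir / "Pass00000.pt")

for name, value in tensors.items():
    if isinstance(value, torch.Tensor):
        print(name, tuple(value.shape), value.dtype)
    else:
        # Operators that returned several tensors are stored as a list.
        print(name, [tuple(t.shape) for t in value])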
python/sglang/srt/layers/logits_processor.py
@@ -38,7 +38,6 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
@@ -47,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
+from sglang.srt.utils import is_npu, use_intel_amx_backend

 logger = logging.getLogger(__name__)
@@ -252,10 +251,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None

-        self.debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
-        )
-
     def compute_logprobs_for_multi_item_scoring(
         self,
         input_ids,
@@ -463,14 +458,6 @@ class LogitsProcessor(nn.Module):
             logits[sample_indices] if sample_indices is not None else logits
         )

-        if self.debug_tensor_dump_output_folder:
-            assert (
-                not self.do_tensor_parallel_all_gather
-                or get_local_attention_dp_size() == 1
-            ), "dp attention + sharded lm_head doesn't support full logits"
-            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
-            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
-
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
python/sglang/srt/model_executor/model_runner.py
@@ -40,6 +40,9 @@ from sglang.srt.configs.model_config import (
 )
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.debug_utils.tensor_dump_forward_hook import (
+    register_forward_hook_for_model,
+)
 from sglang.srt.distributed import (
     get_pp_group,
     get_tp_group,
@@ -791,6 +794,15 @@ class ModelRunner:
             f"avail mem={after_avail_memory:.2f} GB, "
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )

+        if self.server_args.debug_tensor_dump_output_folder is not None:
+            register_forward_hook_for_model(
+                self.model,
+                self.server_args.debug_tensor_dump_output_folder,
+                self.server_args.debug_tensor_dump_layers,
+                self.tp_size,
+                self.tp_rank,
+                self.pp_rank,
+            )
+
         if self.server_args.elastic_ep_backend == "mooncake":
             # Mooncake does not support `monitored_barrier`
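
`register_forward_hook_for_model` is an ordinary helper, so the same call that `ModelRunner` makes here can also be issued by hand against an already loaded model, for example when debugging outside the server. The following is a minimal editorial sketch, assuming a single-GPU setup (TP=1, PP=1), a hypothetical dump folder, and a `model` object that follows the usual `model.layers` layout:

from sglang.srt.debug_utils.tensor_dump_forward_hook import (
    register_forward_hook_for_model,
)

# `model` is assumed to be an already constructed nn.Module whose top-level
# `model` submodule contains a `layers` ModuleList.
dumper = register_forward_hook_for_model(
    model,
    dump_dir="/tmp/sglang_dump",  # hypothetical output folder
    dump_layers=2,                # hook only layers 0 and 1; -1 means all layers
    tp_size=1,
    tp_rank=0,
    pp_rank=0,
)
print("Dumping to", dumper.get_dump_dir())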
python/sglang/srt/server_args.py
@@ -511,6 +511,9 @@ class ServerArgs:
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
+    # -1 means dump all layers.
+    debug_tensor_dump_layers: int = -1
+    # TODO(guoyuhong): clean the old dumper code.
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False
@@ -1784,7 +1787,13 @@
         )

     def _handle_other_validations(self):
-        pass
+        # Handle model inference tensor dump.
+        if self.debug_tensor_dump_output_folder is not None:
+            logger.warning(
+                "Cuda graph and server warmup are disabled because of using tensor dump mode"
+            )
+            self.disable_cuda_graph = True
+            self.skip_server_warmup = True

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -3375,6 +3384,12 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_output_folder,
             help="The output folder for dumping tensors.",
         )
+        parser.add_argument(
+            "--debug-tensor-dump-layers",
+            type=int,
+            default=-1,
+            help="The layer number for dumping tensors.",
+        )
         parser.add_argument(
             "--debug-tensor-dump-input-file",
             type=str,
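
Besides these CLI flags, the hook module also reads two environment variables for models whose submodules are not named `model`/`layers`; a minimal editorial sketch, assuming a hypothetical model whose decoder stack lives under `transformer.h`:

import os

# Hypothetical layout: XxxForCausalLM -> (transformer) -> (h): ModuleList.
# These overrides must be set before register_forward_hook_for_model runs.
os.environ["TENSOR_DUMP_TOP_LEVEL_MODULE_NAME"] = "transformer"
os.environ["TENSOR_DUMP_LAYERS_MODULE_NAME"] = "h"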
test/srt/debug_utils/test_tensor_dump_forward_hook.py (new file, mode 100644)
import unittest

import torch
from torch import nn

from sglang.srt.debug_utils.tensor_dump_forward_hook import (
    register_forward_hook_for_model,
)
from sglang.srt.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import LinearBase
from sglang.srt.models.qwen2 import Qwen2MLP
from sglang.srt.utils import add_prefix

TEST_HIDDEN_SIZE = 32


class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.hidden_size = TEST_HIDDEN_SIZE
        self.rms_norm_eps = 1e-5
        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=self.hidden_size,
            hidden_act="silu",
            quant_config=None,
            prefix=add_prefix("mlp", ""),
        )
        self.layernorm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return hidden_states


class MockCausalLM(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = SimpleModel()

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.model(hidden_states)


def init_weights(module):
    if isinstance(module, LinearBase):
        torch.nn.init.uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, RMSNorm):
        torch.nn.init.ones_(module.weight)


def test_model_forward_dump(tmp_path):
    init_distributed_environment(
        backend="nccl",
        world_size=1,
        rank=0,
        local_rank=0,
        distributed_init_method="tcp://127.0.0.1:2646",
    )
    initialize_model_parallel()
    model = MockCausalLM()
    model.apply(init_weights)
    model = model.cuda().bfloat16()
    dumper = register_forward_hook_for_model(
        model, tmp_path / "sglang_dump", -1, 0, 0, 0
    )
    dir_path = dumper.get_dump_dir()
    inp = torch.randn(4, TEST_HIDDEN_SIZE, dtype=torch.bfloat16) * 0.01
    result = model(inp.cuda())
    data = torch.load(f"{dir_path}/Pass00000.pt")
    assert "model.layernorm" in data
    assert "model.mlp.down_proj" in data
    assert torch.allclose(
        data["model.mlp.down_proj"], result.cpu(), rtol=1e-5, atol=1e-5
    )


if __name__ == "__main__":
    unittest.main()
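
The `torch.allclose` check used in this test generalizes to whole dump directories, which is the typical workflow the hook targets, e.g. comparing a baseline run against a modified one pass by pass. A minimal editorial sketch with hypothetical paths, assuming both runs produced the same pass numbering and keys:

from pathlib import Path

import torch

baseline_dir = Path("/tmp/dump_baseline/TP0_PP0_Rank0_pid111")    # hypothetical
candidate_dir = Path("/tmp/dump_candidate/TP0_PP0_Rank0_pid222")  # hypothetical

for base_file in sorted(baseline_dir.glob("Pass*.pt")):
    base = torch.load(base_file)
    cand = torch.load(candidate_dir / base_file.name)
    for key, b in base.items():
        c = cand.get(key)
        if isinstance(b, torch.Tensor) and isinstance(c, torch.Tensor):
            # Compare in float32 to avoid bfloat16 tolerance artifacts.
            if not torch.allclose(b.float(), c.float(), rtol=1e-3, atol=1e-3):
                print(f"{base_file.name}: first mismatch at {key}")
                break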
test/srt/run_suite.py
@@ -14,6 +14,7 @@ class TestFile:
 # NOTE: please sort the test cases alphabetically by the test file name
 suites = {
     "per-commit-1-gpu": [
+        TestFile("debug_utils/test_tensor_dump_forward_hook.py", 15),
         TestFile("function_call/test_json_schema_constraint.py", 30),
         TestFile("hicache/test_hicache.py", 116),
         TestFile("hicache/test_hicache_eagle.py", 150),