Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
608b5565
Unverified
Commit
608b5565
authored
Jan 31, 2026
by
Angela Yi
Committed by
GitHub
Jan 31, 2026
Browse files
[ez] Add structured torch.compile logs (#33213)
Signed-off-by:
angelayi
<
yiangela7@gmail.com
>
parent
f0a1c845
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
227 additions
and
0 deletions
+227
-0
tests/compile/test_structured_logging.py
tests/compile/test_structured_logging.py
+121
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+58
-0
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+48
-0
No files found.
tests/compile/test_structured_logging.py
0 → 100644
View file @
608b5565
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
patch
import
pytest
import
regex
as
re
import
torch
from
torch
import
nn
import
tests.compile.silly_attention
# noqa
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config.compilation
import
(
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
)
from
vllm.config.scheduler
import
SchedulerConfig
from
vllm.forward_context
import
set_forward_context
# Width (second dimension) of the random input the test below feeds the model.
MLP_SIZE = 64


@support_torch_compile
class SimpleModel(nn.Module):
    """Parameter-free module that wraps a single ``silly.attention`` call.

    The test in this file configures ``silly::attention`` as a splitting op
    and expects two compiled submodules: the arithmetic before the call and
    the arithmetic after it.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
        # ``vllm_config``, ``prefix`` and ``kwargs`` are accepted but not
        # stored here; only the base ``nn.Module`` initialiser runs.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Double ``x``, run it through ``silly.attention``, double the result."""
        doubled = x + x
        # Output buffer handed to the custom op (presumably filled in place).
        attn_output = torch.empty_like(doubled)
        torch.ops.silly.attention(doubled, doubled, doubled, attn_output)
        return attn_output * 2
class TraceStructuredCapture:
    """Callable stand-in for ``trace_structured`` that records each call."""

    def __init__(self):
        # One ``{"event_type": ..., "metadata": ...}`` dict per recorded call,
        # in call order.
        self.calls: list[dict] = []

    def __call__(self, event_type: str, metadata_fn=None, payload_fn=None, **kwargs):
        """Record one call: its event type plus the evaluated metadata.

        ``payload_fn`` and any extra keyword arguments are accepted so the
        call signature lines up with the patched function, but they are
        neither invoked nor stored.
        """
        if metadata_fn:
            metadata = metadata_fn()
        else:
            metadata = {}
        record = {"event_type": event_type, "metadata": metadata}
        self.calls.append(record)

    def get(self, event_type: str, name_pattern: str) -> list[dict]:
        """Return recorded calls of one event type whose name fully matches.

        Args:
            event_type: The event type to filter by (e.g., "artifact",
                "graph_dump")
            name_pattern: Regex pattern that must match the entire
                ``metadata["name"]`` value; a missing name is treated as "".
        """
        matcher = re.compile(name_pattern)
        selected = []
        for call in self.calls:
            if call["event_type"] != event_type:
                continue
            artifact_name = call.get("metadata", {}).get("name", "")
            if matcher.fullmatch(artifact_name):
                selected.append(call)
        return selected
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
    """Test that all expected vLLM artifacts are logged during compilation.

    ``trace_structured`` is patched in both compilation modules with a
    recording stub, a small model is run once with batch size 8, and the
    recorded events are then checked by event type and artifact name.
    """
    # Remember the process-wide default device so it can be restored; leaving
    # it on "cuda" would leak into whatever test runs next in this process.
    previous_device = torch.get_default_device()
    torch.set_default_device("cuda")
    try:
        capture = TraceStructuredCapture()

        vllm_config = VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                compile_sizes=[8],
                splitting_ops=["silly::attention"],
            ),
            scheduler_config=SchedulerConfig(
                max_num_seqs=8,
                max_model_len=8192,
                is_encoder_decoder=False,
            ),
        )

        # Patch trace_structured to capture calls
        with (
            patch("vllm.compilation.backends.trace_structured", capture),
            patch("vllm.compilation.piecewise_backend.trace_structured", capture),
            set_current_vllm_config(vllm_config),
        ):
            model = SimpleModel(vllm_config=vllm_config, prefix="test")

            with set_forward_context({}, vllm_config=vllm_config):
                model(torch.randn(8, MLP_SIZE))
    finally:
        torch.set_default_device(previous_device)

    config_artifacts = capture.get("artifact", "vllm_compilation_config")
    assert len(config_artifacts) == 1, (
        f"Expected 1 vllm_compilation_config, got {len(config_artifacts)}"
    )

    vllm_piecewise_split_graph = capture.get(
        "graph_dump", "vllm_piecewise_split_graph"
    )
    assert len(vllm_piecewise_split_graph) == 1, (
        "Expected 1 toplevel piecewise split graph, "
        f"got {len(vllm_piecewise_split_graph)}"
    )

    compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start")
    assert len(compile_start_artifacts) == 2, (
        "Expected 2 vllm_piecewise_compile_start "
        "(one for dynamic ranges, one for compile size), "
        f"got {len(compile_start_artifacts)}"
    )

    submod_dumps = capture.get("graph_dump", r"vllm_submod_.*")
    assert len(submod_dumps) == 2, (
        "Expected 2 submods (one before attention, one after attention), "
        f"got {len(submod_dumps)}"
    )
vllm/compilation/backends.py
View file @
608b5565
...
@@ -19,6 +19,7 @@ from typing import Any
...
@@ -19,6 +19,7 @@ from typing import Any
import
torch
import
torch
import
torch.fx
as
fx
import
torch.fx
as
fx
from
torch._dispatch.python
import
enable_python_dispatcher
from
torch._dispatch.python
import
enable_python_dispatcher
from
torch._logging._internal
import
trace_structured
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.inductor_pass
import
pass_context
from
vllm.compilation.inductor_pass
import
pass_context
...
@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
...
@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
sym_shape_indices
,
sym_shape_indices
,
self
.
vllm_backend
,
self
.
vllm_backend
,
graph_returns_tuple
(
submod
),
graph_returns_tuple
(
submod
),
submod_name
=
target
,
)
)
self
.
module
.
__dict__
[
target
]
=
wrap_with_cudagraph_if_needed
(
self
.
module
.
__dict__
[
target
]
=
wrap_with_cudagraph_if_needed
(
...
@@ -735,12 +737,61 @@ class VllmBackend:
...
@@ -735,12 +737,61 @@ class VllmBackend:
)
)
self
.
inductor_config
[
self
.
pass_key
]
=
self
.
pass_manager
self
.
inductor_config
[
self
.
pass_key
]
=
self
.
pass_manager
def _log_compilation_config(self):
    """Log vLLM compilation config for TORCH_TRACE/tlparse.

    Emits a single ``"artifact"`` record named ``vllm_compilation_config``
    whose payload is a JSON object summarising ``self.compilation_config``
    together with the model and prefix. Nothing is returned.

    All payload work happens inside the callable passed as ``payload_fn``,
    so the dataclass introspection and JSON encoding are only paid for when
    ``trace_structured`` actually asks for the payload.
    """
    cc = self.compilation_config

    # Helper to convert lists to comma-separated strings for tlparse display
    def list_to_str(lst: list | None) -> str:
        if lst is None:
            return ""
        return ", ".join(str(x) for x in lst)

    def build_payload() -> str:
        pass_cfg = cc.pass_config
        # Get enabled passes by introspecting dataclass fields; only fields
        # whose value is the boolean True count (truthy non-bools do not).
        enabled_passes = [
            f.name
            for f in dataclasses.fields(pass_cfg)
            if getattr(pass_cfg, f.name) is True
        ]
        # NOTE(review): model_config is assumed to be possibly None (e.g. a
        # VllmConfig built with only compilation/scheduler settings) — TODO
        # confirm. Logging must not raise in that case, so emit null.
        model_config = self.vllm_config.model_config
        model = model_config.model if model_config is not None else None
        dynamic_shapes = cc.dynamic_shapes_config
        return json.dumps(
            {
                "model": model,
                "prefix": self.prefix,
                "mode": str(cc.mode),
                "backend": cc.backend,
                "custom_ops": list_to_str(cc.custom_ops),
                "splitting_ops": list_to_str(cc.splitting_ops),
                "cudagraph_mode": str(cc.cudagraph_mode),
                "compile_sizes": list_to_str(cc.compile_sizes),
                "compile_ranges_split_points": list_to_str(
                    cc.compile_ranges_split_points
                ),
                "use_inductor_graph_partition": cc.use_inductor_graph_partition,
                "inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
                "enabled_passes": list_to_str(enabled_passes),
                "dynamic_shapes_type": str(dynamic_shapes.type),
                "dynamic_shapes_evaluate_guards": dynamic_shapes.evaluate_guards,
            }
        )

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "vllm_compilation_config",
            "encoding": "json",
        },
        payload_fn=build_payload,
    )
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
])
->
Any
:
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
])
->
Any
:
from
.caching
import
(
from
.caching
import
(
VllmSerializableFunction
,
VllmSerializableFunction
,
)
)
vllm_config
=
self
.
vllm_config
vllm_config
=
self
.
vllm_config
self
.
_log_compilation_config
()
# Minimal hashing here with existing utilities, reused below.
# Minimal hashing here with existing utilities, reused below.
env_factors
=
envs
.
compile_factors
()
env_factors
=
envs
.
compile_factors
()
...
@@ -892,6 +943,13 @@ class VllmBackend:
...
@@ -892,6 +943,13 @@ class VllmBackend:
lazy_format_graph_code
(
"before split"
,
self
.
graph
)
lazy_format_graph_code
(
"before split"
,
self
.
graph
)
lazy_format_graph_code
(
"after split"
,
self
.
split_gm
)
lazy_format_graph_code
(
"after split"
,
self
.
split_gm
)
# Log the piecewise split graph for TORCH_TRACE/tlparse
trace_structured
(
"graph_dump"
,
metadata_fn
=
lambda
:
{
"name"
:
"vllm_piecewise_split_graph"
},
payload_fn
=
lambda
:
self
.
split_gm
.
print_readable
(
print_output
=
False
),
)
compilation_counter
.
num_piecewise_graphs_seen
+=
len
(
self
.
piecewise_graphs
)
compilation_counter
.
num_piecewise_graphs_seen
+=
len
(
self
.
piecewise_graphs
)
submod_names_to_compile
=
[
submod_names_to_compile
=
[
item
.
submod_name
item
.
submod_name
...
...
vllm/compilation/piecewise_backend.py
View file @
608b5565
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
dataclasses
import
dataclasses
import
io
import
io
import
json
import
pickle
import
pickle
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
pickle
import
Pickler
from
pickle
import
Pickler
...
@@ -11,6 +12,7 @@ from typing import Any
...
@@ -11,6 +12,7 @@ from typing import Any
import
torch._functorch.config
import
torch._functorch.config
import
torch.fx
as
fx
import
torch.fx
as
fx
from
torch._inductor.runtime.triton_heuristics
import
CachingAutotuner
from
torch._inductor.runtime.triton_heuristics
import
CachingAutotuner
from
torch._logging._internal
import
trace_structured
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.monitor
import
end_monitoring_torch_compile
from
vllm.compilation.monitor
import
end_monitoring_torch_compile
...
@@ -39,6 +41,7 @@ class PiecewiseBackend:
...
@@ -39,6 +41,7 @@ class PiecewiseBackend:
vllm_backend
:
VllmBackend
,
vllm_backend
:
VllmBackend
,
returns_tuple
:
bool
,
returns_tuple
:
bool
,
compiled_runnables
:
dict
[
str
,
Callable
[...,
Any
]]
|
None
=
None
,
compiled_runnables
:
dict
[
str
,
Callable
[...,
Any
]]
|
None
=
None
,
submod_name
:
str
=
""
,
):
):
"""
"""
The backend for piecewise compilation.
The backend for piecewise compilation.
...
@@ -70,6 +73,7 @@ class PiecewiseBackend:
...
@@ -70,6 +73,7 @@ class PiecewiseBackend:
self
.
total_piecewise_compiles
=
total_piecewise_compiles
self
.
total_piecewise_compiles
=
total_piecewise_compiles
self
.
vllm_backend
=
vllm_backend
self
.
vllm_backend
=
vllm_backend
self
.
compiled_runnables
=
compiled_runnables
self
.
compiled_runnables
=
compiled_runnables
self
.
submod_name
=
submod_name
self
.
is_first_graph
=
piecewise_compile_index
==
0
self
.
is_first_graph
=
piecewise_compile_index
==
0
self
.
is_last_graph
=
piecewise_compile_index
==
total_piecewise_compiles
-
1
self
.
is_last_graph
=
piecewise_compile_index
==
total_piecewise_compiles
-
1
...
@@ -131,6 +135,9 @@ class PiecewiseBackend:
...
@@ -131,6 +135,9 @@ class PiecewiseBackend:
compile_range
=
range
,
compile_range
=
range
,
)
)
# Track whether we've logged the graph for this subgraph (only log once)
self
.
_graph_logged
=
False
# get the on_compilation_complete callback from context...
# get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call,
# PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py)
# which is when the context is set (see compilation/decorators.py)
...
@@ -221,6 +228,45 @@ class PiecewiseBackend:
...
@@ -221,6 +228,45 @@ class PiecewiseBackend:
assert
len
(
fake_example_inputs
)
==
len
(
args
)
assert
len
(
fake_example_inputs
)
==
len
(
args
)
return
fake_example_inputs
return
fake_example_inputs
def _log_compile_start(self, compile_range: Range):
    """Log compilation event for TORCH_TRACE/tlparse.

    Emits an ``"artifact"`` record named ``vllm_piecewise_compile_start``
    describing this subgraph and ``compile_range``. The first time it is
    called on this instance it also emits a ``"graph_dump"`` record named
    ``vllm_<submod_name>`` containing the readable graph code.

    Args:
        compile_range: The range (or single size) about to be compiled.
    """
    subgraph_index = self.piecewise_compile_index
    submod_name = self.submod_name

    def build_payload() -> str:
        # Only an exact single-size range can correspond to an entry of
        # compile_sizes; a multi-size dynamic range that merely *starts* at
        # one of those sizes must not be flagged.
        # NOTE(review): is_single_size() is assumed to mean start == end —
        # confirm. The key name says "cudagraph capture size" but the check
        # is against compile_sizes; the key is kept as-is for log consumers.
        is_cudagraph_size = (
            self.compile_sizes is not None
            and compile_range.is_single_size()
            and compile_range.start in self.compile_sizes
        )
        return json.dumps(
            {
                "piecewise_index": subgraph_index,
                "submod_name": submod_name,
                "total_piecewise_compiles": self.total_piecewise_compiles,
                "compile_range_start": compile_range.start,
                "compile_range_end": compile_range.end,
                "is_single_size": compile_range.is_single_size(),
                "is_cudagraph_capture_size": is_cudagraph_size,
            }
        )

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "vllm_piecewise_compile_start",
            "encoding": "json",
        },
        payload_fn=build_payload,
    )

    # Log the subgraph graph dump only once per subgraph (not per size)
    # to reduce log file size. The graph code is the same for all sizes.
    if self._graph_logged:
        return
    self._graph_logged = True
    assert self.graph is not None
    trace_structured(
        "graph_dump",
        metadata_fn=lambda: {
            "name": f"vllm_{submod_name}",
        },
        payload_fn=lambda: self.graph.print_readable(print_output=False),
    )
def
_maybe_compile_for_range_entry
(
def
_maybe_compile_for_range_entry
(
self
,
range_entry
:
RangeEntry
,
args
:
tuple
[
Any
,
...]
self
,
range_entry
:
RangeEntry
,
args
:
tuple
[
Any
,
...]
)
->
Any
:
)
->
Any
:
...
@@ -230,6 +276,8 @@ class PiecewiseBackend:
...
@@ -230,6 +276,8 @@ class PiecewiseBackend:
self
.
compiled_runnables
[
str
(
range_entry
.
compile_range
)]
self
.
compiled_runnables
[
str
(
range_entry
.
compile_range
)]
)
)
else
:
else
:
self
.
_log_compile_start
(
range_entry
.
compile_range
)
# args are real arguments
# args are real arguments
# fakify for range, real args for concrete size.
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# For concrete size, we clear the shape env in
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment