Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5569f521
Unverified
Commit
5569f521
authored
Mar 04, 2026
by
Richard Zou
Committed by
GitHub
Mar 04, 2026
Browse files
[torch.compile] Stop lazily compiling (#35472)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
138d891d
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
177 additions
and
150 deletions
+177
-150
tests/compile/test_compile_ranges.py
tests/compile/test_compile_ranges.py
+4
-3
tests/compile/test_structured_logging.py
tests/compile/test_structured_logging.py
+3
-3
vllm/compilation/backends.py
vllm/compilation/backends.py
+20
-31
vllm/compilation/caching.py
vllm/compilation/caching.py
+11
-15
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+12
-7
vllm/compilation/monitor.py
vllm/compilation/monitor.py
+1
-1
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+126
-90
No files found.
tests/compile/test_compile_ranges.py
View file @
5569f521
...
@@ -73,6 +73,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
...
@@ -73,6 +73,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
Range
(
start
=
16
,
end
=
16
),
Range
(
start
=
16
,
end
=
16
),
Range
(
start
=
9
,
end
=
32
),
Range
(
start
=
9
,
end
=
32
),
Range
(
start
=
64
,
end
=
64
),
Range
(
start
=
64
,
end
=
64
),
Range
(
start
=
128
,
end
=
128
),
Range
(
start
=
33
,
end
=
8192
),
Range
(
start
=
33
,
end
=
8192
),
]
]
)
)
...
@@ -95,16 +96,16 @@ def test_compile_ranges(use_fresh_inductor_cache):
...
@@ -95,16 +96,16 @@ def test_compile_ranges(use_fresh_inductor_cache):
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
TestModel
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
()
model
=
TestModel
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
()
# Number of compilations: 3
for each
compile range +
2
compile sizes
# Number of compilations: 3 compile range
s
+
3
compile sizes
batch_sizes
=
[
1
,
4
,
16
,
24
,
48
,
64
,
8192
]
batch_sizes
=
[
1
,
4
,
16
,
24
,
48
,
64
,
8192
]
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
1
,
num_backend_compilations
=
5
,
num_backend_compilations
=
6
,
):
):
run_model
(
vllm_config
,
model
,
batch_sizes
)
run_model
(
vllm_config
,
model
,
batch_sizes
)
assert
post_grad_range_checker
.
num_calls
==
5
assert
post_grad_range_checker
.
num_calls
==
6
def
test_compile_config_get_compile_ranges
():
def
test_compile_config_get_compile_ranges
():
...
...
tests/compile/test_structured_logging.py
View file @
5569f521
...
@@ -109,9 +109,9 @@ def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
...
@@ -109,9 +109,9 @@ def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
f
"got
{
len
(
vllm_piecewise_split_graph
)
}
"
f
"got
{
len
(
vllm_piecewise_split_graph
)
}
"
)
)
compile_start_artifacts
=
capture
.
get
(
"artifact"
,
"vllm_piecewise_compile_start"
)
compile_start_artifacts
=
capture
.
get
(
"artifact"
,
"vllm_piecewise_compile_start"
)
assert
len
(
compile_start_artifacts
)
==
2
,
(
assert
len
(
compile_start_artifacts
)
==
4
,
(
"Expected
2
vllm_piecewise_compile_start "
"Expected
4
vllm_piecewise_compile_start "
"(
one for dynamic ranges, one for
compile size), "
"(
2 subgraphs x 2 ranges each: dynamic +
compile size), "
f
"got
{
len
(
compile_start_artifacts
)
}
"
f
"got
{
len
(
compile_start_artifacts
)
}
"
)
)
submod_dumps
=
capture
.
get
(
"graph_dump"
,
r
"vllm_submod_.*"
)
submod_dumps
=
capture
.
get
(
"graph_dump"
,
r
"vllm_submod_.*"
)
...
...
vllm/compilation/backends.py
View file @
5569f521
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
ast
import
contextvars
import
dataclasses
import
dataclasses
import
hashlib
import
hashlib
import
json
import
json
...
@@ -18,7 +17,7 @@ from typing import Any
...
@@ -18,7 +17,7 @@ from typing import Any
import
torch
import
torch
import
torch.fx
as
fx
import
torch.fx
as
fx
from
torch._d
ispatch.python
import
enable_python_dispatcher
from
torch._d
ynamo.utils
import
dynamo_timed
from
torch._logging._internal
import
trace_structured
from
torch._logging._internal
import
trace_structured
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -510,9 +509,9 @@ def wrap_with_cudagraph_if_needed(
...
@@ -510,9 +509,9 @@ def wrap_with_cudagraph_if_needed(
class
PiecewiseCompileInterpreter
(
torch
.
fx
.
Interpreter
):
# type: ignore[misc]
class
PiecewiseCompileInterpreter
(
torch
.
fx
.
Interpreter
):
# type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given
graph with fake inputs, and compile some
It runs the given
split graph interpreter, and for each submodule in
submodules specified by `compile_submod_names` with the given
`compile_submod_names`, creates a PiecewiseBackend and compiles all
compilation configs
.
ranges up front
.
NOTE: the order in `compile_submod_names` matters, because
NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise
it will be used to determine the order of the compiled piecewise
...
@@ -540,9 +539,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
...
@@ -540,9 +539,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
vllm_backend
:
"VllmBackend"
,
vllm_backend
:
"VllmBackend"
,
)
->
None
:
)
->
None
:
super
().
__init__
(
module
)
super
().
__init__
(
module
)
from
torch._guards
import
detect_fake_mode
self
.
fake_mode
=
detect_fake_mode
()
self
.
compile_submod_names
=
compile_submod_names
self
.
compile_submod_names
=
compile_submod_names
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
...
@@ -552,13 +548,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
...
@@ -552,13 +548,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
@
instrument
(
span_name
=
"Inductor compilation"
)
@
instrument
(
span_name
=
"Inductor compilation"
)
def
run
(
self
,
*
args
:
Any
)
->
Any
:
def
run
(
self
,
*
args
:
Any
)
->
Any
:
# maybe instead just assert inputs are fake?
return
super
().
run
(
*
args
)
fake_args
=
[
self
.
fake_mode
.
from_tensor
(
t
)
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
for
t
in
args
]
with
self
.
fake_mode
,
enable_python_dispatcher
():
return
super
().
run
(
*
fake_args
)
def
call_module
(
def
call_module
(
self
,
self
,
...
@@ -614,21 +604,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
...
@@ -614,21 +604,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
model_tag
:
str
=
"backbone"
model_tag
:
str
=
"backbone"
model_is_encoder
:
bool
=
False
model_is_encoder
:
bool
=
False
_on_compilation_complete_callback
:
contextvars
.
ContextVar
[
Callable
[[],
None
]
|
None
]
=
(
contextvars
.
ContextVar
(
"on_compilation_complete_callback"
,
default
=
None
)
)
@
contextmanager
def
set_on_compilation_complete
(
callback
:
Callable
[[],
None
],
)
->
Generator
[
None
,
None
,
None
]:
token
=
_on_compilation_complete_callback
.
set
(
callback
)
try
:
yield
finally
:
_on_compilation_complete_callback
.
reset
(
token
)
@
contextmanager
@
contextmanager
def
set_model_tag
(
tag
:
str
,
is_encoder
:
bool
=
False
)
->
Generator
[
None
,
None
,
None
]:
def
set_model_tag
(
tag
:
str
,
is_encoder
:
bool
=
False
)
->
Generator
[
None
,
None
,
None
]:
...
@@ -846,6 +821,7 @@ class VllmBackend:
...
@@ -846,6 +821,7 @@ class VllmBackend:
),
),
)
)
@
dynamo_timed
(
"vllm_backend"
)
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
])
->
Any
:
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
])
->
Any
:
from
.caching
import
(
from
.caching
import
(
VllmSerializableFunction
,
VllmSerializableFunction
,
...
@@ -1036,11 +1012,24 @@ class VllmBackend:
...
@@ -1036,11 +1012,24 @@ class VllmBackend:
]
]
# propagate the split graph to the piecewise backend,
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
# compile submodules with symbolic shapes, and compile all ranges
# up front so that compilation is complete before the callable
# is returned.
PiecewiseCompileInterpreter
(
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
).
run
(
*
fake_args
)
).
run
(
*
fake_args
)
# All compilation is done. Save the cache.
time_before_saving
=
time
.
perf_counter
()
self
.
compiler_manager
.
save_to_file
()
elapsed
=
time
.
perf_counter
()
-
time_before_saving
if
elapsed
>
1
:
logger
.
info_once
(
"Saved compiler manager cache in %.2f seconds."
,
elapsed
,
scope
=
"local"
,
)
from
torch._guards
import
detect_fake_mode
from
torch._guards
import
detect_fake_mode
fake_mode
=
detect_fake_mode
()
fake_mode
=
detect_fake_mode
()
...
...
vllm/compilation/caching.py
View file @
5569f521
...
@@ -313,30 +313,26 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -313,30 +313,26 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return
fn
return
fn
# Fall back to standard VllmBackend
# Fall back to standard VllmBackend.
# Use a lazy closure: the backend needs traced_files for cache
# dir computation, but those are only populated after
# _verify_source_unchanged runs in decorators.py (which happens
# after deserialization completes).
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
is_encoder
=
state
.
get
(
"is_encoder"
,
False
)
is_encoder
=
state
.
get
(
"is_encoder"
,
False
)
vllm_backend
:
VllmBackend
=
VllmBackend
(
vllm_config
=
get_current_vllm_config
()
get_current_vllm_config
(),
state
[
"prefix"
],
is_encoder
compile_inputs
=
list
(
state
[
"example_inputs"
])
)
def
optimized_call
(
*
example_inputs
:
Any
)
->
Any
:
def
optimized_call
(
*
example_inputs
:
Any
)
->
Any
:
"""
vllm_backend
:
VllmBackend
=
VllmBackend
(
On the first run of the optimized call, we rerun the compiler
vllm_config
,
state
[
"prefix"
],
is_encoder
backend which should result in a cache hit. After the backend
)
call returns, we just do a one-time replacement of the optimized
call with the compiled function, so that subsequent calls are on
the AOT compiled path.
"""
compile_inputs
=
[
inp
if
inp
is
not
None
else
example_inputs
[
i
]
for
i
,
inp
in
enumerate
(
fn
.
example_inputs
)
]
with
tracing
(
TracingContext
(
fake_mode
)):
with
tracing
(
TracingContext
(
fake_mode
)):
fn
.
optimized_call
=
vllm_backend
(
fn
.
optimized_call
=
vllm_backend
(
state
[
"graph_module"
],
compile_inputs
state
[
"graph_module"
],
compile_inputs
).
optimized_call
).
optimized_call
fn
.
vllm_backend
=
vllm_backend
return
fn
.
optimized_call
(
*
example_inputs
)
return
fn
.
optimized_call
(
*
example_inputs
)
fn
=
cls
(
**
state
,
optimized_call
=
optimized_call
)
fn
=
cls
(
**
state
,
optimized_call
=
optimized_call
)
...
...
vllm/compilation/decorators.py
View file @
5569f521
...
@@ -466,8 +466,12 @@ def _support_torch_compile(
...
@@ -466,8 +466,12 @@ def _support_torch_compile(
"Directly load AOT compilation from path %s"
,
aot_compilation_path
"Directly load AOT compilation from path %s"
,
aot_compilation_path
)
)
# Apply partition wrapper context for proper CUDA graph capture
# Apply partition wrapper context for proper CUDA graph capture
from
.monitor
import
end_monitoring_torch_compile
with
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
):
with
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
):
return
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
end_monitoring_torch_compile
(
self
.
vllm_config
)
return
output
if
self
.
compiled
:
if
self
.
compiled
:
assert
(
assert
(
...
@@ -552,18 +556,19 @@ def _support_torch_compile(
...
@@ -552,18 +556,19 @@ def _support_torch_compile(
logger
.
warning
(
"Detected eager backend, disabling AOT compile."
)
logger
.
warning
(
"Detected eager backend, disabling AOT compile."
)
use_aot_compile
=
False
use_aot_compile
=
False
if
use_aot_compile
:
if
use_aot_compile
:
from
vllm.compilation.backends
import
set_on_compilation_complete
# store the path for saving after warmup
# store the path for saving after warmup
self
.
_aot_compilation_path
=
aot_compilation_path
self
.
_aot_compilation_path
=
aot_compilation_path
self
.
_aot_cache_dir
=
cache_dir
self
.
_aot_cache_dir
=
cache_dir
# set callback in context so it's available when compilation completes
with
set_on_compilation_complete
(
self
.
save_aot_compiled_function
):
self
.
aot_compiled_fn
=
self
.
aot_compile
(
*
args
,
**
kwargs
)
self
.
aot_compiled_fn
=
self
.
aot_compile
(
*
args
,
**
kwargs
)
# All compilation is done at this point, save the AOT artifact.
self
.
save_aot_compiled_function
()
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
else
:
else
:
output
=
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# type: ignore[arg-type]
output
=
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# type: ignore[arg-type]
from
.monitor
import
end_monitoring_torch_compile
end_monitoring_torch_compile
(
self
.
vllm_config
)
self
.
compiled
=
True
self
.
compiled
=
True
return
output
return
output
...
...
vllm/compilation/monitor.py
View file @
5569f521
...
@@ -33,7 +33,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
...
@@ -33,7 +33,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
total_compile_time
:
float
=
time
.
perf_counter
()
-
torch_compile_start_time
total_compile_time
:
float
=
time
.
perf_counter
()
-
torch_compile_start_time
if
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
if
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
logger
.
info_once
(
logger
.
info_once
(
"torch.compile
takes
%.2f s in total"
,
"torch.compile
and initial profiling run took
%.2f s in total"
,
total_compile_time
,
total_compile_time
,
scope
=
"local"
,
scope
=
"local"
,
)
)
...
...
vllm/compilation/piecewise_backend.py
View file @
5569f521
...
@@ -5,7 +5,6 @@ import dataclasses
...
@@ -5,7 +5,6 @@ import dataclasses
import
io
import
io
import
json
import
json
import
pickle
import
pickle
import
time
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
pickle
import
Pickler
from
pickle
import
Pickler
from
typing
import
Any
from
typing
import
Any
...
@@ -16,7 +15,6 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner
...
@@ -16,7 +15,6 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from
torch._logging._internal
import
trace_structured
from
torch._logging._internal
import
trace_structured
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.monitor
import
end_monitoring_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.utils
import
Range
from
vllm.config.utils
import
Range
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -24,6 +22,55 @@ from vllm.logger import init_logger
...
@@ -24,6 +22,55 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
get_fake_args_from_graph
(
graph
:
fx
.
GraphModule
)
->
list
[
Any
]:
"""Get fake args directly from graph placeholder nodes."""
fake_args
=
[]
for
node
in
graph
.
graph
.
nodes
:
if
node
.
op
==
"placeholder"
:
fake_args
.
append
(
node
.
meta
[
"example_value"
])
else
:
break
return
fake_args
def
create_concrete_args
(
graph
:
fx
.
GraphModule
,
size
:
int
)
->
list
[
Any
]:
"""Create example inputs with symbolic dims replaced by a concrete size.
Used for single-size eager compilation where we need concrete-shaped
inputs but don't have real runtime tensors yet.
"""
from
torch._prims_common
import
compute_required_storage_length
from
torch.fx.experimental.symbolic_shapes
import
is_symbolic
def
concretize
(
sym_val
:
Any
)
->
int
:
"""Replace all symbolic variables in a SymInt expression with size."""
if
not
is_symbolic
(
sym_val
):
return
int
(
sym_val
)
expr
=
sym_val
.
node
.
expr
return
int
(
expr
.
subs
({
s
:
size
for
s
in
expr
.
free_symbols
}))
args
:
list
[
Any
]
=
[]
for
node
in
graph
.
graph
.
nodes
:
if
node
.
op
!=
"placeholder"
:
break
val
=
node
.
meta
[
"example_value"
]
if
isinstance
(
val
,
torch
.
SymInt
):
args
.
append
(
concretize
(
val
))
elif
isinstance
(
val
,
torch
.
Tensor
):
new_shape
=
tuple
(
concretize
(
d
)
for
d
in
val
.
shape
)
new_strides
=
tuple
(
concretize
(
s
)
for
s
in
val
.
stride
())
new_storage_offset
=
concretize
(
val
.
storage_offset
())
needed_size
=
compute_required_storage_length
(
new_shape
,
new_strides
,
new_storage_offset
)
t
=
torch
.
empty
(
needed_size
,
dtype
=
val
.
dtype
,
device
=
val
.
device
)
t
=
t
.
as_strided
(
new_shape
,
new_strides
,
new_storage_offset
)
args
.
append
(
t
)
else
:
args
.
append
(
val
)
return
args
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
RangeEntry
:
class
RangeEntry
:
compile_range
:
Range
compile_range
:
Range
...
@@ -109,10 +156,6 @@ class PiecewiseBackend:
...
@@ -109,10 +156,6 @@ class PiecewiseBackend:
# the entries for ranges that we need to either
# the entries for ranges that we need to either
self
.
range_entries
:
dict
[
Range
,
RangeEntry
]
=
{}
self
.
range_entries
:
dict
[
Range
,
RangeEntry
]
=
{}
# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self
.
to_be_compiled_ranges
:
set
[
Range
]
=
set
(
self
.
compile_ranges
)
# We only keep compilation management inside this class directly.
# We only keep compilation management inside this class directly.
if
self
.
compile_sizes
is
not
None
:
if
self
.
compile_sizes
is
not
None
:
for
size
in
self
.
compile_sizes
:
for
size
in
self
.
compile_sizes
:
...
@@ -129,7 +172,6 @@ class PiecewiseBackend:
...
@@ -129,7 +172,6 @@ class PiecewiseBackend:
self
.
range_entries
[
range
]
=
RangeEntry
(
self
.
range_entries
[
range
]
=
RangeEntry
(
compile_range
=
range
,
compile_range
=
range
,
)
)
self
.
to_be_compiled_ranges
.
add
(
range
)
for
range
in
self
.
compile_ranges
:
for
range
in
self
.
compile_ranges
:
self
.
range_entries
[
range
]
=
RangeEntry
(
self
.
range_entries
[
range
]
=
RangeEntry
(
...
@@ -139,12 +181,10 @@ class PiecewiseBackend:
...
@@ -139,12 +181,10 @@ class PiecewiseBackend:
# Track whether we've logged the graph for this subgraph (only log once)
# Track whether we've logged the graph for this subgraph (only log once)
self
.
_graph_logged
=
False
self
.
_graph_logged
=
False
# get the on_compilation_complete callback from context...
if
self
.
graph
is
not
None
:
# PiecewiseBackend is created during the first call,
self
.
compile_all_ranges
()
# which is when the context is set (see compilation/decorators.py)
else
:
from
vllm.compilation.backends
import
_on_compilation_complete_callback
self
.
load_all_ranges
()
self
.
on_compilation_complete
=
_on_compilation_complete_callback
.
get
()
def
get_compiled_graph_wrapper
(
def
get_compiled_graph_wrapper
(
self
,
compiled_graph
:
Callable
[...,
Any
]
self
,
compiled_graph
:
Callable
[...,
Any
]
...
@@ -161,25 +201,6 @@ class PiecewiseBackend:
...
@@ -161,25 +201,6 @@ class PiecewiseBackend:
return
compiled_graph_wrapper
return
compiled_graph_wrapper
def
check_for_ending_compilation
(
self
)
->
None
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_ranges
:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
time_before_saving
=
time
.
perf_counter
()
self
.
vllm_backend
.
compiler_manager
.
save_to_file
()
elapsed
=
time
.
perf_counter
()
-
time_before_saving
if
elapsed
>
1
:
logger
.
info_once
(
"Saved compiler manager cache in %.2f seconds."
,
elapsed
,
scope
=
"local"
,
)
end_monitoring_torch_compile
(
self
.
vllm_config
)
# Call the completion callback (e.g., to save AOT compiled function)
if
self
.
on_compilation_complete
is
not
None
:
self
.
on_compilation_complete
()
def
to_bytes
(
self
)
->
dict
[
str
,
bytes
]:
def
to_bytes
(
self
)
->
dict
[
str
,
bytes
]:
class
StandaloneCompiledArtifactsPickler
(
Pickler
):
class
StandaloneCompiledArtifactsPickler
(
Pickler
):
def
reducer_override
(
self
,
obj
:
object
)
->
Any
:
def
reducer_override
(
self
,
obj
:
object
)
->
Any
:
...
@@ -216,27 +237,54 @@ class PiecewiseBackend:
...
@@ -216,27 +237,54 @@ class PiecewiseBackend:
return
out
return
out
def
_fakify_args
(
self
,
args
:
tuple
[
Any
,
...])
->
list
[
Any
]:
def
compile_all_ranges
(
self
)
->
None
:
# We need to pass fake example_inputs, otherwise torch.compile
"""Compile all range entries for this piecewise subgraph up front."""
# will fakify the example_inputs potentially causing some non dynamic
assert
self
.
graph
is
not
None
,
(
# dimension to be be duck shaped to other existing shapes that have hints
"Cannot compile without a graph. "
# matching their values.
"When loading from cache/AOT artifacts, "
# This is problem because it can lead to unintended specializations!
"compile_all_ranges should not be called."
# if the new wrongly dynamic dim is specialized
)
# it will force specializing the whole shape
# torch.compile probably should not accept
for
range_entry
in
self
.
range_entries
.
values
():
# non fake tensors as example inputs!
if
range_entry
.
compiled
:
# See issue https://github.com/vllm-project/vllm/issues/27899
continue
fake_example_inputs
=
[]
assert
self
.
graph
is
not
None
self
.
_log_compile_start
(
range_entry
.
compile_range
)
for
node
in
self
.
graph
.
graph
.
nodes
:
# All place holders come first
if
range_entry
.
compile_range
.
is_single_size
():
if
node
.
op
==
"placeholder"
:
args_list
=
create_concrete_args
(
fake_example_inputs
.
append
(
node
.
meta
[
"example_value"
])
self
.
graph
,
range_entry
.
compile_range
.
start
)
else
:
else
:
break
args_list
=
get_fake_args_from_graph
(
self
.
graph
)
assert
len
(
fake_example_inputs
)
==
len
(
args
)
return
fake_example_inputs
# TODO(https://github.com/vllm-project/vllm/issues/35766)
# Can we remove strict_autograd_cache and
# force_non_lazy_backward_lowering overrides?
# I added them explicitly because this is what they are
# set to before the refactor
# (https://github.com/vllm-project/vllm/pull/35472).
# They affect the aotautograd cache key computation
# but they shouldn't have any effect on the actual
# compilation.
config_patches
=
dict
(
bundled_autograd_cache
=
True
,
strict_autograd_cache
=
False
,
)
if
hasattr
(
torch
.
_functorch
.
config
,
"force_non_lazy_backward_lowering"
):
config_patches
[
"force_non_lazy_backward_lowering"
]
=
False
with
torch
.
_functorch
.
config
.
patch
(
**
config_patches
):
range_entry
.
runnable
=
self
.
vllm_backend
.
compiler_manager
.
compile
(
self
.
graph
,
args_list
,
self
.
vllm_backend
.
inductor_config
,
self
.
compilation_config
,
compile_range
=
range_entry
.
compile_range
,
graph_index
=
self
.
piecewise_compile_index
,
num_graphs
=
self
.
total_piecewise_compiles
,
)
range_entry
.
compiled
=
True
def
_log_compile_start
(
self
,
compile_range
:
Range
):
def
_log_compile_start
(
self
,
compile_range
:
Range
):
"""Log compilation event for TORCH_TRACE/tlparse."""
"""Log compilation event for TORCH_TRACE/tlparse."""
...
@@ -277,44 +325,29 @@ class PiecewiseBackend:
...
@@ -277,44 +325,29 @@ class PiecewiseBackend:
payload_fn
=
lambda
:
self
.
graph
.
print_readable
(
print_output
=
False
),
payload_fn
=
lambda
:
self
.
graph
.
print_readable
(
print_output
=
False
),
)
)
def
_maybe_compile_for_range_entry
(
def
load_all_ranges
(
self
)
->
None
:
self
,
range_entry
:
RangeEntry
,
args
:
tuple
[
Any
,
...]
"""Load all pre-compiled runnables for this piecewise subgraph.
)
->
Any
:
if
not
range_entry
.
compiled
:
if
self
.
compiled_runnables
is
not
None
:
range_entry
.
runnable
=
self
.
get_compiled_graph_wrapper
(
self
.
compiled_runnables
[
str
(
range_entry
.
compile_range
)]
)
else
:
self
.
_log_compile_start
(
range_entry
.
compile_range
)
# args are real arguments
Called during warm start to wrap all cached compiled_runnables
# fakify for range, real args for concrete size.
into range_entry.runnable up front, analogous to compile_all_ranges()
# For concrete size, we clear the shape env in
for the cold start path.
# compiler_manager.compile() so no need to fakify.
"""
args_list
=
(
assert
self
.
compiled_runnables
is
not
None
,
(
self
.
_fakify_args
(
args
)
"load_all_ranges should only be called when compiled_runnables "
if
not
range_entry
.
compile_range
.
is_single_size
()
"is set (warm start / cache loading path)."
else
list
(
args
)
)
)
for
range_entry
in
self
.
range_entries
.
values
():
with
(
if
range_entry
.
compiled
:
torch
.
_functorch
.
config
.
patch
(
"bundled_autograd_cache"
,
True
),
continue
):
key
=
str
(
range_entry
.
compile_range
)
range_entry
.
runnable
=
self
.
vllm_backend
.
compiler_manager
.
compile
(
assert
key
in
self
.
compiled_runnables
,
(
self
.
graph
,
f
"Missing compiled runnable for range
{
range_entry
.
compile_range
}
. "
args_list
,
f
"Available keys:
{
list
(
self
.
compiled_runnables
.
keys
())
}
"
self
.
vllm_backend
.
inductor_config
,
)
self
.
compilation_config
,
range_entry
.
runnable
=
self
.
get_compiled_graph_wrapper
(
compile_range
=
range_entry
.
compile_range
,
self
.
compiled_runnables
[
key
]
graph_index
=
self
.
piecewise_compile_index
,
num_graphs
=
self
.
total_piecewise_compiles
,
)
)
range_entry
.
compiled
=
True
range_entry
.
compiled
=
True
self
.
to_be_compiled_ranges
.
remove
(
range_entry
.
compile_range
)
self
.
check_for_ending_compilation
()
def
_find_range_for_shape
(
self
,
runtime_shape
:
int
)
->
RangeEntry
|
None
:
def
_find_range_for_shape
(
self
,
runtime_shape
:
int
)
->
RangeEntry
|
None
:
# First we try to find the range entry for the concrete compile size
# First we try to find the range entry for the concrete compile size
...
@@ -338,6 +371,9 @@ class PiecewiseBackend:
...
@@ -338,6 +371,9 @@ class PiecewiseBackend:
assert
range_entry
is
not
None
,
(
assert
range_entry
is
not
None
,
(
f
"Shape:
{
runtime_shape
}
out of considered ranges:
{
self
.
compile_ranges
}
"
f
"Shape:
{
runtime_shape
}
out of considered ranges:
{
self
.
compile_ranges
}
"
)
)
assert
range_entry
.
compiled
,
(
self
.
_maybe_compile_for_range_entry
(
range_entry
,
args
)
"All ranges should be compiled or loaded up front in "
"PiecewiseBackend.__init__. "
f
"range_entry=
{
range_entry
.
compile_range
}
"
)
return
range_entry
.
runnable
(
*
args
)
return
range_entry
.
runnable
(
*
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment