Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f14c922
Unverified
Commit
9f14c922
authored
Feb 04, 2026
by
Richard Zou
Committed by
GitHub
Feb 04, 2026
Browse files
Revert "[torch.compile] Significantly speed up cold start times" (#33820)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
535de06c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
41 deletions
+21
-41
tests/compile/test_cold_start.py
tests/compile/test_cold_start.py
+4
-5
vllm/compilation/backends.py
vllm/compilation/backends.py
+14
-36
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+3
-0
No files found.
tests/compile/test_cold_start.py
View file @
9f14c922
...
...
@@ -37,13 +37,12 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
# The forward pass consists of 32 transformer layers.
# Then, we split on the attention operation. This results in
# 33 subgraphs (not including the attention operation).
#
We then standalone_compile the unique subgraphs
.
#
The 33 subgraphs then get standalone_compile'd
.
#
# There are actually only 3 unique subgraphs for this model
# (all of its transformer layers are the same modulo weights);
# this is true for most vLLM models.
# So we test that during cold start, only 3 subgraphs are compiled
# These 3 subgraphs should cache miss, and then there should be
# no other compilation (so no cache hits).
# So we test that during cold start, the aot_autograd cache
# misses for 3 subgraphs and hits for the rest.
assert
counters
[
"aot_autograd"
][
"autograd_cache_miss"
]
==
3
assert
counters
[
"aot_autograd"
][
"autograd_cache_hit"
]
==
0
assert
counters
[
"aot_autograd"
][
"autograd_cache_hit"
]
==
3
0
vllm/compilation/backends.py
View file @
9f14c922
...
...
@@ -121,7 +121,7 @@ class CompilerManager:
and compiling the graph.
The cache is a dict mapping
`(runtime_shape, graph_
hash
, backend_name)`
`(runtime_shape, graph_
index
, backend_name)`
to `any_data` returned from the compiler.
When serializing the cache, we save it to a Python file
...
...
@@ -130,7 +130,7 @@ class CompilerManager:
"""
def
__init__
(
self
,
compilation_config
:
CompilationConfig
)
->
None
:
self
.
cache
:
dict
[
tuple
[
Range
,
str
,
str
],
Any
]
=
dict
()
self
.
cache
:
dict
[
tuple
[
Range
,
int
,
str
],
Any
]
=
dict
()
self
.
is_cache_updated
=
False
self
.
compilation_config
=
compilation_config
self
.
compiler
=
make_compiler
(
compilation_config
)
...
...
@@ -173,7 +173,6 @@ class CompilerManager:
self
.
disable_cache
=
disable_cache
self
.
cache_dir
=
cache_dir
self
.
cache_file_path
=
os
.
path
.
join
(
cache_dir
,
"vllm_compile_cache.py"
)
self
.
loaded_cache_entries
:
dict
[
tuple
[
Range
,
str
,
str
],
Any
]
=
{}
if
not
disable_cache
and
os
.
path
.
exists
(
self
.
cache_file_path
):
# load the cache from the file
...
...
@@ -187,9 +186,9 @@ class CompilerManager:
if
not
isinstance
(
value
,
ty
):
raise
TypeError
(
f
"Expected
{
ty
}
but got
{
type
(
value
)
}
for
{
value
}
"
)
def
parse_key
(
key
:
Any
)
->
tuple
[
Range
,
str
,
str
]:
range_tuple
,
graph_
hash
,
compiler_name
=
key
check_type
(
graph_
hash
,
str
)
def
parse_key
(
key
:
Any
)
->
tuple
[
Range
,
int
,
str
]:
range_tuple
,
graph_
index
,
compiler_name
=
key
check_type
(
graph_
index
,
int
)
check_type
(
compiler_name
,
str
)
if
isinstance
(
range_tuple
,
tuple
):
start
,
end
=
range_tuple
...
...
@@ -197,7 +196,7 @@ class CompilerManager:
check_type
(
end
,
int
)
range_tuple
=
Range
(
start
=
start
,
end
=
end
)
check_type
(
range_tuple
,
Range
)
return
range_tuple
,
graph_
hash
,
compiler_name
return
range_tuple
,
graph_
index
,
compiler_name
self
.
cache
=
{
parse_key
(
key
):
value
for
key
,
value
in
cache
.
items
()}
...
...
@@ -217,25 +216,18 @@ class CompilerManager:
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
list
[
Any
],
graph_
hash
:
str
,
graph_
index
:
int
,
compile_range
:
Range
,
)
->
Callable
[...,
Any
]
|
None
:
key
=
(
compile_range
,
graph_hash
,
self
.
compiler
.
name
)
# See if we've already loaded this cache entry
if
key
in
self
.
loaded_cache_entries
:
return
self
.
loaded_cache_entries
[
key
]
# Otherwise, go load it from disk
if
key
not
in
self
.
cache
:
if
(
compile_range
,
graph_index
,
self
.
compiler
.
name
)
not
in
self
.
cache
:
return
None
handle
=
self
.
cache
[
key
]
handle
=
self
.
cache
[
(
compile_range
,
graph_index
,
self
.
compiler
.
name
)
]
compiled_graph
=
self
.
compiler
.
load
(
handle
,
graph
,
example_inputs
,
compile_range
handle
,
graph
,
example_inputs
,
graph_index
,
compile_range
)
self
.
loaded_cache_entries
[
key
]
=
compiled_graph
logger
.
debug
(
"Directly load the graph (hash %s) for compile range "
"%sfrom %s via handle %s"
,
graph_hash
,
"Directly load the %s-th graph for compile range %sfrom %s via handle %s"
,
graph_index
,
str
(
compile_range
),
self
.
compiler
.
name
,
handle
,
...
...
@@ -257,22 +249,12 @@ class CompilerManager:
global
compilation_start_time
compilation_start_time
=
time
.
time
()
from
torch._functorch._aot_autograd.autograd_cache
import
(
AOTAutogradCachePickler
,
sanitize_gm_for_cache
,
)
with
sanitize_gm_for_cache
(
graph
):
pickler
=
AOTAutogradCachePickler
(
graph
)
dumped_graph
=
pickler
.
dumps
(
graph
)
graph_hash
=
hashlib
.
sha256
(
dumped_graph
).
hexdigest
()
compilation_counter
.
num_backend_compilations
+=
1
compiled_graph
=
None
# try to load from the cache
compiled_graph
=
self
.
load
(
graph
,
example_inputs
,
graph_
hash
,
compile_range
)
compiled_graph
=
self
.
load
(
graph
,
example_inputs
,
graph_
index
,
compile_range
)
if
compiled_graph
is
not
None
:
if
graph_index
==
num_graphs
-
1
:
# after loading the last graph for this shape, record the time.
...
...
@@ -308,13 +290,9 @@ class CompilerManager:
assert
compiled_graph
is
not
None
,
"Failed to compile the graph"
self
.
loaded_cache_entries
[(
compile_range
,
graph_hash
,
self
.
compiler
.
name
)]
=
(
compiled_graph
)
# store the artifact in the cache
if
is_compile_cache_enabled
(
additional_inductor_config
)
and
handle
is
not
None
:
self
.
cache
[(
compile_range
,
graph_
hash
,
self
.
compiler
.
name
)]
=
handle
self
.
cache
[(
compile_range
,
graph_
index
,
self
.
compiler
.
name
)]
=
handle
compilation_counter
.
num_cache_entries_updated
+=
1
self
.
is_cache_updated
=
True
if
graph_index
==
0
:
...
...
vllm/compilation/compiler_interface.py
View file @
9f14c922
...
...
@@ -101,6 +101,7 @@ class CompilerInterface:
handle
:
Any
,
graph
:
fx
.
GraphModule
,
example_inputs
:
list
[
Any
],
graph_index
:
int
,
compile_range
:
Range
,
)
->
Callable
[...,
Any
]:
"""
...
...
@@ -301,6 +302,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
handle
:
Any
,
graph
:
fx
.
GraphModule
,
example_inputs
:
list
[
Any
],
graph_index
:
int
,
compile_range
:
Range
,
)
->
Callable
[...,
Any
]:
assert
isinstance
(
handle
,
tuple
)
...
...
@@ -525,6 +527,7 @@ class InductorAdaptor(CompilerInterface):
handle
:
Any
,
graph
:
fx
.
GraphModule
,
example_inputs
:
list
[
Any
],
graph_index
:
int
,
compile_range
:
Range
,
)
->
Callable
[...,
Any
]:
assert
isinstance
(
handle
,
tuple
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment