Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
09b95e36
Unverified
Commit
09b95e36
authored
Feb 07, 2025
by
youkaichao
Committed by
GitHub
Feb 07, 2025
Browse files
[torch.compile] PyTorch 2.6 and nightly compatibility (#12393)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
85ac82d2
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
493 additions
and
320 deletions
+493
-320
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+1
-1
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+3
-3
vllm/compilation/backends.py
vllm/compilation/backends.py
+133
-304
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+340
-0
vllm/compilation/counter.py
vllm/compilation/counter.py
+1
-1
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+0
-1
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+15
-1
vllm/config.py
vllm/config.py
+0
-9
No files found.
tests/compile/piecewise/test_simple.py
View file @
09b95e36
...
@@ -92,7 +92,7 @@ def test_simple_piecewise_compile():
...
@@ -92,7 +92,7 @@ def test_simple_piecewise_compile():
num_graphs_seen
=
1
,
# one graph for the model
num_graphs_seen
=
1
,
# one graph for the model
num_piecewise_graphs_seen
=
5
,
# 2 * num_layers + 1
num_piecewise_graphs_seen
=
5
,
# 2 * num_layers + 1
num_piecewise_capturable_graphs_seen
=
3
,
# 1 + num_layers
num_piecewise_capturable_graphs_seen
=
3
,
# 1 + num_layers
num_
inductor
_compilations
=
3
,
# num_piecewise_capturable_graphs_seen
num_
backend
_compilations
=
3
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_caputured
=
num_cudagraph_caputured
=
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
09b95e36
...
@@ -322,7 +322,7 @@ def test_toy_llama():
...
@@ -322,7 +322,7 @@ def test_toy_llama():
num_graphs_seen
=
0
,
num_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_
inductor
_compilations
=
0
,
num_
backend
_compilations
=
0
,
num_cudagraph_caputured
=
0
,
num_cudagraph_caputured
=
0
,
):
):
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
False
))
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
False
))
...
@@ -332,7 +332,7 @@ def test_toy_llama():
...
@@ -332,7 +332,7 @@ def test_toy_llama():
num_graphs_seen
=
1
,
# one graph for the model
num_graphs_seen
=
1
,
# one graph for the model
num_piecewise_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
1
,
num_piecewise_capturable_graphs_seen
=
1
,
num_piecewise_capturable_graphs_seen
=
1
,
num_
inductor
_compilations
=
1
,
# num_piecewise_capturable_graphs_seen
num_
backend
_compilations
=
1
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_caputured
=
num_cudagraph_caputured
=
2
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
2
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
...
@@ -345,7 +345,7 @@ def test_toy_llama():
...
@@ -345,7 +345,7 @@ def test_toy_llama():
1
,
# 2 * num_layers + 1
1
,
# 2 * num_layers + 1
num_piecewise_capturable_graphs_seen
=
1
+
num_piecewise_capturable_graphs_seen
=
1
+
llama_config
.
num_layers
,
# 1 + num_layers
llama_config
.
num_layers
,
# 1 + num_layers
num_
inductor
_compilations
=
1
+
num_
backend
_compilations
=
1
+
llama_config
.
num_layers
,
# num_piecewise_capturable_graphs_seen
llama_config
.
num_layers
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_caputured
=
2
*
num_cudagraph_caputured
=
2
*
(
1
+
llama_config
.
num_layers
(
1
+
llama_config
.
num_layers
...
...
vllm/compilation/backends.py
View file @
09b95e36
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
ast
import
ast
import
copy
import
dataclasses
import
dataclasses
import
os
import
os
import
pprint
import
pprint
import
time
import
time
from
collections
import
defaultdict
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig
...
@@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
weak_ref_tensors
from
vllm.utils
import
weak_ref_tensors
from
.compiler_interface
import
EagerAdaptor
,
InductorAdaptor
from
.counter
import
compilation_counter
from
.counter
import
compilation_counter
from
.inductor_pass
import
InductorPass
from
.inductor_pass
import
InductorPass
from
.monitor
import
end_monitoring_torch_compile
from
.monitor
import
end_monitoring_torch_compile
...
@@ -27,293 +26,115 @@ from .pass_manager import PostGradPassManager
...
@@ -27,293 +26,115 @@ from .pass_manager import PostGradPassManager
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
dataclasses
.
dataclass
class
CompilerManager
:
class
InductorArtifact
:
"""
hash_str
:
str
=
""
A manager to manage the compilation process, including
file_path
:
str
=
""
caching the compiled graph, loading the compiled graph,
and compiling the graph.
The cache is a dict mapping
`(runtime_shape, graph_index, backend_name)`
to `any_data` returned from the compiler.
class
InductorHashCache
:
When serializing the cache, we save it to a Python file
for readability. We don't use json here because json doesn't
support int as key.
"""
"""
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str, file_path)
We use list of tuple for readability.
In-memory format: a defaultdict of dict, where the key is
def
__init__
(
self
,
use_inductor
:
bool
):
runtime_shape, and the value is a dict of graph_index to hash_str.
self
.
cache
:
Dict
[
Tuple
[
Optional
[
int
],
int
,
str
],
Any
]
=
dict
()
cls
=
InductorAdaptor
if
use_inductor
else
EagerAdaptor
self
.
compiler
=
cls
()
The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
we don't use json here because json doesn't support int as key.
return
self
.
compiler
.
compute_hash
(
vllm_config
)
TODO: better off-the-shelf solution to serialize the data?
"""
def
__init__
(
self
,
cache_dir
:
str
,
disabled
:
bool
=
False
):
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
):
self
.
cache
:
Dict
[
Optional
[
int
],
self
.
disable_cache
=
disable_cache
Dict
[
int
,
InductorArtifact
]]
=
defaultdict
(
dict
)
self
.
disabled
=
disabled
self
.
cache_dir
=
cache_dir
self
.
cache_dir
=
cache_dir
self
.
cache_file_path
=
os
.
path
.
join
(
cache_dir
,
self
.
cache_file_path
=
os
.
path
.
join
(
cache_dir
,
"vllm_compile_cache.py"
)
"inductor_hash_cache.py"
)
if
disabled
:
return
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache
=
os
.
path
.
join
(
cache_dir
,
"inductor_cache"
)
os
.
makedirs
(
inductor_cache
,
exist_ok
=
True
)
os
.
environ
[
"TORCHINDUCTOR_CACHE_DIR"
]
=
inductor_cache
triton_cache
=
os
.
path
.
join
(
cache_dir
,
"triton_cache"
)
os
.
makedirs
(
triton_cache
,
exist_ok
=
True
)
os
.
environ
[
"TRITON_CACHE_DIR"
]
=
triton_cache
if
os
.
path
.
exists
(
self
.
cache_file_path
):
with
open
(
self
.
cache_file_path
)
as
f
:
self
.
deserialize
(
f
.
read
())
def
deserialize
(
self
,
data
:
str
):
if
not
disable_cache
and
os
.
path
.
exists
(
self
.
cache_file_path
):
# load the cache from the file
with
open
(
self
.
cache_file_path
)
as
f
:
# we use ast.literal_eval to parse the data
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
# do not use eval(), it is unsafe.
list_data
=
ast
.
literal_eval
(
data
)
self
.
cache
=
ast
.
literal_eval
(
f
.
read
())
for
item
in
list_data
:
runtime_shape
=
item
[
0
]
self
.
compiler
.
initialize_cache
(
cache_dir
=
cache_dir
,
graph_index
=
item
[
1
]
disable_cache
=
disable_cache
)
hash_str
=
item
[
2
]
# for compatibility of old version,
# where we don't have file_path.
# NOTE: after running the new code, the file_path
# will be updated.
file_path
=
""
if
len
(
item
)
==
3
else
item
[
3
]
self
.
cache
[
runtime_shape
][
graph_index
]
=
InductorArtifact
(
hash_str
=
hash_str
,
file_path
=
file_path
)
def
serialize
(
self
)
->
str
:
data
=
[]
for
runtime_shape
,
value
in
self
.
cache
.
items
():
for
graph_index
,
inductor_artifact
in
value
.
items
():
data
.
append
(
(
runtime_shape
,
graph_index
,
inductor_artifact
.
hash_str
,
inductor_artifact
.
file_path
))
printer
=
pprint
.
PrettyPrinter
(
indent
=
4
)
return
printer
.
pformat
(
data
)
def
save_to_file
(
self
):
def
save_to_file
(
self
):
if
self
.
disable
d
:
if
self
.
disable
_cache
:
return
return
with
open
(
self
.
cache_file_path
,
"w"
)
as
f
:
with
open
(
self
.
cache_file_path
,
"w"
)
as
f
:
f
.
write
(
self
.
serialize
())
printer
=
pprint
.
PrettyPrinter
(
indent
=
4
)
data
=
printer
.
pformat
(
self
.
cache
)
def
__contains__
(
self
,
key
:
Tuple
[
Optional
[
int
],
int
])
->
bool
:
f
.
write
(
data
)
if
self
.
disabled
:
return
False
def
load
(
self
,
runtime_shape
,
graph_index
=
key
graph
:
fx
.
GraphModule
,
return
runtime_shape
in
self
.
cache
and
graph_index
in
self
.
cache
[
example_inputs
:
List
[
Any
],
runtime_shape
]
graph_index
:
int
,
runtime_shape
:
Optional
[
int
]
=
None
)
->
Optional
[
Callable
]:
def
__getitem__
(
self
,
key
:
Tuple
[
Optional
[
int
],
int
])
->
InductorArtifact
:
if
(
runtime_shape
,
graph_index
,
self
.
compiler
.
name
)
not
in
self
.
cache
:
if
self
.
disabled
:
return
None
raise
KeyError
(
"cannot read from disabled cache"
)
handle
=
self
.
cache
[(
runtime_shape
,
graph_index
,
self
.
compiler
.
name
)]
runtime_shape
,
graph_index
=
key
compiled_graph
=
self
.
compiler
.
load
(
handle
,
graph
,
example_inputs
,
return
self
.
cache
[
runtime_shape
][
graph_index
]
graph_index
,
runtime_shape
)
logger
.
debug
(
def
__setitem__
(
self
,
key
:
Tuple
[
Optional
[
int
],
int
],
"Directly load the %s-th graph for shape %s from %s via "
value
:
InductorArtifact
):
"handle %s"
,
graph_index
,
str
(
runtime_shape
),
self
.
compiler
.
name
,
# setitem for disabled cache is fine, because we
handle
)
# don't actually write to the disk
return
compiled_graph
runtime_shape
,
graph_index
=
key
self
.
cache
[
runtime_shape
][
graph_index
]
=
value
class
AlwaysHitShapeEnv
:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods are obtained by trial-and-error
until it works.
"""
def
__init__
(
self
)
->
None
:
self
.
guards
:
List
[
Any
]
=
[]
def
evaluate_guards_expression
(
self
,
*
args
,
**
kwargs
):
return
True
def
get_pruned_guards
(
self
,
*
args
,
**
kwargs
):
return
[]
def
produce_guards_expression
(
self
,
*
args
,
**
kwargs
):
return
""
def
wrap_inductor
(
graph
:
fx
.
GraphModule
,
def
compile
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
,
example_inputs
,
additional_inductor_config
,
additional_inductor_config
,
compilation_config
:
CompilationConfig
,
compilation_config
:
CompilationConfig
,
vllm_backend
:
"VllmBackend"
,
graph_index
:
int
=
0
,
graph_index
:
int
=
0
,
num_graphs
:
int
=
1
,
num_graphs
:
int
=
1
,
runtime_shape
:
Optional
[
int
]
=
None
,
runtime_shape
:
Optional
[
int
]
=
None
)
->
Any
:
use_inductor
:
bool
=
True
)
->
Any
:
if
graph_index
==
0
:
if
graph_index
==
0
:
# before compiling the first graph, record the start time
# before compiling the first graph, record the start time
global
compilation_start_time
global
compilation_start_time
compilation_start_time
=
time
.
time
()
compilation_start_time
=
time
.
time
()
if
not
use_inductor
:
compilation_counter
.
num_backend_compilations
+=
1
return
graph
compilation_counter
.
num_inductor_compilations
+=
1
from
torch._inductor
import
config
compiled_graph
=
None
current_config
=
config
.
get_config_copy
()
from
torch._inductor.compile_fx
import
compile_fx
if
additional_inductor_config
is
not
None
:
# try to load from the cache
current_config
.
update
(
additional_inductor_config
)
compiled_graph
=
self
.
load
(
graph
,
example_inputs
,
graph_index
,
runtime_shape
)
if
isinstance
(
runtime_shape
,
int
):
if
compiled_graph
is
not
None
:
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config
[
"max_autotune"
]
=
True
current_config
[
"coordinate_descent_tuning"
]
=
True
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph
=
copy
.
deepcopy
(
graph
)
cache_data
=
vllm_backend
.
inductor_hash_cache
if
(
runtime_shape
,
graph_index
)
in
cache_data
:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
inductor_artifact
=
cache_data
[(
runtime_shape
,
graph_index
)]
hash_str
=
inductor_artifact
.
hash_str
if
graph_index
==
0
:
if
graph_index
==
0
:
# adds some info logging for the first graph
# adds some info logging for the first graph
logger
.
info
(
logger
.
info
(
"Directly load the compiled graph for shape %s "
"Directly lookup the graph for shape %s from the cache"
,
"from the cache"
,
str
(
runtime_shape
))
# noqa
str
(
runtime_shape
))
# noqa
return
compiled_graph
logger
.
debug
(
"directly lookup the %s-th graph for shape %s via hash %s"
,
graph_index
,
str
(
runtime_shape
),
hash_str
)
from
torch._inductor.codecache
import
FxGraphCache
with
patch
(
"torch._inductor.codecache.FxGraphCache._get_shape_env"
,
lambda
*
args
,
**
kwargs
:
AlwaysHitShapeEnv
()):
inductor_compiled_graph
=
FxGraphCache
.
_lookup_graph
(
hash_str
,
example_inputs
,
True
,
False
)
assert
inductor_compiled_graph
is
not
None
,
(
"Inductor cache lookup failed. Please remove"
f
"the cache file
{
cache_data
.
cache_file_path
}
and try again."
# noqa
)
inductor_artifact
.
file_path
=
inductor_compiled_graph
.
current_callable
.
__code__
.
co_filename
# noqa
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from
torch._inductor.compile_fx
import
graph_returns_tuple
returns_tuple
=
graph_returns_tuple
(
graph
)
# this is the callable we return to Dynamo to run
def
compiled_graph
(
*
args
):
# convert args to list
list_args
=
list
(
args
)
graph_output
=
inductor_compiled_graph
(
list_args
)
# unpack the tuple if needed
if
returns_tuple
:
return
graph_output
else
:
return
graph_output
[
0
]
else
:
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
inductor_artifact
=
InductorArtifact
()
from
torch._inductor.codecache
import
(
FxGraphCache
,
compiled_fx_graph_hash
)
original_load
=
FxGraphCache
.
load
def
hijack_load
(
*
args
,
**
kwargs
):
inductor_compiled_graph
=
original_load
(
*
args
,
**
kwargs
)
inductor_artifact
.
file_path
=
inductor_compiled_graph
.
current_callable
.
__code__
.
co_filename
# noqa
return
inductor_compiled_graph
def
hijack_compiled_fx_graph_hash
(
*
args
,
**
kwargs
):
out
=
compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
inductor_artifact
.
hash_str
=
out
[
0
]
return
out
def
_check_can_cache
(
*
args
,
**
kwargs
):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def
_get_shape_env
()
->
AlwaysHitShapeEnv
:
return
AlwaysHitShapeEnv
()
with
ExitStack
()
as
stack
:
if
not
cache_data
.
disabled
:
# compilation cache is enabled, patch several functions
# hijack to get the compiled graph itself
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.FxGraphCache.load"
,
hijack_load
))
# for hijacking the hash of the compiled graph
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.compiled_fx_graph_hash"
,
hijack_compiled_fx_graph_hash
))
# for providing a dummy shape environment
# no compiler cached the graph, or the cache is disabled,
stack
.
enter_context
(
# we need to compile it
patch
(
compiled_graph
,
handle
=
self
.
compiler
.
compile
(
"torch._inductor.codecache.FxGraphCache._get_shape_env"
,
graph
,
example_inputs
,
additional_inductor_config
,
runtime_shape
)
_get_shape_env
))
# for forcing the graph to be cached
assert
compiled_graph
is
not
None
,
"Failed to compile the graph"
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.FxGraphCache._check_can_cache"
,
_check_can_cache
))
compiled_graph
=
compile_fx
(
graph
,
# store the artifact in the cache
example_inputs
,
if
handle
is
not
None
:
config_patches
=
current_config
)
self
.
cache
[(
runtime_shape
,
graph_index
,
# store the inductor_artifact in the cache
self
.
compiler
.
name
)]
=
handle
cache_data
[(
runtime_shape
,
graph_index
)]
=
inductor_artifact
if
graph_index
==
0
:
if
graph_index
==
0
:
# adds some info logging for the first graph
# adds some info logging for the first graph
logger
.
info
(
"Cache the graph of shape %s for later use"
,
logger
.
info
(
"Cache the graph of shape %s for later use"
,
str
(
runtime_shape
))
str
(
runtime_shape
))
logger
.
debug
(
logger
.
debug
(
"store the %s-th graph for shape %s
via hash %s from fi
le %s"
,
"store the %s-th graph for shape %s
from %s via hand
le %s"
,
graph_index
,
str
(
runtime_shape
),
inductor_artifact
.
hash_str
,
graph_index
,
str
(
runtime_shape
),
self
.
compiler
.
name
,
handle
)
inductor_artifact
.
file_path
)
# after compiling the last graph, record the end time
# after compiling the last graph, record the end time
if
graph_index
==
num_graphs
-
1
:
if
graph_index
==
num_graphs
-
1
:
now
=
time
.
time
()
now
=
time
.
time
()
...
@@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i
for
i
,
x
in
enumerate
(
args
)
if
isinstance
(
x
,
torch
.
SymInt
)
i
for
i
,
x
in
enumerate
(
args
)
if
isinstance
(
x
,
torch
.
SymInt
)
]
]
global
compilation_start_time
global
compilation_start_time
compiled_graph_for_general_shape
=
wrap_inductor
(
compiled_graph_for_general_shape
=
self
.
vllm_backend
.
\
compiler_manager
.
compile
(
submod
,
submod
,
args
,
args
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
,
self
.
compilation_config
,
self
.
vllm_backend
,
graph_index
=
index
,
graph_index
=
index
,
num_graphs
=
len
(
self
.
compile_submod_names
),
num_graphs
=
len
(
self
.
compile_submod_names
),
runtime_shape
=
None
,
runtime_shape
=
None
)
use_inductor
=
self
.
compilation_config
.
use_inductor
)
self
.
module
.
__dict__
[
target
]
=
PiecewiseBackend
(
self
.
module
.
__dict__
[
target
]
=
PiecewiseBackend
(
submod
,
self
.
vllm_config
,
self
.
graph_pool
,
index
,
submod
,
self
.
vllm_config
,
self
.
graph_pool
,
index
,
...
@@ -483,7 +303,7 @@ class VllmBackend:
...
@@ -483,7 +303,7 @@ class VllmBackend:
post_grad_passes
:
Sequence
[
Callable
]
post_grad_passes
:
Sequence
[
Callable
]
sym_tensor_indices
:
List
[
int
]
sym_tensor_indices
:
List
[
int
]
input_buffers
:
List
[
torch
.
Tensor
]
input_buffers
:
List
[
torch
.
Tensor
]
inductor_hash_cache
:
InductorHashCache
compiler_manager
:
CompilerManager
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -507,6 +327,9 @@ class VllmBackend:
...
@@ -507,6 +327,9 @@ class VllmBackend:
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compiler_manager
:
CompilerManager
=
CompilerManager
(
self
.
compilation_config
.
use_inductor
)
# `torch.compile` is JIT compiled, so we don't need to
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
# do anything here
...
@@ -533,9 +356,11 @@ class VllmBackend:
...
@@ -533,9 +356,11 @@ class VllmBackend:
# the cache dir will be the same so that we can reuse the compiled
# the cache dir will be the same so that we can reuse the compiled
# graph.
# graph.
factors
=
[]
# 1. factors come from the vllm_config (it mainly summarizes how the
# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
# model is created)
config_hash
=
vllm_config
.
compute_hash
()
config_hash
=
vllm_config
.
compute_hash
()
factors
.
append
(
config_hash
)
# 2. factors come from the code files that are traced by Dynamo (
# 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass)
# it mainly summarizes how the model is used in forward pass)
...
@@ -553,10 +378,15 @@ class VllmBackend:
...
@@ -553,10 +378,15 @@ class VllmBackend:
import
hashlib
import
hashlib
code_hash
=
hashlib
.
md5
(
code_hash
=
hashlib
.
md5
(
"
\n
"
.
join
(
hash_content
).
encode
()).
hexdigest
()
"
\n
"
.
join
(
hash_content
).
encode
()).
hexdigest
()
factors
.
append
(
code_hash
)
# 3. compiler hash
compiler_hash
=
self
.
compiler_manager
.
compute_hash
(
vllm_config
)
factors
.
append
(
compiler_hash
)
# combine all factors to generate the cache dir
hash_key
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()[:
10
]
# combine the two hashes to generate the cache dir
hash_key
=
hashlib
.
md5
(
f
"
{
config_hash
}
_
{
code_hash
}
"
.
encode
()).
hexdigest
()[:
10
]
cache_dir
=
os
.
path
.
join
(
cache_dir
=
os
.
path
.
join
(
envs
.
VLLM_CACHE_ROOT
,
envs
.
VLLM_CACHE_ROOT
,
"torch_compile_cache"
,
"torch_compile_cache"
,
...
@@ -570,15 +400,16 @@ class VllmBackend:
...
@@ -570,15 +400,16 @@ class VllmBackend:
cache_dir
,
f
"rank_
{
vllm_config
.
parallel_config
.
rank
}
"
)
cache_dir
,
f
"rank_
{
vllm_config
.
parallel_config
.
rank
}
"
)
self
.
compilation_config
.
local_cache_dir
=
local_cache_dir
self
.
compilation_config
.
local_cache_dir
=
local_cache_dir
disabled
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
disable_cache
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
self
.
inductor_hash_cache
:
InductorHashCache
=
InductorHashCache
(
local_cache_dir
,
disabled
=
disabled
)
if
disable_cache
:
if
disabled
:
logger
.
info
(
"vLLM's torch.compile cache is disabled."
)
logger
.
info
(
"vLLM's torch.compile cache is disabled."
)
else
:
else
:
logger
.
info
(
"Using cache directory: %s for vLLM's torch.compile"
,
logger
.
info
(
"Using cache directory: %s for vLLM's torch.compile"
,
local_cache_dir
)
local_cache_dir
)
self
.
compiler_manager
.
initialize_cache
(
local_cache_dir
,
disable_cache
)
# when dynamo calls the backend, it means the bytecode
# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
# transform and analysis are done
compilation_counter
.
num_graphs_seen
+=
1
compilation_counter
.
num_graphs_seen
+=
1
...
@@ -759,7 +590,7 @@ class PiecewiseBackend:
...
@@ -759,7 +590,7 @@ class PiecewiseBackend:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
# no specific sizes to compile
# no specific sizes to compile
# save the hash of the inductor graph for the next run
# save the hash of the inductor graph for the next run
self
.
vllm_backend
.
inductor_hash_cache
.
save_to_file
()
self
.
vllm_backend
.
compiler_manager
.
save_to_file
()
end_monitoring_torch_compile
(
self
.
vllm_config
)
end_monitoring_torch_compile
(
self
.
vllm_config
)
def
__call__
(
self
,
*
args
)
->
Any
:
def
__call__
(
self
,
*
args
)
->
Any
:
...
@@ -782,16 +613,14 @@ class PiecewiseBackend:
...
@@ -782,16 +613,14 @@ class PiecewiseBackend:
entry
.
compiled
=
True
entry
.
compiled
=
True
self
.
to_be_compiled_sizes
.
remove
(
runtime_shape
)
self
.
to_be_compiled_sizes
.
remove
(
runtime_shape
)
# args are real arguments
# args are real arguments
entry
.
runnable
=
wrap_inductor
(
entry
.
runnable
=
self
.
vllm_backend
.
compiler_manager
.
compile
(
self
.
graph
,
self
.
graph
,
args
,
args
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
,
self
.
compilation_config
,
self
.
vllm_backend
,
graph_index
=
self
.
piecewise_compile_index
,
graph_index
=
self
.
piecewise_compile_index
,
num_graphs
=
self
.
total_piecewise_compiles
,
num_graphs
=
self
.
total_piecewise_compiles
,
runtime_shape
=
runtime_shape
,
runtime_shape
=
runtime_shape
)
use_inductor
=
self
.
compilation_config
.
use_inductor
)
# finished compilations for all required shapes
# finished compilations for all required shapes
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
...
...
vllm/compilation/compiler_interface.py
0 → 100644
View file @
09b95e36
# SPDX-License-Identifier: Apache-2.0
import
copy
import
hashlib
import
os
from
contextlib
import
ExitStack
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
import
torch
import
torch._inductor.compile_fx
import
torch.fx
as
fx
from
vllm.config
import
VllmConfig
class
CompilerInterface
:
"""
The interface for a compiler that can be used by vLLM.
"""
# The name of the compiler, e.g. inductor.
# This is a class-level attribute.
name
:
str
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
):
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
e.g. by re-directing its own cache directory to a sub-directory.
"""
pass
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
"""
Gather all the relevant information from the VLLM config,
to compute a hash so that we can cache the compiled model.
See :meth:`VllmConfig.compute_hash` to check what information
is already considered by default. This function should only
consider the information that is specific to the compiler.
"""
return
""
def
compile
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
compiler_config
:
Dict
[
str
,
Any
],
runtime_shape
:
Optional
[
int
]
=
None
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
the `example_inputs` have a dynamic shape. Otherwise, the
`runtime_shape` specifies the shape of the inputs. Right now we only
support one variable shape for all inputs, which is the batchsize
(number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
The function should return a compiled callable function, as well as
a handle that can be used to directly load the compiled function.
The handle should be a plain Python object, preferably a string or a
file path for readability.
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.
"""
return
None
,
None
def
load
(
self
,
handle
:
Any
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
graph_index
:
int
,
runtime_shape
:
Optional
[
int
]
=
None
)
->
Callable
:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
The handle is the second return value of the `compile` function.
"""
raise
NotImplementedError
(
"caching is not supported"
)
class
AlwaysHitShapeEnv
:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods are obtained by trial-and-error
until it works.
"""
def
__init__
(
self
)
->
None
:
self
.
guards
:
List
[
Any
]
=
[]
def
evaluate_guards_expression
(
self
,
*
args
,
**
kwargs
):
return
True
def
get_pruned_guards
(
self
,
*
args
,
**
kwargs
):
return
[]
def
produce_guards_expression
(
self
,
*
args
,
**
kwargs
):
return
""
class
InductorAdaptor
(
CompilerInterface
):
"""
The adaptor for the Inductor compiler, version 2.5 and 2.6.
"""
name
=
"inductor"
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
factors
:
List
[
Any
]
=
[]
# summarize system state
from
torch._inductor.codecache
import
CacheBase
system_factors
=
CacheBase
.
get_system
()
factors
.
append
(
system_factors
)
# summarize pytorch state
from
torch._inductor.codecache
import
torch_key
torch_factors
=
torch_key
()
factors
.
append
(
torch_factors
)
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()[:
10
]
return
hash_str
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
):
if
disable_cache
:
return
# redirect the cache directory to a sub-directory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache
=
os
.
path
.
join
(
cache_dir
,
"inductor_cache"
)
os
.
makedirs
(
inductor_cache
,
exist_ok
=
True
)
os
.
environ
[
"TORCHINDUCTOR_CACHE_DIR"
]
=
inductor_cache
triton_cache
=
os
.
path
.
join
(
cache_dir
,
"triton_cache"
)
os
.
makedirs
(
triton_cache
,
exist_ok
=
True
)
os
.
environ
[
"TRITON_CACHE_DIR"
]
=
triton_cache
def
compile
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
compiler_config
:
Dict
[
str
,
Any
],
runtime_shape
:
Optional
[
int
]
=
None
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
from
torch._inductor
import
config
current_config
=
config
.
get_config_copy
()
from
torch._inductor.compile_fx
import
compile_fx
# disable remote cache
current_config
[
"fx_graph_cache"
]
=
True
current_config
[
"fx_graph_remote_cache"
]
=
False
if
compiler_config
is
not
None
:
current_config
.
update
(
compiler_config
)
if
isinstance
(
runtime_shape
,
int
):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config
[
"max_autotune"
]
=
True
current_config
[
"coordinate_descent_tuning"
]
=
True
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph
=
copy
.
deepcopy
(
graph
)
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
hash_str
,
file_path
=
None
,
None
from
torch._inductor.codecache
import
(
FxGraphCache
,
compiled_fx_graph_hash
)
if
torch
.
__version__
.
startswith
(
"2.5"
):
original_load
=
FxGraphCache
.
load
original_load_name
=
"torch._inductor.codecache.FxGraphCache.load"
def
hijack_load
(
*
args
,
**
kwargs
):
inductor_compiled_graph
=
original_load
(
*
args
,
**
kwargs
)
nonlocal
file_path
file_path
=
inductor_compiled_graph
.
current_callable
.
__code__
.
co_filename
# noqa
return
inductor_compiled_graph
hijacked_compile_fx_inner
=
torch
.
_inductor
.
compile_fx
.
compile_fx_inner
# noqa
elif
torch
.
__version__
>=
"2.6"
:
# function renamed in 2.6
original_load_name
=
None
def
hijacked_compile_fx_inner
(
*
args
,
**
kwargs
):
output
=
torch
.
_inductor
.
compile_fx
.
compile_fx_inner
(
*
args
,
**
kwargs
)
nonlocal
hash_str
inductor_compiled_graph
=
output
if
inductor_compiled_graph
is
not
None
:
nonlocal
file_path
file_path
=
inductor_compiled_graph
.
current_callable
.
__code__
.
co_filename
# noqa
hash_str
=
inductor_compiled_graph
.
_fx_graph_cache_key
return
output
def
hijack_compiled_fx_graph_hash
(
*
args
,
**
kwargs
):
out
=
compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
nonlocal
hash_str
hash_str
=
out
[
0
]
return
out
def
_check_can_cache
(
*
args
,
**
kwargs
):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def
_get_shape_env
()
->
AlwaysHitShapeEnv
:
return
AlwaysHitShapeEnv
()
with
ExitStack
()
as
stack
:
# hijack to get the compiled graph itself
if
original_load_name
is
not
None
:
stack
.
enter_context
(
patch
(
original_load_name
,
hijack_load
))
# for hijacking the hash of the compiled graph
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.compiled_fx_graph_hash"
,
hijack_compiled_fx_graph_hash
))
# for providing a dummy shape environment
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.FxGraphCache._get_shape_env"
,
_get_shape_env
))
# for forcing the graph to be cached
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.FxGraphCache._check_can_cache"
,
_check_can_cache
))
compiled_graph
=
compile_fx
(
graph
,
example_inputs
,
inner_compile
=
hijacked_compile_fx_inner
,
config_patches
=
current_config
)
assert
hash_str
is
not
None
,
(
"failed to get the hash of the compiled graph"
)
assert
file_path
is
not
None
,
(
"failed to get the file path of the compiled graph"
)
return
compiled_graph
,
(
hash_str
,
file_path
)
def
load
(
self
,
handle
:
Any
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
graph_index
:
int
,
runtime_shape
:
Optional
[
int
]
=
None
)
->
Callable
:
assert
isinstance
(
handle
,
tuple
)
assert
isinstance
(
handle
[
0
],
str
)
assert
isinstance
(
handle
[
1
],
str
)
hash_str
=
handle
[
0
]
from
torch._inductor.codecache
import
FxGraphCache
with
patch
(
"torch._inductor.codecache.FxGraphCache._get_shape_env"
,
lambda
*
args
,
**
kwargs
:
AlwaysHitShapeEnv
()):
if
torch
.
__version__
.
startswith
(
"2.5"
):
inductor_compiled_graph
=
FxGraphCache
.
_lookup_graph
(
hash_str
,
example_inputs
,
True
,
False
)
assert
inductor_compiled_graph
is
not
None
,
(
"Inductor cache lookup failed. Please remove"
f
"the cache directory and try again."
# noqa
)
elif
torch
.
__version__
>=
"2.6"
:
from
torch._inductor.output_code
import
(
CompiledFxGraphConstantsWithGm
)
constants
=
CompiledFxGraphConstantsWithGm
(
graph
)
inductor_compiled_graph
,
_
=
FxGraphCache
.
_lookup_graph
(
hash_str
,
example_inputs
,
True
,
None
,
constants
)
assert
inductor_compiled_graph
is
not
None
,
(
"Inductor cache lookup failed. Please remove"
f
"the cache directory and try again."
# noqa
)
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from
torch._inductor.compile_fx
import
graph_returns_tuple
returns_tuple
=
graph_returns_tuple
(
graph
)
# this is the callable we return to Dynamo to run
def
compiled_graph
(
*
args
):
# convert args to list
list_args
=
list
(
args
)
graph_output
=
inductor_compiled_graph
(
list_args
)
# unpack the tuple if needed
if
returns_tuple
:
return
graph_output
else
:
return
graph_output
[
0
]
return
compiled_graph
class
EagerAdaptor
(
CompilerInterface
):
name
=
"eager"
def
compile
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
compiler_config
:
Dict
[
str
,
Any
],
runtime_shape
:
Optional
[
int
]
=
None
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
return
graph
,
None
vllm/compilation/counter.py
View file @
09b95e36
...
@@ -13,7 +13,7 @@ class CompilationCounter:
...
@@ -13,7 +13,7 @@ class CompilationCounter:
num_piecewise_graphs_seen
:
int
=
0
num_piecewise_graphs_seen
:
int
=
0
# not including the splitting ops
# not including the splitting ops
num_piecewise_capturable_graphs_seen
:
int
=
0
num_piecewise_capturable_graphs_seen
:
int
=
0
num_
inductor
_compilations
:
int
=
0
num_
backend
_compilations
:
int
=
0
num_cudagraph_caputured
:
int
=
0
num_cudagraph_caputured
:
int
=
0
def
clone
(
self
)
->
"CompilationCounter"
:
def
clone
(
self
)
->
"CompilationCounter"
:
...
...
vllm/compilation/inductor_pass.py
View file @
09b95e36
...
@@ -13,7 +13,6 @@ from torch import fx
...
@@ -13,7 +13,6 @@ from torch import fx
class
InductorPass
(
ABC
):
class
InductorPass
(
ABC
):
"""
"""
General custom inductor pass interface.
General custom inductor pass interface.
TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
"""
"""
@
abstractmethod
@
abstractmethod
...
...
vllm/compilation/pass_manager.py
View file @
09b95e36
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
import
torch
from
torch
import
fx
as
fx
from
torch
import
fx
as
fx
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
...
@@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass
...
@@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
PostGradPassManager
:
class
PlaceHolder
:
pass
if
torch
.
__version__
<
"2.6"
:
Parent
=
PlaceHolder
# type: ignore
else
:
Parent
=
torch
.
_inductor
.
custom_graph_pass
.
CustomGraphPass
# type: ignore
class
PostGradPassManager
(
Parent
):
"""
"""
The pass manager for post-grad passes.
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It handles configuration, adding custom passes, and running passes.
...
@@ -55,6 +66,9 @@ class PostGradPassManager:
...
@@ -55,6 +66,9 @@ class PostGradPassManager:
assert
isinstance
(
pass_
,
InductorPass
)
assert
isinstance
(
pass_
,
InductorPass
)
self
.
passes
.
append
(
pass_
)
self
.
passes
.
append
(
pass_
)
def
uuid
(
self
):
return
self
.
__getstate__
()
def
__getstate__
(
self
)
->
Dict
[
str
,
List
[
Any
]]:
def
__getstate__
(
self
)
->
Dict
[
str
,
List
[
Any
]]:
"""
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
Custom pickling for the pass manager, as some passes cannot be pickled.
...
...
vllm/config.py
View file @
09b95e36
...
@@ -3072,15 +3072,6 @@ class VllmConfig:
...
@@ -3072,15 +3072,6 @@ class VllmConfig:
the final hidden states.
the final hidden states.
"""
"""
factors
:
List
[
Any
]
=
[]
factors
:
List
[
Any
]
=
[]
# summarize system state
from
torch._inductor.codecache
import
CacheBase
system_factors
=
CacheBase
.
get_system
()
factors
.
append
(
system_factors
)
# summarize pytorch state
from
torch._inductor.codecache
import
torch_key
torch_factors
=
torch_key
()
factors
.
append
(
torch_factors
)
# summarize vllm config
# summarize vllm config
vllm_factors
:
List
[
Any
]
=
[]
vllm_factors
:
List
[
Any
]
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment