Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e66faf48
Unverified
Commit
e66faf48
authored
Jan 19, 2025
by
youkaichao
Committed by
GitHub
Jan 19, 2025
Browse files
[torch.compile] store inductor compiled Python file (#12182)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
630eb5b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
33 deletions
+60
-33
vllm/compilation/backends.py
vllm/compilation/backends.py
+58
-22
vllm/config.py
vllm/config.py
+2
-11
No files found.
vllm/compilation/backends.py
View file @
e66faf48
...
@@ -25,23 +25,30 @@ from .pass_manager import PostGradPassManager
...
@@ -25,23 +25,30 @@ from .pass_manager import PostGradPassManager
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
dataclasses
.
dataclass
class
InductorArtifact
:
hash_str
:
str
=
""
file_path
:
str
=
""
class
InductorHashCache
:
class
InductorHashCache
:
"""
"""
Disk format: a Python list of tuples, each tuple is
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str)
(runtime_shape, graph_index, hash_str
, file_path
)
We use list of tuple for readability.
We use list of tuple for readability.
In-memory format: a defaultdict of dict, where the key is
In-memory format: a defaultdict of dict, where the key is
runtime_shape, and the value is a dict of graph_index to hash_str.
runtime_shape, and the value is a dict of graph_index to hash_str.
The data is essentially `Dict[Optional[int], Dict[int,
str
]]`,
The data is essentially `Dict[Optional[int], Dict[int,
InductorArtifact
]]`,
we don't use json here because json doesn't support int as key.
we don't use json here because json doesn't support int as key.
TODO: better off-the-shelf solution to serialize the data?
TODO: better off-the-shelf solution to serialize the data?
"""
"""
def
__init__
(
self
,
cache_dir
:
str
,
disabled
:
bool
=
False
):
def
__init__
(
self
,
cache_dir
:
str
,
disabled
:
bool
=
False
):
self
.
cache
:
defaultdict
=
defaultdict
(
dict
)
self
.
cache
:
Dict
[
Optional
[
int
],
Dict
[
int
,
InductorArtifact
]]
=
defaultdict
(
dict
)
self
.
disabled
=
disabled
self
.
disabled
=
disabled
self
.
cache_dir
=
cache_dir
self
.
cache_dir
=
cache_dir
self
.
cache_file_path
=
os
.
path
.
join
(
cache_dir
,
self
.
cache_file_path
=
os
.
path
.
join
(
cache_dir
,
...
@@ -66,14 +73,25 @@ class InductorHashCache:
...
@@ -66,14 +73,25 @@ class InductorHashCache:
# because it is a safe way to parse Python literals.
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
# do not use eval(), it is unsafe.
list_data
=
ast
.
literal_eval
(
data
)
list_data
=
ast
.
literal_eval
(
data
)
for
runtime_shape
,
graph_index
,
hash_str
in
list_data
:
for
item
in
list_data
:
self
.
cache
[
runtime_shape
][
graph_index
]
=
hash_str
runtime_shape
=
item
[
0
]
graph_index
=
item
[
1
]
hash_str
=
item
[
2
]
# for compatibility of old version,
# where we don't have file_path.
# NOTE: after running the new code, the file_path
# will be updated.
file_path
=
""
if
len
(
item
)
==
3
else
item
[
3
]
self
.
cache
[
runtime_shape
][
graph_index
]
=
InductorArtifact
(
hash_str
=
hash_str
,
file_path
=
file_path
)
def
serialize
(
self
)
->
str
:
def
serialize
(
self
)
->
str
:
data
=
[]
data
=
[]
for
runtime_shape
,
graph_index_to_hash_str
in
self
.
cache
.
items
():
for
runtime_shape
,
value
in
self
.
cache
.
items
():
for
graph_index
,
hash_str
in
graph_index_to_hash_str
.
items
():
for
graph_index
,
inductor_artifact
in
value
.
items
():
data
.
append
((
runtime_shape
,
graph_index
,
hash_str
))
data
.
append
(
(
runtime_shape
,
graph_index
,
inductor_artifact
.
hash_str
,
inductor_artifact
.
file_path
))
printer
=
pprint
.
PrettyPrinter
(
indent
=
4
)
printer
=
pprint
.
PrettyPrinter
(
indent
=
4
)
return
printer
.
pformat
(
data
)
return
printer
.
pformat
(
data
)
...
@@ -90,13 +108,14 @@ class InductorHashCache:
...
@@ -90,13 +108,14 @@ class InductorHashCache:
return
runtime_shape
in
self
.
cache
and
graph_index
in
self
.
cache
[
return
runtime_shape
in
self
.
cache
and
graph_index
in
self
.
cache
[
runtime_shape
]
runtime_shape
]
def
__getitem__
(
self
,
key
:
Tuple
[
Optional
[
int
],
int
])
->
str
:
def
__getitem__
(
self
,
key
:
Tuple
[
Optional
[
int
],
int
])
->
InductorArtifact
:
if
self
.
disabled
:
if
self
.
disabled
:
raise
KeyError
(
"cannot read from disabled cache"
)
raise
KeyError
(
"cannot read from disabled cache"
)
runtime_shape
,
graph_index
=
key
runtime_shape
,
graph_index
=
key
return
self
.
cache
[
runtime_shape
][
graph_index
]
return
self
.
cache
[
runtime_shape
][
graph_index
]
def
__setitem__
(
self
,
key
:
Tuple
[
Optional
[
int
],
int
],
value
:
str
):
def
__setitem__
(
self
,
key
:
Tuple
[
Optional
[
int
],
int
],
value
:
InductorArtifact
):
# setitem for disabled cache is fine, because we
# setitem for disabled cache is fine, because we
# don't actually write to the disk
# don't actually write to the disk
runtime_shape
,
graph_index
=
key
runtime_shape
,
graph_index
=
key
...
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
if
(
runtime_shape
,
graph_index
)
in
cache_data
:
if
(
runtime_shape
,
graph_index
)
in
cache_data
:
# we compiled this graph before
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
# so we can directly lookup the compiled graph via hash
hash_str
=
cache_data
[(
runtime_shape
,
graph_index
)]
inductor_artifact
=
cache_data
[(
runtime_shape
,
graph_index
)]
hash_str
=
inductor_artifact
.
hash_str
if
graph_index
==
0
:
if
graph_index
==
0
:
# adds some info logging for the first graph
# adds some info logging for the first graph
logger
.
info
(
logger
.
info
(
...
@@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
"Inductor cache lookup failed. Please remove"
"Inductor cache lookup failed. Please remove"
f
"the cache file
{
cache_data
.
cache_file_path
}
and try again."
# noqa
f
"the cache file
{
cache_data
.
cache_file_path
}
and try again."
# noqa
)
)
inductor_artifact
.
file_path
=
inductor_compiled_graph
.
current_callable
.
__code__
.
co_filename
# noqa
# Inductor calling convention (function signature):
# Inductor calling convention (function signature):
# f(list) -> tuple
# f(list) -> tuple
...
@@ -224,19 +245,20 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -224,19 +245,20 @@ def wrap_inductor(graph: fx.GraphModule,
# the assumption is that we don't have nested Inductor compilation.
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
# it to get the hash of the compiled graph directly.
from
torch._inductor.codecache
import
compiled_fx_graph_hash
inductor_artifact
=
InductorArtifact
()
from
torch._inductor.codecache
import
(
FxGraphCache
,
compiled_fx_graph_hash
)
original_load
=
FxGraphCache
.
load
def
hijack_load
(
*
args
,
**
kwargs
):
inductor_compiled_graph
=
original_load
(
*
args
,
**
kwargs
)
inductor_artifact
.
file_path
=
inductor_compiled_graph
.
current_callable
.
__code__
.
co_filename
# noqa
return
inductor_compiled_graph
def
hijack_compiled_fx_graph_hash
(
*
args
,
**
kwargs
):
def
hijack_compiled_fx_graph_hash
(
*
args
,
**
kwargs
):
out
=
compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
out
=
compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
# store the hash in the cache
inductor_artifact
.
hash_str
=
out
[
0
]
nonlocal
cache_data
cache_data
[(
runtime_shape
,
graph_index
)]
=
out
[
0
]
if
graph_index
==
0
:
# adds some info logging for the first graph
logger
.
info
(
"Cache the graph of shape %s for later use"
,
str
(
runtime_shape
))
logger
.
debug
(
"store the %s-th graph for shape %s via hash %s"
,
graph_index
,
str
(
runtime_shape
),
out
[
0
])
return
out
return
out
def
_check_can_cache
(
*
args
,
**
kwargs
):
def
_check_can_cache
(
*
args
,
**
kwargs
):
...
@@ -255,6 +277,11 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -255,6 +277,11 @@ def wrap_inductor(graph: fx.GraphModule,
if
not
cache_data
.
disabled
:
if
not
cache_data
.
disabled
:
# compilation cache is enabled, patch several functions
# compilation cache is enabled, patch several functions
# hijack to get the compiled graph itself
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.FxGraphCache.load"
,
hijack_load
))
# for hijacking the hash of the compiled graph
# for hijacking the hash of the compiled graph
stack
.
enter_context
(
stack
.
enter_context
(
patch
(
"torch._inductor.codecache.compiled_fx_graph_hash"
,
patch
(
"torch._inductor.codecache.compiled_fx_graph_hash"
,
...
@@ -275,7 +302,16 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -275,7 +302,16 @@ def wrap_inductor(graph: fx.GraphModule,
compiled_graph
=
compile_fx
(
graph
,
compiled_graph
=
compile_fx
(
graph
,
example_inputs
,
example_inputs
,
config_patches
=
current_config
)
config_patches
=
current_config
)
# store the inductor_artifact in the cache
cache_data
[(
runtime_shape
,
graph_index
)]
=
inductor_artifact
if
graph_index
==
0
:
# adds some info logging for the first graph
logger
.
info
(
"Cache the graph of shape %s for later use"
,
str
(
runtime_shape
))
logger
.
debug
(
"store the %s-th graph for shape %s via hash %s from file %s"
,
graph_index
,
str
(
runtime_shape
),
inductor_artifact
.
hash_str
,
inductor_artifact
.
file_path
)
# after compiling the last graph, record the end time
# after compiling the last graph, record the end time
if
graph_index
==
num_graphs
-
1
:
if
graph_index
==
num_graphs
-
1
:
now
=
time
.
time
()
now
=
time
.
time
()
...
...
vllm/config.py
View file @
e66faf48
...
@@ -2862,17 +2862,8 @@ class CompilationConfig(BaseModel):
...
@@ -2862,17 +2862,8 @@ class CompilationConfig(BaseModel):
"vllm.unified_attention_with_output"
,
"vllm.unified_attention_with_output"
,
]
]
else
:
else
:
# v0 can use full graph compilation without splitting,
# v0 uses full graph compilation
# splitting is optional.
self
.
splitting_ops
=
[]
# right now we still need it. kv cache shape
# will be included in the graph if we don't split
# the graph.
# TODO: hide kv cache in static forward context
# so that inductor does not see it.
self
.
splitting_ops
=
[
"vllm.unified_attention"
,
"vllm.unified_attention_with_output"
,
]
for
k
,
v
in
self
.
inductor_passes
.
items
():
for
k
,
v
in
self
.
inductor_passes
.
items
():
if
not
isinstance
(
v
,
str
):
if
not
isinstance
(
v
,
str
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment