Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1214117
Unverified
Commit
f1214117
authored
Jan 08, 2025
by
youkaichao
Committed by
GitHub
Jan 08, 2025
Browse files
[torch.compile] consider relevant code in compilation cache (#11614)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
cfd3219f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
35 deletions
+99
-35
vllm/compilation/backends.py
vllm/compilation/backends.py
+62
-8
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+27
-1
vllm/config.py
vllm/config.py
+3
-26
vllm/sequence.py
vllm/sequence.py
+7
-0
No files found.
vllm/compilation/backends.py
View file @
f1214117
...
...
@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
example_inputs
,
additional_inductor_config
,
compilation_config
:
CompilationConfig
,
vllm_backend
:
"VllmBackend"
,
graph_index
:
int
=
0
,
num_graphs
:
int
=
1
,
runtime_shape
:
Optional
[
int
]
=
None
,
...
...
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
# see https://github.com/pytorch/pytorch/issues/138980
graph
=
copy
.
deepcopy
(
graph
)
cache_data
=
compilation_config
.
inductor_hash_cache
cache_data
=
vllm_backend
.
inductor_hash_cache
if
(
runtime_shape
,
graph_index
)
in
cache_data
:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
...
...
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
hash_str
,
example_inputs
,
True
,
False
)
assert
inductor_compiled_graph
is
not
None
,
(
"Inductor cache lookup failed. Please remove"
f
"the cache file
{
c
ompilation_config
.
inductor_hash_cache
.
cache_file_path
}
and try again."
# noqa
f
"the cache file
{
c
ache_data
.
cache_file_path
}
and try again."
# noqa
)
# Inductor calling convention (function signature):
...
...
@@ -354,7 +355,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def
__init__
(
self
,
module
:
torch
.
fx
.
GraphModule
,
compile_submod_names
:
List
[
str
],
vllm_config
:
VllmConfig
,
graph_pool
):
graph_pool
,
vllm_backend
:
"VllmBackend"
):
super
().
__init__
(
module
)
from
torch._guards
import
detect_fake_mode
self
.
fake_mode
=
detect_fake_mode
()
...
...
@@ -362,6 +363,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
graph_pool
=
graph_pool
self
.
vllm_config
=
vllm_config
self
.
vllm_backend
=
vllm_backend
def
run
(
self
,
*
args
):
fake_args
=
[
...
...
@@ -389,6 +391,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
args
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
,
self
.
vllm_backend
,
graph_index
=
index
,
num_graphs
=
len
(
self
.
compile_submod_names
),
runtime_shape
=
None
,
...
...
@@ -397,7 +400,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self
.
module
.
__dict__
[
target
]
=
PiecewiseBackend
(
submod
,
self
.
vllm_config
,
self
.
graph_pool
,
index
,
len
(
self
.
compile_submod_names
),
sym_shape_indices
,
compiled_graph_for_general_shape
)
compiled_graph_for_general_shape
,
self
.
vllm_backend
)
compilation_counter
.
num_piecewise_capturable_graphs_seen
+=
1
...
...
@@ -430,6 +433,7 @@ class VllmBackend:
post_grad_passes
:
Sequence
[
Callable
]
sym_tensor_indices
:
List
[
int
]
input_buffers
:
List
[
torch
.
Tensor
]
inductor_hash_cache
:
InductorHashCache
def
__init__
(
self
,
...
...
@@ -472,6 +476,53 @@ class VllmBackend:
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
)
->
Callable
:
if
not
self
.
compilation_config
.
cache_dir
:
# no cache dir was provided, so generate one based on the known factors
# that affect the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
vllm_config
=
self
.
vllm_config
config_hash
=
vllm_config
.
compute_hash
()
# 2. factors that come from the code files traced by Dynamo
# (they mainly summarize how the model is used in the forward pass)
forward_code_files
=
list
(
sorted
(
self
.
compilation_config
.
traced_files
))
self
.
compilation_config
.
traced_files
.
clear
()
logger
.
debug
(
"Traced files (to be considered for compilation cache):
\n
%s"
,
"
\n
"
.
join
(
forward_code_files
))
hash_content
=
[]
for
filepath
in
forward_code_files
:
hash_content
.
append
(
filepath
)
with
open
(
filepath
)
as
f
:
hash_content
.
append
(
f
.
read
())
import
hashlib
code_hash
=
hashlib
.
md5
(
"
\n
"
.
join
(
hash_content
).
encode
()).
hexdigest
()
# combine the two hashes to generate the cache dir
hash_key
=
hashlib
.
md5
(
f
"
{
config_hash
}
_
{
code_hash
}
"
.
encode
()).
hexdigest
()[:
10
]
cache_dir
=
os
.
path
.
join
(
envs
.
VLLM_CACHE_ROOT
,
"torch_compile_cache"
,
hash_key
,
f
"rank_
{
vllm_config
.
parallel_config
.
rank
}
"
)
else
:
cache_dir
=
self
.
compilation_config
.
cache_dir
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
disabled
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
self
.
inductor_hash_cache
:
InductorHashCache
=
InductorHashCache
(
cache_dir
,
disabled
=
disabled
)
if
disabled
:
logger
.
info
(
"vLLM's torch.compile cache is disabled."
)
else
:
logger
.
info
(
"Using cache directory: %s for vLLM's torch.compile"
,
cache_dir
)
# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
compilation_counter
.
num_graphs_seen
+=
1
...
...
@@ -507,8 +558,8 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
.
graph_pool
).
run
(
*
example_inputs
)
self
.
vllm_config
,
self
.
graph_pool
,
self
).
run
(
*
example_inputs
)
self
.
_called
=
True
...
...
@@ -577,7 +628,8 @@ class PiecewiseBackend:
def
__init__
(
self
,
graph
:
fx
.
GraphModule
,
vllm_config
:
VllmConfig
,
graph_pool
:
Any
,
piecewise_compile_index
:
int
,
total_piecewise_compiles
:
int
,
sym_shape_indices
:
List
[
int
],
compiled_graph_for_general_shape
:
Callable
):
compiled_graph_for_general_shape
:
Callable
,
vllm_backend
:
VllmBackend
):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
...
...
@@ -597,6 +649,7 @@ class PiecewiseBackend:
self
.
graph_pool
=
graph_pool
self
.
piecewise_compile_index
=
piecewise_compile_index
self
.
total_piecewise_compiles
=
total_piecewise_compiles
self
.
vllm_backend
=
vllm_backend
self
.
is_first_graph
=
piecewise_compile_index
==
0
self
.
is_last_graph
=
(
...
...
@@ -634,7 +687,7 @@ class PiecewiseBackend:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self
.
compilation_config
.
inductor_hash_cache
.
save_to_file
()
self
.
vllm_backend
.
inductor_hash_cache
.
save_to_file
()
end_monitoring_torch_compile
(
self
.
vllm_config
)
def
__call__
(
self
,
*
args
)
->
Any
:
...
...
@@ -662,6 +715,7 @@ class PiecewiseBackend:
args
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
,
self
.
vllm_backend
,
graph_index
=
self
.
piecewise_compile_index
,
num_graphs
=
self
.
total_piecewise_compiles
,
runtime_shape
=
runtime_shape
,
...
...
vllm/compilation/decorators.py
View file @
f1214117
import
inspect
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
TypeVar
,
Union
,
overload
from
unittest.mock
import
patch
import
torch
import
torch.nn
as
nn
from
torch._dynamo.symbolic_convert
import
InliningInstructionTranslator
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
...
...
@@ -196,7 +198,31 @@ def _support_torch_compile(
# we need to control all the compilation of the model.
torch
.
_dynamo
.
eval_frame
.
remove_from_cache
(
self
.
original_code_object
)
return
self
.
compiled_callable
(
*
args
,
**
kwargs
)
# collect all relevant files traced by Dynamo,
# so that the compilation cache can trigger re-compilation
# properly when any of these files change.
# 1. the file containing the top-level forward function
self
.
vllm_config
.
compilation_config
.
traced_files
.
add
(
self
.
original_code_object
.
co_filename
)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call
# we hijack this function to know all the functions called
# during Dynamo tracing, and their corresponding files
inline_call
=
InliningInstructionTranslator
.
inline_call
def patched_inline_call(parent, func, args, kwargs):
    # Stand-in for InliningInstructionTranslator.inline_call: note the
    # source file of the function being inlined in the compilation
    # config's traced_files set, then defer to the saved original
    # inline_call with the arguments untouched.
    inlined_code = func.get_code()
    traced = self.vllm_config.compilation_config.traced_files
    traced.add(inlined_code.co_filename)
    return inline_call(parent, func, args, kwargs)
with
patch
.
object
(
InliningInstructionTranslator
,
'inline_call'
,
patched_inline_call
):
output
=
self
.
compiled_callable
(
*
args
,
**
kwargs
)
return
output
# usually, capturing the model once is enough, and then we can
# dispatch to the compiled code directly, without going through
...
...
vllm/config.py
View file @
f1214117
...
...
@@ -3,7 +3,6 @@ import copy
import
enum
import
hashlib
import
json
import
os
import
sys
import
warnings
from
contextlib
import
contextmanager
...
...
@@ -2778,9 +2777,8 @@ class CompilationConfig(BaseModel):
# keep track of enabled and disabled custom ops
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
traced_files
:
Set
[
str
]
=
PrivateAttr
compilation_time
:
float
=
PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache
:
Any
=
PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
...
...
@@ -2818,6 +2816,7 @@ class CompilationConfig(BaseModel):
"compilation_time"
,
"bs_to_padded_graph_size"
,
"pass_config"
,
"traced_files"
,
}
return
self
.
model_dump_json
(
exclude
=
exclude
,
exclude_unset
=
True
)
...
...
@@ -2877,6 +2876,7 @@ class CompilationConfig(BaseModel):
self
.
enabled_custom_ops
=
Counter
()
self
.
disabled_custom_ops
=
Counter
()
self
.
traced_files
=
set
()
self
.
static_forward_context
=
{}
self
.
compilation_time
=
0.0
...
...
@@ -2899,29 +2899,6 @@ class CompilationConfig(BaseModel):
# merge with the config use_inductor
assert
self
.
level
==
CompilationLevel
.
PIECEWISE
if
not
self
.
cache_dir
:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
hash_key
=
vllm_config
.
compute_hash
()
cache_dir
=
os
.
path
.
join
(
envs
.
VLLM_CACHE_ROOT
,
"torch_compile_cache"
,
hash_key
,
f
"rank_
{
vllm_config
.
parallel_config
.
rank
}
"
)
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
self
.
cache_dir
=
cache_dir
disabled
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
from
vllm.compilation.backends
import
InductorHashCache
self
.
inductor_hash_cache
:
InductorHashCache
=
InductorHashCache
(
self
.
cache_dir
,
disabled
=
disabled
)
if
disabled
:
logger
.
info
(
"vLLM's torch.compile cache is disabled."
)
else
:
logger
.
info
(
"Using cache directory: %s for vLLM's torch.compile"
,
self
.
cache_dir
)
from
vllm.compilation.backends
import
VllmBackend
return
VllmBackend
(
vllm_config
)
...
...
vllm/sequence.py
View file @
f1214117
...
...
@@ -1108,6 +1108,13 @@ class IntermediateTensors:
tensors
:
Dict
[
str
,
torch
.
Tensor
]
def __init__(self, tensors):
    # Deliberately written by hand instead of being left to the
    # dataclass machinery: a generated __init__ is produced by
    # evaluating a string, which loses the source-file information.
    # Defining it here lets Dynamo know that `IntermediateTensors()`
    # comes from this file.
    self.tensors = tensors
def
__getitem__
(
self
,
key
:
Union
[
str
,
slice
]):
if
isinstance
(
key
,
str
):
return
self
.
tensors
[
key
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment