Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1214117
Unverified
Commit
f1214117
authored
Jan 08, 2025
by
youkaichao
Committed by
GitHub
Jan 08, 2025
Browse files
[torch.compile] consider relevant code in compilation cache (#11614)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
cfd3219f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
35 deletions
+99
-35
vllm/compilation/backends.py
vllm/compilation/backends.py
+62
-8
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+27
-1
vllm/config.py
vllm/config.py
+3
-26
vllm/sequence.py
vllm/sequence.py
+7
-0
No files found.
vllm/compilation/backends.py
View file @
f1214117
...
@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
example_inputs
,
example_inputs
,
additional_inductor_config
,
additional_inductor_config
,
compilation_config
:
CompilationConfig
,
compilation_config
:
CompilationConfig
,
vllm_backend
:
"VllmBackend"
,
graph_index
:
int
=
0
,
graph_index
:
int
=
0
,
num_graphs
:
int
=
1
,
num_graphs
:
int
=
1
,
runtime_shape
:
Optional
[
int
]
=
None
,
runtime_shape
:
Optional
[
int
]
=
None
,
...
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
# see https://github.com/pytorch/pytorch/issues/138980
# see https://github.com/pytorch/pytorch/issues/138980
graph
=
copy
.
deepcopy
(
graph
)
graph
=
copy
.
deepcopy
(
graph
)
cache_data
=
compilation_config
.
inductor_hash_cache
cache_data
=
vllm_backend
.
inductor_hash_cache
if
(
runtime_shape
,
graph_index
)
in
cache_data
:
if
(
runtime_shape
,
graph_index
)
in
cache_data
:
# we compiled this graph before
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
# so we can directly lookup the compiled graph via hash
...
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
...
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
hash_str
,
example_inputs
,
True
,
False
)
hash_str
,
example_inputs
,
True
,
False
)
assert
inductor_compiled_graph
is
not
None
,
(
assert
inductor_compiled_graph
is
not
None
,
(
"Inductor cache lookup failed. Please remove"
"Inductor cache lookup failed. Please remove"
f
"the cache file
{
c
ompilation_config
.
inductor_hash_cache
.
cache_file_path
}
and try again."
# noqa
f
"the cache file
{
c
ache_data
.
cache_file_path
}
and try again."
# noqa
)
)
# Inductor calling convention (function signature):
# Inductor calling convention (function signature):
...
@@ -354,7 +355,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -354,7 +355,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def
__init__
(
self
,
module
:
torch
.
fx
.
GraphModule
,
def
__init__
(
self
,
module
:
torch
.
fx
.
GraphModule
,
compile_submod_names
:
List
[
str
],
vllm_config
:
VllmConfig
,
compile_submod_names
:
List
[
str
],
vllm_config
:
VllmConfig
,
graph_pool
):
graph_pool
,
vllm_backend
:
"VllmBackend"
):
super
().
__init__
(
module
)
super
().
__init__
(
module
)
from
torch._guards
import
detect_fake_mode
from
torch._guards
import
detect_fake_mode
self
.
fake_mode
=
detect_fake_mode
()
self
.
fake_mode
=
detect_fake_mode
()
...
@@ -362,6 +363,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -362,6 +363,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
graph_pool
=
graph_pool
self
.
graph_pool
=
graph_pool
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
vllm_backend
=
vllm_backend
def
run
(
self
,
*
args
):
def
run
(
self
,
*
args
):
fake_args
=
[
fake_args
=
[
...
@@ -389,6 +391,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -389,6 +391,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
args
,
args
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
,
self
.
compilation_config
,
self
.
vllm_backend
,
graph_index
=
index
,
graph_index
=
index
,
num_graphs
=
len
(
self
.
compile_submod_names
),
num_graphs
=
len
(
self
.
compile_submod_names
),
runtime_shape
=
None
,
runtime_shape
=
None
,
...
@@ -397,7 +400,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -397,7 +400,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self
.
module
.
__dict__
[
target
]
=
PiecewiseBackend
(
self
.
module
.
__dict__
[
target
]
=
PiecewiseBackend
(
submod
,
self
.
vllm_config
,
self
.
graph_pool
,
index
,
submod
,
self
.
vllm_config
,
self
.
graph_pool
,
index
,
len
(
self
.
compile_submod_names
),
sym_shape_indices
,
len
(
self
.
compile_submod_names
),
sym_shape_indices
,
compiled_graph_for_general_shape
)
compiled_graph_for_general_shape
,
self
.
vllm_backend
)
compilation_counter
.
num_piecewise_capturable_graphs_seen
+=
1
compilation_counter
.
num_piecewise_capturable_graphs_seen
+=
1
...
@@ -430,6 +433,7 @@ class VllmBackend:
...
@@ -430,6 +433,7 @@ class VllmBackend:
post_grad_passes
:
Sequence
[
Callable
]
post_grad_passes
:
Sequence
[
Callable
]
sym_tensor_indices
:
List
[
int
]
sym_tensor_indices
:
List
[
int
]
input_buffers
:
List
[
torch
.
Tensor
]
input_buffers
:
List
[
torch
.
Tensor
]
inductor_hash_cache
:
InductorHashCache
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -472,6 +476,53 @@ class VllmBackend:
...
@@ -472,6 +476,53 @@ class VllmBackend:
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
)
->
Callable
:
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
)
->
Callable
:
if
not
self
.
compilation_config
.
cache_dir
:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
vllm_config
=
self
.
vllm_config
config_hash
=
vllm_config
.
compute_hash
()
# 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass)
forward_code_files
=
list
(
sorted
(
self
.
compilation_config
.
traced_files
))
self
.
compilation_config
.
traced_files
.
clear
()
logger
.
debug
(
"Traced files (to be considered for compilation cache):
\n
%s"
,
"
\n
"
.
join
(
forward_code_files
))
hash_content
=
[]
for
filepath
in
forward_code_files
:
hash_content
.
append
(
filepath
)
with
open
(
filepath
)
as
f
:
hash_content
.
append
(
f
.
read
())
import
hashlib
code_hash
=
hashlib
.
md5
(
"
\n
"
.
join
(
hash_content
).
encode
()).
hexdigest
()
# combine the two hashes to generate the cache dir
hash_key
=
hashlib
.
md5
(
f
"
{
config_hash
}
_
{
code_hash
}
"
.
encode
()).
hexdigest
()[:
10
]
cache_dir
=
os
.
path
.
join
(
envs
.
VLLM_CACHE_ROOT
,
"torch_compile_cache"
,
hash_key
,
f
"rank_
{
vllm_config
.
parallel_config
.
rank
}
"
)
else
:
cache_dir
=
self
.
compilation_config
.
cache_dir
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
disabled
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
self
.
inductor_hash_cache
:
InductorHashCache
=
InductorHashCache
(
cache_dir
,
disabled
=
disabled
)
if
disabled
:
logger
.
info
(
"vLLM's torch.compile cache is disabled."
)
else
:
logger
.
info
(
"Using cache directory: %s for vLLM's torch.compile"
,
cache_dir
)
# when dynamo calls the backend, it means the bytecode
# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
# transform and analysis are done
compilation_counter
.
num_graphs_seen
+=
1
compilation_counter
.
num_graphs_seen
+=
1
...
@@ -507,8 +558,8 @@ class VllmBackend:
...
@@ -507,8 +558,8 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
.
vllm_config
,
self
.
graph_pool
,
self
.
graph_pool
).
run
(
*
example_inputs
)
self
).
run
(
*
example_inputs
)
self
.
_called
=
True
self
.
_called
=
True
...
@@ -577,7 +628,8 @@ class PiecewiseBackend:
...
@@ -577,7 +628,8 @@ class PiecewiseBackend:
def
__init__
(
self
,
graph
:
fx
.
GraphModule
,
vllm_config
:
VllmConfig
,
def
__init__
(
self
,
graph
:
fx
.
GraphModule
,
vllm_config
:
VllmConfig
,
graph_pool
:
Any
,
piecewise_compile_index
:
int
,
graph_pool
:
Any
,
piecewise_compile_index
:
int
,
total_piecewise_compiles
:
int
,
sym_shape_indices
:
List
[
int
],
total_piecewise_compiles
:
int
,
sym_shape_indices
:
List
[
int
],
compiled_graph_for_general_shape
:
Callable
):
compiled_graph_for_general_shape
:
Callable
,
vllm_backend
:
VllmBackend
):
"""
"""
The backend for piecewise compilation.
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
It mainly handles the compilation and cudagraph capturing.
...
@@ -597,6 +649,7 @@ class PiecewiseBackend:
...
@@ -597,6 +649,7 @@ class PiecewiseBackend:
self
.
graph_pool
=
graph_pool
self
.
graph_pool
=
graph_pool
self
.
piecewise_compile_index
=
piecewise_compile_index
self
.
piecewise_compile_index
=
piecewise_compile_index
self
.
total_piecewise_compiles
=
total_piecewise_compiles
self
.
total_piecewise_compiles
=
total_piecewise_compiles
self
.
vllm_backend
=
vllm_backend
self
.
is_first_graph
=
piecewise_compile_index
==
0
self
.
is_first_graph
=
piecewise_compile_index
==
0
self
.
is_last_graph
=
(
self
.
is_last_graph
=
(
...
@@ -634,7 +687,7 @@ class PiecewiseBackend:
...
@@ -634,7 +687,7 @@ class PiecewiseBackend:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
# no specific sizes to compile
# no specific sizes to compile
# save the hash of the inductor graph for the next run
# save the hash of the inductor graph for the next run
self
.
compilation_config
.
inductor_hash_cache
.
save_to_file
()
self
.
vllm_backend
.
inductor_hash_cache
.
save_to_file
()
end_monitoring_torch_compile
(
self
.
vllm_config
)
end_monitoring_torch_compile
(
self
.
vllm_config
)
def
__call__
(
self
,
*
args
)
->
Any
:
def
__call__
(
self
,
*
args
)
->
Any
:
...
@@ -662,6 +715,7 @@ class PiecewiseBackend:
...
@@ -662,6 +715,7 @@ class PiecewiseBackend:
args
,
args
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
.
inductor_compile_config
,
self
.
compilation_config
,
self
.
compilation_config
,
self
.
vllm_backend
,
graph_index
=
self
.
piecewise_compile_index
,
graph_index
=
self
.
piecewise_compile_index
,
num_graphs
=
self
.
total_piecewise_compiles
,
num_graphs
=
self
.
total_piecewise_compiles
,
runtime_shape
=
runtime_shape
,
runtime_shape
=
runtime_shape
,
...
...
vllm/compilation/decorators.py
View file @
f1214117
import
inspect
import
inspect
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
TypeVar
,
Union
,
overload
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
TypeVar
,
Union
,
overload
from
unittest.mock
import
patch
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch._dynamo.symbolic_convert
import
InliningInstructionTranslator
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
...
@@ -196,7 +198,31 @@ def _support_torch_compile(
...
@@ -196,7 +198,31 @@ def _support_torch_compile(
# we need to control all the compilation of the model.
# we need to control all the compilation of the model.
torch
.
_dynamo
.
eval_frame
.
remove_from_cache
(
torch
.
_dynamo
.
eval_frame
.
remove_from_cache
(
self
.
original_code_object
)
self
.
original_code_object
)
return
self
.
compiled_callable
(
*
args
,
**
kwargs
)
# collect all relevant files traced by Dynamo,
# so that the compilation cache can trigger re-compilation
# properly when any of these files change.
# 1. the file containing the top-level forward function
self
.
vllm_config
.
compilation_config
.
traced_files
.
add
(
self
.
original_code_object
.
co_filename
)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call
# we hijack this function to know all the functions called
# during Dynamo tracing, and their corresponding files
inline_call
=
InliningInstructionTranslator
.
inline_call
def
patched_inline_call
(
parent
,
func
,
args
,
kwargs
):
code
=
func
.
get_code
()
self
.
vllm_config
.
compilation_config
.
traced_files
.
add
(
code
.
co_filename
)
return
inline_call
(
parent
,
func
,
args
,
kwargs
)
with
patch
.
object
(
InliningInstructionTranslator
,
'inline_call'
,
patched_inline_call
):
output
=
self
.
compiled_callable
(
*
args
,
**
kwargs
)
return
output
# usually, capturing the model once is enough, and then we can
# usually, capturing the model once is enough, and then we can
# dispatch to the compiled code directly, without going through
# dispatch to the compiled code directly, without going through
...
...
vllm/config.py
View file @
f1214117
...
@@ -3,7 +3,6 @@ import copy
...
@@ -3,7 +3,6 @@ import copy
import
enum
import
enum
import
hashlib
import
hashlib
import
json
import
json
import
os
import
sys
import
sys
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -2778,9 +2777,8 @@ class CompilationConfig(BaseModel):
...
@@ -2778,9 +2777,8 @@ class CompilationConfig(BaseModel):
# keep track of enabled and disabled custom ops
# keep track of enabled and disabled custom ops
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
traced_files
:
Set
[
str
]
=
PrivateAttr
compilation_time
:
float
=
PrivateAttr
compilation_time
:
float
=
PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache
:
Any
=
PrivateAttr
# Per-model forward context
# Per-model forward context
# Mainly used to store attention cls
# Mainly used to store attention cls
...
@@ -2818,6 +2816,7 @@ class CompilationConfig(BaseModel):
...
@@ -2818,6 +2816,7 @@ class CompilationConfig(BaseModel):
"compilation_time"
,
"compilation_time"
,
"bs_to_padded_graph_size"
,
"bs_to_padded_graph_size"
,
"pass_config"
,
"pass_config"
,
"traced_files"
,
}
}
return
self
.
model_dump_json
(
exclude
=
exclude
,
exclude_unset
=
True
)
return
self
.
model_dump_json
(
exclude
=
exclude
,
exclude_unset
=
True
)
...
@@ -2877,6 +2876,7 @@ class CompilationConfig(BaseModel):
...
@@ -2877,6 +2876,7 @@ class CompilationConfig(BaseModel):
self
.
enabled_custom_ops
=
Counter
()
self
.
enabled_custom_ops
=
Counter
()
self
.
disabled_custom_ops
=
Counter
()
self
.
disabled_custom_ops
=
Counter
()
self
.
traced_files
=
set
()
self
.
static_forward_context
=
{}
self
.
static_forward_context
=
{}
self
.
compilation_time
=
0.0
self
.
compilation_time
=
0.0
...
@@ -2899,29 +2899,6 @@ class CompilationConfig(BaseModel):
...
@@ -2899,29 +2899,6 @@ class CompilationConfig(BaseModel):
# merge with the config use_inductor
# merge with the config use_inductor
assert
self
.
level
==
CompilationLevel
.
PIECEWISE
assert
self
.
level
==
CompilationLevel
.
PIECEWISE
if
not
self
.
cache_dir
:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
hash_key
=
vllm_config
.
compute_hash
()
cache_dir
=
os
.
path
.
join
(
envs
.
VLLM_CACHE_ROOT
,
"torch_compile_cache"
,
hash_key
,
f
"rank_
{
vllm_config
.
parallel_config
.
rank
}
"
)
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
self
.
cache_dir
=
cache_dir
disabled
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
from
vllm.compilation.backends
import
InductorHashCache
self
.
inductor_hash_cache
:
InductorHashCache
=
InductorHashCache
(
self
.
cache_dir
,
disabled
=
disabled
)
if
disabled
:
logger
.
info
(
"vLLM's torch.compile cache is disabled."
)
else
:
logger
.
info
(
"Using cache directory: %s for vLLM's torch.compile"
,
self
.
cache_dir
)
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
return
VllmBackend
(
vllm_config
)
return
VllmBackend
(
vllm_config
)
...
...
vllm/sequence.py
View file @
f1214117
...
@@ -1108,6 +1108,13 @@ class IntermediateTensors:
...
@@ -1108,6 +1108,13 @@ class IntermediateTensors:
tensors
:
Dict
[
str
,
torch
.
Tensor
]
tensors
:
Dict
[
str
,
torch
.
Tensor
]
def
__init__
(
self
,
tensors
):
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self
.
tensors
=
tensors
def
__getitem__
(
self
,
key
:
Union
[
str
,
slice
]):
def
__getitem__
(
self
,
key
:
Union
[
str
,
slice
]):
if
isinstance
(
key
,
str
):
if
isinstance
(
key
,
str
):
return
self
.
tensors
[
key
]
return
self
.
tensors
[
key
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment