Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ea2236bf
Unverified
Commit
ea2236bf
authored
May 09, 2025
by
Richard Zou
Committed by
GitHub
May 09, 2025
Browse files
Add option to use torch._inductor.standalone_compile (#17057)
Signed-off-by:
rzou
<
zou3519@gmail.com
>
parent
7d4aedae
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
150 additions
and
29 deletions
+150
-29
vllm/compilation/backends.py
vllm/compilation/backends.py
+27
-6
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+118
-23
vllm/envs.py
vllm/envs.py
+5
-0
No files found.
vllm/compilation/backends.py
View file @
ea2236bf
...
@@ -17,7 +17,8 @@ from vllm.config import CompilationConfig, VllmConfig
...
@@ -17,7 +17,8 @@ from vllm.config import CompilationConfig, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
weak_ref_tensors
from
vllm.utils
import
weak_ref_tensors
from
.compiler_interface
import
EagerAdaptor
,
InductorAdaptor
from
.compiler_interface
import
(
CompilerInterface
,
EagerAdaptor
,
InductorAdaptor
,
InductorStandaloneAdaptor
)
from
.counter
import
compilation_counter
from
.counter
import
compilation_counter
from
.inductor_pass
import
InductorPass
from
.inductor_pass
import
InductorPass
from
.monitor
import
end_monitoring_torch_compile
from
.monitor
import
end_monitoring_torch_compile
...
@@ -26,6 +27,19 @@ from .pass_manager import PostGradPassManager
...
@@ -26,6 +27,19 @@ from .pass_manager import PostGradPassManager
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Instantiate the compiler adaptor selected by ``compilation_config``.

    Selection rules:
      * ``use_inductor`` is falsy -> ``EagerAdaptor``
      * ``use_inductor`` is truthy and ``envs.VLLM_TEST_STANDALONE_COMPILE``
        is set -> ``InductorStandaloneAdaptor``
      * otherwise -> ``InductorAdaptor``

    The chosen adaptor is logged at INFO level before it is returned.
    """
    # Eager path first: the environment flag is only consulted for Inductor.
    if not compilation_config.use_inductor:
        logger.info("Using EagerAdaptor")
        return EagerAdaptor()

    if envs.VLLM_TEST_STANDALONE_COMPILE:
        logger.info("Using InductorStandaloneAdaptor")
        return InductorStandaloneAdaptor()

    logger.info("Using InductorAdaptor")
    return InductorAdaptor()
class
CompilerManager
:
class
CompilerManager
:
"""
"""
A manager to manage the compilation process, including
A manager to manage the compilation process, including
...
@@ -41,11 +55,11 @@ class CompilerManager:
...
@@ -41,11 +55,11 @@ class CompilerManager:
support int as key.
support int as key.
"""
"""
def
__init__
(
self
,
use_inductor
:
bool
):
def
__init__
(
self
,
compilation_config
:
CompilationConfig
):
self
.
cache
:
Dict
[
Tuple
[
Optional
[
int
],
int
,
str
],
Any
]
=
dict
()
self
.
cache
:
Dict
[
Tuple
[
Optional
[
int
],
int
,
str
],
Any
]
=
dict
()
cls
=
InductorAdaptor
if
use_inductor
else
EagerAdaptor
self
.
compiler
=
cls
()
self
.
is_cache_updated
=
False
self
.
is_cache_updated
=
False
self
.
compilation_config
=
compilation_config
self
.
compiler
=
make_compiler
(
compilation_config
)
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
return
self
.
compiler
.
compute_hash
(
vllm_config
)
return
self
.
compiler
.
compute_hash
(
vllm_config
)
...
@@ -123,8 +137,15 @@ class CompilerManager:
...
@@ -123,8 +137,15 @@ class CompilerManager:
# no compiler cached the graph, or the cache is disabled,
# no compiler cached the graph, or the cache is disabled,
# we need to compile it
# we need to compile it
if
isinstance
(
self
.
compiler
,
InductorAdaptor
):
# Let compile_fx generate a key for us
maybe_key
=
None
else
:
maybe_key
=
\
f
"artifact_shape_
{
runtime_shape
}
_subgraph_
{
graph_index
}
"
compiled_graph
,
handle
=
self
.
compiler
.
compile
(
compiled_graph
,
handle
=
self
.
compiler
.
compile
(
graph
,
example_inputs
,
additional_inductor_config
,
runtime_shape
)
graph
,
example_inputs
,
additional_inductor_config
,
runtime_shape
,
maybe_key
)
assert
compiled_graph
is
not
None
,
"Failed to compile the graph"
assert
compiled_graph
is
not
None
,
"Failed to compile the graph"
...
@@ -336,7 +357,7 @@ class VllmBackend:
...
@@ -336,7 +357,7 @@ class VllmBackend:
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compiler_manager
:
CompilerManager
=
CompilerManager
(
self
.
compiler_manager
:
CompilerManager
=
CompilerManager
(
self
.
compilation_config
.
use_inductor
)
self
.
compilation_config
)
# `torch.compile` is JIT compiled, so we don't need to
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
# do anything here
...
...
vllm/compilation/compiler_interface.py
View file @
ea2236bf
...
@@ -50,7 +50,8 @@ class CompilerInterface:
...
@@ -50,7 +50,8 @@ class CompilerInterface:
graph
:
fx
.
GraphModule
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
example_inputs
:
List
[
Any
],
compiler_config
:
Dict
[
str
,
Any
],
compiler_config
:
Dict
[
str
,
Any
],
runtime_shape
:
Optional
[
int
]
=
None
runtime_shape
:
Optional
[
int
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
"""
"""
Compile the graph with the given example inputs and compiler config,
Compile the graph with the given example inputs and compiler config,
...
@@ -71,6 +72,10 @@ class CompilerInterface:
...
@@ -71,6 +72,10 @@ class CompilerInterface:
If the compiler doesn't support caching, it should return None for the
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.
None for the compiled function as well.
`key` is required for InductorStandaloneAdaptor; it specifies where to
save the compiled artifact. The compiled artifact gets saved to
`cache_dir/key`.
"""
"""
return
None
,
None
return
None
,
None
...
@@ -127,13 +132,7 @@ class AlwaysHitShapeEnv:
...
@@ -127,13 +132,7 @@ class AlwaysHitShapeEnv:
return
""
return
""
class
InductorAdaptor
(
CompilerInterface
):
def
get_inductor_factors
()
->
List
[
Any
]:
"""
The adaptor for the Inductor compiler, version 2.5 and 2.6.
"""
name
=
"inductor"
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
factors
:
List
[
Any
]
=
[]
factors
:
List
[
Any
]
=
[]
# summarize system state
# summarize system state
from
torch._inductor.codecache
import
CacheBase
from
torch._inductor.codecache
import
CacheBase
...
@@ -144,6 +143,97 @@ class InductorAdaptor(CompilerInterface):
...
@@ -144,6 +143,97 @@ class InductorAdaptor(CompilerInterface):
from
torch._inductor.codecache
import
torch_key
from
torch._inductor.codecache
import
torch_key
torch_factors
=
torch_key
()
torch_factors
=
torch_key
()
factors
.
append
(
torch_factors
)
factors
.
append
(
torch_factors
)
return
factors
class InductorStandaloneAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, built on
    ``torch._inductor.standalone_compile``.

    Requires PyTorch 2.8+.
    This is not on by default yet, but we plan to turn it on by default for
    PyTorch 2.8.

    Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
    """
    name = "inductor_standalone"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return a 10-character hex digest of the Inductor hash factors.

        ``vllm_config`` is accepted for interface compatibility; only the
        result of ``get_inductor_factors()`` contributes to the hash.
        """
        factors = get_inductor_factors()
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()[:10]
        return hash_str

    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
        """Remember the directory compiled artifacts are saved under.

        ``disable_cache`` is accepted for interface compatibility and is
        currently not used by this adaptor.
        """
        self.cache_dir = cache_dir

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        """Compile ``graph`` with ``standalone_compile`` and save the result.

        Args:
            graph: The FX graph module to compile.
            example_inputs: Example inputs passed through to the compiler.
            compiler_config: Extra Inductor config entries; may be ``None``.
            runtime_shape: When an ``int``, autotuning options are enabled
                and shapes are taken from the example inputs; otherwise
                shapes are taken from the tracing context.
            key: Required. Name of the artifact under ``self.cache_dir``.

        Returns:
            ``(compiled_graph, (key, path))`` where ``path`` is
            ``os.path.join(self.cache_dir, key)``.

        Raises:
            AssertionError: If ``key`` is ``None``.
        """
        # Fail fast: the artifact cannot be saved without a key, so check
        # before paying for the (expensive) compilation.
        assert key is not None, (
            "`key` is required by InductorStandaloneAdaptor.compile")

        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)
        set_inductor_config(current_config, runtime_shape)

        if isinstance(runtime_shape, int):
            dynamic_shapes = "from_example_inputs"
        else:
            dynamic_shapes = "from_tracing_context"

        # Imported lazily: only available in sufficiently new PyTorch.
        from torch._inductor import standalone_compile
        with pass_context(runtime_shape):
            compiled_graph = standalone_compile(
                graph,
                example_inputs,
                dynamic_shapes=dynamic_shapes,
                options={"config_patches": current_config})

        # Save the compiled artifact to disk in the specified path
        path = os.path.join(self.cache_dir, key)
        compiled_graph.save(path=path, format="unpacked")
        return compiled_graph, (key, path)

    def load(self,
             handle: Any,
             graph: fx.GraphModule,
             example_inputs: List[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Callable:
        """Load a previously saved artifact and wrap it as a callable.

        Args:
            handle: The ``(key, path)`` tuple returned by ``compile``; both
                elements must be strings.
            graph: The original FX graph, used to decide whether the output
                should be returned as a tuple.
            example_inputs: Unused here; part of the common interface.
            graph_index: Unused here; part of the common interface.
            runtime_shape: Unused here; part of the common interface.

        Returns:
            A function forwarding its positional arguments to the loaded
            artifact.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        path = handle[1]
        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
            path=path, format="unpacked")
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        def compiled_graph_wrapper(*args):
            graph_output = inductor_compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper
class
InductorAdaptor
(
CompilerInterface
):
"""
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
"""
name
=
"inductor"
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
factors
=
get_inductor_factors
()
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()[:
10
]
usedforsecurity
=
False
).
hexdigest
()[:
10
]
return
hash_str
return
hash_str
...
@@ -168,23 +258,19 @@ class InductorAdaptor(CompilerInterface):
...
@@ -168,23 +258,19 @@ class InductorAdaptor(CompilerInterface):
graph
:
fx
.
GraphModule
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
example_inputs
:
List
[
Any
],
compiler_config
:
Dict
[
str
,
Any
],
compiler_config
:
Dict
[
str
,
Any
],
runtime_shape
:
Optional
[
int
]
=
None
runtime_shape
:
Optional
[
int
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
current_config
=
{}
from
torch._inductor.compile_fx
import
compile_fx
from
torch._inductor.compile_fx
import
compile_fx
current_config
=
{}
if
compiler_config
is
not
None
:
current_config
.
update
(
compiler_config
)
# disable remote cache
# disable remote cache
current_config
[
"fx_graph_cache"
]
=
True
current_config
[
"fx_graph_cache"
]
=
True
current_config
[
"fx_graph_remote_cache"
]
=
False
current_config
[
"fx_graph_remote_cache"
]
=
False
if
compiler_config
is
not
None
:
set_inductor_config
(
current_config
,
runtime_shape
)
current_config
.
update
(
compiler_config
)
if
isinstance
(
runtime_shape
,
int
):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config
[
"max_autotune"
]
=
True
current_config
[
"coordinate_descent_tuning"
]
=
True
# inductor can inplace modify the graph, so we need to copy it
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
# see https://github.com/pytorch/pytorch/issues/138980
...
@@ -422,6 +508,14 @@ class InductorAdaptor(CompilerInterface):
...
@@ -422,6 +508,14 @@ class InductorAdaptor(CompilerInterface):
return
contextlib
.
nullcontext
()
return
contextlib
.
nullcontext
()
def set_inductor_config(config, runtime_shape):
    """Enable Inductor autotuning options in ``config`` for a fixed shape.

    ``config`` is mutated in place. Nothing happens unless ``runtime_shape``
    is an ``int``.
    """
    if not isinstance(runtime_shape, int):
        return
    # for a specific batchsize, tuning triton kernel parameters
    # can be beneficial
    for option in ("max_autotune", "coordinate_descent_tuning"):
        config[option] = True
class
EagerAdaptor
(
CompilerInterface
):
class
EagerAdaptor
(
CompilerInterface
):
name
=
"eager"
name
=
"eager"
...
@@ -430,7 +524,8 @@ class EagerAdaptor(CompilerInterface):
...
@@ -430,7 +524,8 @@ class EagerAdaptor(CompilerInterface):
graph
:
fx
.
GraphModule
,
graph
:
fx
.
GraphModule
,
example_inputs
:
List
[
Any
],
example_inputs
:
List
[
Any
],
compiler_config
:
Dict
[
str
,
Any
],
compiler_config
:
Dict
[
str
,
Any
],
runtime_shape
:
Optional
[
int
]
=
None
runtime_shape
:
Optional
[
int
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
)
->
Tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
# we don't need to compile the graph, just return the graph itself.
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
# It does not support caching, return None for the handle.
...
...
vllm/envs.py
View file @
ea2236bf
...
@@ -263,6 +263,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -263,6 +263,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
bool
(
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
# Internal flag to enable/disable Inductor standalone compile
"VLLM_TEST_STANDALONE_COMPILE"
:
lambda
:
os
.
environ
.
get
(
"VLLM_TEST_STANDALONE_COMPILE"
,
"0"
)
!=
"0"
,
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
"LOCAL_RANK"
:
"LOCAL_RANK"
:
...
@@ -805,6 +809,7 @@ def compute_hash() -> str:
...
@@ -805,6 +809,7 @@ def compute_hash() -> str:
"VLLM_USE_TRITON_AWQ"
,
"VLLM_USE_TRITON_AWQ"
,
"VLLM_DP_RANK"
,
"VLLM_DP_RANK"
,
"VLLM_DP_SIZE"
,
"VLLM_DP_SIZE"
,
"VLLM_TEST_STANDALONE_COMPILE"
,
]
]
for
key
in
environment_variables_to_hash
:
for
key
in
environment_variables_to_hash
:
if
key
in
environment_variables
:
if
key
in
environment_variables
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment