Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aaf4b70a
Unverified
Commit
aaf4b70a
authored
Jan 09, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 09, 2026
Browse files
[Misc][BE] Type coverage for vllm/compilation [2/3] (#31744)
parent
3adffd5b
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
161 additions
and
91 deletions
+161
-91
vllm/compilation/backends.py
vllm/compilation/backends.py
+6
-6
vllm/compilation/caching.py
vllm/compilation/caching.py
+16
-7
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+12
-10
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+42
-20
vllm/compilation/fix_functionalization.py
vllm/compilation/fix_functionalization.py
+5
-5
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+3
-0
vllm/compilation/noop_elimination.py
vllm/compilation/noop_elimination.py
+2
-2
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+19
-12
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+29
-15
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+23
-13
vllm/config/compilation.py
vllm/config/compilation.py
+3
-0
vllm/distributed/device_communicators/pynccl_allocator.py
vllm/distributed/device_communicators/pynccl_allocator.py
+1
-1
No files found.
vllm/compilation/backends.py
View file @
aaf4b70a
...
...
@@ -179,7 +179,7 @@ class CompilerManager:
example_inputs
:
list
[
Any
],
graph_index
:
int
,
compile_range
:
Range
,
)
->
Callable
|
None
:
)
->
Callable
[...,
Any
]
|
None
:
if
(
compile_range
,
graph_index
,
self
.
compiler
.
name
)
not
in
self
.
cache
:
return
None
handle
=
self
.
cache
[(
compile_range
,
graph_index
,
self
.
compiler
.
name
)]
...
...
@@ -199,7 +199,7 @@ class CompilerManager:
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
list
[
Any
],
additional_inductor_config
,
additional_inductor_config
:
dict
[
str
,
Any
]
,
compilation_config
:
CompilationConfig
,
compile_range
:
Range
,
graph_index
:
int
=
0
,
...
...
@@ -355,7 +355,7 @@ def split_graph(
compilation_start_time
=
0.0
class
PiecewiseCompileInterpreter
(
torch
.
fx
.
Interpreter
):
class
PiecewiseCompileInterpreter
(
torch
.
fx
.
Interpreter
):
# type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
...
...
@@ -506,9 +506,9 @@ class VllmBackend:
# the stiching graph module for all the piecewise graphs
split_gm
:
fx
.
GraphModule
piecewise_graphs
:
list
[
SplitItem
]
returned_callable
:
Callable
returned_callable
:
Callable
[...,
Any
]
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes
:
Sequence
[
Callable
]
post_grad_passes
:
Sequence
[
Callable
[...,
Any
]
]
sym_tensor_indices
:
list
[
int
]
input_buffers
:
list
[
torch
.
Tensor
]
compiler_manager
:
CompilerManager
...
...
@@ -821,7 +821,7 @@ class VllmBackend:
]
# this is the callable we return to Dynamo to run
def
copy_and_call
(
*
args
)
:
def
copy_and_call
(
*
args
:
Any
)
->
Any
:
list_args
=
list
(
args
)
for
i
,
index
in
enumerate
(
self
.
sym_tensor_indices
):
runtime_tensor
=
list_args
[
index
]
...
...
vllm/compilation/caching.py
View file @
aaf4b70a
...
...
@@ -4,6 +4,8 @@
import
inspect
import
os
import
pickle
from
collections.abc
import
Callable
,
Sequence
from
typing
import
Any
,
Literal
from
unittest.mock
import
patch
import
torch
...
...
@@ -25,7 +27,7 @@ assert isinstance(SerializableCallable, type)
logger
=
init_logger
(
__name__
)
class
VllmSerializableFunction
(
SerializableCallable
):
class
VllmSerializableFunction
(
SerializableCallable
):
# type: ignore[misc]
"""
A wrapper around a compiled function by vllm. It will forward the tensor
inputs to the compiled function and return the result.
...
...
@@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable):
"""
def
__init__
(
self
,
graph_module
,
example_inputs
,
prefix
,
optimized_call
,
is_encoder
=
False
):
self
,
graph_module
:
torch
.
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
],
prefix
:
str
,
optimized_call
:
Callable
[...,
Any
],
is_encoder
:
bool
=
False
,
)
->
None
:
assert
isinstance
(
graph_module
,
torch
.
fx
.
GraphModule
)
self
.
graph_module
=
graph_module
self
.
example_inputs
=
example_inputs
...
...
@@ -53,7 +60,7 @@ class VllmSerializableFunction(SerializableCallable):
if
sym_input
is
not
None
:
self
.
shape_env
=
sym_input
.
node
.
shape_env
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
return
self
.
optimized_call
(
*
args
,
**
kwargs
)
@
classmethod
...
...
@@ -73,7 +80,9 @@ class VllmSerializableFunction(SerializableCallable):
graph_reducer_override
=
GraphPickler
.
reducer_override
def
_graph_reducer_override
(
self
,
obj
):
def
_graph_reducer_override
(
self
:
GraphPickler
,
obj
:
Any
)
->
tuple
[
Callable
[...,
Any
],
tuple
[
Any
,
...]]
|
Any
:
if
(
inspect
.
isclass
(
obj
)
and
issubclass
(
obj
,
sympy
.
Function
)
...
...
@@ -114,7 +123,7 @@ class VllmSerializableFunction(SerializableCallable):
get_current_vllm_config
(),
state
[
"prefix"
],
is_encoder
)
def
optimized_call
(
*
example_inputs
)
:
def
optimized_call
(
*
example_inputs
:
Any
)
->
Any
:
"""
On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend
...
...
@@ -136,7 +145,7 @@ class VllmSerializableFunction(SerializableCallable):
return
fn
@
property
def
co_name
(
self
):
def
co_name
(
self
)
->
Literal
[
"VllmSerializableFunction"
]
:
"""
Used for depyf debugging.
"""
...
...
vllm/compilation/cuda_graph.py
View file @
aaf4b70a
...
...
@@ -42,7 +42,9 @@ class CUDAGraphLogging:
"Count"
,
]
def
__init__
(
self
,
cg_mode
:
CUDAGraphMode
,
cg_capture_sizes
:
list
[
int
]
|
None
):
def
__init__
(
self
,
cg_mode
:
CUDAGraphMode
,
cg_capture_sizes
:
list
[
int
]
|
None
)
->
None
:
self
.
reset
()
self
.
cg_mode
=
str
(
cg_mode
)
self
.
cg_capture_sizes
=
str
(
cg_capture_sizes
or
[])
...
...
@@ -54,10 +56,10 @@ class CUDAGraphLogging:
"**CUDAGraph Stats:**
\n\n
"
)
def
reset
(
self
):
self
.
stats
=
[]
def
reset
(
self
)
->
None
:
self
.
stats
:
list
[
CUDAGraphStat
]
=
[]
def
observe
(
self
,
cudagraph_stat
:
CUDAGraphStat
):
def
observe
(
self
,
cudagraph_stat
:
CUDAGraphStat
)
->
None
:
self
.
stats
.
append
(
cudagraph_stat
)
def
generate_metric_table
(
self
)
->
str
:
...
...
@@ -109,7 +111,7 @@ class CUDAGraphLogging:
+
"
\n
"
)
def
log
(
self
,
log_fn
=
logger
.
info
):
def
log
(
self
,
log_fn
:
Callable
[...,
Any
]
=
logger
.
info
)
->
None
:
if
not
self
.
stats
:
return
log_fn
(
self
.
generate_metric_table
())
...
...
@@ -161,11 +163,11 @@ class CUDAGraphWrapper:
def
__init__
(
self
,
runnable
:
Callable
,
runnable
:
Callable
[...,
Any
]
,
vllm_config
:
VllmConfig
,
runtime_mode
:
CUDAGraphMode
,
cudagraph_options
:
CUDAGraphOptions
|
None
=
None
,
):
)
->
None
:
self
.
runnable
=
runnable
self
.
vllm_config
=
vllm_config
self
.
runtime_mode
=
runtime_mode
...
...
@@ -189,7 +191,7 @@ class CUDAGraphWrapper:
# cudagraphs for.
self
.
concrete_cudagraph_entries
:
dict
[
BatchDescriptor
,
CUDAGraphEntry
]
=
{}
def
__getattr__
(
self
,
key
:
str
):
def
__getattr__
(
self
,
key
:
str
)
->
Any
:
# allow accessing the attributes of the runnable.
if
hasattr
(
self
.
runnable
,
key
):
return
getattr
(
self
.
runnable
,
key
)
...
...
@@ -198,11 +200,11 @@ class CUDAGraphWrapper:
f
"cudagraph wrapper:
{
self
.
runnable
}
"
)
def
unwrap
(
self
)
->
Callable
:
def
unwrap
(
self
)
->
Callable
[...,
Any
]
:
# in case we need to access the original runnable.
return
self
.
runnable
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
|
None
:
forward_context
=
get_forward_context
()
batch_descriptor
=
forward_context
.
batch_descriptor
cudagraph_runtime_mode
=
forward_context
.
cudagraph_runtime_mode
...
...
vllm/compilation/decorators.py
View file @
aaf4b70a
...
...
@@ -6,8 +6,8 @@ import hashlib
import
inspect
import
os
import
sys
from
collections.abc
import
Callable
from
typing
import
TypeVar
,
overload
from
collections.abc
import
Callable
,
Generator
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
TypeVar
,
overload
from
unittest.mock
import
patch
import
torch
...
...
@@ -32,6 +32,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
from
.monitor
import
start_monitoring_torch_compile
if
TYPE_CHECKING
:
# Only added on nightly/2.10 so wrap
try
:
from
torch._dynamo.package
import
SourceInfo
except
ImportError
:
# Fallback for old versions not supporting
SourceInfo
=
Any
logger
=
init_logger
(
__name__
)
IGNORE_COMPILE_KEY
=
"_ignore_compile_vllm"
...
...
@@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return
cls
def
_should_ignore_torch_compile
(
cls
)
->
bool
:
def
_should_ignore_torch_compile
(
cls
:
_T
)
->
bool
:
"""
Check if the class should be ignored for torch.compile.
"""
...
...
@@ -224,7 +232,7 @@ def support_torch_compile(
return
cls_decorator_helper
def
_model_hash_key
(
fn
)
->
str
:
def
_model_hash_key
(
fn
:
Callable
[...,
Any
]
)
->
str
:
import
vllm
sha256_hash
=
hashlib
.
sha256
()
...
...
@@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str:
return
sha256_hash
.
hexdigest
()
def
_verify_source_unchanged
(
source_info
,
vllm_config
)
->
None
:
def
_verify_source_unchanged
(
source_info
:
"SourceInfo"
,
vllm_config
:
VllmConfig
)
->
None
:
from
.caching
import
_compute_code_hash
,
_compute_code_hash_with_content
file_contents
=
{}
...
...
@@ -275,8 +285,12 @@ def _support_torch_compile(
setattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
|
None
=
None
,
prefix
:
str
=
""
,
**
kwargs
):
self
:
_T
,
*
,
vllm_config
:
VllmConfig
|
None
=
None
,
prefix
:
str
=
""
,
**
kwargs
:
Any
,
)
->
None
:
if
vllm_config
is
None
:
vllm_config
=
get_current_vllm_config
()
...
...
@@ -309,13 +323,17 @@ def _support_torch_compile(
compilation_counter
.
num_models_seen
+=
1
self
.
compiled
=
False
TorchCompileWithNoGuardsWrapper
.
__init__
(
self
)
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper
.
__init__
(
self
)
# type: ignore[arg-type]
cls
.
__init__
=
__init__
def
_mark_dynamic_inputs
(
mod
,
type
,
*
args
,
**
kwargs
):
def
mark_dynamic
(
arg
,
dims
):
if
type
==
DynamicShapesType
.
UNBACKED
:
def
_mark_dynamic_inputs
(
mod
:
_T
,
ds_type
:
DynamicShapesType
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
def
mark_dynamic
(
arg
:
torch
.
Tensor
,
dims
:
list
[
int
])
->
None
:
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
if
is_torch_equal_or_newer
(
"2.10.0.dev"
):
for
dim
in
dims
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
...
...
@@ -326,7 +344,7 @@ def _support_torch_compile(
else
:
torch
.
_dynamo
.
mark_dynamic
(
arg
,
dims
)
sig
=
inspect
.
signature
(
mod
.
__class__
.
forward
)
sig
=
inspect
.
signature
(
mod
.
__class__
.
forward
)
# type: ignore[attr-defined]
bound_args
=
sig
.
bind
(
mod
,
*
args
,
**
kwargs
)
bound_args
.
apply_defaults
()
for
k
,
dims
in
dynamic_arg_dims
.
items
():
...
...
@@ -364,7 +382,7 @@ def _support_torch_compile(
else
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
:
_T
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
...
...
@@ -444,7 +462,7 @@ def _support_torch_compile(
not
envs
.
VLLM_USE_AOT_COMPILE
or
self
.
vllm_config
.
compilation_config
.
backend
==
"eager"
)
return
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
return
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# type: ignore[arg-type]
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
...
...
@@ -477,7 +495,7 @@ def _support_torch_compile(
# during Dynamo tracing, and their corresponding files
inline_call
=
InliningInstructionTranslator
.
inline_call_
def
patched_inline_call
(
self_
)
:
def
patched_inline_call
(
self_
:
Any
)
->
Any
:
code
=
self_
.
f_code
self
.
compilation_config
.
traced_files
.
add
(
code
.
co_filename
)
return
inline_call
(
self_
)
...
...
@@ -535,7 +553,7 @@ def _support_torch_compile(
str
(
e
),
)
else
:
output
=
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
output
=
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# type: ignore[arg-type]
self
.
compiled
=
True
return
output
...
...
@@ -545,7 +563,9 @@ def _support_torch_compile(
@
contextlib
.
contextmanager
def
maybe_use_cudagraph_partition_wrapper
(
vllm_config
:
VllmConfig
):
def
maybe_use_cudagraph_partition_wrapper
(
vllm_config
:
VllmConfig
,
)
->
Generator
[
None
,
None
,
None
]:
"""
Context manager to set/unset customized cudagraph partition wrappers.
...
...
@@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
current_platform
.
get_static_graph_wrapper_cls
()
)
def
customized_cudagraph_wrapper
(
f
,
metadata
:
CUDAGraphWrapperMetadata
):
def
customized_cudagraph_wrapper
(
f
:
Callable
[...,
Any
],
metadata
:
CUDAGraphWrapperMetadata
)
->
Any
:
partition_id
=
metadata
.
partition_index
num_partitions
=
metadata
.
num_partitions
return
static_graph_wrapper_class
(
...
...
@@ -600,7 +622,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
@
contextlib
.
contextmanager
def
_torch27_patch_tensor_subclasses
():
def
_torch27_patch_tensor_subclasses
()
->
Generator
[
None
,
None
,
None
]
:
"""
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of
...
...
@@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses():
_ColumnvLLMParameter
,
)
def
return_false
(
*
args
,
**
kwargs
)
:
def
return_false
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Literal
[
False
]
:
return
False
if
version
.
parse
(
"2.7"
)
<=
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"2.8"
):
...
...
vllm/compilation/fix_functionalization.py
View file @
aaf4b70a
...
...
@@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if
current_platform
.
is_xpu
():
...
...
@@ -179,7 +179,7 @@ class FixFunctionalizationPass(VllmInductorPass):
)
self
.
nodes_to_remove
.
clear
()
def
_remove
(
self
,
node_or_nodes
:
torch
.
fx
.
Node
|
Iterable
[
torch
.
fx
.
Node
]):
def
_remove
(
self
,
node_or_nodes
:
torch
.
fx
.
Node
|
Iterable
[
torch
.
fx
.
Node
])
->
None
:
"""
Stage a node (or nodes) for removal at the end of the pass.
"""
...
...
@@ -194,7 +194,7 @@ class FixFunctionalizationPass(VllmInductorPass):
node
:
torch
.
fx
.
Node
,
mutated_args
:
dict
[
int
,
torch
.
fx
.
Node
|
str
],
args
:
tuple
[
torch
.
fx
.
Node
|
str
,
...]
|
None
=
None
,
):
)
->
None
:
"""
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
...
...
@@ -206,7 +206,7 @@ class FixFunctionalizationPass(VllmInductorPass):
def
replace_users_with_mutated_args
(
self
,
node
:
torch
.
fx
.
Node
,
mutated_args
:
dict
[
int
,
torch
.
fx
.
Node
|
str
]
):
)
->
None
:
"""
Replace all getitem users of the auto-functionalized node with the
mutated arguments.
...
...
@@ -237,7 +237,7 @@ class FixFunctionalizationPass(VllmInductorPass):
graph
:
torch
.
fx
.
Graph
,
node
:
torch
.
fx
.
Node
,
args
:
tuple
[
torch
.
fx
.
Node
|
str
,
...]
|
None
=
None
,
):
)
->
None
:
"""
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,
...
...
vllm/compilation/inductor_pass.py
View file @
aaf4b70a
...
...
@@ -29,6 +29,9 @@ else:
Torch25CustomGraphPass
as
CustomGraphPass
,
)
# Re-export CustomGraphPass for external usage
__all__
=
[
"CustomGraphPass"
]
_pass_context
=
None
P
=
ParamSpec
(
"P"
)
R
=
TypeVar
(
"R"
)
...
...
vllm/compilation/noop_elimination.py
View file @
aaf4b70a
...
...
@@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass):
"""
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
count
=
0
# Remove no-op reshapes/views:
for
node
in
graph
.
nodes
:
...
...
@@ -117,7 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
2. The dimensions both correspond to the same SymInt
"""
# Case 1
return
statically_known_true
(
dim
==
i_dim
)
return
statically_known_true
(
dim
==
i_dim
)
# type: ignore[no-any-return]
def
all_dims_equivalent
(
self
,
dims
:
Iterable
[
int
|
SymInt
],
i_dims
:
Iterable
[
int
|
SymInt
]
...
...
vllm/compilation/pass_manager.py
View file @
aaf4b70a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
collections.abc
import
Callable
from
typing
import
Any
,
ParamSpec
,
TypeVar
from
torch
import
fx
as
fx
...
...
@@ -40,8 +42,11 @@ from .noop_elimination import NoOpEliminationPass
logger
=
init_logger
(
__name__
)
P
=
ParamSpec
(
"P"
)
R
=
TypeVar
(
"R"
)
def
with_pattern_match_debug
(
fn
):
def
with_pattern_match_debug
(
fn
:
Callable
[
P
,
R
])
->
Callable
[
P
,
R
]:
"""
Function decorator that turns on inductor pattern match debug
for the duration of the call.
...
...
@@ -49,7 +54,7 @@ def with_pattern_match_debug(fn):
"""
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
,
**
kwargs
)
:
def
wrapper
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
R
:
if
(
debug_val
:
=
envs
.
VLLM_PATTERN_MATCH_DEBUG
)
is
not
None
:
# optionally check rank here
with
set_env_var
(
"TORCHINDUCTOR_PATTERN_MATCH_DEBUG"
,
debug_val
):
...
...
@@ -59,7 +64,7 @@ def with_pattern_match_debug(fn):
return
wrapper
class
PostGradPassManager
(
CustomGraphPass
):
class
PostGradPassManager
(
CustomGraphPass
):
# type: ignore[misc]
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
...
...
@@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass):
This way, all passes operate on a functionalized graph.
"""
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
self
.
passes
:
list
[
InductorPass
]
=
[]
@
with_pattern_match_debug
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
VllmInductorPass
.
dump_prefix
=
0
# reset dump index
compile_range
=
get_pass_context
().
compile_range
...
...
@@ -98,7 +103,7 @@ class PostGradPassManager(CustomGraphPass):
self
.
fix_functionalization
(
graph
)
VllmInductorPass
.
dump_prefix
=
None
# Cleanup index
def
configure
(
self
,
config
:
VllmConfig
):
def
configure
(
self
,
config
:
VllmConfig
)
->
None
:
self
.
pass_config
=
config
.
compilation_config
.
pass_config
# Set the current vllm config to allow tracing CustomOp instances
...
...
@@ -135,23 +140,25 @@ class PostGradPassManager(CustomGraphPass):
self
.
post_cleanup
=
PostCleanupPass
(
config
)
self
.
fix_functionalization
=
FixFunctionalizationPass
(
config
)
def
add
(
self
,
pass_
:
InductorPass
):
def
add
(
self
,
pass_
:
InductorPass
)
->
None
:
assert
isinstance
(
pass_
,
InductorPass
)
self
.
passes
.
append
(
pass_
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
"""
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state
=
{
"pass_config"
:
self
.
pass_config
.
compute_hash
(),
"passes"
:
[]}
passes
=
[]
state
:
dict
[
str
,
Any
]
=
{
"pass_config"
:
self
.
pass_config
.
compute_hash
()}
for
pass_
in
self
.
passes
:
state
[
"
passes
"
]
.
append
(
pass_
.
uuid
())
state
[
"
passes
"
]
.
append
(
self
.
fix_functionalization
.
uuid
())
passes
.
append
(
pass_
.
uuid
())
passes
.
append
(
self
.
fix_functionalization
.
uuid
())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state
[
"compile_range"
]
=
str
(
get_pass_context
().
compile_range
)
state
[
"passes"
]
=
passes
return
InductorPass
.
hash_dict
(
state
)
vllm/compilation/piecewise_backend.py
View file @
aaf4b70a
...
...
@@ -86,7 +86,16 @@ class PiecewiseBackend:
self
.
to_be_compiled_ranges
:
set
[
Range
]
=
set
(
self
.
compile_ranges
)
# We only keep compilation management inside this class directly.
if
self
.
compile_sizes
is
not
None
:
for
size
in
self
.
compile_sizes
:
if
isinstance
(
size
,
str
):
assert
size
==
"cudagraph_capture_sizes"
raise
NotImplementedError
(
"cudagraph_capture_sizes not supported in compile_sizes."
"This should be handled in `post_init_cudagraph_sizes`."
)
else
:
assert
isinstance
(
size
,
int
)
range
=
Range
(
start
=
size
,
end
=
size
)
if
range
not
in
self
.
compile_ranges
:
self
.
range_entries
[
range
]
=
RangeEntry
(
...
...
@@ -99,14 +108,14 @@ class PiecewiseBackend:
compile_range
=
range
,
)
def
check_for_ending_compilation
(
self
):
def
check_for_ending_compilation
(
self
)
->
None
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_ranges
:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self
.
vllm_backend
.
compiler_manager
.
save_to_file
()
end_monitoring_torch_compile
(
self
.
vllm_config
)
def
_fakify_args
(
self
,
args
:
list
[
Any
])
->
list
[
Any
]:
def
_fakify_args
(
self
,
args
:
tuple
[
Any
,
...
])
->
list
[
Any
]:
# We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs potentially causing some non dynamic
# dimension to be be duck shaped to other existing shapes that have hints
...
...
@@ -127,7 +136,9 @@ class PiecewiseBackend:
assert
len
(
fake_example_inputs
)
==
len
(
args
)
return
fake_example_inputs
def
_maybe_compile_for_range_entry
(
self
,
range_entry
:
RangeEntry
,
args
)
->
Any
:
def
_maybe_compile_for_range_entry
(
self
,
range_entry
:
RangeEntry
,
args
:
tuple
[
Any
,
...]
)
->
Any
:
if
not
range_entry
.
compiled
:
range_entry
.
compiled
=
True
self
.
to_be_compiled_ranges
.
remove
(
range_entry
.
compile_range
)
...
...
@@ -136,14 +147,14 @@ class PiecewiseBackend:
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args
=
(
args
_list
=
(
self
.
_fakify_args
(
args
)
if
not
range_entry
.
compile_range
.
is_single_size
()
else
args
else
list
(
args
)
)
range_entry
.
runnable
=
self
.
vllm_backend
.
compiler_manager
.
compile
(
self
.
graph
,
args
,
args
_list
,
self
.
vllm_backend
.
inductor_config
,
self
.
compilation_config
,
compile_range
=
range_entry
.
compile_range
,
...
...
@@ -153,10 +164,13 @@ class PiecewiseBackend:
self
.
check_for_ending_compilation
()
def
_find_range_for_shape
(
self
,
runtime_shape
:
int
)
->
Range
|
None
:
def
_find_range_for_shape
(
self
,
runtime_shape
:
int
)
->
Range
Entry
|
None
:
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# that contains the runtime shape.
if
self
.
compile_sizes
is
None
:
return
None
if
runtime_shape
in
self
.
compile_sizes
:
return
self
.
range_entries
[
Range
(
start
=
runtime_shape
,
end
=
runtime_shape
)]
else
:
...
...
@@ -165,7 +179,7 @@ class PiecewiseBackend:
return
self
.
range_entries
[
range
]
return
None
def
__call__
(
self
,
*
args
)
->
Any
:
def
__call__
(
self
,
*
args
:
Any
)
->
Any
:
runtime_shape
=
args
[
self
.
sym_shape_indices
[
0
]]
range_entry
=
self
.
_find_range_for_shape
(
runtime_shape
)
...
...
vllm/compilation/wrapper.py
View file @
aaf4b70a
...
...
@@ -4,9 +4,10 @@
import
os
import
sys
from
abc
import
abstractmethod
from
collections.abc
import
Callable
,
Generator
from
contextlib
import
contextmanager
,
nullcontext
from
types
import
CodeType
from
typing
import
Any
from
typing
import
Any
,
ParamSpec
,
TypeVar
import
torch
import
torch._C._dynamo.guards
...
...
@@ -19,19 +20,26 @@ from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger
=
init_logger
(
__name__
)
R
=
TypeVar
(
"R"
)
P
=
ParamSpec
(
"P"
)
def
_noop_add_global_state_guard
(
self
,
*
args
,
**
kwargs
):
def
_noop_add_global_state_guard
(
self
:
torch
.
_C
.
_dynamo
.
guards
.
GuardManager
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
"""No-op to skip the GLOBAL_STATE guard entirely"""
pass
def
_noop_add_torch_function_mode_stack_guard
(
self
,
*
args
,
**
kwargs
):
def
_noop_add_torch_function_mode_stack_guard
(
self
:
torch
.
_C
.
_dynamo
.
guards
.
GuardManager
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
pass
@
contextmanager
def
_compilation_context
():
def
_compilation_context
()
->
Generator
[
None
,
None
,
None
]
:
"""Context manager for compilation settings and patches.
This manager:
...
...
@@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
"""
def
check_invariants_and_forward
(
self
,
*
args
,
**
kwargs
)
:
def
check_invariants_and_forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
assert
hasattr
(
self
,
"_check_shape_invariants"
)
self
.
_check_shape_invariants
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
def
_call_with_optional_nvtx_range
(
self
,
callable_fn
,
*
args
,
**
kwargs
):
def
_call_with_optional_nvtx_range
(
self
,
callable_fn
:
Callable
[
P
,
R
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
Any
:
if
self
.
layerwise_nvtx_tracing_enabled
:
args_list
=
list
(
args
)
kwargs_dict
=
dict
(
kwargs
)
...
...
@@ -108,7 +118,7 @@ class TorchCompileWithNoGuardsWrapper:
return
ctx
.
result
return
callable_fn
(
*
args
,
**
kwargs
)
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
self
.
compiled
=
False
vllm_config
=
get_current_vllm_config
()
...
...
@@ -192,9 +202,9 @@ class TorchCompileWithNoGuardsWrapper:
if
envs
.
VLLM_USE_BYTECODE_HOOK
and
mode
!=
CompilationMode
.
STOCK_TORCH_COMPILE
:
torch
.
_dynamo
.
convert_frame
.
register_bytecode_hook
(
self
.
bytecode_hook
)
self
.
_compiled_bytecode
=
None
self
.
_compiled_bytecode
:
CodeType
|
None
=
None
def
aot_compile
(
self
,
*
args
,
**
kwargs
)
:
def
aot_compile
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
not
hasattr
(
self
.
_compiled_callable
,
"aot_compile"
):
raise
RuntimeError
(
"aot_compile is not supported by the current configuration. "
...
...
@@ -203,7 +213,7 @@ class TorchCompileWithNoGuardsWrapper:
)
return
self
.
_compiled_callable
.
aot_compile
((
args
,
kwargs
))
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
envs
.
VLLM_USE_BYTECODE_HOOK
:
if
(
self
.
vllm_config
.
compilation_config
.
mode
...
...
@@ -236,13 +246,13 @@ class TorchCompileWithNoGuardsWrapper:
)
@
abstractmethod
def
forward
(
self
,
*
args
,
**
kwargs
)
:
...
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
...
def
original_code_object
(
self
)
->
CodeType
:
"""Return the original code object of the forward method."""
return
self
.
__class__
.
forward
.
__code__
def
bytecode_hook
(
self
,
old_code
:
CodeType
,
new_code
:
CodeType
):
def
bytecode_hook
(
self
,
old_code
:
CodeType
,
new_code
:
CodeType
)
->
None
:
"""Hook to save the compiled bytecode for direct execution."""
if
old_code
is
not
self
.
original_code_object
():
return
...
...
@@ -299,7 +309,7 @@ class TorchCompileWithNoGuardsWrapper:
raise
RuntimeError
(
msg
)
@
contextmanager
def
_dispatch_to_compiled_code
(
self
):
def
_dispatch_to_compiled_code
(
self
)
->
Generator
[
None
,
None
,
None
]
:
# noqa: E501
"""
Context manager to dispatch to internally compiled code for torch<2.8.
...
...
vllm/config/compilation.py
View file @
aaf4b70a
...
...
@@ -32,6 +32,9 @@ else:
logger
=
init_logger
(
__name__
)
# Explicitly exports Range
__all__
=
[
"Range"
]
class
CompilationMode
(
enum
.
IntEnum
):
"""The compilation approach used for torch.compile-based compilation of the
...
...
vllm/distributed/device_communicators/pynccl_allocator.py
View file @
aaf4b70a
...
...
@@ -60,7 +60,7 @@ def is_symmetric_memory_tensor(tensor: torch.Tensor):
return
False
def
set_graph_pool_id
(
graph_pool_id
)
:
def
set_graph_pool_id
(
graph_pool_id
:
Any
)
->
None
:
global
_graph_pool_id
_graph_pool_id
=
graph_pool_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment