Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aaf4b70a
Unverified
Commit
aaf4b70a
authored
Jan 09, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 09, 2026
Browse files
[Misc][BE] Type coverage for vllm/compilation [2/3] (#31744)
parent
3adffd5b
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
161 additions
and
91 deletions
+161
-91
vllm/compilation/backends.py
vllm/compilation/backends.py
+6
-6
vllm/compilation/caching.py
vllm/compilation/caching.py
+16
-7
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+12
-10
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+42
-20
vllm/compilation/fix_functionalization.py
vllm/compilation/fix_functionalization.py
+5
-5
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+3
-0
vllm/compilation/noop_elimination.py
vllm/compilation/noop_elimination.py
+2
-2
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+19
-12
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+29
-15
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+23
-13
vllm/config/compilation.py
vllm/config/compilation.py
+3
-0
vllm/distributed/device_communicators/pynccl_allocator.py
vllm/distributed/device_communicators/pynccl_allocator.py
+1
-1
No files found.
vllm/compilation/backends.py
View file @
aaf4b70a
...
@@ -179,7 +179,7 @@ class CompilerManager:
...
@@ -179,7 +179,7 @@ class CompilerManager:
example_inputs
:
list
[
Any
],
example_inputs
:
list
[
Any
],
graph_index
:
int
,
graph_index
:
int
,
compile_range
:
Range
,
compile_range
:
Range
,
)
->
Callable
|
None
:
)
->
Callable
[...,
Any
]
|
None
:
if
(
compile_range
,
graph_index
,
self
.
compiler
.
name
)
not
in
self
.
cache
:
if
(
compile_range
,
graph_index
,
self
.
compiler
.
name
)
not
in
self
.
cache
:
return
None
return
None
handle
=
self
.
cache
[(
compile_range
,
graph_index
,
self
.
compiler
.
name
)]
handle
=
self
.
cache
[(
compile_range
,
graph_index
,
self
.
compiler
.
name
)]
...
@@ -199,7 +199,7 @@ class CompilerManager:
...
@@ -199,7 +199,7 @@ class CompilerManager:
self
,
self
,
graph
:
fx
.
GraphModule
,
graph
:
fx
.
GraphModule
,
example_inputs
:
list
[
Any
],
example_inputs
:
list
[
Any
],
additional_inductor_config
,
additional_inductor_config
:
dict
[
str
,
Any
]
,
compilation_config
:
CompilationConfig
,
compilation_config
:
CompilationConfig
,
compile_range
:
Range
,
compile_range
:
Range
,
graph_index
:
int
=
0
,
graph_index
:
int
=
0
,
...
@@ -355,7 +355,7 @@ def split_graph(
...
@@ -355,7 +355,7 @@ def split_graph(
compilation_start_time
=
0.0
compilation_start_time
=
0.0
class
PiecewiseCompileInterpreter
(
torch
.
fx
.
Interpreter
):
class
PiecewiseCompileInterpreter
(
torch
.
fx
.
Interpreter
):
# type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
submodules specified by `compile_submod_names` with the given
...
@@ -506,9 +506,9 @@ class VllmBackend:
...
@@ -506,9 +506,9 @@ class VllmBackend:
# the stiching graph module for all the piecewise graphs
# the stiching graph module for all the piecewise graphs
split_gm
:
fx
.
GraphModule
split_gm
:
fx
.
GraphModule
piecewise_graphs
:
list
[
SplitItem
]
piecewise_graphs
:
list
[
SplitItem
]
returned_callable
:
Callable
returned_callable
:
Callable
[...,
Any
]
# Inductor passes to run on the graph pre-defunctionalization
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes
:
Sequence
[
Callable
]
post_grad_passes
:
Sequence
[
Callable
[...,
Any
]
]
sym_tensor_indices
:
list
[
int
]
sym_tensor_indices
:
list
[
int
]
input_buffers
:
list
[
torch
.
Tensor
]
input_buffers
:
list
[
torch
.
Tensor
]
compiler_manager
:
CompilerManager
compiler_manager
:
CompilerManager
...
@@ -821,7 +821,7 @@ class VllmBackend:
...
@@ -821,7 +821,7 @@ class VllmBackend:
]
]
# this is the callable we return to Dynamo to run
# this is the callable we return to Dynamo to run
def
copy_and_call
(
*
args
)
:
def
copy_and_call
(
*
args
:
Any
)
->
Any
:
list_args
=
list
(
args
)
list_args
=
list
(
args
)
for
i
,
index
in
enumerate
(
self
.
sym_tensor_indices
):
for
i
,
index
in
enumerate
(
self
.
sym_tensor_indices
):
runtime_tensor
=
list_args
[
index
]
runtime_tensor
=
list_args
[
index
]
...
...
vllm/compilation/caching.py
View file @
aaf4b70a
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
import
inspect
import
inspect
import
os
import
os
import
pickle
import
pickle
from
collections.abc
import
Callable
,
Sequence
from
typing
import
Any
,
Literal
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -25,7 +27,7 @@ assert isinstance(SerializableCallable, type)
...
@@ -25,7 +27,7 @@ assert isinstance(SerializableCallable, type)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
VllmSerializableFunction
(
SerializableCallable
):
class
VllmSerializableFunction
(
SerializableCallable
):
# type: ignore[misc]
"""
"""
A wrapper around a compiled function by vllm. It will forward the tensor
A wrapper around a compiled function by vllm. It will forward the tensor
inputs to the compiled function and return the result.
inputs to the compiled function and return the result.
...
@@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable):
...
@@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable):
"""
"""
def
__init__
(
def
__init__
(
self
,
graph_module
,
example_inputs
,
prefix
,
optimized_call
,
is_encoder
=
False
self
,
):
graph_module
:
torch
.
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
],
prefix
:
str
,
optimized_call
:
Callable
[...,
Any
],
is_encoder
:
bool
=
False
,
)
->
None
:
assert
isinstance
(
graph_module
,
torch
.
fx
.
GraphModule
)
assert
isinstance
(
graph_module
,
torch
.
fx
.
GraphModule
)
self
.
graph_module
=
graph_module
self
.
graph_module
=
graph_module
self
.
example_inputs
=
example_inputs
self
.
example_inputs
=
example_inputs
...
@@ -53,7 +60,7 @@ class VllmSerializableFunction(SerializableCallable):
...
@@ -53,7 +60,7 @@ class VllmSerializableFunction(SerializableCallable):
if
sym_input
is
not
None
:
if
sym_input
is
not
None
:
self
.
shape_env
=
sym_input
.
node
.
shape_env
self
.
shape_env
=
sym_input
.
node
.
shape_env
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
return
self
.
optimized_call
(
*
args
,
**
kwargs
)
return
self
.
optimized_call
(
*
args
,
**
kwargs
)
@
classmethod
@
classmethod
...
@@ -73,7 +80,9 @@ class VllmSerializableFunction(SerializableCallable):
...
@@ -73,7 +80,9 @@ class VllmSerializableFunction(SerializableCallable):
graph_reducer_override
=
GraphPickler
.
reducer_override
graph_reducer_override
=
GraphPickler
.
reducer_override
def
_graph_reducer_override
(
self
,
obj
):
def
_graph_reducer_override
(
self
:
GraphPickler
,
obj
:
Any
)
->
tuple
[
Callable
[...,
Any
],
tuple
[
Any
,
...]]
|
Any
:
if
(
if
(
inspect
.
isclass
(
obj
)
inspect
.
isclass
(
obj
)
and
issubclass
(
obj
,
sympy
.
Function
)
and
issubclass
(
obj
,
sympy
.
Function
)
...
@@ -114,7 +123,7 @@ class VllmSerializableFunction(SerializableCallable):
...
@@ -114,7 +123,7 @@ class VllmSerializableFunction(SerializableCallable):
get_current_vllm_config
(),
state
[
"prefix"
],
is_encoder
get_current_vllm_config
(),
state
[
"prefix"
],
is_encoder
)
)
def
optimized_call
(
*
example_inputs
)
:
def
optimized_call
(
*
example_inputs
:
Any
)
->
Any
:
"""
"""
On the first run of the optimized call, we rerun the compiler
On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend
backend which should result in a cache hit. After the backend
...
@@ -136,7 +145,7 @@ class VllmSerializableFunction(SerializableCallable):
...
@@ -136,7 +145,7 @@ class VllmSerializableFunction(SerializableCallable):
return
fn
return
fn
@
property
@
property
def
co_name
(
self
):
def
co_name
(
self
)
->
Literal
[
"VllmSerializableFunction"
]
:
"""
"""
Used for depyf debugging.
Used for depyf debugging.
"""
"""
...
...
vllm/compilation/cuda_graph.py
View file @
aaf4b70a
...
@@ -42,7 +42,9 @@ class CUDAGraphLogging:
...
@@ -42,7 +42,9 @@ class CUDAGraphLogging:
"Count"
,
"Count"
,
]
]
def
__init__
(
self
,
cg_mode
:
CUDAGraphMode
,
cg_capture_sizes
:
list
[
int
]
|
None
):
def
__init__
(
self
,
cg_mode
:
CUDAGraphMode
,
cg_capture_sizes
:
list
[
int
]
|
None
)
->
None
:
self
.
reset
()
self
.
reset
()
self
.
cg_mode
=
str
(
cg_mode
)
self
.
cg_mode
=
str
(
cg_mode
)
self
.
cg_capture_sizes
=
str
(
cg_capture_sizes
or
[])
self
.
cg_capture_sizes
=
str
(
cg_capture_sizes
or
[])
...
@@ -54,10 +56,10 @@ class CUDAGraphLogging:
...
@@ -54,10 +56,10 @@ class CUDAGraphLogging:
"**CUDAGraph Stats:**
\n\n
"
"**CUDAGraph Stats:**
\n\n
"
)
)
def
reset
(
self
):
def
reset
(
self
)
->
None
:
self
.
stats
=
[]
self
.
stats
:
list
[
CUDAGraphStat
]
=
[]
def
observe
(
self
,
cudagraph_stat
:
CUDAGraphStat
):
def
observe
(
self
,
cudagraph_stat
:
CUDAGraphStat
)
->
None
:
self
.
stats
.
append
(
cudagraph_stat
)
self
.
stats
.
append
(
cudagraph_stat
)
def
generate_metric_table
(
self
)
->
str
:
def
generate_metric_table
(
self
)
->
str
:
...
@@ -109,7 +111,7 @@ class CUDAGraphLogging:
...
@@ -109,7 +111,7 @@ class CUDAGraphLogging:
+
"
\n
"
+
"
\n
"
)
)
def
log
(
self
,
log_fn
=
logger
.
info
):
def
log
(
self
,
log_fn
:
Callable
[...,
Any
]
=
logger
.
info
)
->
None
:
if
not
self
.
stats
:
if
not
self
.
stats
:
return
return
log_fn
(
self
.
generate_metric_table
())
log_fn
(
self
.
generate_metric_table
())
...
@@ -161,11 +163,11 @@ class CUDAGraphWrapper:
...
@@ -161,11 +163,11 @@ class CUDAGraphWrapper:
def
__init__
(
def
__init__
(
self
,
self
,
runnable
:
Callable
,
runnable
:
Callable
[...,
Any
]
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
runtime_mode
:
CUDAGraphMode
,
runtime_mode
:
CUDAGraphMode
,
cudagraph_options
:
CUDAGraphOptions
|
None
=
None
,
cudagraph_options
:
CUDAGraphOptions
|
None
=
None
,
):
)
->
None
:
self
.
runnable
=
runnable
self
.
runnable
=
runnable
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
runtime_mode
=
runtime_mode
self
.
runtime_mode
=
runtime_mode
...
@@ -189,7 +191,7 @@ class CUDAGraphWrapper:
...
@@ -189,7 +191,7 @@ class CUDAGraphWrapper:
# cudagraphs for.
# cudagraphs for.
self
.
concrete_cudagraph_entries
:
dict
[
BatchDescriptor
,
CUDAGraphEntry
]
=
{}
self
.
concrete_cudagraph_entries
:
dict
[
BatchDescriptor
,
CUDAGraphEntry
]
=
{}
def
__getattr__
(
self
,
key
:
str
):
def
__getattr__
(
self
,
key
:
str
)
->
Any
:
# allow accessing the attributes of the runnable.
# allow accessing the attributes of the runnable.
if
hasattr
(
self
.
runnable
,
key
):
if
hasattr
(
self
.
runnable
,
key
):
return
getattr
(
self
.
runnable
,
key
)
return
getattr
(
self
.
runnable
,
key
)
...
@@ -198,11 +200,11 @@ class CUDAGraphWrapper:
...
@@ -198,11 +200,11 @@ class CUDAGraphWrapper:
f
"cudagraph wrapper:
{
self
.
runnable
}
"
f
"cudagraph wrapper:
{
self
.
runnable
}
"
)
)
def
unwrap
(
self
)
->
Callable
:
def
unwrap
(
self
)
->
Callable
[...,
Any
]
:
# in case we need to access the original runnable.
# in case we need to access the original runnable.
return
self
.
runnable
return
self
.
runnable
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
|
None
:
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
batch_descriptor
=
forward_context
.
batch_descriptor
batch_descriptor
=
forward_context
.
batch_descriptor
cudagraph_runtime_mode
=
forward_context
.
cudagraph_runtime_mode
cudagraph_runtime_mode
=
forward_context
.
cudagraph_runtime_mode
...
...
vllm/compilation/decorators.py
View file @
aaf4b70a
...
@@ -6,8 +6,8 @@ import hashlib
...
@@ -6,8 +6,8 @@ import hashlib
import
inspect
import
inspect
import
os
import
os
import
sys
import
sys
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Generator
from
typing
import
TypeVar
,
overload
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
TypeVar
,
overload
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -32,6 +32,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
...
@@ -32,6 +32,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
from
.monitor
import
start_monitoring_torch_compile
from
.monitor
import
start_monitoring_torch_compile
if
TYPE_CHECKING
:
# Only added on nightly/2.10 so wrap
try
:
from
torch._dynamo.package
import
SourceInfo
except
ImportError
:
# Fallback for old versions not supporting
SourceInfo
=
Any
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
IGNORE_COMPILE_KEY
=
"_ignore_compile_vllm"
IGNORE_COMPILE_KEY
=
"_ignore_compile_vllm"
...
@@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T:
...
@@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return
cls
return
cls
def
_should_ignore_torch_compile
(
cls
)
->
bool
:
def
_should_ignore_torch_compile
(
cls
:
_T
)
->
bool
:
"""
"""
Check if the class should be ignored for torch.compile.
Check if the class should be ignored for torch.compile.
"""
"""
...
@@ -224,7 +232,7 @@ def support_torch_compile(
...
@@ -224,7 +232,7 @@ def support_torch_compile(
return
cls_decorator_helper
return
cls_decorator_helper
def
_model_hash_key
(
fn
)
->
str
:
def
_model_hash_key
(
fn
:
Callable
[...,
Any
]
)
->
str
:
import
vllm
import
vllm
sha256_hash
=
hashlib
.
sha256
()
sha256_hash
=
hashlib
.
sha256
()
...
@@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str:
...
@@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str:
return
sha256_hash
.
hexdigest
()
return
sha256_hash
.
hexdigest
()
def
_verify_source_unchanged
(
source_info
,
vllm_config
)
->
None
:
def
_verify_source_unchanged
(
source_info
:
"SourceInfo"
,
vllm_config
:
VllmConfig
)
->
None
:
from
.caching
import
_compute_code_hash
,
_compute_code_hash_with_content
from
.caching
import
_compute_code_hash
,
_compute_code_hash_with_content
file_contents
=
{}
file_contents
=
{}
...
@@ -275,8 +285,12 @@ def _support_torch_compile(
...
@@ -275,8 +285,12 @@ def _support_torch_compile(
setattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
setattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
|
None
=
None
,
prefix
:
str
=
""
,
**
kwargs
self
:
_T
,
):
*
,
vllm_config
:
VllmConfig
|
None
=
None
,
prefix
:
str
=
""
,
**
kwargs
:
Any
,
)
->
None
:
if
vllm_config
is
None
:
if
vllm_config
is
None
:
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
...
@@ -309,13 +323,17 @@ def _support_torch_compile(
...
@@ -309,13 +323,17 @@ def _support_torch_compile(
compilation_counter
.
num_models_seen
+=
1
compilation_counter
.
num_models_seen
+=
1
self
.
compiled
=
False
self
.
compiled
=
False
TorchCompileWithNoGuardsWrapper
.
__init__
(
self
)
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper
.
__init__
(
self
)
# type: ignore[arg-type]
cls
.
__init__
=
__init__
cls
.
__init__
=
__init__
def
_mark_dynamic_inputs
(
mod
,
type
,
*
args
,
**
kwargs
):
def
_mark_dynamic_inputs
(
def
mark_dynamic
(
arg
,
dims
):
mod
:
_T
,
ds_type
:
DynamicShapesType
,
*
args
:
Any
,
**
kwargs
:
Any
if
type
==
DynamicShapesType
.
UNBACKED
:
)
->
None
:
def
mark_dynamic
(
arg
:
torch
.
Tensor
,
dims
:
list
[
int
])
->
None
:
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
if
is_torch_equal_or_newer
(
"2.10.0.dev"
):
if
is_torch_equal_or_newer
(
"2.10.0.dev"
):
for
dim
in
dims
:
for
dim
in
dims
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
...
@@ -326,7 +344,7 @@ def _support_torch_compile(
...
@@ -326,7 +344,7 @@ def _support_torch_compile(
else
:
else
:
torch
.
_dynamo
.
mark_dynamic
(
arg
,
dims
)
torch
.
_dynamo
.
mark_dynamic
(
arg
,
dims
)
sig
=
inspect
.
signature
(
mod
.
__class__
.
forward
)
sig
=
inspect
.
signature
(
mod
.
__class__
.
forward
)
# type: ignore[attr-defined]
bound_args
=
sig
.
bind
(
mod
,
*
args
,
**
kwargs
)
bound_args
=
sig
.
bind
(
mod
,
*
args
,
**
kwargs
)
bound_args
.
apply_defaults
()
bound_args
.
apply_defaults
()
for
k
,
dims
in
dynamic_arg_dims
.
items
():
for
k
,
dims
in
dynamic_arg_dims
.
items
():
...
@@ -364,7 +382,7 @@ def _support_torch_compile(
...
@@ -364,7 +382,7 @@ def _support_torch_compile(
else
:
else
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
:
_T
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
...
@@ -444,7 +462,7 @@ def _support_torch_compile(
...
@@ -444,7 +462,7 @@ def _support_torch_compile(
not
envs
.
VLLM_USE_AOT_COMPILE
not
envs
.
VLLM_USE_AOT_COMPILE
or
self
.
vllm_config
.
compilation_config
.
backend
==
"eager"
or
self
.
vllm_config
.
compilation_config
.
backend
==
"eager"
)
)
return
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
return
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# type: ignore[arg-type]
# This is the path for the first compilation.
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
# the first compilation needs to have dynamic shapes marked
...
@@ -477,7 +495,7 @@ def _support_torch_compile(
...
@@ -477,7 +495,7 @@ def _support_torch_compile(
# during Dynamo tracing, and their corresponding files
# during Dynamo tracing, and their corresponding files
inline_call
=
InliningInstructionTranslator
.
inline_call_
inline_call
=
InliningInstructionTranslator
.
inline_call_
def
patched_inline_call
(
self_
)
:
def
patched_inline_call
(
self_
:
Any
)
->
Any
:
code
=
self_
.
f_code
code
=
self_
.
f_code
self
.
compilation_config
.
traced_files
.
add
(
code
.
co_filename
)
self
.
compilation_config
.
traced_files
.
add
(
code
.
co_filename
)
return
inline_call
(
self_
)
return
inline_call
(
self_
)
...
@@ -535,7 +553,7 @@ def _support_torch_compile(
...
@@ -535,7 +553,7 @@ def _support_torch_compile(
str
(
e
),
str
(
e
),
)
)
else
:
else
:
output
=
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
output
=
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# type: ignore[arg-type]
self
.
compiled
=
True
self
.
compiled
=
True
return
output
return
output
...
@@ -545,7 +563,9 @@ def _support_torch_compile(
...
@@ -545,7 +563,9 @@ def _support_torch_compile(
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
maybe_use_cudagraph_partition_wrapper
(
vllm_config
:
VllmConfig
):
def
maybe_use_cudagraph_partition_wrapper
(
vllm_config
:
VllmConfig
,
)
->
Generator
[
None
,
None
,
None
]:
"""
"""
Context manager to set/unset customized cudagraph partition wrappers.
Context manager to set/unset customized cudagraph partition wrappers.
...
@@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
...
@@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
current_platform
.
get_static_graph_wrapper_cls
()
current_platform
.
get_static_graph_wrapper_cls
()
)
)
def
customized_cudagraph_wrapper
(
f
,
metadata
:
CUDAGraphWrapperMetadata
):
def
customized_cudagraph_wrapper
(
f
:
Callable
[...,
Any
],
metadata
:
CUDAGraphWrapperMetadata
)
->
Any
:
partition_id
=
metadata
.
partition_index
partition_id
=
metadata
.
partition_index
num_partitions
=
metadata
.
num_partitions
num_partitions
=
metadata
.
num_partitions
return
static_graph_wrapper_class
(
return
static_graph_wrapper_class
(
...
@@ -600,7 +622,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
...
@@ -600,7 +622,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
_torch27_patch_tensor_subclasses
():
def
_torch27_patch_tensor_subclasses
()
->
Generator
[
None
,
None
,
None
]
:
"""
"""
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of
using torch 2.7.0. This enables using weight_loader_v2 and the use of
...
@@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses():
...
@@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses():
_ColumnvLLMParameter
,
_ColumnvLLMParameter
,
)
)
def
return_false
(
*
args
,
**
kwargs
)
:
def
return_false
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Literal
[
False
]
:
return
False
return
False
if
version
.
parse
(
"2.7"
)
<=
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"2.8"
):
if
version
.
parse
(
"2.7"
)
<=
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"2.8"
):
...
...
vllm/compilation/fix_functionalization.py
View file @
aaf4b70a
...
@@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
...
@@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
"""
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
# XPU does not support auto-functionalization yet.
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
# Will enable this when switch to vllm-xpu-kernels.
if
current_platform
.
is_xpu
():
if
current_platform
.
is_xpu
():
...
@@ -179,7 +179,7 @@ class FixFunctionalizationPass(VllmInductorPass):
...
@@ -179,7 +179,7 @@ class FixFunctionalizationPass(VllmInductorPass):
)
)
self
.
nodes_to_remove
.
clear
()
self
.
nodes_to_remove
.
clear
()
def
_remove
(
self
,
node_or_nodes
:
torch
.
fx
.
Node
|
Iterable
[
torch
.
fx
.
Node
]):
def
_remove
(
self
,
node_or_nodes
:
torch
.
fx
.
Node
|
Iterable
[
torch
.
fx
.
Node
])
->
None
:
"""
"""
Stage a node (or nodes) for removal at the end of the pass.
Stage a node (or nodes) for removal at the end of the pass.
"""
"""
...
@@ -194,7 +194,7 @@ class FixFunctionalizationPass(VllmInductorPass):
...
@@ -194,7 +194,7 @@ class FixFunctionalizationPass(VllmInductorPass):
node
:
torch
.
fx
.
Node
,
node
:
torch
.
fx
.
Node
,
mutated_args
:
dict
[
int
,
torch
.
fx
.
Node
|
str
],
mutated_args
:
dict
[
int
,
torch
.
fx
.
Node
|
str
],
args
:
tuple
[
torch
.
fx
.
Node
|
str
,
...]
|
None
=
None
,
args
:
tuple
[
torch
.
fx
.
Node
|
str
,
...]
|
None
=
None
,
):
)
->
None
:
"""
"""
De-functionalize a node by replacing it with a call to the original.
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
It also replaces the getitem users with the mutated arguments.
...
@@ -206,7 +206,7 @@ class FixFunctionalizationPass(VllmInductorPass):
...
@@ -206,7 +206,7 @@ class FixFunctionalizationPass(VllmInductorPass):
def
replace_users_with_mutated_args
(
def
replace_users_with_mutated_args
(
self
,
node
:
torch
.
fx
.
Node
,
mutated_args
:
dict
[
int
,
torch
.
fx
.
Node
|
str
]
self
,
node
:
torch
.
fx
.
Node
,
mutated_args
:
dict
[
int
,
torch
.
fx
.
Node
|
str
]
):
)
->
None
:
"""
"""
Replace all getitem users of the auto-functionalized node with the
Replace all getitem users of the auto-functionalized node with the
mutated arguments.
mutated arguments.
...
@@ -237,7 +237,7 @@ class FixFunctionalizationPass(VllmInductorPass):
...
@@ -237,7 +237,7 @@ class FixFunctionalizationPass(VllmInductorPass):
graph
:
torch
.
fx
.
Graph
,
graph
:
torch
.
fx
.
Graph
,
node
:
torch
.
fx
.
Node
,
node
:
torch
.
fx
.
Node
,
args
:
tuple
[
torch
.
fx
.
Node
|
str
,
...]
|
None
=
None
,
args
:
tuple
[
torch
.
fx
.
Node
|
str
,
...]
|
None
=
None
,
):
)
->
None
:
"""
"""
Insert a new defunctionalized node into the graph before node.
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,
If one of the kwargs is 'out', provide args directly,
...
...
vllm/compilation/inductor_pass.py
View file @
aaf4b70a
...
@@ -29,6 +29,9 @@ else:
...
@@ -29,6 +29,9 @@ else:
Torch25CustomGraphPass
as
CustomGraphPass
,
Torch25CustomGraphPass
as
CustomGraphPass
,
)
)
# Re-export CustomGraphPass for external usage
__all__
=
[
"CustomGraphPass"
]
_pass_context
=
None
_pass_context
=
None
P
=
ParamSpec
(
"P"
)
P
=
ParamSpec
(
"P"
)
R
=
TypeVar
(
"R"
)
R
=
TypeVar
(
"R"
)
...
...
vllm/compilation/noop_elimination.py
View file @
aaf4b70a
...
@@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass):
...
@@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass):
"""
"""
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
count
=
0
count
=
0
# Remove no-op reshapes/views:
# Remove no-op reshapes/views:
for
node
in
graph
.
nodes
:
for
node
in
graph
.
nodes
:
...
@@ -117,7 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
...
@@ -117,7 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
2. The dimensions both correspond to the same SymInt
2. The dimensions both correspond to the same SymInt
"""
"""
# Case 1
# Case 1
return
statically_known_true
(
dim
==
i_dim
)
return
statically_known_true
(
dim
==
i_dim
)
# type: ignore[no-any-return]
def
all_dims_equivalent
(
def
all_dims_equivalent
(
self
,
dims
:
Iterable
[
int
|
SymInt
],
i_dims
:
Iterable
[
int
|
SymInt
]
self
,
dims
:
Iterable
[
int
|
SymInt
],
i_dims
:
Iterable
[
int
|
SymInt
]
...
...
vllm/compilation/pass_manager.py
View file @
aaf4b70a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
functools
from
collections.abc
import
Callable
from
typing
import
Any
,
ParamSpec
,
TypeVar
from
torch
import
fx
as
fx
from
torch
import
fx
as
fx
...
@@ -40,8 +42,11 @@ from .noop_elimination import NoOpEliminationPass
...
@@ -40,8 +42,11 @@ from .noop_elimination import NoOpEliminationPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
P
=
ParamSpec
(
"P"
)
R
=
TypeVar
(
"R"
)
def
with_pattern_match_debug
(
fn
):
def
with_pattern_match_debug
(
fn
:
Callable
[
P
,
R
])
->
Callable
[
P
,
R
]:
"""
"""
Function decorator that turns on inductor pattern match debug
Function decorator that turns on inductor pattern match debug
for the duration of the call.
for the duration of the call.
...
@@ -49,7 +54,7 @@ def with_pattern_match_debug(fn):
...
@@ -49,7 +54,7 @@ def with_pattern_match_debug(fn):
"""
"""
@
functools
.
wraps
(
fn
)
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
,
**
kwargs
)
:
def
wrapper
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
R
:
if
(
debug_val
:
=
envs
.
VLLM_PATTERN_MATCH_DEBUG
)
is
not
None
:
if
(
debug_val
:
=
envs
.
VLLM_PATTERN_MATCH_DEBUG
)
is
not
None
:
# optionally check rank here
# optionally check rank here
with
set_env_var
(
"TORCHINDUCTOR_PATTERN_MATCH_DEBUG"
,
debug_val
):
with
set_env_var
(
"TORCHINDUCTOR_PATTERN_MATCH_DEBUG"
,
debug_val
):
...
@@ -59,7 +64,7 @@ def with_pattern_match_debug(fn):
...
@@ -59,7 +64,7 @@ def with_pattern_match_debug(fn):
return
wrapper
return
wrapper
class
PostGradPassManager
(
CustomGraphPass
):
class
PostGradPassManager
(
CustomGraphPass
):
# type: ignore[misc]
"""
"""
The pass manager for post-grad passes.
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It handles configuration, adding custom passes, and running passes.
...
@@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass):
...
@@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass):
This way, all passes operate on a functionalized graph.
This way, all passes operate on a functionalized graph.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
self
.
passes
:
list
[
InductorPass
]
=
[]
self
.
passes
:
list
[
InductorPass
]
=
[]
@
with_pattern_match_debug
@
with_pattern_match_debug
def
__call__
(
self
,
graph
:
fx
.
Graph
):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
VllmInductorPass
.
dump_prefix
=
0
# reset dump index
VllmInductorPass
.
dump_prefix
=
0
# reset dump index
compile_range
=
get_pass_context
().
compile_range
compile_range
=
get_pass_context
().
compile_range
...
@@ -98,7 +103,7 @@ class PostGradPassManager(CustomGraphPass):
...
@@ -98,7 +103,7 @@ class PostGradPassManager(CustomGraphPass):
self
.
fix_functionalization
(
graph
)
self
.
fix_functionalization
(
graph
)
VllmInductorPass
.
dump_prefix
=
None
# Cleanup index
VllmInductorPass
.
dump_prefix
=
None
# Cleanup index
def
configure
(
self
,
config
:
VllmConfig
):
def
configure
(
self
,
config
:
VllmConfig
)
->
None
:
self
.
pass_config
=
config
.
compilation_config
.
pass_config
self
.
pass_config
=
config
.
compilation_config
.
pass_config
# Set the current vllm config to allow tracing CustomOp instances
# Set the current vllm config to allow tracing CustomOp instances
...
@@ -135,23 +140,25 @@ class PostGradPassManager(CustomGraphPass):
...
@@ -135,23 +140,25 @@ class PostGradPassManager(CustomGraphPass):
self
.
post_cleanup
=
PostCleanupPass
(
config
)
self
.
post_cleanup
=
PostCleanupPass
(
config
)
self
.
fix_functionalization
=
FixFunctionalizationPass
(
config
)
self
.
fix_functionalization
=
FixFunctionalizationPass
(
config
)
def
add
(
self
,
pass_
:
InductorPass
):
def
add
(
self
,
pass_
:
InductorPass
)
->
None
:
assert
isinstance
(
pass_
,
InductorPass
)
assert
isinstance
(
pass_
,
InductorPass
)
self
.
passes
.
append
(
pass_
)
self
.
passes
.
append
(
pass_
)
def
uuid
(
self
):
def
uuid
(
self
)
->
str
:
"""
"""
The PostGradPassManager is set as a custom pass in the Inductor and
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
dependent passes and the pass config. See InductorPass for more info.
"""
"""
state
=
{
"pass_config"
:
self
.
pass_config
.
compute_hash
(),
"passes"
:
[]}
passes
=
[]
state
:
dict
[
str
,
Any
]
=
{
"pass_config"
:
self
.
pass_config
.
compute_hash
()}
for
pass_
in
self
.
passes
:
for
pass_
in
self
.
passes
:
state
[
"
passes
"
]
.
append
(
pass_
.
uuid
())
passes
.
append
(
pass_
.
uuid
())
state
[
"
passes
"
]
.
append
(
self
.
fix_functionalization
.
uuid
())
passes
.
append
(
self
.
fix_functionalization
.
uuid
())
# Include the compile range in the uuid to ensure that inductor
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
# recompiles the graph for the new dynamic compile range.
state
[
"compile_range"
]
=
str
(
get_pass_context
().
compile_range
)
state
[
"compile_range"
]
=
str
(
get_pass_context
().
compile_range
)
state
[
"passes"
]
=
passes
return
InductorPass
.
hash_dict
(
state
)
return
InductorPass
.
hash_dict
(
state
)
vllm/compilation/piecewise_backend.py
View file @
aaf4b70a
...
@@ -86,7 +86,16 @@ class PiecewiseBackend:
...
@@ -86,7 +86,16 @@ class PiecewiseBackend:
self
.
to_be_compiled_ranges
:
set
[
Range
]
=
set
(
self
.
compile_ranges
)
self
.
to_be_compiled_ranges
:
set
[
Range
]
=
set
(
self
.
compile_ranges
)
# We only keep compilation management inside this class directly.
# We only keep compilation management inside this class directly.
if
self
.
compile_sizes
is
not
None
:
for
size
in
self
.
compile_sizes
:
for
size
in
self
.
compile_sizes
:
if
isinstance
(
size
,
str
):
assert
size
==
"cudagraph_capture_sizes"
raise
NotImplementedError
(
"cudagraph_capture_sizes not supported in compile_sizes."
"This should be handled in `post_init_cudagraph_sizes`."
)
else
:
assert
isinstance
(
size
,
int
)
range
=
Range
(
start
=
size
,
end
=
size
)
range
=
Range
(
start
=
size
,
end
=
size
)
if
range
not
in
self
.
compile_ranges
:
if
range
not
in
self
.
compile_ranges
:
self
.
range_entries
[
range
]
=
RangeEntry
(
self
.
range_entries
[
range
]
=
RangeEntry
(
...
@@ -99,14 +108,14 @@ class PiecewiseBackend:
...
@@ -99,14 +108,14 @@ class PiecewiseBackend:
compile_range
=
range
,
compile_range
=
range
,
)
)
def
check_for_ending_compilation
(
self
):
def
check_for_ending_compilation
(
self
)
->
None
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_ranges
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_ranges
:
# no specific sizes to compile
# no specific sizes to compile
# save the hash of the inductor graph for the next run
# save the hash of the inductor graph for the next run
self
.
vllm_backend
.
compiler_manager
.
save_to_file
()
self
.
vllm_backend
.
compiler_manager
.
save_to_file
()
end_monitoring_torch_compile
(
self
.
vllm_config
)
end_monitoring_torch_compile
(
self
.
vllm_config
)
def
_fakify_args
(
self
,
args
:
list
[
Any
])
->
list
[
Any
]:
def
_fakify_args
(
self
,
args
:
tuple
[
Any
,
...
])
->
list
[
Any
]:
# We need to pass fake example_inputs, otherwise torch.compile
# We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs potentially causing some non dynamic
# will fakify the example_inputs potentially causing some non dynamic
# dimension to be be duck shaped to other existing shapes that have hints
# dimension to be be duck shaped to other existing shapes that have hints
...
@@ -127,7 +136,9 @@ class PiecewiseBackend:
...
@@ -127,7 +136,9 @@ class PiecewiseBackend:
assert
len
(
fake_example_inputs
)
==
len
(
args
)
assert
len
(
fake_example_inputs
)
==
len
(
args
)
return
fake_example_inputs
return
fake_example_inputs
def
_maybe_compile_for_range_entry
(
self
,
range_entry
:
RangeEntry
,
args
)
->
Any
:
def
_maybe_compile_for_range_entry
(
self
,
range_entry
:
RangeEntry
,
args
:
tuple
[
Any
,
...]
)
->
Any
:
if
not
range_entry
.
compiled
:
if
not
range_entry
.
compiled
:
range_entry
.
compiled
=
True
range_entry
.
compiled
=
True
self
.
to_be_compiled_ranges
.
remove
(
range_entry
.
compile_range
)
self
.
to_be_compiled_ranges
.
remove
(
range_entry
.
compile_range
)
...
@@ -136,14 +147,14 @@ class PiecewiseBackend:
...
@@ -136,14 +147,14 @@ class PiecewiseBackend:
# fakify for range, real args for concrete size.
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
# compiler_manager.compile() so no need to fakify.
args
=
(
args
_list
=
(
self
.
_fakify_args
(
args
)
self
.
_fakify_args
(
args
)
if
not
range_entry
.
compile_range
.
is_single_size
()
if
not
range_entry
.
compile_range
.
is_single_size
()
else
args
else
list
(
args
)
)
)
range_entry
.
runnable
=
self
.
vllm_backend
.
compiler_manager
.
compile
(
range_entry
.
runnable
=
self
.
vllm_backend
.
compiler_manager
.
compile
(
self
.
graph
,
self
.
graph
,
args
,
args
_list
,
self
.
vllm_backend
.
inductor_config
,
self
.
vllm_backend
.
inductor_config
,
self
.
compilation_config
,
self
.
compilation_config
,
compile_range
=
range_entry
.
compile_range
,
compile_range
=
range_entry
.
compile_range
,
...
@@ -153,10 +164,13 @@ class PiecewiseBackend:
...
@@ -153,10 +164,13 @@ class PiecewiseBackend:
self
.
check_for_ending_compilation
()
self
.
check_for_ending_compilation
()
def
_find_range_for_shape
(
self
,
runtime_shape
:
int
)
->
Range
|
None
:
def
_find_range_for_shape
(
self
,
runtime_shape
:
int
)
->
Range
Entry
|
None
:
# First we try to find the range entry for the concrete compile size
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# If not found, we search for the range entry
# that contains the runtime shape.
# that contains the runtime shape.
if
self
.
compile_sizes
is
None
:
return
None
if
runtime_shape
in
self
.
compile_sizes
:
if
runtime_shape
in
self
.
compile_sizes
:
return
self
.
range_entries
[
Range
(
start
=
runtime_shape
,
end
=
runtime_shape
)]
return
self
.
range_entries
[
Range
(
start
=
runtime_shape
,
end
=
runtime_shape
)]
else
:
else
:
...
@@ -165,7 +179,7 @@ class PiecewiseBackend:
...
@@ -165,7 +179,7 @@ class PiecewiseBackend:
return
self
.
range_entries
[
range
]
return
self
.
range_entries
[
range
]
return
None
return
None
def
__call__
(
self
,
*
args
)
->
Any
:
def
__call__
(
self
,
*
args
:
Any
)
->
Any
:
runtime_shape
=
args
[
self
.
sym_shape_indices
[
0
]]
runtime_shape
=
args
[
self
.
sym_shape_indices
[
0
]]
range_entry
=
self
.
_find_range_for_shape
(
runtime_shape
)
range_entry
=
self
.
_find_range_for_shape
(
runtime_shape
)
...
...
vllm/compilation/wrapper.py
View file @
aaf4b70a
...
@@ -4,9 +4,10 @@
...
@@ -4,9 +4,10 @@
import
os
import
os
import
sys
import
sys
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
collections.abc
import
Callable
,
Generator
from
contextlib
import
contextmanager
,
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
types
import
CodeType
from
types
import
CodeType
from
typing
import
Any
from
typing
import
Any
,
ParamSpec
,
TypeVar
import
torch
import
torch
import
torch._C._dynamo.guards
import
torch._C._dynamo.guards
...
@@ -19,19 +20,26 @@ from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
...
@@ -19,19 +20,26 @@ from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
R
=
TypeVar
(
"R"
)
P
=
ParamSpec
(
"P"
)
def
_noop_add_global_state_guard
(
self
,
*
args
,
**
kwargs
):
def
_noop_add_global_state_guard
(
self
:
torch
.
_C
.
_dynamo
.
guards
.
GuardManager
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
"""No-op to skip the GLOBAL_STATE guard entirely"""
"""No-op to skip the GLOBAL_STATE guard entirely"""
pass
pass
def
_noop_add_torch_function_mode_stack_guard
(
self
,
*
args
,
**
kwargs
):
def
_noop_add_torch_function_mode_stack_guard
(
self
:
torch
.
_C
.
_dynamo
.
guards
.
GuardManager
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
pass
pass
@
contextmanager
@
contextmanager
def
_compilation_context
():
def
_compilation_context
()
->
Generator
[
None
,
None
,
None
]
:
"""Context manager for compilation settings and patches.
"""Context manager for compilation settings and patches.
This manager:
This manager:
...
@@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper:
...
@@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
since we drop all guards.
"""
"""
def
check_invariants_and_forward
(
self
,
*
args
,
**
kwargs
)
:
def
check_invariants_and_forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
assert
hasattr
(
self
,
"_check_shape_invariants"
)
assert
hasattr
(
self
,
"_check_shape_invariants"
)
self
.
_check_shape_invariants
(
*
args
,
**
kwargs
)
self
.
_check_shape_invariants
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
def
_call_with_optional_nvtx_range
(
self
,
callable_fn
,
*
args
,
**
kwargs
):
def
_call_with_optional_nvtx_range
(
self
,
callable_fn
:
Callable
[
P
,
R
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
Any
:
if
self
.
layerwise_nvtx_tracing_enabled
:
if
self
.
layerwise_nvtx_tracing_enabled
:
args_list
=
list
(
args
)
args_list
=
list
(
args
)
kwargs_dict
=
dict
(
kwargs
)
kwargs_dict
=
dict
(
kwargs
)
...
@@ -108,7 +118,7 @@ class TorchCompileWithNoGuardsWrapper:
...
@@ -108,7 +118,7 @@ class TorchCompileWithNoGuardsWrapper:
return
ctx
.
result
return
ctx
.
result
return
callable_fn
(
*
args
,
**
kwargs
)
return
callable_fn
(
*
args
,
**
kwargs
)
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
self
.
compiled
=
False
self
.
compiled
=
False
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
...
@@ -192,9 +202,9 @@ class TorchCompileWithNoGuardsWrapper:
...
@@ -192,9 +202,9 @@ class TorchCompileWithNoGuardsWrapper:
if
envs
.
VLLM_USE_BYTECODE_HOOK
and
mode
!=
CompilationMode
.
STOCK_TORCH_COMPILE
:
if
envs
.
VLLM_USE_BYTECODE_HOOK
and
mode
!=
CompilationMode
.
STOCK_TORCH_COMPILE
:
torch
.
_dynamo
.
convert_frame
.
register_bytecode_hook
(
self
.
bytecode_hook
)
torch
.
_dynamo
.
convert_frame
.
register_bytecode_hook
(
self
.
bytecode_hook
)
self
.
_compiled_bytecode
=
None
self
.
_compiled_bytecode
:
CodeType
|
None
=
None
def
aot_compile
(
self
,
*
args
,
**
kwargs
)
:
def
aot_compile
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
not
hasattr
(
self
.
_compiled_callable
,
"aot_compile"
):
if
not
hasattr
(
self
.
_compiled_callable
,
"aot_compile"
):
raise
RuntimeError
(
raise
RuntimeError
(
"aot_compile is not supported by the current configuration. "
"aot_compile is not supported by the current configuration. "
...
@@ -203,7 +213,7 @@ class TorchCompileWithNoGuardsWrapper:
...
@@ -203,7 +213,7 @@ class TorchCompileWithNoGuardsWrapper:
)
)
return
self
.
_compiled_callable
.
aot_compile
((
args
,
kwargs
))
return
self
.
_compiled_callable
.
aot_compile
((
args
,
kwargs
))
def
__call__
(
self
,
*
args
,
**
kwargs
)
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
envs
.
VLLM_USE_BYTECODE_HOOK
:
if
envs
.
VLLM_USE_BYTECODE_HOOK
:
if
(
if
(
self
.
vllm_config
.
compilation_config
.
mode
self
.
vllm_config
.
compilation_config
.
mode
...
@@ -236,13 +246,13 @@ class TorchCompileWithNoGuardsWrapper:
...
@@ -236,13 +246,13 @@ class TorchCompileWithNoGuardsWrapper:
)
)
@
abstractmethod
@
abstractmethod
def
forward
(
self
,
*
args
,
**
kwargs
)
:
...
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
...
def
original_code_object
(
self
)
->
CodeType
:
def
original_code_object
(
self
)
->
CodeType
:
"""Return the original code object of the forward method."""
"""Return the original code object of the forward method."""
return
self
.
__class__
.
forward
.
__code__
return
self
.
__class__
.
forward
.
__code__
def
bytecode_hook
(
self
,
old_code
:
CodeType
,
new_code
:
CodeType
):
def
bytecode_hook
(
self
,
old_code
:
CodeType
,
new_code
:
CodeType
)
->
None
:
"""Hook to save the compiled bytecode for direct execution."""
"""Hook to save the compiled bytecode for direct execution."""
if
old_code
is
not
self
.
original_code_object
():
if
old_code
is
not
self
.
original_code_object
():
return
return
...
@@ -299,7 +309,7 @@ class TorchCompileWithNoGuardsWrapper:
...
@@ -299,7 +309,7 @@ class TorchCompileWithNoGuardsWrapper:
raise
RuntimeError
(
msg
)
raise
RuntimeError
(
msg
)
@
contextmanager
@
contextmanager
def
_dispatch_to_compiled_code
(
self
):
def
_dispatch_to_compiled_code
(
self
)
->
Generator
[
None
,
None
,
None
]
:
# noqa: E501
# noqa: E501
"""
"""
Context manager to dispatch to internally compiled code for torch<2.8.
Context manager to dispatch to internally compiled code for torch<2.8.
...
...
vllm/config/compilation.py
View file @
aaf4b70a
...
@@ -32,6 +32,9 @@ else:
...
@@ -32,6 +32,9 @@ else:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Explicitly exports Range
__all__
=
[
"Range"
]
class
CompilationMode
(
enum
.
IntEnum
):
class
CompilationMode
(
enum
.
IntEnum
):
"""The compilation approach used for torch.compile-based compilation of the
"""The compilation approach used for torch.compile-based compilation of the
...
...
vllm/distributed/device_communicators/pynccl_allocator.py
View file @
aaf4b70a
...
@@ -60,7 +60,7 @@ def is_symmetric_memory_tensor(tensor: torch.Tensor):
...
@@ -60,7 +60,7 @@ def is_symmetric_memory_tensor(tensor: torch.Tensor):
return
False
return
False
def
set_graph_pool_id
(
graph_pool_id
)
:
def
set_graph_pool_id
(
graph_pool_id
:
Any
)
->
None
:
global
_graph_pool_id
global
_graph_pool_id
_graph_pool_id
=
graph_pool_id
_graph_pool_id
=
graph_pool_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment