Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
873480d1
Unverified
Commit
873480d1
authored
Jan 06, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 06, 2026
Browse files
[Misc][BE] Type coverage for vllm/compilation [1/3] (#31554)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
6f351548
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
103 additions
and
85 deletions
+103
-85
vllm/compilation/backends.py
vllm/compilation/backends.py
+13
-13
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+7
-7
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+24
-24
vllm/compilation/counter.py
vllm/compilation/counter.py
+3
-1
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+1
-0
vllm/compilation/fx_utils.py
vllm/compilation/fx_utils.py
+3
-2
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+15
-11
vllm/compilation/monitor.py
vllm/compilation/monitor.py
+4
-4
vllm/compilation/partition_rules.py
vllm/compilation/partition_rules.py
+4
-1
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+7
-7
vllm/compilation/torch25_custom_graph_pass.py
vllm/compilation/torch25_custom_graph_pass.py
+3
-3
vllm/compilation/vllm_inductor_pass.py
vllm/compilation/vllm_inductor_pass.py
+19
-12
No files found.
vllm/compilation/backends.py
View file @
873480d1
...
...
@@ -9,7 +9,7 @@ import operator
import
os
import
pprint
import
time
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Callable
,
Generator
,
Sequence
from
contextlib
import
contextmanager
from
copy
import
deepcopy
from
functools
import
partial
...
...
@@ -90,7 +90,7 @@ class CompilerManager:
support int as key.
"""
def
__init__
(
self
,
compilation_config
:
CompilationConfig
):
def
__init__
(
self
,
compilation_config
:
CompilationConfig
)
->
None
:
self
.
cache
:
dict
[
tuple
[
Range
,
int
,
str
],
Any
]
=
dict
()
self
.
is_cache_updated
=
False
self
.
compilation_config
=
compilation_config
...
...
@@ -100,7 +100,7 @@ class CompilerManager:
return
self
.
compiler
.
compute_hash
(
vllm_config
)
@
contextmanager
def
compile_context
(
self
,
compile_range
:
Range
):
def
compile_context
(
self
,
compile_range
:
Range
)
->
Generator
[
None
,
None
,
None
]
:
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
...
...
@@ -115,7 +115,7 @@ class CompilerManager:
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
"""
Initialize the cache directory for the compiler.
...
...
@@ -143,7 +143,7 @@ class CompilerManager:
# do not use eval(), it is unsafe.
cache
=
ast
.
literal_eval
(
f
.
read
())
def
check_type
(
value
,
ty
)
:
def
check_type
(
value
:
Any
,
ty
:
type
)
->
None
:
if
not
isinstance
(
value
,
ty
):
raise
TypeError
(
f
"Expected
{
ty
}
but got
{
type
(
value
)
}
for
{
value
}
"
)
...
...
@@ -165,7 +165,7 @@ class CompilerManager:
cache_dir
=
cache_dir
,
disable_cache
=
disable_cache
,
prefix
=
prefix
)
def
save_to_file
(
self
):
def
save_to_file
(
self
)
->
None
:
if
self
.
disable_cache
or
not
self
.
is_cache_updated
:
return
printer
=
pprint
.
PrettyPrinter
(
indent
=
4
)
...
...
@@ -198,7 +198,7 @@ class CompilerManager:
def
compile
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
,
example_inputs
:
list
[
Any
]
,
additional_inductor_config
,
compilation_config
:
CompilationConfig
,
compile_range
:
Range
,
...
...
@@ -373,7 +373,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
compile_submod_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
vllm_backend
:
"VllmBackend"
,
):
)
->
None
:
super
().
__init__
(
module
)
from
torch._guards
import
detect_fake_mode
...
...
@@ -385,7 +385,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self
.
extra_traceback
=
False
def
run
(
self
,
*
args
)
:
def
run
(
self
,
*
args
:
Any
)
->
Any
:
# maybe instead just assert inputs are fake?
fake_args
=
[
self
.
fake_mode
.
from_tensor
(
t
)
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
...
...
@@ -467,7 +467,7 @@ model_is_encoder: bool = False
@
contextmanager
def
set_model_tag
(
tag
:
str
,
is_encoder
:
bool
=
False
):
def
set_model_tag
(
tag
:
str
,
is_encoder
:
bool
=
False
)
->
Generator
[
None
,
None
,
None
]
:
"""Context manager to set the model tag."""
global
model_tag
global
model_is_encoder
...
...
@@ -521,7 +521,7 @@ class VllmBackend:
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
is_encoder
:
bool
=
False
,
):
)
->
None
:
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
...
...
@@ -558,7 +558,7 @@ class VllmBackend:
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
def
configure_post_pass
(
self
):
def
configure_post_pass
(
self
)
->
None
:
self
.
pass_manager
.
configure
(
self
.
vllm_config
)
# Post-grad custom passes are run using the post_grad_custom_post_pass
...
...
@@ -580,7 +580,7 @@ class VllmBackend:
self
.
inductor_config
[
self
.
pass_key
]
=
self
.
pass_manager
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
]
)
->
VllmSerializableFunction
:
vllm_config
=
self
.
vllm_config
# Minimal hashing here with existing utilities, reused below.
...
...
vllm/compilation/collective_fusion.py
View file @
873480d1
...
...
@@ -50,7 +50,7 @@ if hasattr(torch.ops._C, "scaled_fp4_quant"):
class
BasePattern
:
def
__init__
(
self
,
dtype
:
torch
.
dtype
,
device
:
str
)
:
def
__init__
(
self
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
self
.
dtype
=
dtype
self
.
device
=
device
self
.
tp
=
get_tp_group
()
...
...
@@ -637,7 +637,7 @@ class AllReduceRMSNormPattern(BasePattern):
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
...
...
@@ -692,7 +692,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
...
...
@@ -759,7 +759,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
...
...
@@ -828,7 +828,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
...
...
@@ -902,7 +902,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
...
...
@@ -988,7 +988,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
...
...
vllm/compilation/compiler_interface.py
View file @
873480d1
...
...
@@ -31,7 +31,7 @@ class CompilerInterface:
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
...
...
@@ -66,7 +66,7 @@ class CompilerInterface:
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
"""
Compile the graph with the given example inputs and compiler config,
with a range. The `compile_range` specifies the range of the inputs,
...
...
@@ -100,7 +100,7 @@ class CompilerInterface:
example_inputs
:
list
[
Any
],
graph_index
:
int
,
compile_range
:
Range
,
)
->
Callable
:
)
->
Callable
[...,
Any
]
:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
...
...
@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
def
__init__
(
self
)
->
None
:
self
.
guards
:
list
[
Any
]
=
[]
def
evaluate_guards_expression
(
self
,
*
args
,
**
kwargs
)
:
def
evaluate_guards_expression
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Literal
[
True
]
:
return
True
def
get_pruned_guards
(
self
,
*
args
,
**
kwargs
)
:
def
get_pruned_guards
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
list
[
Any
]
:
return
[]
def
produce_guards_expression
(
self
,
*
args
,
**
kwargs
)
:
def
produce_guards_expression
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Literal
[
""
]
:
return
""
...
...
@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
name
=
"inductor_standalone"
def
__init__
(
self
,
save_format
:
Literal
[
"binary"
,
"unpacked"
]):
def
__init__
(
self
,
save_format
:
Literal
[
"binary"
,
"unpacked"
])
->
None
:
self
.
save_format
=
save_format
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
...
...
@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
self
.
cache_dir
=
cache_dir
def
compile
(
...
...
@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
compilation_counter
.
num_inductor_compiles
+=
1
current_config
=
{}
if
compiler_config
is
not
None
:
...
...
@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
example_inputs
:
list
[
Any
],
graph_index
:
int
,
compile_range
:
Range
,
)
->
Callable
:
)
->
Callable
[...,
Any
]
:
assert
isinstance
(
handle
,
tuple
)
assert
isinstance
(
handle
[
0
],
str
)
assert
isinstance
(
handle
[
1
],
str
)
...
...
@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
returns_tuple
=
graph_returns_tuple
(
graph
)
def
compiled_graph_wrapper
(
*
args
)
:
def
compiled_graph_wrapper
(
*
args
:
Any
)
->
tuple
[
Any
,
...]
|
Any
:
graph_output
=
inductor_compiled_graph
(
*
args
)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
...
...
@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
self
.
cache_dir
=
cache_dir
self
.
prefix
=
prefix
self
.
base_cache_dir
=
cache_dir
[:
-
len
(
prefix
)]
if
prefix
else
cache_dir
...
...
@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
compilation_counter
.
num_inductor_compiles
+=
1
from
torch._inductor.compile_fx
import
compile_fx
...
...
@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
original_load
=
FxGraphCache
.
load
original_load_name
=
"torch._inductor.codecache.FxGraphCache.load"
def
hijack_load
(
*
args
,
**
kwargs
)
:
def
hijack_load
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
inductor_compiled_graph
=
original_load
(
*
args
,
**
kwargs
)
nonlocal
file_path
compiled_fn
=
inductor_compiled_graph
.
current_callable
...
...
@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
# function renamed in 2.6
original_load_name
=
None
def
hijacked_compile_fx_inner
(
*
args
,
**
kwargs
)
:
def
hijacked_compile_fx_inner
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
output
=
torch
.
_inductor
.
compile_fx
.
compile_fx_inner
(
*
args
,
**
kwargs
)
nonlocal
hash_str
inductor_compiled_graph
=
output
...
...
@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
hash_str
=
inductor_compiled_graph
.
_fx_graph_cache_key
return
output
def
hijack_compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
:
def
hijack_compiled_fx_graph_hash
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
out
=
compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
nonlocal
hash_str
hash_str
=
out
[
0
]
return
out
def
_check_can_cache
(
*
args
,
**
kwargs
)
:
def
_check_can_cache
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
...
...
@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
example_inputs
:
list
[
Any
],
graph_index
:
int
,
compile_range
:
Range
,
)
->
Callable
:
)
->
Callable
[...,
Any
]
:
assert
isinstance
(
handle
,
tuple
)
assert
isinstance
(
handle
[
0
],
str
)
assert
isinstance
(
handle
[
1
],
str
)
...
...
@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
returns_tuple
=
graph_returns_tuple
(
graph
)
# this is the callable we return to Dynamo to run
def
compiled_graph
(
*
args
)
:
def
compiled_graph
(
*
args
:
Any
)
->
tuple
[
Any
,
...]
|
Any
:
# convert args to list
list_args
=
list
(
args
)
graph_output
=
inductor_compiled_graph
(
list_args
)
...
...
@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
return
compiled_graph
def
metrics_context
(
self
)
->
contextlib
.
AbstractContextManager
:
def
metrics_context
(
self
)
->
contextlib
.
AbstractContextManager
[
Any
]
:
"""
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
...
...
@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
if
is_torch_equal_or_newer
(
"2.6"
):
import
torch._dynamo.utils
return
torch
.
_dynamo
.
utils
.
get_metrics_context
()
return
torch
.
_dynamo
.
utils
.
get_metrics_context
()
# type: ignore[no-any-return]
else
:
return
contextlib
.
nullcontext
()
def
set_inductor_config
(
config
,
compile_range
:
Range
):
def
set_inductor_config
(
config
:
dict
[
str
,
Any
]
,
compile_range
:
Range
)
->
None
:
if
compile_range
.
is_single_size
():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
...
...
@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
)
def
set_functorch_config
():
def
set_functorch_config
()
->
None
:
torch
.
_functorch
.
config
.
bundled_autograd_cache
=
False
...
...
@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
compilation_counter
.
num_eager_compiles
+=
1
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
...
...
vllm/compilation/counter.py
View file @
873480d1
...
...
@@ -3,7 +3,9 @@
import
copy
import
dataclasses
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
typing
import
Any
@
dataclasses
.
dataclass
...
...
@@ -34,7 +36,7 @@ class CompilationCounter:
return
copy
.
deepcopy
(
self
)
@
contextmanager
def
expect
(
self
,
**
kwargs
)
:
def
expect
(
self
,
**
kwargs
:
Any
)
->
Generator
[
None
,
None
,
None
]
:
old
=
self
.
clone
()
yield
for
k
,
v
in
kwargs
.
items
():
...
...
vllm/compilation/cuda_graph.py
View file @
873480d1
...
...
@@ -219,6 +219,7 @@ class CUDAGraphWrapper:
# runtime modes.
return
self
.
runnable
(
*
args
,
**
kwargs
)
assert
batch_descriptor
is
not
None
if
batch_descriptor
not
in
self
.
concrete_cudagraph_entries
:
# create a new entry for this batch descriptor
self
.
concrete_cudagraph_entries
[
batch_descriptor
]
=
CUDAGraphEntry
(
...
...
vllm/compilation/fx_utils.py
View file @
873480d1
...
...
@@ -7,10 +7,11 @@ from collections.abc import Iterable, Iterator
from
torch
import
fx
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._ops
import
OpOverload
,
OpOverloadPacket
from
torch.fx.node
import
Target
def
is_func
(
node
:
fx
.
Node
,
target
)
->
bool
:
return
node
.
op
==
"call_function"
and
node
.
target
==
target
def
is_func
(
node
:
fx
.
Node
,
target
:
Target
)
->
bool
:
return
bool
(
node
.
op
==
"call_function"
and
node
.
target
==
target
)
def
is_auto_func
(
node
:
fx
.
Node
,
op
:
OpOverload
)
->
bool
:
...
...
vllm/compilation/inductor_pass.py
View file @
873480d1
...
...
@@ -8,9 +8,9 @@ import hashlib
import
inspect
import
json
import
types
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Generator
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
ParamSpec
,
TypeVar
import
torch
from
torch
import
fx
...
...
@@ -30,6 +30,8 @@ else:
)
_pass_context
=
None
P
=
ParamSpec
(
"P"
)
R
=
TypeVar
(
"R"
)
class
PassContext
:
...
...
@@ -44,7 +46,7 @@ def get_pass_context() -> PassContext:
@
contextmanager
def
pass_context
(
compile_range
:
Range
):
def
pass_context
(
compile_range
:
Range
)
->
Generator
[
None
,
None
,
None
]
:
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
...
...
@@ -57,7 +59,7 @@ def pass_context(compile_range: Range):
_pass_context
=
prev_context
class
InductorPass
(
CustomGraphPass
):
class
InductorPass
(
CustomGraphPass
):
# type: ignore[misc]
"""
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
...
...
@@ -73,7 +75,7 @@ class InductorPass(CustomGraphPass):
return
InductorPass
.
hash_source
(
self
)
@
staticmethod
def
hash_source
(
*
srcs
:
str
|
Any
):
def
hash_source
(
*
srcs
:
str
|
Any
)
->
str
:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
...
...
@@ -93,7 +95,7 @@ class InductorPass(CustomGraphPass):
return
hasher
.
hexdigest
()
@
staticmethod
def
hash_dict
(
dict_
:
dict
[
Any
,
Any
]):
def
hash_dict
(
dict_
:
dict
[
Any
,
Any
])
->
str
:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
...
...
@@ -101,7 +103,7 @@ class InductorPass(CustomGraphPass):
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
return
hashlib
.
sha256
(
encoded
).
hexdigest
()
def
is_applicable_for_range
(
self
,
compile_range
:
Range
):
def
is_applicable_for_range
(
self
,
compile_range
:
Range
)
->
bool
:
return
True
...
...
@@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass):
implementation of the UUID.
"""
def
__init__
(
self
,
callable
:
Callable
[[
fx
.
Graph
],
None
],
uuid
:
Any
|
None
=
None
):
def
__init__
(
self
,
callable
:
Callable
[[
fx
.
Graph
],
None
],
uuid
:
Any
|
None
=
None
)
->
None
:
self
.
callable
=
callable
self
.
_uuid
=
self
.
hash_source
(
callable
)
if
uuid
is
None
else
uuid
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
callable
(
graph
)
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
def
enable_fake_mode
(
fn
:
Callable
[
...,
Any
])
->
Callable
[
...,
Any
]:
def
enable_fake_mode
(
fn
:
Callable
[
P
,
R
])
->
Callable
[
P
,
R
]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@
functools
.
wraps
(
fn
)
def
fn_new
(
*
args
,
**
kwargs
)
->
Any
:
def
fn_new
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
R
:
with
torch
.
_guards
.
tracing
(
None
),
unset_fake_temporarily
(),
FakeTensorMode
():
result
=
fn
(
*
args
,
**
kwargs
)
...
...
vllm/compilation/monitor.py
View file @
873480d1
...
...
@@ -12,7 +12,7 @@ context_manager = None
torch_compile_start_time
:
float
=
0.0
def
start_monitoring_torch_compile
(
vllm_config
:
VllmConfig
):
def
start_monitoring_torch_compile
(
vllm_config
:
VllmConfig
)
->
None
:
global
torch_compile_start_time
torch_compile_start_time
=
time
.
time
()
...
...
@@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
context_manager
.
__enter__
()
def
end_monitoring_torch_compile
(
vllm_config
:
VllmConfig
):
def
end_monitoring_torch_compile
(
vllm_config
:
VllmConfig
)
->
None
:
compilation_config
:
CompilationConfig
=
vllm_config
.
compilation_config
if
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
logger
.
info_once
(
...
...
@@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
cudagraph_capturing_enabled
:
bool
=
True
def
validate_cudagraph_capturing_enabled
():
def
validate_cudagraph_capturing_enabled
()
->
None
:
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
...
...
@@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled():
)
def
set_cudagraph_capturing_enabled
(
enabled
:
bool
):
def
set_cudagraph_capturing_enabled
(
enabled
:
bool
)
->
None
:
global
cudagraph_capturing_enabled
cudagraph_capturing_enabled
=
enabled
vllm/compilation/partition_rules.py
View file @
873480d1
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
from
collections.abc
import
Generator
import
torch
...
...
@@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
@
contextlib
.
contextmanager
def
inductor_partition_rule_context
(
splitting_ops
:
list
[
str
]):
def
inductor_partition_rule_context
(
splitting_ops
:
list
[
str
]
|
None
,
)
->
Generator
[
None
,
None
,
None
]:
"""Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the
...
...
vllm/compilation/sequence_parallelism.py
View file @
873480d1
...
...
@@ -41,8 +41,8 @@ class _SequenceParallelPatternHelper:
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
):
device
:
str
|
None
,
)
->
None
:
self
.
epsilon
=
epsilon
self
.
dtype
=
dtype
self
.
device
=
device
...
...
@@ -64,7 +64,7 @@ class _SequenceParallelPatternHelper:
class
FirstAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
)
:
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
...
...
@@ -74,7 +74,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
return
[
input
,
arg3_1
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
input
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
...
...
@@ -100,7 +100,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
...
...
@@ -162,7 +162,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
):
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
...
...
@@ -203,7 +203,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
...
...
vllm/compilation/torch25_custom_graph_pass.py
View file @
873480d1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
from
typing
import
Any
,
NoReturn
import
torch
...
...
@@ -29,14 +29,14 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition)
Return None to skip inductor code caching entirely.
"""
def
__getstate__
(
self
):
def
__getstate__
(
self
)
->
Any
|
None
:
"""
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
to enable subclasses to only have to implement uuid.
"""
return
self
.
uuid
()
def
__setstate__
(
self
,
state
)
:
def
__setstate__
(
self
,
state
:
Any
)
->
NoReturn
:
raise
ValueError
(
"Cannot unpickle CustomGraphPass because pickling"
" is used for cache key uuid. Use torch>=2.6 with"
...
...
vllm/compilation/vllm_inductor_pass.py
View file @
873480d1
...
...
@@ -3,6 +3,7 @@
import
functools
import
operator
import
time
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
typing
import
ClassVar
...
...
@@ -43,13 +44,17 @@ class VllmInductorPass(InductorPass):
)
self
.
pass_config
=
config
.
compilation_config
.
pass_config
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
self
.
device
=
config
.
device_config
.
device
if
config
.
device_config
else
None
self
.
device
:
str
|
None
=
(
config
.
device_config
.
device
if
config
.
device_config
else
None
)
self
.
pass_name
=
self
.
__class__
.
__name__
@
staticmethod
def
time_and_log
(
call_fn
):
def
time_and_log
(
call_fn
:
Callable
[[
"VllmInductorPass"
,
torch
.
fx
.
Graph
],
None
],
)
->
Callable
[[
"VllmInductorPass"
,
torch
.
fx
.
Graph
],
None
]:
@
functools
.
wraps
(
call_fn
)
def
wrapped
(
self
:
VllmInductorPass
,
graph
:
torch
.
fx
.
Graph
):
def
wrapped
(
self
:
VllmInductorPass
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
begin
()
self
.
dump_graph
(
graph
,
"before"
)
call_fn
(
self
,
graph
)
...
...
@@ -58,17 +63,17 @@ class VllmInductorPass(InductorPass):
return
wrapped
def
dump_graph
(
self
,
graph
:
torch
.
fx
.
Graph
,
stage
:
str
):
def
dump_graph
(
self
,
graph
:
torch
.
fx
.
Graph
,
stage
:
str
)
->
None
:
i
=
VllmInductorPass
.
dump_prefix
i_str
=
""
if
i
is
None
else
f
".
{
i
}
"
lazy_format_graph_code
(
f
"post_grad
{
i_str
}
.
{
self
.
pass_name
}
.
{
stage
}
"
,
graph
.
owning_module
)
def
begin
(
self
):
def
begin
(
self
)
->
None
:
self
.
_start_time
=
time
.
perf_counter_ns
()
def
end_and_log
(
self
):
def
end_and_log
(
self
)
->
None
:
self
.
_end_time
=
time
.
perf_counter_ns
()
duration_ms
=
float
(
self
.
_end_time
-
self
.
_start_time
)
/
1.0e6
logger
.
debug
(
"%s completed in %.1f ms"
,
self
.
pass_name
,
duration_ms
)
...
...
@@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass):
def
_replace_op_overloads
(
self
,
string
:
str
)
->
str
:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return
self
.
_OP_OVERLOAD_PATTERN
.
sub
(
return
str
(
self
.
_OP_OVERLOAD_PATTERN
.
sub
(
lambda
m
:
f
"torch.ops.
{
m
.
group
(
1
)
}
.
{
m
.
group
(
2
)
}
"
,
string
,
)
)
def
dump_patterns
(
self
,
config
:
VllmConfig
,
pm_pass
:
PatternMatcherPass
):
def
dump_patterns
(
self
,
config
:
VllmConfig
,
pm_pass
:
PatternMatcherPass
)
->
None
:
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
...
...
@@ -165,9 +172,9 @@ class VllmPatternMatcherPass(VllmInductorPass):
class
PrinterInductorPass
(
VllmInductorPass
):
def
__init__
(
self
,
name
:
str
,
config
:
VllmConfig
):
def
__init__
(
self
,
name
:
str
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
name
=
name
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
dump_graph
(
graph
,
self
.
name
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment