Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
873480d1
Unverified
Commit
873480d1
authored
Jan 06, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 06, 2026
Browse files
[Misc][BE] Type coverage for vllm/compilation [1/3] (#31554)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
6f351548
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
103 additions
and
85 deletions
+103
-85
vllm/compilation/backends.py
vllm/compilation/backends.py
+13
-13
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+7
-7
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+24
-24
vllm/compilation/counter.py
vllm/compilation/counter.py
+3
-1
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+1
-0
vllm/compilation/fx_utils.py
vllm/compilation/fx_utils.py
+3
-2
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+15
-11
vllm/compilation/monitor.py
vllm/compilation/monitor.py
+4
-4
vllm/compilation/partition_rules.py
vllm/compilation/partition_rules.py
+4
-1
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+7
-7
vllm/compilation/torch25_custom_graph_pass.py
vllm/compilation/torch25_custom_graph_pass.py
+3
-3
vllm/compilation/vllm_inductor_pass.py
vllm/compilation/vllm_inductor_pass.py
+19
-12
No files found.
vllm/compilation/backends.py
View file @
873480d1
...
@@ -9,7 +9,7 @@ import operator
...
@@ -9,7 +9,7 @@ import operator
import
os
import
os
import
pprint
import
pprint
import
time
import
time
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Callable
,
Generator
,
Sequence
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
copy
import
deepcopy
from
copy
import
deepcopy
from
functools
import
partial
from
functools
import
partial
...
@@ -90,7 +90,7 @@ class CompilerManager:
...
@@ -90,7 +90,7 @@ class CompilerManager:
support int as key.
support int as key.
"""
"""
def
__init__
(
self
,
compilation_config
:
CompilationConfig
):
def
__init__
(
self
,
compilation_config
:
CompilationConfig
)
->
None
:
self
.
cache
:
dict
[
tuple
[
Range
,
int
,
str
],
Any
]
=
dict
()
self
.
cache
:
dict
[
tuple
[
Range
,
int
,
str
],
Any
]
=
dict
()
self
.
is_cache_updated
=
False
self
.
is_cache_updated
=
False
self
.
compilation_config
=
compilation_config
self
.
compilation_config
=
compilation_config
...
@@ -100,7 +100,7 @@ class CompilerManager:
...
@@ -100,7 +100,7 @@ class CompilerManager:
return
self
.
compiler
.
compute_hash
(
vllm_config
)
return
self
.
compiler
.
compute_hash
(
vllm_config
)
@
contextmanager
@
contextmanager
def
compile_context
(
self
,
compile_range
:
Range
):
def
compile_context
(
self
,
compile_range
:
Range
)
->
Generator
[
None
,
None
,
None
]
:
"""Provide compilation context for the duration of compilation to set
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
compilation (e.g. partition rules, pass context)."""
...
@@ -115,7 +115,7 @@ class CompilerManager:
...
@@ -115,7 +115,7 @@ class CompilerManager:
def
initialize_cache
(
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
"""
"""
Initialize the cache directory for the compiler.
Initialize the cache directory for the compiler.
...
@@ -143,7 +143,7 @@ class CompilerManager:
...
@@ -143,7 +143,7 @@ class CompilerManager:
# do not use eval(), it is unsafe.
# do not use eval(), it is unsafe.
cache
=
ast
.
literal_eval
(
f
.
read
())
cache
=
ast
.
literal_eval
(
f
.
read
())
def
check_type
(
value
,
ty
)
:
def
check_type
(
value
:
Any
,
ty
:
type
)
->
None
:
if
not
isinstance
(
value
,
ty
):
if
not
isinstance
(
value
,
ty
):
raise
TypeError
(
f
"Expected
{
ty
}
but got
{
type
(
value
)
}
for
{
value
}
"
)
raise
TypeError
(
f
"Expected
{
ty
}
but got
{
type
(
value
)
}
for
{
value
}
"
)
...
@@ -165,7 +165,7 @@ class CompilerManager:
...
@@ -165,7 +165,7 @@ class CompilerManager:
cache_dir
=
cache_dir
,
disable_cache
=
disable_cache
,
prefix
=
prefix
cache_dir
=
cache_dir
,
disable_cache
=
disable_cache
,
prefix
=
prefix
)
)
def
save_to_file
(
self
):
def
save_to_file
(
self
)
->
None
:
if
self
.
disable_cache
or
not
self
.
is_cache_updated
:
if
self
.
disable_cache
or
not
self
.
is_cache_updated
:
return
return
printer
=
pprint
.
PrettyPrinter
(
indent
=
4
)
printer
=
pprint
.
PrettyPrinter
(
indent
=
4
)
...
@@ -198,7 +198,7 @@ class CompilerManager:
...
@@ -198,7 +198,7 @@ class CompilerManager:
def
compile
(
def
compile
(
self
,
self
,
graph
:
fx
.
GraphModule
,
graph
:
fx
.
GraphModule
,
example_inputs
,
example_inputs
:
list
[
Any
]
,
additional_inductor_config
,
additional_inductor_config
,
compilation_config
:
CompilationConfig
,
compilation_config
:
CompilationConfig
,
compile_range
:
Range
,
compile_range
:
Range
,
...
@@ -373,7 +373,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -373,7 +373,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
compile_submod_names
:
list
[
str
],
compile_submod_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
vllm_backend
:
"VllmBackend"
,
vllm_backend
:
"VllmBackend"
,
):
)
->
None
:
super
().
__init__
(
module
)
super
().
__init__
(
module
)
from
torch._guards
import
detect_fake_mode
from
torch._guards
import
detect_fake_mode
...
@@ -385,7 +385,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -385,7 +385,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# When True, it annoyingly dumps the torch.fx.Graph on errors.
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self
.
extra_traceback
=
False
self
.
extra_traceback
=
False
def
run
(
self
,
*
args
)
:
def
run
(
self
,
*
args
:
Any
)
->
Any
:
# maybe instead just assert inputs are fake?
# maybe instead just assert inputs are fake?
fake_args
=
[
fake_args
=
[
self
.
fake_mode
.
from_tensor
(
t
)
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
self
.
fake_mode
.
from_tensor
(
t
)
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
...
@@ -467,7 +467,7 @@ model_is_encoder: bool = False
...
@@ -467,7 +467,7 @@ model_is_encoder: bool = False
@
contextmanager
@
contextmanager
def
set_model_tag
(
tag
:
str
,
is_encoder
:
bool
=
False
):
def
set_model_tag
(
tag
:
str
,
is_encoder
:
bool
=
False
)
->
Generator
[
None
,
None
,
None
]
:
"""Context manager to set the model tag."""
"""Context manager to set the model tag."""
global
model_tag
global
model_tag
global
model_is_encoder
global
model_is_encoder
...
@@ -521,7 +521,7 @@ class VllmBackend:
...
@@ -521,7 +521,7 @@ class VllmBackend:
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
is_encoder
:
bool
=
False
,
is_encoder
:
bool
=
False
,
):
)
->
None
:
# if the model is initialized with a non-empty prefix,
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
# e.g. language_model, vision_model, etc.
...
@@ -558,7 +558,7 @@ class VllmBackend:
...
@@ -558,7 +558,7 @@ class VllmBackend:
# `torch.compile` is JIT compiled, so we don't need to
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
# do anything here
def
configure_post_pass
(
self
):
def
configure_post_pass
(
self
)
->
None
:
self
.
pass_manager
.
configure
(
self
.
vllm_config
)
self
.
pass_manager
.
configure
(
self
.
vllm_config
)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# Post-grad custom passes are run using the post_grad_custom_post_pass
...
@@ -580,7 +580,7 @@ class VllmBackend:
...
@@ -580,7 +580,7 @@ class VllmBackend:
self
.
inductor_config
[
self
.
pass_key
]
=
self
.
pass_manager
self
.
inductor_config
[
self
.
pass_key
]
=
self
.
pass_manager
def
__call__
(
def
__call__
(
self
,
graph
:
fx
.
GraphModule
,
example_inputs
self
,
graph
:
fx
.
GraphModule
,
example_inputs
:
Sequence
[
Any
]
)
->
VllmSerializableFunction
:
)
->
VllmSerializableFunction
:
vllm_config
=
self
.
vllm_config
vllm_config
=
self
.
vllm_config
# Minimal hashing here with existing utilities, reused below.
# Minimal hashing here with existing utilities, reused below.
...
...
vllm/compilation/collective_fusion.py
View file @
873480d1
...
@@ -50,7 +50,7 @@ if hasattr(torch.ops._C, "scaled_fp4_quant"):
...
@@ -50,7 +50,7 @@ if hasattr(torch.ops._C, "scaled_fp4_quant"):
class
BasePattern
:
class
BasePattern
:
def
__init__
(
self
,
dtype
:
torch
.
dtype
,
device
:
str
)
:
def
__init__
(
self
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
self
.
tp
=
get_tp_group
()
self
.
tp
=
get_tp_group
()
...
@@ -637,7 +637,7 @@ class AllReduceRMSNormPattern(BasePattern):
...
@@ -637,7 +637,7 @@ class AllReduceRMSNormPattern(BasePattern):
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
):
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
...
@@ -692,7 +692,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
...
@@ -692,7 +692,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
):
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
...
@@ -759,7 +759,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -759,7 +759,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
):
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
...
@@ -828,7 +828,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -828,7 +828,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
):
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
...
@@ -902,7 +902,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -902,7 +902,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
):
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
...
@@ -988,7 +988,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -988,7 +988,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
):
super
().
__init__
(
dtype
,
device
)
super
().
__init__
(
dtype
,
device
)
...
...
vllm/compilation/compiler_interface.py
View file @
873480d1
...
@@ -31,7 +31,7 @@ class CompilerInterface:
...
@@ -31,7 +31,7 @@ class CompilerInterface:
def
initialize_cache
(
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
"""
"""
when the vLLM process uses `cache_dir` as the cache directory,
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
the compiler should initialize itself with the cache directory,
...
@@ -66,7 +66,7 @@ class CompilerInterface:
...
@@ -66,7 +66,7 @@ class CompilerInterface:
compiler_config
:
dict
[
str
,
Any
],
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
compile_range
:
Range
,
key
:
str
|
None
=
None
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
"""
"""
Compile the graph with the given example inputs and compiler config,
Compile the graph with the given example inputs and compiler config,
with a range. The `compile_range` specifies the range of the inputs,
with a range. The `compile_range` specifies the range of the inputs,
...
@@ -100,7 +100,7 @@ class CompilerInterface:
...
@@ -100,7 +100,7 @@ class CompilerInterface:
example_inputs
:
list
[
Any
],
example_inputs
:
list
[
Any
],
graph_index
:
int
,
graph_index
:
int
,
compile_range
:
Range
,
compile_range
:
Range
,
)
->
Callable
:
)
->
Callable
[...,
Any
]
:
"""
"""
Load the compiled function from the handle.
Load the compiled function from the handle.
Raises an error if the handle is invalid.
Raises an error if the handle is invalid.
...
@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
...
@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
guards
:
list
[
Any
]
=
[]
self
.
guards
:
list
[
Any
]
=
[]
def
evaluate_guards_expression
(
self
,
*
args
,
**
kwargs
)
:
def
evaluate_guards_expression
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Literal
[
True
]
:
return
True
return
True
def
get_pruned_guards
(
self
,
*
args
,
**
kwargs
)
:
def
get_pruned_guards
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
list
[
Any
]
:
return
[]
return
[]
def
produce_guards_expression
(
self
,
*
args
,
**
kwargs
)
:
def
produce_guards_expression
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Literal
[
""
]
:
return
""
return
""
...
@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
name
=
"inductor_standalone"
name
=
"inductor_standalone"
def
__init__
(
self
,
save_format
:
Literal
[
"binary"
,
"unpacked"
]):
def
__init__
(
self
,
save_format
:
Literal
[
"binary"
,
"unpacked"
])
->
None
:
self
.
save_format
=
save_format
self
.
save_format
=
save_format
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
...
@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
def
initialize_cache
(
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
self
.
cache_dir
=
cache_dir
self
.
cache_dir
=
cache_dir
def
compile
(
def
compile
(
...
@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
compiler_config
:
dict
[
str
,
Any
],
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
compile_range
:
Range
,
key
:
str
|
None
=
None
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
compilation_counter
.
num_inductor_compiles
+=
1
compilation_counter
.
num_inductor_compiles
+=
1
current_config
=
{}
current_config
=
{}
if
compiler_config
is
not
None
:
if
compiler_config
is
not
None
:
...
@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
example_inputs
:
list
[
Any
],
example_inputs
:
list
[
Any
],
graph_index
:
int
,
graph_index
:
int
,
compile_range
:
Range
,
compile_range
:
Range
,
)
->
Callable
:
)
->
Callable
[...,
Any
]
:
assert
isinstance
(
handle
,
tuple
)
assert
isinstance
(
handle
,
tuple
)
assert
isinstance
(
handle
[
0
],
str
)
assert
isinstance
(
handle
[
0
],
str
)
assert
isinstance
(
handle
[
1
],
str
)
assert
isinstance
(
handle
[
1
],
str
)
...
@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
returns_tuple
=
graph_returns_tuple
(
graph
)
returns_tuple
=
graph_returns_tuple
(
graph
)
def
compiled_graph_wrapper
(
*
args
)
:
def
compiled_graph_wrapper
(
*
args
:
Any
)
->
tuple
[
Any
,
...]
|
Any
:
graph_output
=
inductor_compiled_graph
(
*
args
)
graph_output
=
inductor_compiled_graph
(
*
args
)
# unpack the tuple if needed
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# TODO(rzou): the implication is that we're not
...
@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
def
initialize_cache
(
def
initialize_cache
(
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
self
,
cache_dir
:
str
,
disable_cache
:
bool
=
False
,
prefix
:
str
=
""
):
)
->
None
:
self
.
cache_dir
=
cache_dir
self
.
cache_dir
=
cache_dir
self
.
prefix
=
prefix
self
.
prefix
=
prefix
self
.
base_cache_dir
=
cache_dir
[:
-
len
(
prefix
)]
if
prefix
else
cache_dir
self
.
base_cache_dir
=
cache_dir
[:
-
len
(
prefix
)]
if
prefix
else
cache_dir
...
@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
compiler_config
:
dict
[
str
,
Any
],
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
compile_range
:
Range
,
key
:
str
|
None
=
None
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
compilation_counter
.
num_inductor_compiles
+=
1
compilation_counter
.
num_inductor_compiles
+=
1
from
torch._inductor.compile_fx
import
compile_fx
from
torch._inductor.compile_fx
import
compile_fx
...
@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
original_load
=
FxGraphCache
.
load
original_load
=
FxGraphCache
.
load
original_load_name
=
"torch._inductor.codecache.FxGraphCache.load"
original_load_name
=
"torch._inductor.codecache.FxGraphCache.load"
def
hijack_load
(
*
args
,
**
kwargs
)
:
def
hijack_load
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
inductor_compiled_graph
=
original_load
(
*
args
,
**
kwargs
)
inductor_compiled_graph
=
original_load
(
*
args
,
**
kwargs
)
nonlocal
file_path
nonlocal
file_path
compiled_fn
=
inductor_compiled_graph
.
current_callable
compiled_fn
=
inductor_compiled_graph
.
current_callable
...
@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
# function renamed in 2.6
# function renamed in 2.6
original_load_name
=
None
original_load_name
=
None
def
hijacked_compile_fx_inner
(
*
args
,
**
kwargs
)
:
def
hijacked_compile_fx_inner
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
output
=
torch
.
_inductor
.
compile_fx
.
compile_fx_inner
(
*
args
,
**
kwargs
)
output
=
torch
.
_inductor
.
compile_fx
.
compile_fx_inner
(
*
args
,
**
kwargs
)
nonlocal
hash_str
nonlocal
hash_str
inductor_compiled_graph
=
output
inductor_compiled_graph
=
output
...
@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
...
@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
hash_str
=
inductor_compiled_graph
.
_fx_graph_cache_key
hash_str
=
inductor_compiled_graph
.
_fx_graph_cache_key
return
output
return
output
def
hijack_compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
:
def
hijack_compiled_fx_graph_hash
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
out
=
compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
out
=
compiled_fx_graph_hash
(
*
args
,
**
kwargs
)
nonlocal
hash_str
nonlocal
hash_str
hash_str
=
out
[
0
]
hash_str
=
out
[
0
]
return
out
return
out
def
_check_can_cache
(
*
args
,
**
kwargs
)
:
def
_check_can_cache
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
# no error means it can be cached.
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# tracing context, and also disables caching for graphs
...
@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
example_inputs
:
list
[
Any
],
example_inputs
:
list
[
Any
],
graph_index
:
int
,
graph_index
:
int
,
compile_range
:
Range
,
compile_range
:
Range
,
)
->
Callable
:
)
->
Callable
[...,
Any
]
:
assert
isinstance
(
handle
,
tuple
)
assert
isinstance
(
handle
,
tuple
)
assert
isinstance
(
handle
[
0
],
str
)
assert
isinstance
(
handle
[
0
],
str
)
assert
isinstance
(
handle
[
1
],
str
)
assert
isinstance
(
handle
[
1
],
str
)
...
@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
returns_tuple
=
graph_returns_tuple
(
graph
)
returns_tuple
=
graph_returns_tuple
(
graph
)
# this is the callable we return to Dynamo to run
# this is the callable we return to Dynamo to run
def
compiled_graph
(
*
args
)
:
def
compiled_graph
(
*
args
:
Any
)
->
tuple
[
Any
,
...]
|
Any
:
# convert args to list
# convert args to list
list_args
=
list
(
args
)
list_args
=
list
(
args
)
graph_output
=
inductor_compiled_graph
(
list_args
)
graph_output
=
inductor_compiled_graph
(
list_args
)
...
@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
return
compiled_graph
return
compiled_graph
def
metrics_context
(
self
)
->
contextlib
.
AbstractContextManager
:
def
metrics_context
(
self
)
->
contextlib
.
AbstractContextManager
[
Any
]
:
"""
"""
This method returns the Dynamo metrics context (if it exists,
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
otherwise a null context). It is used by various compile components.
...
@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
...
@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
if
is_torch_equal_or_newer
(
"2.6"
):
if
is_torch_equal_or_newer
(
"2.6"
):
import
torch._dynamo.utils
import
torch._dynamo.utils
return
torch
.
_dynamo
.
utils
.
get_metrics_context
()
return
torch
.
_dynamo
.
utils
.
get_metrics_context
()
# type: ignore[no-any-return]
else
:
else
:
return
contextlib
.
nullcontext
()
return
contextlib
.
nullcontext
()
def
set_inductor_config
(
config
,
compile_range
:
Range
):
def
set_inductor_config
(
config
:
dict
[
str
,
Any
]
,
compile_range
:
Range
)
->
None
:
if
compile_range
.
is_single_size
():
if
compile_range
.
is_single_size
():
# for a specific batch size, tuning triton kernel parameters
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
# can be beneficial
...
@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
...
@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
)
)
def
set_functorch_config
():
def
set_functorch_config
()
->
None
:
torch
.
_functorch
.
config
.
bundled_autograd_cache
=
False
torch
.
_functorch
.
config
.
bundled_autograd_cache
=
False
...
@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
...
@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
compiler_config
:
dict
[
str
,
Any
],
compiler_config
:
dict
[
str
,
Any
],
compile_range
:
Range
,
compile_range
:
Range
,
key
:
str
|
None
=
None
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
|
None
,
Any
|
None
]:
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
compilation_counter
.
num_eager_compiles
+=
1
compilation_counter
.
num_eager_compiles
+=
1
# we don't need to compile the graph, just return the graph itself.
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
# It does not support caching, return None for the handle.
...
...
vllm/compilation/counter.py
View file @
873480d1
...
@@ -3,7 +3,9 @@
...
@@ -3,7 +3,9 @@
import
copy
import
copy
import
dataclasses
import
dataclasses
from
collections.abc
import
Generator
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -34,7 +36,7 @@ class CompilationCounter:
...
@@ -34,7 +36,7 @@ class CompilationCounter:
return
copy
.
deepcopy
(
self
)
return
copy
.
deepcopy
(
self
)
@
contextmanager
@
contextmanager
def
expect
(
self
,
**
kwargs
)
:
def
expect
(
self
,
**
kwargs
:
Any
)
->
Generator
[
None
,
None
,
None
]
:
old
=
self
.
clone
()
old
=
self
.
clone
()
yield
yield
for
k
,
v
in
kwargs
.
items
():
for
k
,
v
in
kwargs
.
items
():
...
...
vllm/compilation/cuda_graph.py
View file @
873480d1
...
@@ -219,6 +219,7 @@ class CUDAGraphWrapper:
...
@@ -219,6 +219,7 @@ class CUDAGraphWrapper:
# runtime modes.
# runtime modes.
return
self
.
runnable
(
*
args
,
**
kwargs
)
return
self
.
runnable
(
*
args
,
**
kwargs
)
assert
batch_descriptor
is
not
None
if
batch_descriptor
not
in
self
.
concrete_cudagraph_entries
:
if
batch_descriptor
not
in
self
.
concrete_cudagraph_entries
:
# create a new entry for this batch descriptor
# create a new entry for this batch descriptor
self
.
concrete_cudagraph_entries
[
batch_descriptor
]
=
CUDAGraphEntry
(
self
.
concrete_cudagraph_entries
[
batch_descriptor
]
=
CUDAGraphEntry
(
...
...
vllm/compilation/fx_utils.py
View file @
873480d1
...
@@ -7,10 +7,11 @@ from collections.abc import Iterable, Iterator
...
@@ -7,10 +7,11 @@ from collections.abc import Iterable, Iterator
from
torch
import
fx
from
torch
import
fx
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._ops
import
OpOverload
,
OpOverloadPacket
from
torch._ops
import
OpOverload
,
OpOverloadPacket
from
torch.fx.node
import
Target
def
is_func
(
node
:
fx
.
Node
,
target
)
->
bool
:
def
is_func
(
node
:
fx
.
Node
,
target
:
Target
)
->
bool
:
return
node
.
op
==
"call_function"
and
node
.
target
==
target
return
bool
(
node
.
op
==
"call_function"
and
node
.
target
==
target
)
def
is_auto_func
(
node
:
fx
.
Node
,
op
:
OpOverload
)
->
bool
:
def
is_auto_func
(
node
:
fx
.
Node
,
op
:
OpOverload
)
->
bool
:
...
...
vllm/compilation/inductor_pass.py
View file @
873480d1
...
@@ -8,9 +8,9 @@ import hashlib
...
@@ -8,9 +8,9 @@ import hashlib
import
inspect
import
inspect
import
json
import
json
import
types
import
types
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Generator
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
ParamSpec
,
TypeVar
import
torch
import
torch
from
torch
import
fx
from
torch
import
fx
...
@@ -30,6 +30,8 @@ else:
...
@@ -30,6 +30,8 @@ else:
)
)
_pass_context
=
None
_pass_context
=
None
P
=
ParamSpec
(
"P"
)
R
=
TypeVar
(
"R"
)
class
PassContext
:
class
PassContext
:
...
@@ -44,7 +46,7 @@ def get_pass_context() -> PassContext:
...
@@ -44,7 +46,7 @@ def get_pass_context() -> PassContext:
@
contextmanager
@
contextmanager
def
pass_context
(
compile_range
:
Range
):
def
pass_context
(
compile_range
:
Range
)
->
Generator
[
None
,
None
,
None
]
:
"""A context manager that stores the current pass context,
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
usually it is a list of sizes to specialize.
"""
"""
...
@@ -57,7 +59,7 @@ def pass_context(compile_range: Range):
...
@@ -57,7 +59,7 @@ def pass_context(compile_range: Range):
_pass_context
=
prev_context
_pass_context
=
prev_context
class
InductorPass
(
CustomGraphPass
):
class
InductorPass
(
CustomGraphPass
):
# type: ignore[misc]
"""
"""
A custom graph pass that uses a hash of its source as the UUID.
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
This is defined as a convenience and should work in most cases.
...
@@ -73,7 +75,7 @@ class InductorPass(CustomGraphPass):
...
@@ -73,7 +75,7 @@ class InductorPass(CustomGraphPass):
return
InductorPass
.
hash_source
(
self
)
return
InductorPass
.
hash_source
(
self
)
@
staticmethod
@
staticmethod
def
hash_source
(
*
srcs
:
str
|
Any
):
def
hash_source
(
*
srcs
:
str
|
Any
)
->
str
:
"""
"""
Utility method to hash the sources of functions or objects.
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
:param srcs: strings or objects to add to the hash.
...
@@ -93,7 +95,7 @@ class InductorPass(CustomGraphPass):
...
@@ -93,7 +95,7 @@ class InductorPass(CustomGraphPass):
return
hasher
.
hexdigest
()
return
hasher
.
hexdigest
()
@
staticmethod
@
staticmethod
def
hash_dict
(
dict_
:
dict
[
Any
,
Any
]):
def
hash_dict
(
dict_
:
dict
[
Any
,
Any
])
->
str
:
"""
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
:return: A sha256 hash of the json rep of the dictionary.
...
@@ -101,7 +103,7 @@ class InductorPass(CustomGraphPass):
...
@@ -101,7 +103,7 @@ class InductorPass(CustomGraphPass):
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
return
hashlib
.
sha256
(
encoded
).
hexdigest
()
return
hashlib
.
sha256
(
encoded
).
hexdigest
()
def
is_applicable_for_range
(
self
,
compile_range
:
Range
):
def
is_applicable_for_range
(
self
,
compile_range
:
Range
)
->
bool
:
return
True
return
True
...
@@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass):
...
@@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass):
implementation of the UUID.
implementation of the UUID.
"""
"""
def
__init__
(
self
,
callable
:
Callable
[[
fx
.
Graph
],
None
],
uuid
:
Any
|
None
=
None
):
def
__init__
(
self
,
callable
:
Callable
[[
fx
.
Graph
],
None
],
uuid
:
Any
|
None
=
None
)
->
None
:
self
.
callable
=
callable
self
.
callable
=
callable
self
.
_uuid
=
self
.
hash_source
(
callable
)
if
uuid
is
None
else
uuid
self
.
_uuid
=
self
.
hash_source
(
callable
)
if
uuid
is
None
else
uuid
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
callable
(
graph
)
self
.
callable
(
graph
)
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
return
self
.
_uuid
def
enable_fake_mode
(
fn
:
Callable
[
...,
Any
])
->
Callable
[
...,
Any
]:
def
enable_fake_mode
(
fn
:
Callable
[
P
,
R
])
->
Callable
[
P
,
R
]:
"""
"""
Applies a FakeTensorMode context. This is useful when you don't want to
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
create or run things with real tensors.
"""
"""
@
functools
.
wraps
(
fn
)
@
functools
.
wraps
(
fn
)
def
fn_new
(
*
args
,
**
kwargs
)
->
Any
:
def
fn_new
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
R
:
with
torch
.
_guards
.
tracing
(
None
),
unset_fake_temporarily
(),
FakeTensorMode
():
with
torch
.
_guards
.
tracing
(
None
),
unset_fake_temporarily
(),
FakeTensorMode
():
result
=
fn
(
*
args
,
**
kwargs
)
result
=
fn
(
*
args
,
**
kwargs
)
...
...
vllm/compilation/monitor.py
View file @
873480d1
...
@@ -12,7 +12,7 @@ context_manager = None
...
@@ -12,7 +12,7 @@ context_manager = None
torch_compile_start_time
:
float
=
0.0
torch_compile_start_time
:
float
=
0.0
def
start_monitoring_torch_compile
(
vllm_config
:
VllmConfig
):
def
start_monitoring_torch_compile
(
vllm_config
:
VllmConfig
)
->
None
:
global
torch_compile_start_time
global
torch_compile_start_time
torch_compile_start_time
=
time
.
time
()
torch_compile_start_time
=
time
.
time
()
...
@@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
...
@@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
context_manager
.
__enter__
()
context_manager
.
__enter__
()
def
end_monitoring_torch_compile
(
vllm_config
:
VllmConfig
):
def
end_monitoring_torch_compile
(
vllm_config
:
VllmConfig
)
->
None
:
compilation_config
:
CompilationConfig
=
vllm_config
.
compilation_config
compilation_config
:
CompilationConfig
=
vllm_config
.
compilation_config
if
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
if
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
logger
.
info_once
(
logger
.
info_once
(
...
@@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
...
@@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
cudagraph_capturing_enabled
:
bool
=
True
cudagraph_capturing_enabled
:
bool
=
True
def
validate_cudagraph_capturing_enabled
():
def
validate_cudagraph_capturing_enabled
()
->
None
:
# used to monitor whether a cudagraph capturing is legal at runtime.
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
# if an illegal cudagraph capturing happens, raise an error.
...
@@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled():
...
@@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled():
)
)
def
set_cudagraph_capturing_enabled
(
enabled
:
bool
):
def
set_cudagraph_capturing_enabled
(
enabled
:
bool
)
->
None
:
global
cudagraph_capturing_enabled
global
cudagraph_capturing_enabled
cudagraph_capturing_enabled
=
enabled
cudagraph_capturing_enabled
=
enabled
vllm/compilation/partition_rules.py
View file @
873480d1
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
contextlib
from
collections.abc
import
Generator
import
torch
import
torch
...
@@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
...
@@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
inductor_partition_rule_context
(
splitting_ops
:
list
[
str
]):
def
inductor_partition_rule_context
(
splitting_ops
:
list
[
str
]
|
None
,
)
->
Generator
[
None
,
None
,
None
]:
"""Context manager to temporarily register Inductor partition rules.
"""Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the
Registers custom partition rules for specified operators, forcing the
...
...
vllm/compilation/sequence_parallelism.py
View file @
873480d1
...
@@ -41,8 +41,8 @@ class _SequenceParallelPatternHelper:
...
@@ -41,8 +41,8 @@ class _SequenceParallelPatternHelper:
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
):
)
->
None
:
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
...
@@ -64,7 +64,7 @@ class _SequenceParallelPatternHelper:
...
@@ -64,7 +64,7 @@ class _SequenceParallelPatternHelper:
class
FirstAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
class
FirstAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
)
:
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
)
->
None
:
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
...
@@ -74,7 +74,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
...
@@ -74,7 +74,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
return
[
input
,
arg3_1
]
return
[
input
,
arg3_1
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
arg3_1
:
torch
.
Tensor
,
...
@@ -100,7 +100,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
...
@@ -100,7 +100,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
class
MiddleAllReduceRMSNormPattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
...
@@ -162,7 +162,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
...
@@ -162,7 +162,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
self
,
self
,
epsilon
:
float
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
|
None
,
):
):
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherRMSNorm
(
epsilon
)
...
@@ -203,7 +203,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
...
@@ -203,7 +203,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class
MiddleAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
class
MiddleAllReduceRMSNormStaticFP8Pattern
(
_SequenceParallelPatternHelper
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
):
super
().
__init__
(
epsilon
,
dtype
,
device
)
super
().
__init__
(
epsilon
,
dtype
,
device
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
rmsnorm_matcher
=
MatcherFusedAddRMSNorm
(
epsilon
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
self
.
quant_matcher
=
MatcherQuantFP8
(
kFp8StaticTensorSym
)
...
...
vllm/compilation/torch25_custom_graph_pass.py
View file @
873480d1
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
from
typing
import
Any
,
NoReturn
import
torch
import
torch
...
@@ -29,14 +29,14 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition)
...
@@ -29,14 +29,14 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition)
Return None to skip inductor code caching entirely.
Return None to skip inductor code caching entirely.
"""
"""
def
__getstate__
(
self
):
def
__getstate__
(
self
)
->
Any
|
None
:
"""
"""
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
to enable subclasses to only have to implement uuid.
to enable subclasses to only have to implement uuid.
"""
"""
return
self
.
uuid
()
return
self
.
uuid
()
def
__setstate__
(
self
,
state
)
:
def
__setstate__
(
self
,
state
:
Any
)
->
NoReturn
:
raise
ValueError
(
raise
ValueError
(
"Cannot unpickle CustomGraphPass because pickling"
"Cannot unpickle CustomGraphPass because pickling"
" is used for cache key uuid. Use torch>=2.6 with"
" is used for cache key uuid. Use torch>=2.6 with"
...
...
vllm/compilation/vllm_inductor_pass.py
View file @
873480d1
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
functools
import
functools
import
operator
import
operator
import
time
import
time
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
ClassVar
from
typing
import
ClassVar
...
@@ -43,13 +44,17 @@ class VllmInductorPass(InductorPass):
...
@@ -43,13 +44,17 @@ class VllmInductorPass(InductorPass):
)
)
self
.
pass_config
=
config
.
compilation_config
.
pass_config
self
.
pass_config
=
config
.
compilation_config
.
pass_config
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
self
.
device
=
config
.
device_config
.
device
if
config
.
device_config
else
None
self
.
device
:
str
|
None
=
(
config
.
device_config
.
device
if
config
.
device_config
else
None
)
self
.
pass_name
=
self
.
__class__
.
__name__
self
.
pass_name
=
self
.
__class__
.
__name__
@
staticmethod
@
staticmethod
def
time_and_log
(
call_fn
):
def
time_and_log
(
call_fn
:
Callable
[[
"VllmInductorPass"
,
torch
.
fx
.
Graph
],
None
],
)
->
Callable
[[
"VllmInductorPass"
,
torch
.
fx
.
Graph
],
None
]:
@
functools
.
wraps
(
call_fn
)
@
functools
.
wraps
(
call_fn
)
def
wrapped
(
self
:
VllmInductorPass
,
graph
:
torch
.
fx
.
Graph
):
def
wrapped
(
self
:
VllmInductorPass
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
begin
()
self
.
begin
()
self
.
dump_graph
(
graph
,
"before"
)
self
.
dump_graph
(
graph
,
"before"
)
call_fn
(
self
,
graph
)
call_fn
(
self
,
graph
)
...
@@ -58,17 +63,17 @@ class VllmInductorPass(InductorPass):
...
@@ -58,17 +63,17 @@ class VllmInductorPass(InductorPass):
return
wrapped
return
wrapped
def
dump_graph
(
self
,
graph
:
torch
.
fx
.
Graph
,
stage
:
str
):
def
dump_graph
(
self
,
graph
:
torch
.
fx
.
Graph
,
stage
:
str
)
->
None
:
i
=
VllmInductorPass
.
dump_prefix
i
=
VllmInductorPass
.
dump_prefix
i_str
=
""
if
i
is
None
else
f
".
{
i
}
"
i_str
=
""
if
i
is
None
else
f
".
{
i
}
"
lazy_format_graph_code
(
lazy_format_graph_code
(
f
"post_grad
{
i_str
}
.
{
self
.
pass_name
}
.
{
stage
}
"
,
graph
.
owning_module
f
"post_grad
{
i_str
}
.
{
self
.
pass_name
}
.
{
stage
}
"
,
graph
.
owning_module
)
)
def
begin
(
self
):
def
begin
(
self
)
->
None
:
self
.
_start_time
=
time
.
perf_counter_ns
()
self
.
_start_time
=
time
.
perf_counter_ns
()
def
end_and_log
(
self
):
def
end_and_log
(
self
)
->
None
:
self
.
_end_time
=
time
.
perf_counter_ns
()
self
.
_end_time
=
time
.
perf_counter_ns
()
duration_ms
=
float
(
self
.
_end_time
-
self
.
_start_time
)
/
1.0e6
duration_ms
=
float
(
self
.
_end_time
-
self
.
_start_time
)
/
1.0e6
logger
.
debug
(
"%s completed in %.1f ms"
,
self
.
pass_name
,
duration_ms
)
logger
.
debug
(
"%s completed in %.1f ms"
,
self
.
pass_name
,
duration_ms
)
...
@@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass):
...
@@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass):
def
_replace_op_overloads
(
self
,
string
:
str
)
->
str
:
def
_replace_op_overloads
(
self
,
string
:
str
)
->
str
:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return
self
.
_OP_OVERLOAD_PATTERN
.
sub
(
return
str
(
lambda
m
:
f
"torch.ops.
{
m
.
group
(
1
)
}
.
{
m
.
group
(
2
)
}
"
,
self
.
_OP_OVERLOAD_PATTERN
.
sub
(
string
,
lambda
m
:
f
"torch.ops.
{
m
.
group
(
1
)
}
.
{
m
.
group
(
2
)
}
"
,
string
,
)
)
)
def
dump_patterns
(
self
,
config
:
VllmConfig
,
pm_pass
:
PatternMatcherPass
):
def
dump_patterns
(
self
,
config
:
VllmConfig
,
pm_pass
:
PatternMatcherPass
)
->
None
:
"""
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
into the debug_dump_path folder next to the dumped fx graphs.
...
@@ -165,9 +172,9 @@ class VllmPatternMatcherPass(VllmInductorPass):
...
@@ -165,9 +172,9 @@ class VllmPatternMatcherPass(VllmInductorPass):
class
PrinterInductorPass
(
VllmInductorPass
):
class
PrinterInductorPass
(
VllmInductorPass
):
def
__init__
(
self
,
name
:
str
,
config
:
VllmConfig
):
def
__init__
(
self
,
name
:
str
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
name
=
name
self
.
name
=
name
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
dump_graph
(
graph
,
self
.
name
)
self
.
dump_graph
(
graph
,
self
.
name
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment