Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
15e302df
Unverified
Commit
15e302df
authored
Jan 22, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 22, 2026
Browse files
[Misc][BE] Turn on strict type coverage for vllm/compilation (#31756)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
d117a4d1
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
121 additions
and
68 deletions
+121
-68
pyproject.toml
pyproject.toml
+7
-0
tests/compile/test_pass_manager.py
tests/compile/test_pass_manager.py
+1
-1
tools/pre_commit/mypy.py
tools/pre_commit/mypy.py
+37
-3
vllm/compilation/backends.py
vllm/compilation/backends.py
+1
-1
vllm/compilation/caching.py
vllm/compilation/caching.py
+12
-11
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+15
-13
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+6
-6
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+17
-17
vllm/compilation/matcher_utils.py
vllm/compilation/matcher_utils.py
+13
-9
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+8
-6
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+4
-1
No files found.
pyproject.toml
View file @
15e302df
...
@@ -100,6 +100,13 @@ ignore_missing_imports = true
...
@@ -100,6 +100,13 @@ ignore_missing_imports = true
check_untyped_defs
=
true
check_untyped_defs
=
true
follow_imports
=
"silent"
follow_imports
=
"silent"
[[tool.mypy.overrides]]
module
=
"vllm.compilation.*"
disallow_untyped_defs
=
true
disallow_incomplete_defs
=
true
warn_return_any
=
true
follow_imports
=
"silent"
[tool.pytest.ini_options]
[tool.pytest.ini_options]
markers
=
[
markers
=
[
"slow_test"
,
"slow_test"
,
...
...
tests/compile/test_pass_manager.py
View file @
15e302df
...
@@ -28,7 +28,7 @@ def test_bad_callable():
...
@@ -28,7 +28,7 @@ def test_bad_callable():
pass_manager
.
configure
(
config
)
pass_manager
.
configure
(
config
)
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
pass_manager
.
add
(
simple_callable
)
pass_manager
.
add
(
simple_callable
)
# type: ignore[arg-type]
# Pass that inherits from InductorPass
# Pass that inherits from InductorPass
...
...
tools/pre_commit/mypy.py
View file @
15e302df
...
@@ -77,6 +77,11 @@ EXCLUDE = [
...
@@ -77,6 +77,11 @@ EXCLUDE = [
"vllm/v1/attention/ops"
,
"vllm/v1/attention/ops"
,
]
]
# Directories that should be checked with --strict
STRICT_DIRS
=
[
"vllm/compilation"
,
]
def
group_files
(
changed_files
:
list
[
str
])
->
dict
[
str
,
list
[
str
]]:
def
group_files
(
changed_files
:
list
[
str
])
->
dict
[
str
,
list
[
str
]]:
"""
"""
...
@@ -108,11 +113,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
...
@@ -108,11 +113,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
return
file_groups
return
file_groups
def
is_strict_file
(
filepath
:
str
)
->
bool
:
"""Check if a file should be checked with strict mode."""
return
any
(
filepath
.
startswith
(
strict_dir
)
for
strict_dir
in
STRICT_DIRS
)
def
mypy
(
def
mypy
(
targets
:
list
[
str
],
targets
:
list
[
str
],
python_version
:
str
|
None
,
python_version
:
str
|
None
,
follow_imports
:
str
|
None
,
follow_imports
:
str
|
None
,
file_group
:
str
,
file_group
:
str
,
strict
:
bool
=
False
,
)
->
int
:
)
->
int
:
"""
"""
Run mypy on the given targets.
Run mypy on the given targets.
...
@@ -124,6 +135,7 @@ def mypy(
...
@@ -124,6 +135,7 @@ def mypy(
follow_imports: Value for the --follow-imports option or None to use
follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior.
the default mypy behavior.
file_group: The file group name for logging purposes.
file_group: The file group name for logging purposes.
strict: If True, run mypy with --strict flag.
Returns:
Returns:
The return code from mypy.
The return code from mypy.
...
@@ -133,6 +145,8 @@ def mypy(
...
@@ -133,6 +145,8 @@ def mypy(
args
+=
[
"--python-version"
,
python_version
]
args
+=
[
"--python-version"
,
python_version
]
if
follow_imports
is
not
None
:
if
follow_imports
is
not
None
:
args
+=
[
"--follow-imports"
,
follow_imports
]
args
+=
[
"--follow-imports"
,
follow_imports
]
if
strict
:
args
+=
[
"--strict"
]
print
(
f
"$
{
' '
.
join
(
args
)
}
{
file_group
}
"
)
print
(
f
"$
{
' '
.
join
(
args
)
}
{
file_group
}
"
)
return
subprocess
.
run
(
args
+
targets
,
check
=
False
).
returncode
return
subprocess
.
run
(
args
+
targets
,
check
=
False
).
returncode
...
@@ -149,9 +163,29 @@ def main():
...
@@ -149,9 +163,29 @@ def main():
for
file_group
,
changed_files
in
file_groups
.
items
():
for
file_group
,
changed_files
in
file_groups
.
items
():
follow_imports
=
None
if
ci
and
file_group
==
""
else
"skip"
follow_imports
=
None
if
ci
and
file_group
==
""
else
"skip"
if
changed_files
:
if
changed_files
:
returncode
|=
mypy
(
# Separate files into strict and non-strict groups
changed_files
,
python_version
,
follow_imports
,
file_group
strict_files
=
[
f
for
f
in
changed_files
if
is_strict_file
(
f
)]
)
non_strict_files
=
[
f
for
f
in
changed_files
if
not
is_strict_file
(
f
)]
# Run mypy on non-strict files
if
non_strict_files
:
returncode
|=
mypy
(
non_strict_files
,
python_version
,
follow_imports
,
file_group
,
strict
=
False
,
)
# Run mypy on strict files with --strict flag
if
strict_files
:
returncode
|=
mypy
(
strict_files
,
python_version
,
follow_imports
,
f
"
{
file_group
}
(strict)"
,
strict
=
True
,
)
return
returncode
return
returncode
...
...
vllm/compilation/backends.py
View file @
15e302df
...
@@ -68,7 +68,7 @@ def make_copy_and_call(
...
@@ -68,7 +68,7 @@ def make_copy_and_call(
A wrapper function that copies inputs and calls the compiled function
A wrapper function that copies inputs and calls the compiled function
"""
"""
def
copy_and_call
(
*
args
)
:
def
copy_and_call
(
*
args
:
Any
)
->
Any
:
list_args
=
list
(
args
)
list_args
=
list
(
args
)
for
i
,
index
in
enumerate
(
sym_tensor_indices
):
for
i
,
index
in
enumerate
(
sym_tensor_indices
):
runtime_tensor
=
list_args
[
index
]
runtime_tensor
=
list_args
[
index
]
...
...
vllm/compilation/caching.py
View file @
15e302df
...
@@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts:
...
@@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts:
split on attn)
split on attn)
"""
"""
def
__init__
(
self
):
def
__init__
(
self
)
->
None
:
# dict from submodule name to byte hash
# dict from submodule name to byte hash
self
.
submodule_bytes
=
{}
self
.
submodule_bytes
:
dict
[
str
,
str
]
=
{}
# dict from byte hash to bytes
# dict from byte hash to bytes
self
.
submodule_bytes_store
=
{}
self
.
submodule_bytes_store
:
dict
[
str
,
bytes
]
=
{}
# dict from byte hash to loaded module
# dict from byte hash to loaded module
self
.
loaded_submodule_store
=
{}
self
.
loaded_submodule_store
:
dict
[
str
,
Any
]
=
{}
def
insert
(
self
,
submod_name
:
str
,
shape
:
str
,
entry
:
bytes
):
def
insert
(
self
,
submod_name
:
str
,
shape
:
str
,
entry
:
bytes
)
->
None
:
hasher
=
hashlib
.
sha256
()
hasher
=
hashlib
.
sha256
()
hasher
.
update
(
entry
)
hasher
.
update
(
entry
)
hex_digest
=
hasher
.
hexdigest
()
hex_digest
=
hasher
.
hexdigest
()
...
@@ -86,7 +86,7 @@ class StandaloneCompiledArtifacts:
...
@@ -86,7 +86,7 @@ class StandaloneCompiledArtifacts:
self
.
submodule_bytes
[
f
"
{
submod_name
}
_
{
shape
}
"
]
self
.
submodule_bytes
[
f
"
{
submod_name
}
_
{
shape
}
"
]
]
]
def
get_loaded
(
self
,
submod_name
:
str
,
shape
:
str
):
def
get_loaded
(
self
,
submod_name
:
str
,
shape
:
str
)
->
Any
:
logger
.
debug
(
logger
.
debug
(
"getting artifact for submod %s with shape %s"
,
"getting artifact for submod %s with shape %s"
,
submod_name
,
submod_name
,
...
@@ -119,7 +119,7 @@ class StandaloneCompiledArtifacts:
...
@@ -119,7 +119,7 @@ class StandaloneCompiledArtifacts:
from
torch._inductor.standalone_compile
import
AOTCompiledArtifact
from
torch._inductor.standalone_compile
import
AOTCompiledArtifact
def
_load_entry
(
entry_bytes
)
->
AOTCompiledArtifact
:
def
_load_entry
(
entry_
bytes
:
bytes
)
->
AOTCompiledArtifact
:
entry
=
pickle
.
loads
(
entry_bytes
)
entry
=
pickle
.
loads
(
entry_bytes
)
return
AOTCompiledArtifact
.
deserialize
(
entry
)
return
AOTCompiledArtifact
.
deserialize
(
entry
)
...
@@ -132,13 +132,13 @@ class StandaloneCompiledArtifacts:
...
@@ -132,13 +132,13 @@ class StandaloneCompiledArtifacts:
logger
.
debug
(
"loaded all %s submodules"
,
self
.
num_artifacts
())
logger
.
debug
(
"loaded all %s submodules"
,
self
.
num_artifacts
())
def
__getstate__
(
self
):
def
__getstate__
(
self
)
->
dict
[
str
,
dict
[
str
,
str
]
|
dict
[
str
,
bytes
]]
:
return
{
return
{
"submodule_bytes"
:
self
.
submodule_bytes
,
"submodule_bytes"
:
self
.
submodule_bytes
,
"submodule_bytes_store"
:
self
.
submodule_bytes_store
,
"submodule_bytes_store"
:
self
.
submodule_bytes_store
,
}
}
def
__setstate__
(
self
,
state
)
:
def
__setstate__
(
self
,
state
:
dict
[
str
,
dict
[
str
,
Any
]])
->
None
:
self
.
submodule_bytes
=
state
[
"submodule_bytes"
]
self
.
submodule_bytes
=
state
[
"submodule_bytes"
]
self
.
submodule_bytes_store
=
state
[
"submodule_bytes_store"
]
self
.
submodule_bytes_store
=
state
[
"submodule_bytes_store"
]
self
.
loaded_submodule_store
=
{}
self
.
loaded_submodule_store
=
{}
...
@@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
...
@@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
standalone_compile_artifacts
.
load_all
()
standalone_compile_artifacts
.
load_all
()
submod_names
=
standalone_compile_artifacts
.
submodule_names
()
submod_names
=
standalone_compile_artifacts
.
submodule_names
()
compiled_callables
:
dict
[
str
,
dict
[
str
,
Callable
]]
=
{}
compiled_callables
:
dict
[
str
,
dict
[
str
,
Callable
[...,
Any
]
]]
=
{}
for
cache_key
in
standalone_compile_artifacts
.
submodule_bytes
:
for
cache_key
in
standalone_compile_artifacts
.
submodule_bytes
:
submod_name
,
shape_str
=
cache_key
.
rsplit
(
"_"
,
1
)
submod_name
,
shape_str
=
cache_key
.
rsplit
(
"_"
,
1
)
...
@@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
...
@@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these.
# e.g. exec(). We can't actually check these.
continue
continue
hash_content
.
append
(
content
)
hash_content
.
append
(
content
)
re
turn
safe_hash
(
re
sult
:
str
=
safe_hash
(
"
\n
"
.
join
(
hash_content
).
encode
(),
usedforsecurity
=
False
"
\n
"
.
join
(
hash_content
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
).
hexdigest
()
return
result
def
_compute_code_hash
(
files
:
set
[
str
])
->
str
:
def
_compute_code_hash
(
files
:
set
[
str
])
->
str
:
...
...
vllm/compilation/collective_fusion.py
View file @
15e302df
...
@@ -30,19 +30,15 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
...
@@ -30,19 +30,15 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
flashinfer_comm
:
ModuleType
|
None
=
None
if
find_spec
(
"flashinfer"
):
if
find_spec
(
"flashinfer"
):
try
:
try
:
import
flashinfer.comm
as
flashinfer_comm
import
flashinfer.comm
as
_
flashinfer_comm
flashinfer_comm
:
ModuleType
|
None
=
(
# type: ignore[no-redef]
if
hasattr
(
_flashinfer_comm
,
"trtllm_allreduce_fusion"
):
flashinfer_comm
flashinfer_comm
=
_flashinfer_comm
if
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
)
else
None
)
except
ImportError
:
except
ImportError
:
flashinfer_comm
=
None
# type: ignore[assignment]
pass
else
:
flashinfer_comm
=
None
# type: ignore[assignment]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -441,7 +437,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
...
@@ -441,7 +437,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
):
):
return
True
return
True
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
return
compile_range
.
is_single_size
()
and
compile_range
.
end
%
tp_size
==
0
return
bool
(
compile_range
.
is_single_size
()
and
compile_range
.
end
%
tp_size
==
0
)
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
...
@@ -516,7 +512,7 @@ if flashinfer_comm is not None:
...
@@ -516,7 +512,7 @@ if flashinfer_comm is not None:
# Get one shot input size limit for the current world size
# Get one shot input size limit for the current world size
# for the current device capability
# for the current device capability
max_one_shot_size
=
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB
.
get
(
max_one_shot_size
=
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB
.
get
(
device_capability
,
# type: ignore[arg-type]
device_capability
,
# type: ignore[arg-type
, unused-ignore
]
{},
{},
).
get
(
world_size
,
None
)
).
get
(
world_size
,
None
)
# Use one shot if no max size is specified
# Use one shot if no max size is specified
...
@@ -666,6 +662,7 @@ class AllReduceRMSNormPattern(BasePattern):
...
@@ -666,6 +662,7 @@ class AllReduceRMSNormPattern(BasePattern):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
torch
.
zeros_like
(
input
)
residual
=
torch
.
zeros_like
(
input
)
rms_result
=
torch
.
empty_like
(
input
)
rms_result
=
torch
.
empty_like
(
input
)
assert
flashinfer_comm
is
not
None
,
"FlashInfer must be enabled"
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -722,6 +719,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
...
@@ -722,6 +719,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
def
replacement
(
def
replacement
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
flashinfer_comm
is
not
None
,
"FlashInfer must be enabled"
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -800,6 +798,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -800,6 +798,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
residual
=
torch
.
zeros_like
(
input
)
residual
=
torch
.
zeros_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
assert
flashinfer_comm
is
not
None
,
"FlashInfer must be enabled"
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -875,6 +874,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
...
@@ -875,6 +874,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
result_quant
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
assert
flashinfer_comm
is
not
None
,
"FlashInfer must be enabled"
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -960,6 +960,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -960,6 +960,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
torch
.
zeros_like
(
input
)
residual
=
torch
.
zeros_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
result_rms
=
torch
.
empty_like
(
input
)
assert
flashinfer_comm
is
not
None
,
"FlashInfer must be enabled"
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -1055,6 +1056,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -1055,6 +1056,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
flashinfer_comm
is
not
None
,
"FlashInfer must be enabled"
allreduce
=
auto_functionalized
(
allreduce
=
auto_functionalized
(
flashinfer_trtllm_fused_allreduce_norm
,
flashinfer_trtllm_fused_allreduce_norm
,
allreduce_in
=
input
,
allreduce_in
=
input
,
...
@@ -1131,7 +1133,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1131,7 +1133,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
)
)
self
.
ipc_handles
,
workspace_tensor
=
(
self
.
ipc_handles
,
workspace_tensor
=
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
# type: ignore[misc]
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
tp_rank
=
rank
,
tp_rank
=
rank
,
tp_size
=
self
.
tp_size
,
tp_size
=
self
.
tp_size
,
max_token_num
=
self
.
max_token_num
,
max_token_num
=
self
.
max_token_num
,
...
@@ -1204,7 +1206,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1204,7 +1206,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
if
self
.
disabled
:
if
self
.
disabled
:
logger
.
warning_once
(
"AllReduce fusion pass is disabled."
)
logger
.
warning_once
(
"AllReduce fusion pass is disabled."
)
return
False
return
False
return
compile_range
.
end
<=
self
.
max_token_num
return
bool
(
compile_range
.
end
<=
self
.
max_token_num
)
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
...
...
vllm/compilation/compiler_interface.py
View file @
15e302df
...
@@ -201,9 +201,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -201,9 +201,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
factors
=
get_inductor_factors
()
factors
=
get_inductor_factors
()
hash_str
=
safe_hash
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()[
hash_str
:
str
=
safe_hash
(
:
10
str
(
factors
).
encode
(),
usedforsecurity
=
False
]
).
hexdigest
()[:
10
]
return
hash_str
return
hash_str
def
initialize_cache
(
def
initialize_cache
(
...
@@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface):
...
@@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface):
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
def
compute_hash
(
self
,
vllm_config
:
VllmConfig
)
->
str
:
factors
=
get_inductor_factors
()
factors
=
get_inductor_factors
()
hash_str
=
safe_hash
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()[
hash_str
:
str
=
safe_hash
(
:
10
str
(
factors
).
encode
(),
usedforsecurity
=
False
]
).
hexdigest
()[:
10
]
return
hash_str
return
hash_str
def
initialize_cache
(
def
initialize_cache
(
...
...
vllm/compilation/decorators.py
View file @
15e302df
...
@@ -45,10 +45,10 @@ logger = init_logger(__name__)
...
@@ -45,10 +45,10 @@ logger = init_logger(__name__)
IGNORE_COMPILE_KEY
=
"_ignore_compile_vllm"
IGNORE_COMPILE_KEY
=
"_ignore_compile_vllm"
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
]
)
_T
=
TypeVar
(
"_T"
,
bound
=
nn
.
Module
)
def
ignore_torch_compile
(
cls
:
_T
)
->
_T
:
def
ignore_torch_compile
(
cls
:
type
[
_T
]
)
->
type
[
_T
]
:
"""
"""
A decorator to ignore support_torch_compile decorator
A decorator to ignore support_torch_compile decorator
on the class. This is useful when a parent class has
on the class. This is useful when a parent class has
...
@@ -68,7 +68,7 @@ def ignore_torch_compile(cls: _T) -> _T:
...
@@ -68,7 +68,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return
cls
return
cls
def
_should_ignore_torch_compile
(
cls
:
_T
)
->
bool
:
def
_should_ignore_torch_compile
(
cls
:
type
[
_T
]
)
->
bool
:
"""
"""
Check if the class should be ignored for torch.compile.
Check if the class should be ignored for torch.compile.
"""
"""
...
@@ -79,21 +79,21 @@ def _should_ignore_torch_compile(cls: _T) -> bool:
...
@@ -79,21 +79,21 @@ def _should_ignore_torch_compile(cls: _T) -> bool:
def
support_torch_compile
(
def
support_torch_compile
(
*
,
*
,
enable_if
:
Callable
[[
VllmConfig
],
bool
]
|
None
=
None
,
enable_if
:
Callable
[[
VllmConfig
],
bool
]
|
None
=
None
,
)
->
Callable
[[
_T
],
_T
]:
...
)
->
Callable
[[
type
[
_T
]
]
,
type
[
_T
]
]
:
...
@
overload
@
overload
def
support_torch_compile
(
def
support_torch_compile
(
*
,
*
,
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
)
->
Callable
[[
_T
],
_T
]:
...
)
->
Callable
[[
type
[
_T
]
]
,
type
[
_T
]
]
:
...
@
overload
@
overload
def
support_torch_compile
(
def
support_torch_compile
(
*
,
*
,
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
)
->
Callable
[[
_T
],
_T
]:
...
)
->
Callable
[[
type
[
_T
]
]
,
type
[
_T
]
]
:
...
@
overload
@
overload
...
@@ -101,21 +101,21 @@ def support_torch_compile(
...
@@ -101,21 +101,21 @@ def support_torch_compile(
*
,
*
,
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
,
)
->
Callable
[[
_T
],
_T
]:
...
)
->
Callable
[[
type
[
_T
]
]
,
type
[
_T
]
]
:
...
@
overload
@
overload
def
support_torch_compile
(
cls
:
_T
)
->
_T
:
...
def
support_torch_compile
(
cls
:
type
[
_T
]
)
->
type
[
_T
]
:
...
def
support_torch_compile
(
def
support_torch_compile
(
cls
:
_T
|
None
=
None
,
cls
:
type
[
_T
]
|
None
=
None
,
*
,
*
,
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
=
None
,
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
=
None
,
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
=
None
,
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
=
None
,
enable_if
:
Callable
[[
VllmConfig
],
bool
]
|
None
=
None
,
enable_if
:
Callable
[[
VllmConfig
],
bool
]
|
None
=
None
,
shape_invariants
:
Callable
[...,
None
]
=
lambda
*
args
,
**
kwargs
:
None
,
shape_invariants
:
Callable
[...,
None
]
=
lambda
*
args
,
**
kwargs
:
None
,
)
->
Callable
[[
_T
],
_T
]
|
_T
:
)
->
Callable
[[
type
[
_T
]
]
,
type
[
_T
]
]
|
type
[
_T
]
:
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
...
@@ -182,7 +182,7 @@ def support_torch_compile(
...
@@ -182,7 +182,7 @@ def support_torch_compile(
errors.
errors.
"""
"""
def
cls_decorator_helper
(
cls
:
_T
)
->
_T
:
def
cls_decorator_helper
(
cls
:
type
[
_T
]
)
->
type
[
_T
]
:
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
# to avoid too much indentation for `_support_torch_compile`
# to avoid too much indentation for `_support_torch_compile`
if
not
hasattr
(
cls
,
"forward"
):
if
not
hasattr
(
cls
,
"forward"
):
...
@@ -263,12 +263,12 @@ def _verify_source_unchanged(
...
@@ -263,12 +263,12 @@ def _verify_source_unchanged(
def
_support_torch_compile
(
def
_support_torch_compile
(
cls
:
_T
,
cls
:
type
[
_T
]
,
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]],
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]],
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
=
None
,
mark_unbacked_dims
:
dict
[
str
,
int
|
list
[
int
]]
|
None
=
None
,
enable_if
:
Callable
[[
VllmConfig
],
bool
]
|
None
=
None
,
enable_if
:
Callable
[[
VllmConfig
],
bool
]
|
None
=
None
,
shape_invariants
:
Callable
[...,
None
]
=
lambda
*
args
,
**
kwargs
:
None
,
shape_invariants
:
Callable
[...,
None
]
=
lambda
*
args
,
**
kwargs
:
None
,
)
->
_T
:
)
->
type
[
_T
]
:
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
"""
"""
...
@@ -325,12 +325,12 @@ def _support_torch_compile(
...
@@ -325,12 +325,12 @@ def _support_torch_compile(
self
.
compiled
=
False
self
.
compiled
=
False
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper
.
__init__
(
self
)
# type: ignore[arg-type]
TorchCompileWithNoGuardsWrapper
.
__init__
(
self
)
cls
.
__init__
=
__init__
cls
.
__init__
=
__init__
def
_mark_dynamic_inputs
(
def
_mark_dynamic_inputs
(
mod
:
_T
,
ds_type
:
DynamicShapesType
,
*
args
:
Any
,
**
kwargs
:
Any
mod
:
type
[
_T
]
,
ds_type
:
DynamicShapesType
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
)
->
None
:
def
mark_dynamic
(
arg
:
torch
.
Tensor
,
dims
:
list
[
int
])
->
None
:
def
mark_dynamic
(
arg
:
torch
.
Tensor
,
dims
:
list
[
int
])
->
None
:
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
...
@@ -382,7 +382,7 @@ def _support_torch_compile(
...
@@ -382,7 +382,7 @@ def _support_torch_compile(
else
:
else
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
def
__call__
(
self
:
_T
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
def
__call__
(
self
:
type
[
_T
]
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
...
@@ -564,7 +564,7 @@ def _support_torch_compile(
...
@@ -564,7 +564,7 @@ def _support_torch_compile(
return
output
return
output
# triggers VllmSerializableFunction.serialize()
# triggers VllmSerializableFunction.serialize()
def
save_aot_compiled_function
(
self
)
:
def
save_aot_compiled_function
(
self
:
type
[
_T
])
->
None
:
if
self
.
was_aot_compile_fn_loaded_from_disk
:
if
self
.
was_aot_compile_fn_loaded_from_disk
:
logger
.
debug
(
"AOT compiled function was loaded from cache, skipping save"
)
logger
.
debug
(
"AOT compiled function was loaded from cache, skipping save"
)
return
return
...
...
vllm/compilation/matcher_utils.py
View file @
15e302df
...
@@ -141,15 +141,18 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
...
@@ -141,15 +141,18 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
key
:
torch
.
Tensor
|
None
,
key
:
torch
.
Tensor
|
None
,
cos_sin_cache
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
RotaryEmbedding
.
forward_static
(
result
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]
=
(
positions
,
RotaryEmbedding
.
forward_static
(
query
,
positions
,
key
,
query
,
self
.
head_size
,
key
,
self
.
rotary_dim
,
self
.
head_size
,
cos_sin_cache
,
self
.
rotary_dim
,
self
.
is_neox
,
cos_sin_cache
,
self
.
is_neox
,
)
)
)
return
result
class
MatcherRMSNorm
(
MatcherCustomOp
):
class
MatcherRMSNorm
(
MatcherCustomOp
):
...
@@ -275,9 +278,10 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
...
@@ -275,9 +278,10 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
re
turn
RMSNorm
.
forward_static
(
re
sult
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
RMSNorm
.
forward_static
(
input
,
self
.
epsilon
,
input
.
size
(
-
1
),
self
.
model_dtype
,
weight
,
residual
input
,
self
.
epsilon
,
input
.
size
(
-
1
),
self
.
model_dtype
,
weight
,
residual
)
)
return
result
class
MatcherQuantFP8
(
MatcherCustomOp
):
class
MatcherQuantFP8
(
MatcherCustomOp
):
...
...
vllm/compilation/piecewise_backend.py
View file @
15e302df
...
@@ -25,7 +25,7 @@ logger = init_logger(__name__)
...
@@ -25,7 +25,7 @@ logger = init_logger(__name__)
class
RangeEntry
:
class
RangeEntry
:
compile_range
:
Range
compile_range
:
Range
compiled
:
bool
=
False
compiled
:
bool
=
False
runnable
:
Callable
=
None
# type: ignore
runnable
:
Callable
[...,
Any
]
=
None
# type: ignore
class
PiecewiseBackend
:
class
PiecewiseBackend
:
...
@@ -38,7 +38,7 @@ class PiecewiseBackend:
...
@@ -38,7 +38,7 @@ class PiecewiseBackend:
sym_shape_indices
:
list
[
int
],
sym_shape_indices
:
list
[
int
],
vllm_backend
:
VllmBackend
,
vllm_backend
:
VllmBackend
,
returns_tuple
:
bool
,
returns_tuple
:
bool
,
compiled_runnables
:
dict
[
str
,
Callable
]
|
None
=
None
,
compiled_runnables
:
dict
[
str
,
Callable
[...,
Any
]
]
|
None
=
None
,
):
):
"""
"""
The backend for piecewise compilation.
The backend for piecewise compilation.
...
@@ -138,8 +138,10 @@ class PiecewiseBackend:
...
@@ -138,8 +138,10 @@ class PiecewiseBackend:
self
.
on_compilation_complete
=
_on_compilation_complete_callback
.
get
()
self
.
on_compilation_complete
=
_on_compilation_complete_callback
.
get
()
def
get_compiled_graph_wrapper
(
self
,
compiled_graph
):
def
get_compiled_graph_wrapper
(
def
compiled_graph_wrapper
(
*
args
):
self
,
compiled_graph
:
Callable
[...,
Any
]
)
->
Callable
[...,
Any
]:
def
compiled_graph_wrapper
(
*
args
:
Any
)
->
Any
:
graph_output
=
compiled_graph
(
*
args
)
graph_output
=
compiled_graph
(
*
args
)
# unpack the tuple if needed
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# TODO(rzou): the implication is that we're not
...
@@ -163,7 +165,7 @@ class PiecewiseBackend:
...
@@ -163,7 +165,7 @@ class PiecewiseBackend:
def
to_bytes
(
self
)
->
dict
[
str
,
bytes
]:
def
to_bytes
(
self
)
->
dict
[
str
,
bytes
]:
class
StandaloneCompiledArtifactsPickler
(
Pickler
):
class
StandaloneCompiledArtifactsPickler
(
Pickler
):
def
reducer_override
(
self
,
obj
)
:
def
reducer_override
(
self
,
obj
:
object
)
->
Any
:
if
isinstance
(
obj
,
CachingAutotuner
):
if
isinstance
(
obj
,
CachingAutotuner
):
obj
.
prepare_for_pickle
()
obj
.
prepare_for_pickle
()
return
pickle
.
loads
,
(
return
pickle
.
loads
,
(
...
@@ -173,7 +175,7 @@ class PiecewiseBackend:
...
@@ -173,7 +175,7 @@ class PiecewiseBackend:
)
)
return
NotImplemented
return
NotImplemented
def
serialize
(
fn
)
->
bytes
:
def
serialize
(
fn
:
Callable
[...,
Any
]
)
->
bytes
:
assert
hasattr
(
fn
,
"serialize"
),
"fn must have serialize method"
assert
hasattr
(
fn
,
"serialize"
),
"fn must have serialize method"
with
torch
.
_functorch
.
config
.
patch
(
"bundled_autograd_cache"
,
True
):
with
torch
.
_functorch
.
config
.
patch
(
"bundled_autograd_cache"
,
True
):
entry
=
fn
.
serialize
()
entry
=
fn
.
serialize
()
...
...
vllm/compilation/sequence_parallelism.py
View file @
15e302df
...
@@ -358,7 +358,10 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
...
@@ -358,7 +358,10 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
):
):
return
True
return
True
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
return
(
compile_range
.
is_single_size
())
and
(
compile_range
.
end
%
tp_size
==
0
)
result
:
bool
=
(
compile_range
.
is_single_size
())
and
(
compile_range
.
end
%
tp_size
==
0
)
return
result
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment