Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f622dbcf
Unverified
Commit
f622dbcf
authored
Mar 23, 2025
by
Luka Govedič
Committed by
GitHub
Mar 24, 2025
Browse files
[Fix] [torch.compile] Improve UUID system for custom passes (#15249)
Signed-off-by:
luka
<
luka@neuralmagic.com
>
parent
dccf535f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
125 additions
and
84 deletions
+125
-84
tests/compile/test_pass_manager.py
tests/compile/test_pass_manager.py
+46
-16
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+25
-28
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+9
-35
vllm/compilation/torch25_custom_graph_pass.py
vllm/compilation/torch25_custom_graph_pass.py
+41
-0
vllm/config.py
vllm/config.py
+4
-5
No files found.
tests/compile/test_pass_manager.py
View file @
f622dbcf
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
copy
import
pickle
import
pytest
import
pytest
import
torch
import
torch
...
@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
...
@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
# dummy custom pass that doesn't inherit
def
simple_callable
(
graph
:
torch
.
fx
.
Graph
):
def
simple_callable
(
graph
:
torch
.
fx
.
Graph
):
pass
pass
callable_uuid
=
CallableInductorPass
(
simple_callable
,
# Should fail to add directly to the pass manager
InductorPass
.
hash_source
(
__file__
))
def
test_bad_callable
():
config
=
CompilationConfig
().
pass_config
pass_manager
=
PostGradPassManager
()
pass_manager
.
configure
(
config
)
with
pytest
.
raises
(
AssertionError
):
pass_manager
.
add
(
simple_callable
)
# noqa, type wrong on purpose
# Pass that inherits from InductorPass
class
ProperPass
(
InductorPass
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
graph
.
Graph
)
->
None
:
pass
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"
works,
callable"
,
"callable"
,
[
[
(
False
,
simple_callable
),
ProperPass
(),
(
True
,
callable_uuid
),
# Can also wrap callables in CallableInductorPass for compliance
(
True
,
CallableInductorPass
(
simple_callable
)),
CallableInductorPass
(
simple_callable
),
CallableInductorPass
(
simple_callable
,
InductorPass
.
hash_source
(
__file__
))
],
],
)
)
def
test_pass_manager
(
works
:
bool
,
callable
):
def
test_pass_manager
_uuid
(
callable
):
config
=
CompilationConfig
().
pass_config
config
=
CompilationConfig
().
pass_config
pass_manager
=
PostGradPassManager
()
pass_manager
=
PostGradPassManager
()
pass_manager
.
configure
(
config
)
pass_manager
.
configure
(
config
)
# Try to add the callable to the pass manager
# Check that UUID is different if the same pass is added 2x
if
works
:
pass_manager
.
add
(
callable
)
pass_manager
.
add
(
callable
)
uuid1
=
pass_manager
.
uuid
()
pickle
.
dumps
(
pass_manager
)
pass_manager
.
add
(
callable
)
else
:
uuid2
=
pass_manager
.
uuid
()
with
pytest
.
raises
(
AssertionError
):
assert
uuid1
!=
uuid2
pass_manager
.
add
(
callable
)
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2
=
PostGradPassManager
()
pass_manager2
.
configure
(
config
)
pass_manager2
.
add
(
callable
)
assert
uuid1
==
pass_manager2
.
uuid
()
# UUID should be different due to config change
config2
=
copy
.
deepcopy
(
config
)
config2
.
enable_fusion
=
not
config2
.
enable_fusion
pass_manager3
=
PostGradPassManager
()
pass_manager3
.
configure
(
config2
)
pass_manager3
.
add
(
callable
)
assert
uuid1
!=
pass_manager3
.
uuid
()
vllm/compilation/inductor_pass.py
View file @
f622dbcf
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
hashlib
import
hashlib
import
importlib.metadata
import
inspect
import
inspect
import
json
import
types
import
types
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
from
packaging.version
import
Version
from
torch
import
fx
from
torch
import
fx
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
from
torch._inductor.custom_graph_pass
import
CustomGraphPass
else
:
# CustomGraphPass is not present in 2.5 or lower, import our version
from
.torch25_custom_graph_pass
import
(
# noqa: yapf
Torch25CustomGraphPass
as
CustomGraphPass
)
class
InductorPass
(
ABC
):
class
InductorPass
(
CustomGraphPass
):
"""
"""
General custom inductor pass interface.
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
"""
@
abstractmethod
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
"""
Execute the pass on the given graph.
"""
raise
NotImplementedError
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
Any
:
"""
"""
Provide a unique identifier for the pass, used in Inductor code cache.
Provide a unique identifier for the pass, used in Inductor code cache.
...
@@ -48,7 +51,16 @@ class InductorPass(ABC):
...
@@ -48,7 +51,16 @@ class InductorPass(ABC):
else
:
else
:
src_str
=
inspect
.
getsource
(
src
.
__class__
)
src_str
=
inspect
.
getsource
(
src
.
__class__
)
hasher
.
update
(
src_str
.
encode
(
"utf-8"
))
hasher
.
update
(
src_str
.
encode
(
"utf-8"
))
return
hasher
.
digest
()
return
hasher
.
hexdigest
()
@
staticmethod
def
hash_dict
(
dict_
:
Dict
[
Any
,
Any
]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
return
hashlib
.
sha256
(
encoded
).
hexdigest
()
class
CallableInductorPass
(
InductorPass
):
class
CallableInductorPass
(
InductorPass
):
...
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
...
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
callable
:
Callable
[[
fx
.
Graph
],
None
],
callable
:
Callable
[[
fx
.
Graph
],
None
],
uuid
:
Optional
[
Any
]
=
None
):
uuid
:
Optional
[
Any
]
=
None
):
self
.
callable
=
callable
self
.
callable
=
callable
if
uuid
is
None
:
self
.
_uuid
=
self
.
hash_source
(
callable
)
if
uuid
is
None
else
uuid
uuid
=
InductorPass
.
hash_source
(
callable
)
self
.
_uuid
=
uuid
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
self
.
callable
(
graph
)
self
.
callable
(
graph
)
def
uuid
(
self
)
->
Any
:
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
return
self
.
_uuid
def
__getstate__
(
self
):
"""
Pickling occurs in the Inductor code cache if a pass is not given to
the pass manager but is instead directly added to config as a pass.
See PostGradPassManager for more.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
return
self
.
_uuid
def
__setstate__
(
self
,
state
):
raise
ValueError
(
"Cannot unpickle CallableInductorPass"
)
vllm/compilation/pass_manager.py
View file @
f622dbcf
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
from
typing
import
List
import
torch
from
torch
import
fx
as
fx
from
torch
import
fx
as
fx
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
...
@@ -10,29 +9,18 @@ from vllm.logger import init_logger
...
@@ -10,29 +9,18 @@ from vllm.logger import init_logger
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fusion
import
FusionPass
from
.fusion
import
FusionPass
from
.inductor_pass
import
InductorPass
from
.inductor_pass
import
CustomGraphPass
,
InductorPass
from
.noop_elimination
import
NoOpEliminationPass
from
.noop_elimination
import
NoOpEliminationPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
PlaceHolder
:
class
PostGradPassManager
(
CustomGraphPass
):
pass
if
torch
.
__version__
<
"2.6"
:
Parent
=
PlaceHolder
# type: ignore
else
:
Parent
=
torch
.
_inductor
.
custom_graph_pass
.
CustomGraphPass
# type: ignore
class
PostGradPassManager
(
Parent
):
"""
"""
The pass manager for post-grad passes.
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It handles configuration, adding custom passes, and running passes.
It also supports pickling, which is used by the Inductor code cache.
It supports uuid for the Inductor code cache. That includes torch<2.6
TODO(torch==2.6), use CustomGraphPass
support using pickling (in .inductor_pass.CustomGraphPass).
(torch._inductor.custom_graph_pass.CustomGraphPass)
The order of the post-grad post-passes is:
The order of the post-grad post-passes is:
1. passes (constructor parameter)
1. passes (constructor parameter)
...
@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
...
@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
self
.
passes
.
append
(
pass_
)
self
.
passes
.
append
(
pass_
)
def
uuid
(
self
):
def
uuid
(
self
):
return
self
.
__getstate__
()
def
__getstate__
(
self
)
->
Dict
[
str
,
List
[
Any
]]:
"""
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
The PostGradPassManager is set as a custom pass in the Inductor and
Pickling occurs because the pass manager is set as the value of
affects compilation caching. Its uuid depends on the UUIDs of all
`config["post_grad_custom_post_pass"]` in the Inductor config.
dependent passes and the pass config. See InductorPass for more info.
The config is pickled to act as a key in the Inductor code cache.
Any other passes in the config are pickled as well.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
"""
state
=
{
"pass_config"
:
self
.
pass_config
.
uuid
(),
"passes"
:
[]}
state
=
{
"pass_config"
:
self
.
pass_config
.
uuid
(),
"passes"
:
[]}
for
pass_
in
self
.
passes
:
for
pass_
in
self
.
passes
:
state
[
"passes"
].
append
(
pass_
.
uuid
())
state
[
"passes"
].
append
(
pass_
.
uuid
())
state
[
"passes"
].
append
(
self
.
fix_functionalization
.
uuid
())
state
[
"passes"
].
append
(
self
.
fix_functionalization
.
uuid
())
return
state
return
InductorPass
.
hash_dict
(
state
)
def
__setstate__
(
self
,
state
):
"""
Do not allow unpickling of the pass manager.
If this is needed in the future, it should properly pickle the passes.
"""
raise
ValueError
(
"Cannot unpickle PostGradPassManager"
)
vllm/compilation/torch25_custom_graph_pass.py
0 → 100644
View file @
f622dbcf
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Optional
import
torch
class
Torch25CustomGraphPass
(
ABC
):
# noqa (redefinition)
"""
This class replaces CustomGraphPass from torch==2.6 when using torch<2.6.
It conforms to the 2.6 interface but also supports pickling, as that's what
the inductor code cache uses to determine the cache key before 2.6.
(in 2.6 and above, uuid() is used.)
Subclasses can just "pretend" that uuid is used.
"""
@
abstractmethod
def
__call__
(
self
,
graph
:
torch
.
fx
.
graph
.
Graph
)
->
None
:
"""
Implementation of the custom pass.
"""
@
abstractmethod
def
uuid
(
self
)
->
Optional
[
Any
]:
"""
Return an ID to uniquely identify your custom pass implementation.
Return None to skip inductor code caching entirely.
"""
def
__getstate__
(
self
):
"""
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
to enable subclasses to only have to implement uuid.
"""
return
self
.
uuid
()
def
__setstate__
(
self
,
state
):
raise
ValueError
(
"Cannot unpickle CustomGraphPass because pickling"
" is used for cache key uuid. Use torch>=2.6 with"
" native uuid support for custom passes."
)
vllm/config.py
View file @
f622dbcf
...
@@ -4,6 +4,7 @@ import ast
...
@@ -4,6 +4,7 @@ import ast
import
copy
import
copy
import
enum
import
enum
import
hashlib
import
hashlib
import
importlib.metadata
import
json
import
json
import
sys
import
sys
import
warnings
import
warnings
...
@@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
...
@@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional
,
Protocol
,
Union
)
Optional
,
Protocol
,
Union
)
import
torch
import
torch
from
packaging.version
import
Version
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -52,8 +54,6 @@ if TYPE_CHECKING:
...
@@ -52,8 +54,6 @@ if TYPE_CHECKING:
else
:
else
:
QuantizationConfig
=
None
QuantizationConfig
=
None
from
packaging.version
import
Version
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# This value is chosen to have a balance between ITL and TTFT. Note it is
...
@@ -3088,8 +3088,7 @@ class CompilationConfig(BaseModel):
...
@@ -3088,8 +3088,7 @@ class CompilationConfig(BaseModel):
compilation.
compilation.
"""
"""
dict_
=
self
.
model_dump
(
include
=
{
"enable_fusion"
,
"enable_noop"
})
dict_
=
self
.
model_dump
(
include
=
{
"enable_fusion"
,
"enable_noop"
})
encoded
=
json
.
dumps
(
dict_
,
sort_keys
=
True
).
encode
(
"utf-8"
)
return
InductorPass
.
hash_dict
(
dict_
)
return
hashlib
.
sha256
(
encoded
).
digest
()
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
if
not
self
.
enable_noop
and
self
.
enable_fusion
:
if
not
self
.
enable_noop
and
self
.
enable_fusion
:
...
@@ -3178,7 +3177,7 @@ class CompilationConfig(BaseModel):
...
@@ -3178,7 +3177,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here:
# and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703
# https://github.com/vllm-project/vllm/issues/14703
if
Version
(
torch
.
__version__
)
>=
Version
(
"2.6"
):
if
Version
(
importlib
.
metadata
.
version
(
'torch'
)
)
>=
Version
(
"2.6"
):
KEY
=
'enable_auto_functionalized_v2'
KEY
=
'enable_auto_functionalized_v2'
if
KEY
not
in
self
.
inductor_compile_config
:
if
KEY
not
in
self
.
inductor_compile_config
:
self
.
inductor_compile_config
[
KEY
]
=
False
self
.
inductor_compile_config
[
KEY
]
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment