# SPDX-License-Identifier: Apache-2.0
import copy

import pytest
import torch

from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import VllmConfig


# Dummy custom pass that deliberately does NOT inherit from InductorPass
def simple_callable(graph: torch.fx.Graph) -> None:
    """No-op graph pass used to exercise pass-manager type checks.

    Being a plain function rather than an InductorPass subclass, it must
    be wrapped in CallableInductorPass before it can be registered.
    """


# Should fail to add directly to the pass manager
def test_bad_callable():
    """A bare function (not an InductorPass) must be rejected by add().

    The test expects PostGradPassManager.add() to raise AssertionError
    when handed ``simple_callable``, which does not subclass InductorPass.
    """
    config = VllmConfig()

    pass_manager = PostGradPassManager()
    pass_manager.configure(config)

    with pytest.raises(AssertionError):
        pass_manager.add(simple_callable)  # noqa, type wrong on purpose


# Pass that inherits from InductorPass
class ProperPass(InductorPass):
    """No-op pass that subclasses InductorPass.

    Because it inherits from InductorPass it can be registered with the
    pass manager directly, without a CallableInductorPass wrapper.
    """

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Leave *graph* untouched."""


@pytest.mark.parametrize(
    "inductor_pass",
    [
        ProperPass(),
        # Can also wrap callables in CallableInductorPass for compliance
        CallableInductorPass(simple_callable),
        # ...optionally with an explicitly supplied hash
        CallableInductorPass(simple_callable,
                             InductorPass.hash_source(__file__)),
    ],
)
def test_pass_manager_uuid(inductor_pass):
    """PostGradPassManager.uuid() must reflect both its passes and config.

    Checks that:
    - adding the same pass a second time changes the UUID;
    - two managers built the same way (same config, same pass) agree;
    - toggling ``pass_config.enable_fusion`` changes the UUID.
    """
    config = VllmConfig()

    pass_manager = PostGradPassManager()
    pass_manager.configure(config)

    # Check that UUID is different if the same pass is added 2x
    pass_manager.add(inductor_pass)
    uuid1 = pass_manager.uuid()
    pass_manager.add(inductor_pass)
    uuid2 = pass_manager.uuid()
    assert uuid1 != uuid2

    # UUID should be the same as the original one,
    # as we constructed in the same way.
    pass_manager2 = PostGradPassManager()
    pass_manager2.configure(config)
    pass_manager2.add(inductor_pass)
    assert uuid1 == pass_manager2.uuid()

    # UUID should be different due to config change. Deep-copy first so the
    # `config` shared with the managers above is left unmodified.
    config2 = copy.deepcopy(config)
    pass_config = config2.compilation_config.pass_config
    pass_config.enable_fusion = not pass_config.enable_fusion

    pass_manager3 = PostGradPassManager()
    pass_manager3.configure(config2)
    pass_manager3.add(inductor_pass)
    assert uuid1 != pass_manager3.uuid()