# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
5
6
7

import pytest
import torch

8
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
9
from vllm.compilation.pass_manager import PostGradPassManager
10
from vllm.config import VllmConfig
11
12


# Dummy custom pass: a plain function that does NOT inherit from InductorPass.
def simple_callable(graph: torch.fx.Graph):
    """Ignore *graph* entirely; used only as a non-InductorPass pass."""


# Should fail to add directly to the pass manager
def test_bad_callable():
    """A bare callable (not an InductorPass) must be rejected by ``add``.

    ``PostGradPassManager.add`` is expected to raise ``AssertionError`` when
    given ``simple_callable`` directly, without a ``CallableInductorPass``
    wrapper.
    """
    config = VllmConfig()

    pass_manager = PostGradPassManager()
    pass_manager.configure(config)

    with pytest.raises(AssertionError):
        pass_manager.add(simple_callable)


# Minimal pass that satisfies the InductorPass interface by subclassing it.
class ProperPass(InductorPass):
    """No-op InductorPass subclass, used as a compliant pass in the tests."""

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Leave *graph* unmodified."""


@pytest.mark.parametrize(
    "callable",
    [
        ProperPass(),
        # Can also wrap callables in CallableInductorPass for compliance
        CallableInductorPass(simple_callable),
        CallableInductorPass(simple_callable,
                             InductorPass.hash_source(__file__)),
    ],
)
def test_pass_manager_uuid(callable):
    """Check that ``PostGradPassManager.uuid()`` tracks passes and config.

    The UUID must:
      * change when the same pass is added a second time,
      * be reproducible for a manager built the same way,
      * change when the manager is configured with a different pass config.
    """
    # NOTE: ``callable`` shadows the builtin of the same name; it is kept
    # because pytest derives the parametrized test IDs from this argname.
    config = VllmConfig()

    pass_manager = PostGradPassManager()
    pass_manager.configure(config)

    # Check that UUID is different if the same pass is added 2x
    pass_manager.add(callable)
    uuid1 = pass_manager.uuid()
    pass_manager.add(callable)
    uuid2 = pass_manager.uuid()
    assert uuid1 != uuid2

    # UUID should be the same as the original one,
    # as we constructed in the same way.
    pass_manager2 = PostGradPassManager()
    pass_manager2.configure(config)
    pass_manager2.add(callable)
    assert uuid1 == pass_manager2.uuid()

    # UUID should be different due to config change. Deep-copy so the
    # original ``config`` is left untouched by the toggle below.
    config2 = copy.deepcopy(config)
    pass_config2 = config2.compilation_config.pass_config
    pass_config2.enable_fusion = not pass_config2.enable_fusion
    pass_manager3 = PostGradPassManager()
    pass_manager3.configure(config2)
    pass_manager3.add(callable)
    assert uuid1 != pass_manager3.uuid()