# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
4
5
6
7

import pytest
import torch

from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
9
from vllm.compilation.pass_manager import PostGradPassManager
10
from vllm.config import ModelConfig, VllmConfig
11
12


# Dummy custom pass that doesn't inherit from InductorPass
def simple_callable(graph: torch.fx.Graph):
    """Do nothing with ``graph``.

    A plain function used by the tests as a pass that is *not* an
    ``InductorPass`` subclass. The graph is neither read nor modified.
    """
    return None


# Adding a bare callable directly to the pass manager should fail
def test_bad_callable():
    """Check that a bare callable cannot be added to the pass manager.

    A ``PostGradPassManager`` configured with a default ``VllmConfig`` is
    expected to raise ``AssertionError`` when ``add`` is given
    ``simple_callable``, which is a plain function rather than an
    ``InductorPass``.
    """
    config = VllmConfig()

    pass_manager = PostGradPassManager()
    pass_manager.configure(config)

    # The un-wrapped function must be rejected by ``add``.
    with pytest.raises(AssertionError):
        pass_manager.add(simple_callable)
27
28
29
30
31
32


# Pass that inherits from InductorPass
class ProperPass(InductorPass):
    """No-op pass that subclasses ``InductorPass``."""

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Accept ``graph`` and leave it unchanged."""
        return None
33
34
35


@pytest.mark.parametrize(
    "callable",
    [
        ProperPass(),
        # Can also wrap callables in CallableInductorPass for compliance
        CallableInductorPass(simple_callable),
        CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
    ],
)
def test_pass_manager_uuid(callable):
    """Check how ``PostGradPassManager.uuid()`` reacts to passes and config.

    For each parametrized pass object the test asserts that:

    * adding the same pass a second time changes the UUID;
    * a second manager built the same way (same config, one ``add``)
      reports the same UUID as the first manager did after one ``add``;
    * a manager configured with flipped ``fuse_norm_quant`` and
      ``fuse_act_quant`` flags reports a different UUID.

    Args:
        callable: The pass instance supplied by ``parametrize``. The name
            shadows the builtin but is kept because it is tied to the
            parametrize argument name.
    """
    # Some passes need dtype to be set
    config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))

    pass_manager = PostGradPassManager()
    pass_manager.configure(config)

    # Check that UUID is different if the same pass is added 2x
    pass_manager.add(callable)
    uuid1 = pass_manager.uuid()
    pass_manager.add(callable)
    uuid2 = pass_manager.uuid()
    assert uuid1 != uuid2

    # UUID should be the same as the original one,
    # as we constructed in the same way.
    pass_manager2 = PostGradPassManager()
    pass_manager2.configure(config)
    pass_manager2.add(callable)
    assert uuid1 == pass_manager2.uuid()

    # UUID should be different due to config change.
    # Work on a deep copy so the original ``config`` is left untouched.
    config2 = copy.deepcopy(config)
    pass_config = config2.compilation_config.pass_config
    pass_config.fuse_norm_quant = not pass_config.fuse_norm_quant
    pass_config.fuse_act_quant = not pass_config.fuse_act_quant

    pass_manager3 = PostGradPassManager()
    pass_manager3.configure(config2)
    pass_manager3.add(callable)
    assert uuid1 != pass_manager3.uuid()