test_pass_manager.py 1.01 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
import pickle

import pytest
import torch

Jovan Sardinha's avatar
Jovan Sardinha committed
8
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
9
from vllm.compilation.pass_manager import PostGradPassManager
Jovan Sardinha's avatar
Jovan Sardinha committed
10
from vllm.config import CompilationConfig
11
12
13
14
15
16


def simple_callable(graph: torch.fx.Graph):
    """No-op graph pass: accepts a graph, leaves it untouched, returns None.

    Serves as the bare (unwrapped) callable in this module's pass-manager
    test cases.
    """


Jovan Sardinha's avatar
Jovan Sardinha committed
17
18
# The no-op pass wrapped in CallableInductorPass, with an explicit second
# argument computed from this test file by InductorPass.hash_source.
# NOTE(review): presumably that value acts as the wrapper's UUID -- confirm
# against CallableInductorPass.
callable_uuid = CallableInductorPass(
    simple_callable,
    InductorPass.hash_source(__file__),
)
19
20
21
22


@pytest.mark.parametrize(
    "works, callable",
    [
        # A bare function is expected to be rejected by the pass manager.
        (False, simple_callable),
        # Wrapped pass constructed with an explicit second argument.
        (True, callable_uuid),
        # Wrapped pass constructed from the callable alone.
        (True, CallableInductorPass(simple_callable)),
    ],
)
def test_pass_manager(works: bool, callable):
    """Check which callables PostGradPassManager.add accepts.

    When ``works`` is true, the callable must be added without error and the
    manager must then be serializable with ``pickle.dumps``. Otherwise,
    ``add`` must raise ``AssertionError``.
    """
    pass_config = CompilationConfig().pass_config

    manager = PostGradPassManager()
    manager.configure(pass_config)

    # Rejection path first: adding the callable must trip an assertion.
    if not works:
        with pytest.raises(AssertionError):
            manager.add(callable)
        return

    # Acceptance path: the add succeeds and the manager pickles cleanly.
    manager.add(callable)
    pickle.dumps(manager)