# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Callable, Union

from torch import fx

from vllm.compilation.fx_utils import (find_specified_fn,
                                       find_specified_fn_maybe)
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config


class TestBackend:
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.

    Inductor config can be modified directly by editing the inductor_config
    attribute. This can be helpful for adding passes like the
    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
    Inductor config is default-initialized from VllmConfig.CompilationConfig.

    Attributes set in __init__:
        custom_passes: list of the callables given to the constructor.
        inductor_config: the config mapping handed to compile_fx as
            config_patches.

    Attributes set later (they do not exist until the corresponding method
    has run):
        graph_pre_compile: deep copy of the module passed to __call__.
        graph_pre_pass: deep copy of the graph taken in post_pass before any
            custom pass runs.
        graph_post_pass: deep copy of the graph taken in post_pass after all
            custom passes have run.
        final_graph: the graph object received by post_pass itself (not a
            copy).
    """

    def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                             None]]):
        """Store the custom passes and prepare the Inductor config.

        Args:
            *passes: callables that are invoked with the fx.Graph, in the
                order given, from post_pass.
        """
        self.custom_passes = list(passes)
        compile_config = get_current_vllm_config().compilation_config
        # NOTE(review): no copy is made here, so the two writes below go to
        # whatever object inductor_compile_config returns - presumably the
        # dict held by the current vLLM config; confirm this sharing is
        # intended.
        self.inductor_config = compile_config.inductor_compile_config
        self.inductor_config['force_disable_caches'] = True
        # Register our hook under Inductor's custom post-grad pass key.
        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass

    def __call__(self, graph: fx.GraphModule, example_inputs):
        """Compile ``graph`` with Inductor using ``self.inductor_config``.

        A deep copy of the incoming module is kept in ``graph_pre_compile``
        before compilation starts.

        Args:
            graph: the module to compile.
            example_inputs: forwarded unchanged to compile_fx.

        Returns:
            Whatever compile_fx returns.
        """
        self.graph_pre_compile = deepcopy(graph)
        # Imported lazily, only when a compilation is actually requested.
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
                          config_patches=self.inductor_config)

    def post_pass(self, graph: fx.Graph):
        """Run the custom passes on ``graph`` and record snapshots.

        This is the callable stored under 'post_grad_custom_post_pass' in
        the Inductor config by __init__.

        Args:
            graph: the graph handed to the hook; each custom pass is called
                with this same object.
        """
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph

    def check_before_ops(self,
                         ops,
                         find_fn=find_specified_fn,
                         find_fn_maybe=find_specified_fn_maybe,
                         ops_fully_replaced=True):
        """Check ops that should be present before the custom passes.

        Requires post_pass to have run, since it reads ``graph_pre_pass``
        and ``graph_post_pass``.

        Args:
            ops: iterable of ops to look for.
            find_fn: called as ``find_fn(nodes, op)`` on the pre-pass graph
                nodes; its result is discarded (presumably it fails on its
                own when the op is missing - verify against fx_utils).
            find_fn_maybe: called as ``find_fn_maybe(nodes, op)`` on the
                post-pass graph nodes; must return None for the check to
                pass.
            ops_fully_replaced: when true, additionally assert that each op
                is absent from the post-pass graph.

        Raises:
            AssertionError: if ``ops_fully_replaced`` is true and
                ``find_fn_maybe`` finds an op in the post-pass graph.
        """
        for op in ops:
            find_fn(self.graph_pre_pass.nodes, op)
            if ops_fully_replaced:
                assert find_fn_maybe(self.graph_post_pass.nodes, op) is None

    def check_after_ops(self,
                        ops,
                        find_fn=find_specified_fn,
                        find_fn_maybe=find_specified_fn_maybe):
        """Check ops that should only be present after the custom passes.

        Requires post_pass to have run, since it reads ``graph_pre_pass``
        and ``graph_post_pass``.

        Args:
            ops: iterable of ops to look for.
            find_fn: called as ``find_fn(nodes, op)`` on the post-pass graph
                nodes; its result is discarded (presumably it fails on its
                own when the op is missing - verify against fx_utils).
            find_fn_maybe: called as ``find_fn_maybe(nodes, op)`` on the
                pre-pass graph nodes; must return None for the check to
                pass.

        Raises:
            AssertionError: if ``find_fn_maybe`` finds an op in the pre-pass
                graph.
        """
        for op in ops:
            find_fn(self.graph_post_pass.nodes, op)
            assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None