# backend.py
1
2
# SPDX-License-Identifier: Apache-2.0

3
from copy import deepcopy
4
from typing import Callable, Union
5

6
7
8
from torch import fx

from vllm.compilation.inductor_pass import InductorPass
9
10
11
12
13
14
15


class TestBackend:
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.
16
17
18
19

    Inductor config can be modified directly by editing the inductor_config
    property. This can be helpful for adding passes like the
    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
20
21
    """

22
23
24
    def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                             None]]):
        self.custom_passes = list(passes)
25
        from torch._inductor import config
26
27
28
        self.inductor_config = config.shallow_copy_dict()
        self.inductor_config['force_disable_caches'] = True
        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
29

30
    def __call__(self, graph: fx.GraphModule, example_inputs):
31
        self.graph_pre_compile = deepcopy(graph)
32
33
34
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
35
                          config_patches=self.inductor_config)
36

37
    def post_pass(self, graph: fx.Graph):
38
39
40
41
42
43
44
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph