backend.py 1.37 KB
Newer Older
1
from copy import deepcopy
2
from typing import Callable, Union
3

4
5
6
from torch import fx

from vllm.compilation.inductor_pass import InductorPass
7
8
9
10
11
12
13
14
15


class TestBackend:
    """
    Simple Inductor backend that can be used for testing.

    It takes a list of custom passes and runs them after Inductor's passes
    (they are registered through the ``post_grad_custom_post_pass`` config
    entry). It also saves the graph before and after the custom passes for
    inspection:

    - ``graph_pre_pass``: deep copy taken right before the custom passes run.
    - ``graph_post_pass``: deep copy taken right after the custom passes run.
    - ``final_graph``: the graph object itself (not a copy).

    All three attributes are ``None`` until ``post_pass`` has been invoked.
    """

    # The class name matches pytest's default ``Test*`` collection pattern,
    # but this is a helper with an ``__init__``, not a test case. Opt out of
    # collection explicitly so importing it into test modules stays quiet.
    __test__ = False

    def __init__(
        self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]
    ):
        """
        Store the custom passes and build the Inductor config patches.

        Args:
            *passes: callables that each receive the ``fx.Graph`` and modify
                it in place; they are executed in the order given.
        """
        self.custom_passes = list(passes)

        # Populated by ``post_pass``; defined here so that reading them
        # before any compilation yields ``None`` instead of AttributeError.
        self.graph_pre_pass: Union[fx.Graph, None] = None
        self.graph_post_pass: Union[fx.Graph, None] = None
        self.final_graph: Union[fx.Graph, None] = None

        # Imported lazily, as in the original code, so that importing this
        # module does not import torch._inductor.
        from torch._inductor import config

        # Start from a copy of the current Inductor config and override only
        # the entries this backend needs.
        self.current_config = config.shallow_copy_dict()
        # NOTE(review): presumably required so a cached compilation result
        # cannot skip the custom pass hook - confirm against Inductor docs.
        self.current_config['force_disable_caches'] = True
        self.current_config['post_grad_custom_post_pass'] = self.post_pass

    def __call__(self, graph: fx.GraphModule, example_inputs):
        """
        Compile ``graph`` with Inductor using this backend's config patches.

        Args:
            graph: the graph module to compile.
            example_inputs: example inputs forwarded to ``compile_fx``.

        Returns:
            Whatever ``compile_fx`` returns for the given graph.
        """
        from torch._inductor.compile_fx import compile_fx

        return compile_fx(
            graph, example_inputs, config_patches=self.current_config
        )

    def post_pass(self, graph: fx.Graph):
        """
        Run every custom pass on ``graph`` and record snapshots around them.

        Args:
            graph: the graph to transform; the passes mutate it in place.
        """
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph