"vllm/model_executor/models/glm4_1v.py" did not exist on "a35ca765a52fff242edf0e9fd3203ea2534aed58"
backend.py 1.41 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from copy import deepcopy
4
from typing import Callable, Union
5

6
7
8
from torch import fx

from vllm.compilation.inductor_pass import InductorPass
9
10
11
12
13
14
15
16
17


class TestBackend:
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.

    ``__call__`` takes an ``fx.GraphModule`` plus example inputs, i.e. the
    shape of a ``torch.compile`` backend callable. Once Inductor has invoked
    ``post_pass``, the snapshots are available on ``graph_pre_pass``,
    ``graph_post_pass`` and ``final_graph``; before that they are ``None``.
    """

    def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                             None]]):
        """Store the custom passes and build the Inductor config patches.

        Args:
            passes: Passes applied, in order, by ``post_pass``. Each one is
                called as ``pass_(graph)``; its return value is ignored, so
                a pass must modify the graph in place.
        """
        self.custom_passes = list(passes)

        # Inspection snapshots, filled in by post_pass(). They stay None until
        # Inductor has actually run the hook, so reading them early gives a
        # clear "not run yet" value instead of an AttributeError.
        self.graph_pre_pass = None
        self.graph_post_pass = None
        self.final_graph = None

        from torch._inductor import config

        # Mutate a copy of the config dict rather than the global config
        # module; the copy is applied per compile via config_patches below.
        self.current_config = config.shallow_copy_dict()
        # Disable Inductor caches -- presumably so every compile really runs
        # post_pass instead of reusing a cached artifact (TODO confirm).
        self.current_config['force_disable_caches'] = True
        # Register post_pass as Inductor's post-grad custom pass hook.
        self.current_config['post_grad_custom_post_pass'] = self.post_pass

    def __call__(self, graph: fx.GraphModule, example_inputs):
        """Compile ``graph`` with Inductor using this backend's config.

        Args:
            graph: The captured FX graph module to compile.
            example_inputs: Example inputs forwarded unchanged to
                ``compile_fx``.

        Returns:
            Whatever ``compile_fx`` returns for ``graph``.
        """
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
                          config_patches=self.current_config)

    def post_pass(self, graph: fx.Graph) -> None:
        """Run the custom passes on ``graph``, snapshotting it around them.

        Installed as ``post_grad_custom_post_pass``. The passes modify
        ``graph`` in place. The snapshot attributes are overwritten on every
        invocation, so they reflect the most recent graph Inductor handed in.
        """
        # Deep copies freeze the graph state at this point for later checks.
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph