"docs/vscode:/vscode.git/clone" did not exist on "75beba29b5c7316a9ebde9b0886f609dd5bf05bd"
backend.py 3.96 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import weakref
5
from collections.abc import Sequence
6
from copy import deepcopy
7
from typing import Callable, Union
8

9
from torch import fx
10
from torch._ops import OpOverload
11

12
from vllm.compilation.fx_utils import find_op_nodes
13
from vllm.compilation.inductor_pass import InductorPass
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config


class LazyInitPass(InductorPass):
    """
    Wrapper that postpones constructing a pass until the wrapper is called.

    Intended for tests: construction only records the pass class and the
    config. Each call instantiates the wrapped pass and then runs it on the
    given graph right away.
    """

    def __init__(self, pass_cls: type[VllmInductorPass],
                 vllm_config: VllmConfig):
        """
        :param pass_cls: pass class to instantiate when the wrapper is called.
        :param vllm_config: config handed to ``pass_cls`` at that time.
        """
        self.pass_cls = pass_cls
        # Hold the config through a weak proxy instead of a strong reference
        # so that this wrapper does not take part in a reference cycle.
        self.vllm_config = weakref.proxy(vllm_config)

    def __call__(self, graph: fx.Graph) -> None:
        """Build the wrapped pass, keep it on ``self.pass_``, and run it."""
        built_pass = self.pass_cls(self.vllm_config)
        # Stored before invocation so the instance stays reachable afterwards.
        self.pass_ = built_pass
        built_pass(graph)
34
35
36
37
38
39
40


class TestBackend:
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.

    Inductor config can be modified directly by editing the inductor_config
    property. This can be helpful for adding passes like the
    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
    Inductor config is default-initialized from VllmConfig.CompilationConfig.
    """

    def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                             None]]):
        """
        Record the custom passes and register this backend's post-pass hook.

        :param passes: callables taking an ``fx.Graph``; ``post_pass`` invokes
            them in the order given.
        """
        self.custom_passes = list(passes)
        compile_config = get_current_vllm_config().compilation_config
        # NOTE(review): assigned without copying, so the two keys set below
        # presumably land on the dict owned by the current compilation
        # config - confirm if isolation between backends is ever required.
        self.inductor_config = compile_config.inductor_compile_config
        self.inductor_config['force_disable_caches'] = True
        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass

    def __call__(self, graph: fx.GraphModule, example_inputs):
        """
        Snapshot the incoming graph module, then compile it with Inductor.

        :param graph: graph module handed to the backend.
        :param example_inputs: forwarded unchanged to ``compile_fx``.
        :return: the result of ``compile_fx``.
        """
        # Copy taken before compilation so the untouched input can be
        # inspected later via ``graph_pre_compile``.
        self.graph_pre_compile = deepcopy(graph)
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
                          config_patches=self.inductor_config)

    @with_pattern_match_debug
    def post_pass(self, graph: fx.Graph):
        """
        Run every custom pass on ``graph`` in place, keeping snapshots.

        Sets ``graph_pre_pass`` (copy taken before the passes),
        ``graph_post_pass`` (copy taken after them) and ``final_graph``
        (the live graph object itself).

        :param graph: graph to transform; passes receive this same object.
        """
        self.graph_pre_pass = deepcopy(graph)

        VllmInductorPass.dump_prefix = 0
        try:
            for pass_ in self.custom_passes:
                pass_(graph)
                VllmInductorPass.dump_prefix += 1
        finally:
            # dump_prefix lives on the class, not on this instance; always
            # clear it so a pass that raises cannot leave it set for code
            # that runs afterwards.
            VllmInductorPass.dump_prefix = None

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph

    def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
        """
        Assert that each op was present before the passes and reduced by them.

        :param ops: ops expected in the pre-pass graph.
        :param fully_replaced: when true, additionally require that no
            occurrence of the op is left in the post-pass graph.
        :raises AssertionError: if an op is missing from the pre-pass graph,
            if its count did not decrease, or (with ``fully_replaced``) if
            any occurrence remains.
        """
        for op in ops:
            num_pre = self.op_count(op, before=True)
            num_post = self.op_count(op)
            assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
            assert num_pre > num_post, f"All nodes remain for op {op.name()}"
            if fully_replaced:
                assert num_post == 0, \
                    f"Unexpected op {op.name()} in post-pass graph"

    def check_after_ops(self, ops: Sequence[OpOverload]):
        """
        Assert that each op is absent before the passes and present after.

        :param ops: ops expected to be introduced by the custom passes.
        :raises AssertionError: if an op already occurs in the pre-pass graph
            or does not occur in the post-pass graph.
        """
        for op in ops:
            num_pre = self.op_count(op, before=True)
            num_post = self.op_count(op)
            assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
            assert num_post > 0, f"Op {op.name()} not found in post-pass graph"

    def op_count(self, op: OpOverload, before=False) -> int:
        """
        Count the nodes for ``op`` in one of the saved graph snapshots.

        :param op: op to look for.
        :param before: count in the pre-pass snapshot when true, otherwise
            in the post-pass snapshot.
        :return: number of nodes yielded by ``find_op_nodes`` for ``op``.
        """
        graph = self.graph_pre_pass if before else self.graph_post_pass
        return len(list(find_op_nodes(op, graph)))