backend.py 4.97 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import weakref
5
from collections.abc import Callable, Sequence
6
from contextlib import nullcontext
7
8
from copy import deepcopy

9
import depyf
10
from torch import fx
11
from torch._ops import OpOverload, OpOverloadPacket
12
from torch.fx._utils import lazy_format_graph_code
13

14
15
16
17
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.inductor_pass import InductorPass
from vllm.compilation.passes.pass_manager import with_pattern_match_debug
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
18
from vllm.config import VllmConfig, get_current_vllm_config
19
20
21
from vllm.logger import init_logger

# Module-level logger; TestBackend.__init__ uses it to report the depyf dump path.
logger = init_logger("vllm.tests.compile.backend")
22
23
24
25
26
27
28
29
30


class LazyInitPass(InductorPass):
    """
    Wrapper that defers construction of a pass until the pass is actually run.

    Useful in tests where the wrapped pass should not be built up front:
    every invocation constructs a fresh ``pass_cls(vllm_config)``, stores it
    on ``self.pass_`` and immediately applies it to the graph.
    """

    def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
        """
        Args:
            pass_cls: class of the pass to construct on each invocation.
            vllm_config: config handed to ``pass_cls`` when it is constructed.
        """
        self.pass_cls = pass_cls
        # Keep only a weak proxy to the config so that this wrapper does not
        # take part in a reference cycle.
        self.vllm_config = weakref.proxy(vllm_config)

    def __call__(self, graph: fx.Graph) -> None:
        """Build the wrapped pass, remember it as ``self.pass_``, and run it."""
        built_pass = self.pass_cls(self.vllm_config)
        # Stored before running so the instance stays reachable through
        # ``self.pass_`` even if applying it to the graph raises.
        self.pass_ = built_pass
        built_pass(graph)
38
39
40
41
42
43
44


class TestBackend:
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.

    Inductor config can be modified directly by editing the inductor_config
    property. This can be helpful for adding passes like the
    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
    Inductor config is default-initialized from VllmConfig.CompilationConfig.

    Attributes set while compiling:
        graph_pre_compile: deep copy of the graph module handed to __call__.
        graph_pre_pass: deep copy of the graph before the custom passes ran.
        graph_post_pass: deep copy of the graph after the custom passes ran.
        final_graph: the live graph object (not a copy) given to post_pass.
    """

    # This is a helper, not a test case: the name starts with "Test" and the
    # class defines __init__, so tell pytest not to try to collect it.
    __test__ = False

    def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
        """
        Args:
            *passes: callables applied in order to the post-grad graph; each
                is invoked as ``pass_(graph)``.
        """
        self.custom_passes = list(passes)
        vllm_config = get_current_vllm_config()
        compile_config = vllm_config.compilation_config
        # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
        self.inductor_config = deepcopy(compile_config.inductor_compile_config)
        self.inductor_config["force_disable_caches"] = True
        # Hook our pass runner into Inductor via its config.
        self.inductor_config["post_grad_custom_post_pass"] = self.post_pass

        if debug_dump_path := vllm_config.compile_debug_dump_path():
            logger.debug("Dumping depyf output to %s", debug_dump_path)
            # NOTE(review): this context object is created once and entered on
            # every __call__; confirm it may be re-entered if the backend is
            # invoked more than once.
            self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
        else:
            self.debug_ctx = nullcontext()

    def __call__(self, graph: fx.GraphModule, example_inputs):
        """
        Compile ``graph`` with Inductor using this backend's config.

        A deep copy of the incoming graph module is kept on
        ``graph_pre_compile``. Returns whatever ``compile_fx`` returns.
        """
        self.graph_pre_compile = deepcopy(graph)
        # Imported lazily, at call time, rather than at module import.
        from torch._inductor.compile_fx import compile_fx

        with self.debug_ctx:
            return compile_fx(
                graph, example_inputs, config_patches=self.inductor_config
            )

    @with_pattern_match_debug
    def post_pass(self, graph: fx.Graph) -> None:
        """
        Run every custom pass on ``graph`` in order, in place.

        Snapshots of the graph are stored on ``graph_pre_pass`` and
        ``graph_post_pass``; ``final_graph`` keeps a reference to the live
        graph. Any exception raised by a pass propagates to the caller.
        """
        self.graph_pre_pass = deepcopy(graph)
        # NOTE(review): the value returned by lazy_format_graph_code is
        # discarded here (same for the post-pass call below).
        lazy_format_graph_code("graph_pre_pass", graph.owning_module)

        # dump_prefix is class-level state shared by all VllmInductorPass
        # instances: it counts the passes run so far and must be reset to None
        # even when a pass raises, so a failure cannot leak into later runs.
        VllmInductorPass.dump_prefix = 0
        try:
            for pass_ in self.custom_passes:
                pass_(graph)
                VllmInductorPass.dump_prefix += 1
        finally:
            VllmInductorPass.dump_prefix = None

        self.graph_post_pass = deepcopy(graph)
        lazy_format_graph_code("graph_post_pass", graph.owning_module)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph

    @staticmethod
    def _op_name(op: OpOverload | OpOverloadPacket) -> str:
        """Return a printable name for ``op`` for use in assertion messages."""
        # Only OpOverload provides name(); attribute access on an
        # OpOverloadPacket is an overload lookup, so fall back to str().
        if isinstance(op, OpOverload):
            return op.name()
        return str(op)

    def check_before_ops(
        self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
    ) -> None:
        """
        Assert that each op was present before the custom passes and that at
        least one occurrence was removed by them.

        Args:
            ops: ops expected in the pre-pass graph.
            fully_replaced: if True, additionally require that no occurrence
                of the op remains in the post-pass graph.

        Raises:
            AssertionError: if any expectation does not hold.
        """
        for op in ops:
            num_pre = self.op_count(op, before=True)
            num_post = self.op_count(op)
            name = self._op_name(op)
            assert num_pre > 0, f"Op {name} not found in pre-pass graph"
            assert num_pre > num_post, f"All nodes remain for op {name}"
            if fully_replaced:
                assert num_post == 0, f"Unexpected op {name} in post-pass graph"

    def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]) -> None:
        """
        Assert that each op is absent before the custom passes and present
        after them.

        Raises:
            AssertionError: if any expectation does not hold.
        """
        for op in ops:
            num_pre = self.op_count(op, before=True)
            num_post = self.op_count(op)
            name = self._op_name(op)
            assert num_pre == 0, f"Unexpected op {name} in pre-pass graph"
            assert num_post > 0, f"Op {name} not found in post-pass graph"

    def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
        """
        Count the nodes for ``op`` found by ``find_op_nodes``.

        Args:
            op: the op to look for.
            before: count in the pre-pass snapshot if True, otherwise in the
                post-pass snapshot.
        """
        graph = self.graph_pre_pass if before else self.graph_post_pass
        return sum(1 for _ in find_op_nodes(op, graph))

    def print_graphs(self) -> None:
        """Print the generated source of the pre- and post-pass graph snapshots."""
        print("=== Graph before custom passes ===")
        print(self.graph_pre_pass.python_code(root_module="self", verbose=True).src)
        print("=== Graph after custom passes ===")
        print(self.graph_post_pass.python_code(root_module="self", verbose=True).src)