# vllm_inductor_pass.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import functools
import operator
5
import time
6
from typing import ClassVar
7

8
import regex as re
9
import torch
10
from torch._dynamo.utils import lazy_format_graph_code
11
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
12

13
from vllm.config import VllmConfig
14
15
16
17
18
19
20
21
22
23
24
25
from vllm.logger import init_logger

from .inductor_pass import InductorPass

# Module-level logger; used below to report per-pass timing at debug level.
logger = init_logger(__name__)


class VllmInductorPass(InductorPass):
    """
    An inductor pass with access to vLLM PassConfig.
    It provides timing, logging, and dumping utilities.
    """

    dump_prefix: ClassVar[int | None] = None
    """Keep track of pass index for debug dump ordering."""

    def __init__(self, config: VllmConfig):
        """
        Cache the parts of the vLLM config that passes read.

        model_dtype and device are set to None when the corresponding
        model_config / device_config is falsy.
        """
        self.pass_config = config.compilation_config.pass_config
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device = config.device_config.device if config.device_config else None
        # Used in dump names and in the timing log line.
        self.pass_name = self.__class__.__name__

    @staticmethod
    def time_and_log(call_fn):
        """
        Decorator for a pass's __call__-style method (self, graph).

        The wrapper starts the timer, dumps the graph with stage "before",
        invokes call_fn, dumps the graph with stage "after", and finally
        logs the elapsed time. The wrapper returns None; any return value
        of call_fn is discarded.
        """

        @functools.wraps(call_fn)
        def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return wrapped

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        """
        Hand the graph's owning module to lazy_format_graph_code under the
        name "post_grad[.<dump_prefix>].<pass_name>.<stage>".

        The ".<dump_prefix>" part is omitted while dump_prefix is None.
        """
        i = VllmInductorPass.dump_prefix
        i_str = "" if i is None else f".{i}"
        # NOTE(review): the return value is intentionally not used here;
        # presumably the actual dump is produced by a hook on
        # lazy_format_graph_code -- confirm against the debug-dump setup.
        lazy_format_graph_code(
            f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
        )

    def begin(self):
        """Record the start timestamp (nanoseconds) for this pass run."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        """
        Record the end timestamp and log the elapsed time in milliseconds.

        Reads self._start_time, so begin() must have been called first.
        """
        self._end_time = time.perf_counter_ns()
        duration_ms = float(self._end_time - self._start_time) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
62
63


64
65
66
67
68
69
70
71
class VllmPatternMatcherPass(VllmInductorPass):
    """
    A VllmInductorPass that uses the Inductor pattern matcher.
    Its main use is providing the dump_patterns utility that dumps the
    Inductor pattern matcher patterns into a file, which greatly aids debugging.

    TODO(luka) move more utilities to this pass.
    """

    matched_count: int = 0
    """The number of matched patterns in the pass."""

    # Matches the repr of an op overload, capturing the op name and the
    # overload name so they can be rewritten as "torch.ops.<op>.<overload>".
    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
    )

    def _replace_op_overloads(self, string: str) -> str:
        """Replace <OpOverload(..., ...)> with nicer formulations"""
        return self._OP_OVERLOAD_PATTERN.sub(
            lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
            string,
        )

    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
        """
        If debug dumping is enabled, dump the Inductor pattern-matcher patterns
        into the debug_dump_path folder next to the dumped fx graphs.

        This method does its best to print something that looks like Python code
        for easier debugging and potentially navigation. If any errors appear in
        the output, please add to this method.

        Nothing is written when config.compile_debug_dump_path() is falsy.
        The output goes to a fresh "patterns.<pass_name>.<i>.py" file inside
        that directory (created if missing).

        TODO(luka): use pattern object to manually produce pattern graph
        """
        debug_dump_path = config.compile_debug_dump_path()
        if not debug_dump_path:
            return

        debug_dump_path.mkdir(parents=True, exist_ok=True)

        from vllm.utils import unique_filepath

        file_path = unique_filepath(
            lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
        )

        # The file is meant to look like Python source, so write it as UTF-8
        # rather than relying on the platform's locale encoding.
        with file_path.open("w", encoding="utf-8") as f:
            print(
                "# This file was produced by VllmPatternMatcherPass."
                f"dump_patterns for {self.pass_name}.\n"
                "# It does its best to produce valid-Python-looking code but"
                " please add to dump_patterns if there are any errors.\n\n"
                "from torch._higher_order_ops.auto_functionalize import "
                "auto_functionalized as auto_functionalized\n"
                "from torch._inductor.pattern_matcher import *",
                file=f,
            )

            for node, patterns in pm_pass.patterns.items():
                # fix the operator.getitem repr
                if node[1] == operator.getitem:
                    node_repr = f"({repr(node[0])}, operator.getitem)"
                else:
                    node_repr = repr(node)

                node_repr = self._replace_op_overloads(node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # reserve auto_functionalized ahead of time
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble pattern
                    out_node = pp.pretty_print(pattern.pattern)
                    # One "name = expr" line per memoized sub-pattern, wrapped
                    # in a function; the final replace() indents the body.
                    pattern_repr = "\n".join(
                        [f"def pattern_{i}():"]
                        + [
                            f"{pp.memoized_objs_names[key]} = "
                            f"{pp.memoized_objs_pp[key]}"
                            for key in pp.memoized_objs_names
                        ]
                        + [f"return {out_node}"]
                    ).replace("\n", "\n    ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)


153
class PrinterInductorPass(VllmInductorPass):
    """
    A pass that does nothing but dump the graph it is called with,
    using the name given at construction as the dump stage label.
    """

    def __init__(self, name: str, config: VllmConfig):
        """Store the stage label used by __call__ when dumping the graph."""
        super().__init__(config)
        self.name = name

    def __call__(self, graph: torch.fx.Graph):
        """Dump the graph via dump_graph with self.name as the stage."""
        self.dump_graph(graph, self.name)