piecewise_backend.py 7.09 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import dataclasses
5
6
from collections.abc import Callable
from typing import Any
7
8
9
10
11
12

import torch.fx as fx

from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
13
from vllm.config.compilation import Range
14
15
16
17
18
19
from vllm.logger import init_logger

# Module-scoped logger; PiecewiseBackend uses it for its debug output.
logger = init_logger(__name__)


@dataclasses.dataclass
class RangeEntry:
    """Compilation state for one compile range of a ``PiecewiseBackend``.

    Attributes:
        compile_range: The range of sizes this entry serves; a range whose
            start equals its end stands for a single concrete size.
        compiled: Set once compilation for ``compile_range`` has been handled.
        runnable: The compiled callable for ``compile_range``; ``None`` until
            it has been compiled.
    """

    compile_range: Range
    compiled: bool = dataclasses.field(default=False)
    runnable: Callable = dataclasses.field(default=None)  # type: ignore


class PiecewiseBackend:
    """Lazily compile one piece of a split graph and dispatch by runtime size.

    The graph given to this backend is piece ``piecewise_compile_index`` out of
    ``total_piecewise_compiles``. Nothing is compiled up front. Each entry in
    ``range_entries`` is compiled the first time a call whose runtime size maps
    to it arrives. After that, calls for the same entry reuse its runnable.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
    ):
        """Set up per-range compilation bookkeeping for one graph piece.

        Args:
            graph: The FX graph piece to compile.
            vllm_config: Global config. Its ``compilation_config`` supplies the
                compile ranges and the concrete ``compile_sizes``.
            piecewise_compile_index: Zero-based index of this piece.
            total_piecewise_compiles: Number of pieces. The last piece performs
                the end-of-compilation bookkeeping.
            sym_shape_indices: Positions in the call arguments that hold
                symbolic sizes. The first one is used for dispatch.
            vllm_backend: Owning backend. It provides the compiler manager, the
                inductor config and the model ``prefix``.

        The ranges come from ``compilation_config.get_compile_ranges()``. Each
        size in ``compilation_config.compile_sizes`` that is not already a
        compile range also gets its own single-size entry.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        # TODO: we need to generalize encoder compilation to other models
        self.is_encoder_compilation = vllm_backend.prefix in [
            "Qwen2_5_VisionPatchEmbed",
            "Qwen2_5_VisionPatchMerger",
            "Qwen2_5_VisionBlock",
        ]

        # Take a private copy. The encoder branch below overwrites the last
        # element, and that must never leak into a list the config might share.
        self.compile_ranges: list[Range] = list(
            self.compilation_config.get_compile_ranges()
        )
        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices

        # One entry for every range this piece may be compiled for (all compile
        # ranges plus single-size ranges for concrete compile sizes),
        # whether or not it has been compiled yet.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the ranges still to compile and shrinks
        # during compilation, so it is a separate set rather than a view.
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        # We only keep compilation management inside this class directly.
        for size in self.compile_sizes:
            size_range = Range(start=size, end=size)
            if size_range not in self.compile_ranges:
                self.range_entries[size_range] = RangeEntry(compile_range=size_range)
                self.to_be_compiled_ranges.add(size_range)

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(compile_range=compile_range)

    def check_for_ending_compilation(self):
        """Finish compilation once the last piece has no ranges left.

        Only the last piece acts. When it has nothing left to compile, it
        saves the compiler manager's state to file and ends torch.compile
        monitoring.
        """
        if self.is_last_graph and not self.to_be_compiled_ranges:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def _fakify_args(self, args: tuple[Any, ...] | list[Any]) -> list[Any]:
        """Return the graph's placeholder example values in place of ``args``.

        The result has one entry per placeholder node of ``self.graph``, taken
        from ``node.meta["example_value"]``. The number of placeholders must
        equal ``len(args)``.
        """
        # We need to pass fake example_inputs, otherwise torch.compile
        # will fakify the example_inputs potentially causing some non dynamic
        # dimension to be be duck shaped to other existing shapes that have hints
        # matching their values.
        # This is problem because it can lead to unintended specializations!
        # if the new wrongly dynamic dim is specialized
        # it will force specializing the whole shape
        # torch.compile probably should not accept
        # non fake tensors as example inputs!
        # See issue https://github.com/vllm-project/vllm/issues/27899
        fake_example_inputs = []
        for node in self.graph.graph.nodes:
            # All place holders come first
            if node.op == "placeholder":
                fake_example_inputs.append(node.meta["example_value"])
            else:
                break
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> None:
        """Compile ``self.graph`` for ``range_entry`` unless already compiled.

        ``args`` are the real call arguments. The entry is marked compiled and
        removed from ``to_be_compiled_ranges`` only after ``compile()``
        returns. A failed compilation is therefore retried on the next call
        instead of leaving ``runnable`` unset.
        """
        if range_entry.compiled:
            return

        # fakify for range, real args for concrete size.
        # For concrete size, we clear the shape env in
        # compiler_manager.compile() so no need to fakify.
        compile_args = (
            args
            if range_entry.compile_range.is_single_size()
            else self._fakify_args(args)
        )
        range_entry.runnable = self.vllm_backend.compiler_manager.compile(
            self.graph,
            compile_args,
            self.vllm_backend.inductor_config,
            self.compilation_config,
            compile_range=range_entry.compile_range,
            graph_index=self.piecewise_compile_index,
            num_graphs=self.total_piecewise_compiles,
        )
        range_entry.compiled = True
        self.to_be_compiled_ranges.remove(range_entry.compile_range)

        self.check_for_ending_compilation()

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the entry that should serve ``runtime_shape``, or ``None``.

        A concrete compile size takes precedence. Otherwise the first compile
        range that contains ``runtime_shape`` is used.
        """
        if runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args) -> Any:
        """Run the runnable for this call's size, compiling it first if needed.

        The dispatch size is ``args[self.sym_shape_indices[0]]``.
        """
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        # The upper bound differs for encoder pieces, so report the real ranges.
        assert range_entry is not None, (
            f"Shape out of considered range: {runtime_shape} "
            "[1, max_num_batched_tokens]; "
            f"compile ranges: {self.compile_ranges}"
        )

        self._maybe_compile_for_range_entry(range_entry, args)
        return range_entry.runnable(*args)