# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
5
6
from collections.abc import Callable
from typing import Any
7
8
9
10
11
12

import torch.fx as fx

from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
13
from vllm.config.compilation import Range
14
15
16
17
18
19
from vllm.logger import init_logger

# Module-level logger; PiecewiseBackend uses it for its debug_once
# messages about compile ranges and compile sizes.
logger = init_logger(__name__)


@dataclasses.dataclass
class RangeEntry:
    """Compilation state of one piecewise submodule for one compile range."""

    # The range (possibly a single concrete size) this entry is compiled for.
    compile_range: Range
    # Flipped to True once compilation for this range has been started.
    compiled: bool = False
    # The compiled callable for this range; None until it has been compiled.
    # The non-Optional annotation (with the type: ignore) lets callers invoke
    # it directly once `compiled` is set.
    runnable: Callable = None  # type: ignore


26
class PiecewiseBackend:
    """Lazily compile one piecewise-split FX submodule per shape range.

    Each instance owns piece ``piecewise_compile_index`` of a graph that was
    split into ``total_piecewise_compiles`` pieces. On every call it looks up
    the range entry covering the runtime shape, compiles the graph for that
    range on first use, and then runs the compiled callable.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        `self.graph` is compiled lazily: once for each range returned by
        `compilation_config.get_compile_ranges()`, and once for each concrete
        size in `compilation_config.compile_sizes` that is not already one of
        those ranges.

        Args:
            graph: The FX submodule this backend compiles and runs.
            vllm_config: Global config; supplies the compilation config and,
                for encoders, the scheduler's max_num_batched_tokens.
            piecewise_compile_index: 0-based index of this piece.
            total_piecewise_compiles: Total number of pieces. The last piece
                performs the end-of-compilation bookkeeping.
            sym_shape_indices: Positions in the call arguments that hold
                symbolic sizes; the first one is used for dispatch.
            vllm_backend: Owning backend; provides `compiler_manager`,
                `inductor_config` and the `is_encoder` flag.

        Raises:
            NotImplementedError: If `compile_sizes` still contains the
                unexpanded "cudagraph_capture_sizes" placeholder.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        # Take a private copy: the encoder branch below rewrites the last
        # element, which must not mutate a container owned by the config.
        self.compile_ranges = list(self.compilation_config.get_compile_ranges())
        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices

        # One entry per range we can dispatch to (the compile ranges plus the
        # single-size ranges from compile_sizes); each is compiled lazily.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the remaining ranges to compile and
        # shrinks during the compilation process, so it must be its own copy.
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        # We only keep compilation management inside this class directly.
        if self.compile_sizes is not None:
            for size in self.compile_sizes:
                if isinstance(size, str):
                    assert size == "cudagraph_capture_sizes"
                    raise NotImplementedError(
                        "cudagraph_capture_sizes not supported in compile_sizes. "
                        "This should be handled in `post_init_cudagraph_sizes`."
                    )
                assert isinstance(size, int)
                size_range = Range(start=size, end=size)
                # Sizes that already are a compile range get their entry
                # from the loop below.
                if size_range not in self.compile_ranges:
                    self.range_entries[size_range] = RangeEntry(
                        compile_range=size_range,
                    )
                    self.to_be_compiled_ranges.add(size_range)

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range,
            )

    def check_for_ending_compilation(self) -> None:
        """Run end-of-compilation bookkeeping once nothing is left to compile.

        Only the last piece does this, and only when its set of pending
        ranges is empty: it calls `compiler_manager.save_to_file()` and
        `end_monitoring_torch_compile`.
        """
        if self.is_last_graph and not self.to_be_compiled_ranges:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
        """Return the graph's recorded example values in place of ``args``.

        The values are ``node.meta["example_value"]`` of the graph's leading
        placeholder nodes; ``args`` is only used to check the count matches.
        """
        # We need to pass fake example_inputs, otherwise torch.compile
        # will fakify the example_inputs, potentially causing some non-dynamic
        # dimension to be duck-shaped to other existing shapes that have hints
        # matching their values.
        # This is a problem because it can lead to unintended specializations:
        # if the new wrongly-dynamic dim is specialized,
        # it will force specializing the whole shape.
        # torch.compile probably should not accept
        # non-fake tensors as example inputs!
        # See issue https://github.com/vllm-project/vllm/issues/27899
        fake_example_inputs = []
        for node in self.graph.graph.nodes:
            # All placeholders come first
            if node.op != "placeholder":
                break
            fake_example_inputs.append(node.meta["example_value"])
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _maybe_compile_for_range_entry(
        self, range_entry: RangeEntry, args: tuple[Any, ...]
    ) -> None:
        """Compile ``self.graph`` for ``range_entry`` unless already done.

        Stores the result in ``range_entry.runnable``, removes the range from
        ``to_be_compiled_ranges`` and then checks whether compilation for
        this backend is finished.
        """
        if range_entry.compiled:
            return

        # NOTE(review): the entry is marked compiled before compile() runs;
        # if compile() raises, it stays marked with no runnable. Confirm
        # this ordering is intended.
        range_entry.compiled = True
        self.to_be_compiled_ranges.remove(range_entry.compile_range)

        # args are real arguments.
        # Fakify for a range; pass the real args for a concrete size.
        # For a concrete size, the shape env is cleared in
        # compiler_manager.compile(), so there is no need to fakify.
        if range_entry.compile_range.is_single_size():
            args_list = list(args)
        else:
            args_list = self._fakify_args(args)

        range_entry.runnable = self.vllm_backend.compiler_manager.compile(
            self.graph,
            args_list,
            self.vllm_backend.inductor_config,
            self.compilation_config,
            compile_range=range_entry.compile_range,
            graph_index=self.piecewise_compile_index,
            num_graphs=self.total_piecewise_compiles,
        )

        self.check_for_ending_compilation()

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the range entry that should serve ``runtime_shape``.

        An exact concrete compile size wins. Otherwise the first compile
        range containing the shape is used. Returns None if nothing matches.
        """
        # compile_sizes may be None (treated as "no concrete sizes" in
        # __init__); the compile ranges must still be searched in that case.
        if self.compile_sizes is not None and runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args: Any) -> Any:
        """Dispatch to (and lazily compile) the entry covering the shape."""
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
        )

        self._maybe_compile_for_range_entry(range_entry, args)
        return range_entry.runnable(*args)