# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from collections.abc import Callable
from typing import Any

import torch.fx as fx

from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class RangeEntry:
    """Per-range compilation state tracked by `PiecewiseBackend`."""

    # The compile range (or single concrete size) this entry belongs to.
    compile_range: Range
    # Flipped to True when compilation for this range is kicked off, so the
    # range is compiled at most once.
    compiled: bool = False
    # The callable produced by compilation; stays None until it is assigned.
    runnable: Callable[..., Any] = None  # type: ignore


26
class PiecewiseBackend:
    """Callable backend for one piece of a piecewise-compiled graph.

    It keeps one `RangeEntry` per compile range, plus one per concrete
    compile size that is not itself a compile range. Each entry is compiled
    the first time a matching runtime shape is seen, and every call is
    dispatched to the runnable of the matching entry.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        `self.graph` is compiled separately for every range returned by
        `compilation_config.get_compile_ranges()` and for every concrete
        size listed in `compilation_config.compile_sizes`.

        Args:
            graph: The fx graph module this backend compiles and runs.
            vllm_config: Global config; its `compilation_config` supplies
                the compile ranges and compile sizes.
            piecewise_compile_index: Index of this piece among all pieces.
            total_piecewise_compiles: Total number of pieces.
            sym_shape_indices: Positions in the call arguments; the first
                one is read as the runtime shape used for dispatching.
            vllm_backend: Owning backend; provides `compiler_manager`,
                `inductor_config` and the `is_encoder` flag.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        self.compile_ranges = self.compilation_config.get_compile_ranges()
        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices

        # One entry per compile range / concrete compile size; filled below.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the remaining ranges to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        # We only keep compilation management inside this class directly.
        # Concrete sizes become single-size ranges; the ones that are not
        # already compile ranges need their own entry and pending marker.
        for size in self.compile_sizes:
            size_range = Range(start=size, end=size)
            if size_range not in self.compile_ranges:
                self.range_entries[size_range] = RangeEntry(
                    compile_range=size_range,
                )
                self.to_be_compiled_ranges.add(size_range)

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range,
            )

    def check_for_ending_compilation(self) -> None:
        """Finish compilation bookkeeping once nothing is left to compile.

        Only acts when this is the last piece and no range is pending.
        """
        if self.is_last_graph and not self.to_be_compiled_ranges:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def _fakify_args(self, args: list[Any]) -> list[Any]:
        """Return the graph placeholders' example values in place of `args`.

        The values are read from `node.meta["example_value"]` of the leading
        placeholder nodes of `self.graph`; `args` is only used to check that
        the number of placeholders matches.
        """
        # We need to pass fake example_inputs, otherwise torch.compile
        # will fakify the example_inputs potentially causing some non dynamic
        # dimension to be duck shaped to other existing shapes that have hints
        # matching their values.
        # This is a problem because it can lead to unintended specializations!
        # if the new wrongly dynamic dim is specialized
        # it will force specializing the whole shape
        # torch.compile probably should not accept
        # non fake tensors as example inputs!
        # See issue https://github.com/vllm-project/vllm/issues/27899
        fake_example_inputs = []
        for node in self.graph.graph.nodes:
            # All place holders come first
            if node.op != "placeholder":
                break
            fake_example_inputs.append(node.meta["example_value"])
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> None:
        """Compile `range_entry` on first use and store its runnable.

        Does nothing if the entry is already marked as compiled.
        """
        if range_entry.compiled:
            return

        range_entry.compiled = True
        self.to_be_compiled_ranges.remove(range_entry.compile_range)

        # args are real arguments
        # fakify for range, real args for concrete size.
        # For concrete size, we clear the shape env in
        # compiler_manager.compile() so no need to fakify.
        if range_entry.compile_range.is_single_size():
            example_inputs = args
        else:
            example_inputs = self._fakify_args(args)

        range_entry.runnable = self.vllm_backend.compiler_manager.compile(
            self.graph,
            example_inputs,
            self.vllm_backend.inductor_config,
            self.compilation_config,
            compile_range=range_entry.compile_range,
            graph_index=self.piecewise_compile_index,
            num_graphs=self.total_piecewise_compiles,
        )

        self.check_for_ending_compilation()

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the entry responsible for `runtime_shape`, or None."""
        # First we try to find the range entry for the concrete compile size
        # If not found, we search for the range entry
        # that contains the runtime shape.
        if runtime_shape in self.compile_sizes:
            exact_range = Range(start=runtime_shape, end=runtime_shape)
            return self.range_entries[exact_range]

        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args) -> Any:
        """Run the graph through the runnable matching the runtime shape.

        The runtime shape is the argument at `self.sym_shape_indices[0]`.
        Fails an assertion when no entry covers that shape.
        """
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape out of considered range: {runtime_shape} "
            "[1, max_num_batched_tokens]"
        )

        self._maybe_compile_for_range_entry(range_entry, args)
        return range_entry.runnable(*args)