piecewise_backend.py 6.33 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import dataclasses
5
6
from collections.abc import Callable
from typing import Any
7
8
9
10
11
12

import torch.fx as fx

from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
13
from vllm.config.compilation import Range
14
15
16
17
18
19
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class RangeEntry:
    """Per-range bookkeeping record used by ``PiecewiseBackend``.

    Each entry starts uncompiled with no runnable; the backend fills in
    ``runnable`` with the result of compiling its graph for
    ``compile_range`` and flips ``compiled``.
    """

    # The shape range (a ``Range`` with ``start``/``end``) this entry serves.
    compile_range: Range
    # Flag the backend sets once it has handled compilation for this range.
    compiled: bool = dataclasses.field(default=False)
    # Compiled callable for this range; stays ``None`` until compiled.
    runnable: Callable = dataclasses.field(default=None)  # type: ignore


26
class PiecewiseBackend:
    """Compile and dispatch a single piece of a piecewise-split FX graph.

    Compilation is lazy: the subgraph is compiled for a ``RangeEntry`` the
    first time a runtime shape maps to that entry, and the compiled runnable
    is reused for every later call that maps to the same entry.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        `self.graph` is compiled lazily: once for every range returned by
        `compilation_config.get_compile_ranges()`, and once for every
        concrete size in `compilation_config.compile_sizes` that is not
        already one of those ranges.

        Args:
            graph: The FX subgraph handled by this backend.
            vllm_config: Global config; its `compilation_config` provides
                the compile ranges and the concrete compile sizes.
            piecewise_compile_index: 0-based index of this subgraph among
                all piecewise subgraphs.
            total_piecewise_compiles: Total number of piecewise subgraphs.
            sym_shape_indices: Positions in the call arguments that hold
                symbolic shapes; the first one is read as the runtime shape
                used for dispatch.
            vllm_backend: Owning backend; supplies `compiler_manager` and
                `inductor_config`.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1

        self.compile_ranges = self.compilation_config.get_compile_ranges()
        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices

        # One entry per range we may compile for; an entry holds the
        # compiled runnable once its range has been compiled.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the ranges still awaiting compilation
        # and shrinks as compilation proceeds, so it must be a copy rather
        # than an alias of compile_ranges.
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        # Every concrete compile size that is not already one of the compile
        # ranges gets its own single-size entry.
        for size in self.compile_sizes:
            size_range = Range(start=size, end=size)
            if size_range not in self.compile_ranges:
                self.range_entries[size_range] = RangeEntry(
                    compile_range=size_range,
                )
                self.to_be_compiled_ranges.add(size_range)

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range,
            )

    def check_for_ending_compilation(self):
        """Finish compilation bookkeeping once everything has been compiled.

        Acts only on the last piecewise subgraph, and only once that
        subgraph has no ranges left in ``to_be_compiled_ranges``.
        """
        if self.is_last_graph and not self.to_be_compiled_ranges:
            # no specific sizes left to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def _fakify_args(self, args: list[Any]) -> list[Any]:
        """Return the graph's fake example inputs to use in place of ``args``.

        The fake values come from the ``example_value`` metadata of the
        graph's leading placeholder nodes; their count must match ``args``.
        """
        # We need to pass fake example_inputs, otherwise torch.compile
        # will fakify the example_inputs, potentially causing some non-dynamic
        # dimension to be duck-shaped to other existing shapes that have hints
        # matching their values.
        # This is a problem because it can lead to unintended specializations:
        # if the new wrongly-dynamic dim is specialized, it will force
        # specializing the whole shape.
        # torch.compile probably should not accept non-fake tensors as
        # example inputs!
        # See issue https://github.com/vllm-project/vllm/issues/27899
        fake_example_inputs = []
        for node in self.graph.graph.nodes:
            # All placeholders come first
            if node.op != "placeholder":
                break
            fake_example_inputs.append(node.meta["example_value"])
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> None:
        """Compile ``self.graph`` for ``range_entry`` if not done already.

        ``args`` are the real call arguments. On success the compiled
        runnable is stored on the entry, the entry's range is removed from
        ``to_be_compiled_ranges``, and ``check_for_ending_compilation`` runs.
        If compilation raises, the entry is left untouched so it can be
        retried on a later call.
        """
        if range_entry.compiled:
            return

        compile_range = range_entry.compile_range
        # Fakify for a range; use the real args for a concrete size.
        # For a concrete size, the shape env is cleared in
        # compiler_manager.compile(), so there is no need to fakify.
        example_args = (
            args if compile_range.is_single_size() else self._fakify_args(args)
        )
        runnable = self.vllm_backend.compiler_manager.compile(
            self.graph,
            example_args,
            self.vllm_backend.inductor_config,
            self.compilation_config,
            compile_range=compile_range,
            graph_index=self.piecewise_compile_index,
            num_graphs=self.total_piecewise_compiles,
        )

        # Update bookkeeping only after compile() succeeded; doing it first
        # would leave a failed entry marked compiled with runnable=None.
        range_entry.runnable = runnable
        range_entry.compiled = True
        self.to_be_compiled_ranges.remove(compile_range)

        self.check_for_ending_compilation()

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the ``RangeEntry`` that should serve ``runtime_shape``.

        An exact concrete compile size takes priority; otherwise the first
        compile range containing the shape is used. Returns ``None`` when
        nothing matches.
        """
        if runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args) -> Any:
        """Run the subgraph on ``args``.

        The runtime shape is read from ``args[sym_shape_indices[0]]``; the
        subgraph is compiled for the matching entry on first use.
        """
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape out of considered range: {runtime_shape} "
            "[1, max_num_batched_tokens]"
        )

        self._maybe_compile_for_range_entry(range_entry, args)
        return range_entry.runnable(*args)