__init__.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
from __future__ import annotations

import multiprocessing
6
from concurrent.futures import Future, ThreadPoolExecutor
7
8
9
10
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
11
12
13
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
14
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
15
16
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)
17
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
18
19
20
21

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
22
    import torch
23

24
    from vllm.reasoning import ReasoningParser
25
    from vllm.v1.request import Request
26
27
else:
    torch = LazyLoader("torch", globals(), "torch")
28
29
30
31
32

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)


class StructuredOutputManager:
33
    """Engine-level manager for structured output requests."""
34

35
    def __init__(self, vllm_config: VllmConfig):
36
        self.backend: Optional[StructuredOutputBackend] = None
37
        self.reasoner: Optional[ReasoningParser] = None
38
        self.vllm_config = vllm_config
39

40
        self._grammar_bitmask: Optional[torch.Tensor] = None
41
        self._full_mask = torch.tensor(-1, dtype=torch.int32)
42

43
44
45
46
47
48
49
50
51
52
53
        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        self.fill_bitmask_parallel_threshold = 128
        if self.fill_bitmask_parallel_threshold < max_batch_size:
            self.fill_bitmask_parallel_batch_size = 16
            # Use:
            # - at least 1 CPU
            # - at most half the number of CPUs or 8, whichever is less
            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
            self.executor_for_fillmask = ThreadPoolExecutor(
                max_workers=max_workers)

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
        if not self.vllm_config.model_config.skip_tokenizer_init:
            # The default max_workers if not specified is the number of
            # CPUs * 5, which is way too high since these tasks are CPU-bound,
            # not I/O bound. We also know we would never dominate CPU usage
            # with just grammar compilation, so we set it to half the number
            # of CPUs.
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = init_tokenizer_from_configs(
                model_config=self.vllm_config.model_config,
                scheduler_config=self.vllm_config.scheduler_config,
                lora_config=self.vllm_config.lora_config,
            ).get_lora_tokenizer(None)
            reasoning_backend = \
                    self.vllm_config.decoding_config.reasoning_backend
            if reasoning_backend:
                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                    reasoning_backend)
                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
73

74
    def grammar_init(self, request: Request) -> None:
75
76
77
        if request.structured_output_request is None:
            return

78
        if TYPE_CHECKING:
79
80
            assert request.sampling_params is not None and \
                request.sampling_params.guided_decoding is not None
81

82
83
84
85
86
        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
87
            assert request.sampling_params is not None
88
            backend = request.sampling_params.guided_decoding.backend
89
            vocab_size = self.vllm_config.model_config.get_vocab_size()
90
            if backend == "xgrammar":
91
92
93
94
95
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
96
            elif backend == "guidance":
97
98
99
100
101
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
102
103
104
105
106
107
108
109
110
            elif backend == "outlines":
                from vllm.v1.structured_output.backend_outlines import (
                    OutlinesBackend)

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
111
112
            else:
                raise ValueError(
113
                    f"Unsupported structured output backend: {backend}")
114

115
        grammar = self.executor.submit(self._async_create_grammar, request)
116
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]
117

118
    def _async_create_grammar(
119
120
121
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
122
123
124
125
126
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
127
128
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
129
130
        request_type, grammar_spec = key

131
132
        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)
133

134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    def _fill_bitmasks(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> None:
        assert self._grammar_bitmask is not None
        for grammar, index, apply_bitmask in batch:
            if apply_bitmask and not grammar.is_terminated():
                grammar.fill_bitmask(self._grammar_bitmask, index)
            else:
                # Note that for thinking support, we will need to
                # reset the relevant part of the bitmask for consequent
                # requests here.
                self._grammar_bitmask[index].fill_(self._full_mask)

    def _async_submit_fill_bitmask(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> Future:
        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

154
155
156
157
    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
158
        scheduled_spec_decode_tokens: dict[str, list[int]],
159
160
161
162
163
    ) -> Optional[npt.NDArray[np.int32]]:
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

164
165
166
167
168
        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens

169
170
        if self._grammar_bitmask is None:
            assert self.backend is not None
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = \
                self.backend.allocate_token_bitmask(
                    max_batch_size * (1 + max_num_spec_tokens))

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])
187

188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
                max_num_spec_tokens == 0:
            promises = []
            batch = []
            for req_id, _ in ordered_seq:
                request = requests[req_id]
                structured_output_request = request.structured_output_request
                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None

                apply_bitmask = self.should_fill_bitmask(request)
                batch.append((structured_output_request.grammar,
                              cumulative_index, apply_bitmask))
                if len(batch) == self.fill_bitmask_parallel_batch_size:
                    promises.append(self._async_submit_fill_bitmask(batch))
                    batch = []

                cumulative_index += 1
            if batch:
                promises.append(self._async_submit_fill_bitmask(batch))

            # Wait for all bitmask filling tasks to complete.
            for promise in promises:
                promise.result()
        else:
            # Fallback to serial filling of bitmasks for small-batch-size cases
            for req_id, _ in ordered_seq:
                request = requests[req_id]
                structured_output_request = request.structured_output_request

                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                apply_bitmask = self.should_fill_bitmask(request)

                state_advancements = 0
                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
                for i, token in enumerate(req_tokens + [None]):
                    self._fill_bitmasks([(structured_output_request.grammar,
                                          cumulative_index, apply_bitmask)])

                    if apply_bitmask and token is not None and \
                        not structured_output_request.grammar.is_terminated():
234
235
                        assert structured_output_request.grammar.accept_tokens(
                            req_id, [token])
236
                        state_advancements += 1
237
238
239
240
                    cumulative_index += 1
                if state_advancements > 0:
                    structured_output_request.grammar.rollback(
                        state_advancements)
241

242
        bitmask_tensor = self._grammar_bitmask
243
244
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]
245
246
247
248
249

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()
250

251
252
253
254
255
256
257
258
259
    def should_fill_bitmask(self, request: Request) -> bool:
        if self.reasoner is not None:
            assert request.structured_output_request is not None
            if request.structured_output_request.reasoning_ended is None:
                request.structured_output_request.reasoning_ended = \
                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
            return request.structured_output_request.reasoning_ended
        return True

260
261
262
263
264
265
266
267
268
269
    def should_advance(self, request: Request) -> bool:
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
270
        # for cases that don't use thinking mode.
271
272
273
274
275
276
277
278
        if self.reasoner is not None:
            structured_req = request.structured_output_request

            if structured_req.reasoning_ended:
                return True

            # Check if reasoning ends in *this* step
            if self.reasoner.is_reasoning_end(request.all_token_ids):
279
                # Reasoning just ended, so we shouldn't advance til
280
281
282
283
284
285
286
                # next pass
                structured_req.reasoning_ended = True

            return False
        else:
            return True

287
288
289
    def clear_backend(self) -> None:
        if self.backend is not None:
            self.backend.destroy()