# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Engine-level management of structured output requests for the V1 engine."""

from __future__ import annotations

import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (
    StructuredOutputBackend,
    StructuredOutputGrammar,
)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.reasoning import ReasoningParser
    from vllm.v1.request import Request
else:
    # Type checkers see the real `torch` module above; at runtime `torch` is a
    # LazyLoader proxy (presumably deferring the actual import until first
    # use — see vllm.utils.LazyLoader).
    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)


class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Owns the single structured output backend, compiles per-request grammars
    on a thread pool, and builds the per-step token bitmask that is unpacked
    by the GPU runner.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Set up thread pools, the tokenizer and the optional reasoner.

        The backend is created lazily in ``grammar_init`` and the bitmask
        tensor lazily in ``grammar_bitmask``.

        Args:
            vllm_config: engine configuration; the scheduler, model,
                structured-outputs and speculative settings are read from it.
        """
        self.backend: StructuredOutputBackend | None = None
        self.reasoner: ReasoningParser | None = None
        self.vllm_config = vllm_config

        # Allocated on first use in grammar_bitmask(), once a backend exists.
        self._grammar_bitmask: torch.Tensor | None = None
        # -1 as int32 has every bit set; used to reset the bitmask rows of
        # requests that are not constrained in the current step.
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        # Large batches fill the bitmask in parallel. The batch size and the
        # executor are only created when max_num_seqs exceeds the threshold
        # (presumably a step never has more than max_num_seqs requests, so
        # the parallel path in grammar_bitmask() is otherwise unreachable).
        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        self.fill_bitmask_parallel_threshold = 128
        if self.fill_bitmask_parallel_threshold < max_batch_size:
            self.fill_bitmask_parallel_batch_size = 16
            # Use:
            # - at least 1 CPU
            # - at most half the number of CPUs or 8, whichever is less
            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
            self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers)

        # NOTE(review): `self.executor` and `self.tokenizer` only exist when
        # the tokenizer is initialized; grammar_init() relies on both.
        if not self.vllm_config.model_config.skip_tokenizer_init:
            # The default max_workers if not specified is the number of
            # CPUs * 5, which is way too high since these tasks are CPU-bound,
            # not I/O bound. We also know we would never dominate CPU usage
            # with just grammar compilation, so we set it to half the number
            # of CPUs.
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = init_tokenizer_from_configs(
                model_config=self.vllm_config.model_config
            )
            reasoning_parser = (
                self.vllm_config.structured_outputs_config.reasoning_parser
            )
            if reasoning_parser:
                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                    reasoning_parser
                )
                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

    def grammar_init(self, request: Request) -> None:
        """Start compiling the grammar for ``request`` in the background.

        Requests without a structured output request are ignored. On first
        use, the backend named by ``sampling_params.structured_outputs._backend``
        is instantiated; it is then shared by all later requests. The
        compilation ``Future`` is stored as
        ``request.structured_output_request.grammar``.

        Raises:
            ValueError: if the requested backend name is not supported.
        """
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert (
                request.sampling_params is not None
                and request.sampling_params.structured_outputs is not None
            )

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        # _backend is set in Processor._validate_structured_output
        if self.backend is None:
            assert request.sampling_params is not None
            backend = request.sampling_params.structured_outputs._backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "outlines":
                from vllm.v1.structured_output.backend_outlines import OutlinesBackend

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "lm-format-enforcer":
                from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
                    LMFormatEnforcerBackend,
                )

                self.backend = LMFormatEnforcerBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(f"Unsupported structured output backend: {backend}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for ``request`` (runs on ``self.executor``)."""
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def _fill_bitmasks(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> None:
        """Fill bitmask rows for ``(grammar, row_index, apply_bitmask)`` items.

        A row is filled by its grammar when the mask should be applied and
        the grammar has not terminated; otherwise it is reset to ``-1``.
        """
        assert self._grammar_bitmask is not None
        for grammar, index, apply_bitmask in batch:
            if apply_bitmask and not grammar.is_terminated():
                grammar.fill_bitmask(self._grammar_bitmask, index)
            else:
                # Note that for thinking support, we will need to
                # reset the relevant part of the bitmask for consequent
                # requests here.
                self._grammar_bitmask[index].fill_(self._full_mask)

    def _async_submit_fill_bitmask(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> Future:
        """Schedule ``_fill_bitmasks(batch)`` on the fill-mask thread pool."""
        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> npt.NDArray[np.int32] | None:
        """Build the structured output token bitmask for the scheduled batch.

        Args:
            requests: requests keyed by request id.
            structured_output_request_ids: ids of the structured output
                requests in this batch, mapped to the integer used to order
                their rows.
            scheduled_spec_decode_tokens: draft tokens scheduled per request
                id; requests without an entry get a single row.

        Returns:
            ``None`` if no structured output request is scheduled. Otherwise,
            the used rows of the bitmask as a NumPy array. There is one row per
            request, or, on the serial path, one row per draft token plus one
            for the bonus token.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = (
                self.vllm_config.speculative_config.num_speculative_tokens
            )

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = self.backend.allocate_token_bitmask(
                max_batch_size * (1 + max_num_spec_tokens)
            )

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])

        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if (
            len(ordered_seq) > self.fill_bitmask_parallel_threshold
            and max_num_spec_tokens == 0
        ):
            promises = []
            batch = []
            for req_id, _ in ordered_seq:
                request = requests[req_id]
                structured_output_request = request.structured_output_request
                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None

                apply_bitmask = self.should_fill_bitmask(request)
                batch.append(
                    (structured_output_request.grammar, cumulative_index, apply_bitmask)
                )
                if len(batch) == self.fill_bitmask_parallel_batch_size:
                    promises.append(self._async_submit_fill_bitmask(batch))
                    batch = []

                cumulative_index += 1
            if batch:
                promises.append(self._async_submit_fill_bitmask(batch))

            # Wait for all bitmask filling tasks to complete; result()
            # re-raises any exception from a worker.
            for promise in promises:
                promise.result()
        else:
            # Fallback to serial filling of bitmasks for small-batch-size cases
            # (and for every batch when speculative decoding is enabled).
            for req_id, _ in ordered_seq:
                request = requests[req_id]
                structured_output_request = request.structured_output_request

                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                grammar = structured_output_request.grammar
                apply_bitmask = self.should_fill_bitmask(request)

                # One row per draft token plus one for the bonus token. The
                # grammar is advanced through the draft tokens so each row is
                # filled from the state after the preceding draft tokens.
                state_advancements = 0
                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
                for token in [*req_tokens, None]:
                    self._fill_bitmasks([(grammar, cumulative_index, apply_bitmask)])

                    if (
                        apply_bitmask
                        and token is not None
                        and not grammar.is_terminated()
                    ):
                        # Call outside the assert: asserts are stripped under
                        # `python -O`, which would skip advancing the grammar
                        # while still counting (and later rolling back) it.
                        accepted = grammar.accept_tokens(req_id, [token])
                        assert accepted, (req_id, token)
                        state_advancements += 1
                    cumulative_index += 1
                # Roll back exactly the draft tokens accepted above.
                if state_advancements > 0:
                    grammar.rollback(state_advancements)

        bitmask_tensor = self._grammar_bitmask
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    def should_fill_bitmask(self, request: Request) -> bool:
        """Return whether ``request``'s bitmask row should be constrained.

        Without a reasoner this is always ``True``. With a reasoner, the
        request's ``reasoning_ended`` flag is returned. If the flag is still
        unset, it is first initialized from the prompt token ids.
        """
        if self.reasoner is not None:
            assert request.structured_output_request is not None
            if request.structured_output_request.reasoning_ended is None:
                request.structured_output_request.reasoning_ended = (
                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
                )
            return request.structured_output_request.reasoning_ended
        return True

    def should_advance(self, request: Request) -> bool:
        """Return whether ``request``'s grammar may be advanced this step.

        Requests that do not use structured output are never advanced. With
        a reasoner, advancing starts only after reasoning has ended. The step
        in which the end is first detected marks ``reasoning_ended`` but
        still returns ``False``.
        """
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
        # for cases that don't use thinking mode.
        if self.reasoner is not None:
            structured_req = request.structured_output_request

            if structured_req.reasoning_ended:
                return True

            # Check if reasoning ends in *this* step
            if self.reasoner.is_reasoning_end(request.all_token_ids):
                # Reasoning just ended, so we shouldn't advance til
                # next pass
                structured_req.reasoning_ended = True

            return False
        else:
            return True

    def clear_backend(self) -> None:
        """Destroy the backend, if one was created.

        NOTE(review): ``self.backend`` is not reset to ``None`` afterwards.
        """
        if self.backend is not None:
            self.backend.destroy()