# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.reasoning import ReasoningParser
    from vllm.v1.request import Request
else:
    # Defer the (heavy) torch import until an attribute is first accessed.
    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)


class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Lazily creates a single structured output backend, compiles request
    grammars on a thread pool, and builds the per-step token bitmask used to
    constrain structured output requests.
    """

    def __init__(self, vllm_config: VllmConfig):
        self.backend: Optional[StructuredOutputBackend] = None
        self.reasoner: Optional[ReasoningParser] = None
        self.vllm_config = vllm_config

        # Allocated on the first ``grammar_bitmask`` call, then reused.
        self._grammar_bitmask: Optional[torch.Tensor] = None
        # Fill value (-1, i.e. every bit set) for rows the grammar does not
        # constrain.
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        # Batches with more structured output requests than this threshold
        # (and no speculative tokens) fill their bitmask rows in parallel,
        # ``fill_bitmask_parallel_batch_size`` rows per task. The pool is
        # only created when max_num_seqs can actually exceed the threshold.
        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        self.fill_bitmask_parallel_threshold = 128
        self.fill_bitmask_parallel_batch_size = 16
        self.executor_for_fillmask: Optional[ThreadPoolExecutor] = None
        if self.fill_bitmask_parallel_threshold < max_batch_size:
            # Use:
            # - at least 1 CPU
            # - at most half the number of CPUs or 8, whichever is less
            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
            self.executor_for_fillmask = ThreadPoolExecutor(
                max_workers=max_workers)

        if not self.vllm_config.model_config.skip_tokenizer_init:
            # The default max_workers if not specified is the number of
            # CPUs * 5, which is way too high since these tasks are CPU-bound,
            # not I/O bound. We also know we would never dominate CPU usage
            # with just grammar compilation, so we set it to half the number
            # of CPUs (rounded up).
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = init_tokenizer_from_configs(
                model_config=self.vllm_config.model_config)

            reasoning_backend = \
                    self.vllm_config.decoding_config.reasoning_backend
            if reasoning_backend:
                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                    reasoning_backend)
                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

    def grammar_init(self, request: Request) -> None:
        """Start asynchronous grammar compilation for ``request``.

        Does nothing for requests without a structured output request. The
        backend named by the first structured output request is created on
        demand and reused for all later requests. The compilation Future is
        stored on ``request.structured_output_request.grammar``.

        NOTE: ``self.tokenizer`` and ``self.executor`` only exist when the
        tokenizer was not skipped in ``__init__``.

        Raises:
            ValueError: if the requested backend name is not supported.
        """
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert request.sampling_params is not None and \
                request.sampling_params.guided_decoding is not None

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
            assert request.sampling_params is not None
            backend = request.sampling_params.guided_decoding.backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "outlines":
                from vllm.v1.structured_output.backend_outlines import (
                    OutlinesBackend)

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "lm-format-enforcer":
                from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
                    LMFormatEnforcerBackend)
                self.backend = LMFormatEnforcerBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for ``request`` (submitted to the executor).

        The request's ``structured_output_key`` is unpacked into a
        ``(request_type, grammar_spec)`` pair and passed to the backend.
        """
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def _fill_bitmasks(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> None:
        """Fill one bitmask row per ``(grammar, row_index, apply)`` entry.

        Rows whose grammar should be applied and has not terminated are
        filled by the grammar. All other rows get ``self._full_mask``.
        """
        assert self._grammar_bitmask is not None
        for grammar, index, apply_bitmask in batch:
            if apply_bitmask and not grammar.is_terminated():
                grammar.fill_bitmask(self._grammar_bitmask, index)
            else:
                # Note that for thinking support, we will need to
                # reset the relevant part of the bitmask for subsequent
                # requests here.
                self._grammar_bitmask[index].fill_(self._full_mask)

    def _async_submit_fill_bitmask(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> Future:
        """Schedule ``_fill_bitmasks(batch)`` on the fill-mask thread pool."""
        # Only reachable when the batch exceeds the parallel threshold, which
        # implies the pool was created in ``__init__``.
        assert self.executor_for_fillmask is not None
        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> Optional[npt.NDArray[np.int32]]:
        """Build the structured output bitmask for the current batch.

        Args:
            requests: all requests, keyed by request id.
            structured_output_request_ids: request id -> index; requests are
                laid out in ascending order of this index.
            scheduled_spec_decode_tokens: request id -> speculative tokens
                scheduled this step (missing ids have none).

        Returns:
            None when there are no structured output requests. Otherwise a
            numpy copy-free view of the filled rows. Each request gets
            ``len(spec tokens) + 1`` consecutive rows, one per position.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = \
                self.backend.allocate_token_bitmask(
                    max_batch_size * (1 + max_num_spec_tokens))

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])

        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
                max_num_spec_tokens == 0:
            num_rows = self._fill_bitmasks_parallel(requests, ordered_seq)
        else:
            num_rows = self._fill_bitmasks_serial(
                requests, ordered_seq, scheduled_spec_decode_tokens)

        bitmask_tensor = self._grammar_bitmask
        if num_rows < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:num_rows]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    def _fill_bitmasks_parallel(
        self,
        requests: dict[str, Request],
        ordered_seq: list[tuple[str, int]],
    ) -> int:
        """Fill one row per request on the fill-mask pool; return row count.

        Only used without speculative decoding, so each request owns exactly
        one row. Blocks until every submitted task has finished.
        """
        promises: list[Future] = []
        batch: list[tuple[StructuredOutputGrammar, int, bool]] = []
        for row, (req_id, _) in enumerate(ordered_seq):
            request = requests[req_id]
            structured_output_request = request.structured_output_request
            if TYPE_CHECKING:
                assert structured_output_request is not None
                assert structured_output_request.grammar is not None

            apply_bitmask = self.should_fill_bitmask(request)
            batch.append((structured_output_request.grammar, row,
                          apply_bitmask))
            if len(batch) == self.fill_bitmask_parallel_batch_size:
                promises.append(self._async_submit_fill_bitmask(batch))
                batch = []
        if batch:
            promises.append(self._async_submit_fill_bitmask(batch))

        # Wait for all bitmask filling tasks to complete; re-raises any
        # exception from a worker.
        for promise in promises:
            promise.result()
        return len(ordered_seq)

    def _fill_bitmasks_serial(
        self,
        requests: dict[str, Request],
        ordered_seq: list[tuple[str, int]],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> int:
        """Fill rows serially, including speculative positions; return count.

        For each request, a row is filled before each speculative token and
        once more for the bonus position. The grammar is advanced through
        the speculative tokens to produce those rows, then rolled back.
        """
        cumulative_index = 0
        for req_id, _ in ordered_seq:
            request = requests[req_id]
            structured_output_request = request.structured_output_request
            if TYPE_CHECKING:
                assert structured_output_request is not None
                assert structured_output_request.grammar is not None
            grammar = structured_output_request.grammar
            apply_bitmask = self.should_fill_bitmask(request)

            state_advancements = 0
            req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
            for token in req_tokens + [None]:
                self._fill_bitmasks([(grammar, cumulative_index,
                                      apply_bitmask)])

                if apply_bitmask and token is not None and \
                        not grammar.is_terminated():
                    # Call outside the ``assert``: with assertions disabled
                    # (python -O) the grammar must still advance, otherwise
                    # the rollback below would undo state that never moved.
                    accepted = grammar.accept_tokens(req_id, [token])
                    assert accepted, (
                        f"Speculative token {token} rejected by grammar "
                        f"for request {req_id}")
                    state_advancements += 1

                cumulative_index += 1
            if state_advancements > 0:
                grammar.rollback(state_advancements)
        return cumulative_index

    def should_fill_bitmask(self, request: Request) -> bool:
        """Return whether the grammar's bitmask should apply to ``request``.

        Always True without a reasoner. With a reasoner, the result of
        checking the prompt for a reasoning end is cached on the request the
        first time and returned thereafter.
        """
        if self.reasoner is not None:
            assert request.structured_output_request is not None
            if request.structured_output_request.reasoning_ended is None:
                request.structured_output_request.reasoning_ended = \
                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
            return request.structured_output_request.reasoning_ended
        return True

    def should_advance(self, request: Request) -> bool:
        """Return whether the grammar FSM should consume this step's tokens.

        False for requests without structured output. With a reasoner, the
        FSM only advances once reasoning has ended. The step in which the
        end is first detected marks ``reasoning_ended`` but still returns
        False, so advancing starts on the next pass.
        """
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None

        # By default, we should always advance for cases that don't use
        # thinking mode.
        if self.reasoner is None:
            return True

        structured_req = request.structured_output_request
        if structured_req.reasoning_ended:
            return True

        # Check if reasoning ends in *this* step
        if self.reasoner.is_reasoning_end(request.all_token_ids):
            # Reasoning just ended, so we shouldn't advance til
            # next pass
            structured_req.reasoning_ended = True

        return False

    def clear_backend(self) -> None:
        """Destroy the backend, if one has been created."""
        if self.backend is not None:
            self.backend.destroy()