__init__.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
from __future__ import annotations

import multiprocessing
6
from concurrent.futures import Future, ThreadPoolExecutor
7
8
9
10
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
11
from vllm.reasoning import ReasoningParserManager
12
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
13
from vllm.utils import LazyLoader
14
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
15
16
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)
17
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
18
19
20
21

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
22
    import torch
23

24
    from vllm.reasoning import ReasoningParser
25
    from vllm.v1.request import Request
26
27
else:
    torch = LazyLoader("torch", globals(), "torch")
28
29
30
31
32

# Module-level logger, obtained from vLLM's init_logger under this module's name.
logger = init_logger(__name__)


class StructuredOutputManager:
33
    """Engine-level manager for structured output requests."""
34

35
    def __init__(self, vllm_config: VllmConfig):
36
        self.backend: Optional[StructuredOutputBackend] = None
37
        self.reasoner: Optional[ReasoningParser] = None
38
        self.vllm_config = vllm_config
39

40
        self._grammar_bitmask: Optional[torch.Tensor] = None
41
        self._full_mask = torch.tensor(-1, dtype=torch.int32)
42

43
44
45
46
47
48
49
50
51
52
53
        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        self.fill_bitmask_parallel_threshold = 128
        if self.fill_bitmask_parallel_threshold < max_batch_size:
            self.fill_bitmask_parallel_batch_size = 16
            # Use:
            # - at least 1 CPU
            # - at most half the number of CPUs or 8, whichever is less
            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
            self.executor_for_fillmask = ThreadPoolExecutor(
                max_workers=max_workers)

54
55
56
57
58
59
60
61
62
        if not self.vllm_config.model_config.skip_tokenizer_init:
            # The default max_workers if not specified is the number of
            # CPUs * 5, which is way too high since these tasks are CPU-bound,
            # not I/O bound. We also know we would never dominate CPU usage
            # with just grammar compilation, so we set it to half the number
            # of CPUs.
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = init_tokenizer_from_configs(
63
                model_config=self.vllm_config.model_config)
64
65
66
            reasoning_parser = \
                    self.vllm_config.structured_outputs_config.reasoning_parser
            if reasoning_parser:
67
                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
68
                    reasoning_parser)
69
                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
70

71
    def grammar_init(self, request: Request) -> None:
72
73
74
        if request.structured_output_request is None:
            return

75
        if TYPE_CHECKING:
76
            assert request.sampling_params is not None and \
77
                request.sampling_params.structured_outputs is not None
78

79
80
81
82
        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
83
        # _backend is set in Processor._validate_structured_output
84
        if self.backend is None:
85
            assert request.sampling_params is not None
86
            backend = request.sampling_params.structured_outputs._backend
87
            vocab_size = self.vllm_config.model_config.get_vocab_size()
88
            if backend == "xgrammar":
89
90
91
92
93
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
94
            elif backend == "guidance":
95
96
97
98
99
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
100
101
102
103
104
105
106
107
108
            elif backend == "outlines":
                from vllm.v1.structured_output.backend_outlines import (
                    OutlinesBackend)

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
109
110
111
112
113
114
115
116
            elif backend == "lm-format-enforcer":
                from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
                    LMFormatEnforcerBackend)
                self.backend = LMFormatEnforcerBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
117
118
            else:
                raise ValueError(
119
                    f"Unsupported structured output backend: {backend}")
120

121
        grammar = self.executor.submit(self._async_create_grammar, request)
122
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]
123

124
    def _async_create_grammar(
125
126
127
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
128
129
130
131
132
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
133
134
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
135
136
        request_type, grammar_spec = key

137
138
        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)
139

140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
    def _fill_bitmasks(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> None:
        assert self._grammar_bitmask is not None
        for grammar, index, apply_bitmask in batch:
            if apply_bitmask and not grammar.is_terminated():
                grammar.fill_bitmask(self._grammar_bitmask, index)
            else:
                # Note that for thinking support, we will need to
                # reset the relevant part of the bitmask for consequent
                # requests here.
                self._grammar_bitmask[index].fill_(self._full_mask)

    def _async_submit_fill_bitmask(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> Future:
        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

160
161
162
163
    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
164
        scheduled_spec_decode_tokens: dict[str, list[int]],
165
166
167
168
169
    ) -> Optional[npt.NDArray[np.int32]]:
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

170
171
172
173
174
        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens

175
176
        if self._grammar_bitmask is None:
            assert self.backend is not None
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = \
                self.backend.allocate_token_bitmask(
                    max_batch_size * (1 + max_num_spec_tokens))

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])
193

194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
                max_num_spec_tokens == 0:
            promises = []
            batch = []
            for req_id, _ in ordered_seq:
                request = requests[req_id]
                structured_output_request = request.structured_output_request
                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None

                apply_bitmask = self.should_fill_bitmask(request)
                batch.append((structured_output_request.grammar,
                              cumulative_index, apply_bitmask))
                if len(batch) == self.fill_bitmask_parallel_batch_size:
                    promises.append(self._async_submit_fill_bitmask(batch))
                    batch = []

                cumulative_index += 1
            if batch:
                promises.append(self._async_submit_fill_bitmask(batch))

            # Wait for all bitmask filling tasks to complete.
            for promise in promises:
                promise.result()
        else:
            # Fallback to serial filling of bitmasks for small-batch-size cases
            for req_id, _ in ordered_seq:
                request = requests[req_id]
                structured_output_request = request.structured_output_request

                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                apply_bitmask = self.should_fill_bitmask(request)

                state_advancements = 0
                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
                for i, token in enumerate(req_tokens + [None]):
                    self._fill_bitmasks([(structured_output_request.grammar,
                                          cumulative_index, apply_bitmask)])

                    if apply_bitmask and token is not None and \
                        not structured_output_request.grammar.is_terminated():
240
241
                        assert structured_output_request.grammar.accept_tokens(
                            req_id, [token])
242
                        state_advancements += 1
243
244
245
246
                    cumulative_index += 1
                if state_advancements > 0:
                    structured_output_request.grammar.rollback(
                        state_advancements)
247

248
        bitmask_tensor = self._grammar_bitmask
249
250
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]
251
252
253
254
255

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()
256

257
258
259
260
261
262
263
264
265
    def should_fill_bitmask(self, request: Request) -> bool:
        if self.reasoner is not None:
            assert request.structured_output_request is not None
            if request.structured_output_request.reasoning_ended is None:
                request.structured_output_request.reasoning_ended = \
                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
            return request.structured_output_request.reasoning_ended
        return True

266
267
268
269
270
271
272
273
274
275
    def should_advance(self, request: Request) -> bool:
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
276
        # for cases that don't use thinking mode.
277
278
279
280
281
282
283
284
        if self.reasoner is not None:
            structured_req = request.structured_output_request

            if structured_req.reasoning_ended:
                return True

            # Check if reasoning ends in *this* step
            if self.reasoner.is_reasoning_end(request.all_token_ids):
285
                # Reasoning just ended, so we shouldn't advance til
286
287
288
289
290
291
292
                # next pass
                structured_req.reasoning_ended = True

            return False
        else:
            return True

293
294
295
    def clear_backend(self) -> None:
        if self.backend is not None:
            self.backend.destroy()