1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import itertools
4
import multiprocessing
5
from collections.abc import Iterable
6
from concurrent.futures import Future, ThreadPoolExecutor
7
from typing import TYPE_CHECKING
8
9
10

from vllm.config import VllmConfig
from vllm.logger import init_logger
11
from vllm.reasoning import ReasoningParserManager
12
from vllm.tokenizers import cached_tokenizer_from_config
13
from vllm.utils.import_utils import LazyLoader
14
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
15
16
17
18
from vllm.v1.structured_output.backend_types import (
    StructuredOutputBackend,
    StructuredOutputGrammar,
)
19
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
20
21
22
23

if TYPE_CHECKING:
    # These imports are only needed by static type checkers; at runtime the
    # names are referenced solely inside string / unevaluated annotations.
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.reasoning import ReasoningParser
    from vllm.v1.request import Request
else:
    # At runtime, `torch` is bound to a LazyLoader proxy instead of being
    # imported eagerly; presumably the real import happens on first attribute
    # access — NOTE(review): confirm against vllm.utils.import_utils.
    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)


class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Owns the single structured-output backend (created lazily on first use),
    compiles per-request grammars, and builds the token bitmask for each
    scheduled batch of structured-output requests.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Set up executors, tokenizer and optional reasoning parser.

        Args:
            vllm_config: Engine configuration; the scheduler, parallel,
                model, speculative and structured-outputs sub-configs are
                read here and in the other methods.
        """
        self.backend: StructuredOutputBackend | None = None
        self.reasoner: ReasoningParser | None = None
        self.vllm_config = vllm_config

        # When in external_launcher mode, async grammar compilation causes deadlocks
        # due to external_launcher mode having a scheduler for each TP rank.
        # Async grammar compilation causes the WAITING_FOR_FSM → WAITING transition to
        # happen at different times on different TP ranks,
        # breaking the determinism assumption that external_launcher relies on.
        self._use_async_grammar_compilation = (
            vllm_config.parallel_config.distributed_executor_backend
            != "external_launcher"
        )

        # Bitmask buffer, allocated lazily by grammar_bitmask() once the
        # backend exists.
        self._grammar_bitmask: torch.Tensor | None = None
        # Scalar used to fill rows that must not constrain sampling
        # (an int32 of -1 has every bit set).
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        self.fill_bitmask_parallel_threshold = 128
        # The parallel-fill attributes are only created when the configured
        # batch size can exceed the threshold. grammar_bitmask() only takes
        # the parallel path for more than `fill_bitmask_parallel_threshold`
        # requests, which presumably cannot happen otherwise —
        # NOTE(review): relies on the scheduler never exceeding max_num_seqs.
        if self.fill_bitmask_parallel_threshold < max_batch_size:
            self.fill_bitmask_parallel_batch_size = 16
            # Use:
            # - at least 1 CPU
            # - at most half the number of CPUs or 8, whichever is less
            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
            self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers)

        # NOTE(review): `self.executor` and `self.tokenizer` are only set when
        # the tokenizer is initialized; grammar_init() reads both, so it is
        # presumably never reached with skip_tokenizer_init — confirm.
        if not self.vllm_config.model_config.skip_tokenizer_init:
            # The default max_workers if not specified is the number of
            # CPUs * 5, which is way too high since these tasks are CPU-bound,
            # not I/O bound. We also know we would never dominate CPU usage
            # with just grammar compilation, so we set it to half the number
            # of CPUs.
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = cached_tokenizer_from_config(
                model_config=self.vllm_config.model_config
            )
            reasoning_parser_plugin = (
                self.vllm_config.structured_outputs_config.reasoning_parser_plugin
            )
            # NOTE(review): the `> 3` length check is unexplained here;
            # presumably it filters out trivially short plugin paths — confirm.
            if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3:
                ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin)

            reasoning_parser = (
                self.vllm_config.structured_outputs_config.reasoning_parser
            )
            if reasoning_parser:
                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                    reasoning_parser
                )
                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

        self.enable_in_reasoning = (
            self.vllm_config.structured_outputs_config.enable_in_reasoning
        )

    def grammar_init(self, request: "Request") -> None:
        """Start grammar compilation for ``request`` if it uses structured output.

        On the first call, creates the backend named by the request's
        ``sampling_params.structured_outputs._backend``. The grammar is then
        compiled either on ``self.executor`` (a ``Future`` is stored) or inline
        (when async compilation is disabled), and the result is assigned to
        ``request.structured_output_request.grammar``.

        Raises:
            ValueError: If the backend name is not one of ``xgrammar``,
                ``guidance``, ``outlines`` or ``lm-format-enforcer``.
        """
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert (
                request.sampling_params is not None
                and request.sampling_params.structured_outputs is not None
            )

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        # _backend is set in Processor._validate_structured_output
        if self.backend is None:
            assert request.sampling_params is not None
            backend = request.sampling_params.structured_outputs._backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "outlines":
                from vllm.v1.structured_output.backend_outlines import OutlinesBackend

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "lm-format-enforcer":
                from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
                    LMFormatEnforcerBackend,
                )

                self.backend = LMFormatEnforcerBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(f"Unsupported structured output backend: {backend}")

        if self._use_async_grammar_compilation:
            grammar = self.executor.submit(self._create_grammar, request)
        else:
            grammar = self._create_grammar(request)  # type: ignore[assignment]
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _create_grammar(self, request: "Request") -> StructuredOutputGrammar:
        """Compile the grammar described by the request's structured output key.

        The key is unpacked as ``(request_type, grammar_spec)`` and passed to
        ``self.backend.compile_grammar``; the backend must already exist.
        """
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def _fill_bitmasks(
        self, batch: Iterable[tuple[StructuredOutputGrammar, int, bool]]
    ) -> None:
        """Fill bitmask rows for ``(grammar, row_index, apply_bitmask)`` items.

        Rows whose grammar should constrain sampling are filled by the
        grammar itself; all other rows are filled with ``self._full_mask``.
        """
        assert self._grammar_bitmask is not None
        for grammar, index, apply_bitmask in batch:
            if apply_bitmask and not grammar.is_terminated():
                grammar.fill_bitmask(self._grammar_bitmask, index)
            else:
                # Note that for thinking support, we will need to
                # reset the relevant part of the bitmask for consequent
                # requests here.
                self._grammar_bitmask[index].fill_(self._full_mask)

    def _async_submit_fill_bitmask(
        self, batch: list[tuple[StructuredOutputGrammar, int, bool]]
    ) -> Future:
        """Schedule ``_fill_bitmasks(batch)`` on the fill-mask thread pool."""
        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

    def grammar_bitmask(
        self,
        requests: dict[str, "Request"],
        structured_output_request_ids: list[str],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> "npt.NDArray[np.int32] | None":
        """Build the structured-output bitmask for the scheduled batch.

        Args:
            requests: All known requests, keyed by request id.
            structured_output_request_ids: Ids of the scheduled requests that
                use structured output; rows are emitted in this order.
            scheduled_spec_decode_tokens: Draft tokens scheduled per request
                id (only consulted on the serial path).

        Returns:
            ``None`` when no structured-output request is scheduled;
            otherwise the used rows of the bitmask buffer as a NumPy array.
            Each request contributes one row per scheduled draft token plus
            one row for the bonus / non-speculative token.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = (
                self.vllm_config.speculative_config.num_speculative_tokens
            )

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = self.backend.allocate_token_bitmask(
                max_batch_size * (1 + max_num_spec_tokens)
            )

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0

        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if (
            len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
            and max_num_spec_tokens == 0
        ):
            promises = []
            batch = []
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request
                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                grammar = structured_output_request.grammar

                apply_bitmask = self.should_fill_bitmask(request)
                # Without speculation each request owns exactly one row.
                batch.append((grammar, cumulative_index, apply_bitmask))
                if len(batch) == self.fill_bitmask_parallel_batch_size:
                    promises.append(self._async_submit_fill_bitmask(batch))
                    batch = []

                cumulative_index += 1
            if batch:
                promises.append(self._async_submit_fill_bitmask(batch))

            # Wait for all bitmask filling tasks to complete.
            for promise in promises:
                promise.result()
        else:
            # Fallback to serial filling of bitmasks for small-batch-size cases
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request

                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                grammar = structured_output_request.grammar
                apply_bitmask = self.should_fill_bitmask(request)

                state_advancements = 0
                req_tokens = scheduled_spec_decode_tokens.get(req_id, ())
                # Fill one row per draft token plus a final row (the -1
                # sentinel). Between rows the grammar is advanced by the
                # draft token, so each row reflects the state after the
                # preceding drafts; the advancements are rolled back below.
                for token in itertools.chain(req_tokens, (-1,)):
                    self._fill_bitmasks(((grammar, cumulative_index, apply_bitmask),))
                    if token == -1:
                        # Stop advancing the grammar once we hit a padding token.
                        apply_bitmask = False
                    if apply_bitmask and not grammar.is_terminated():
                        accepted = grammar.accept_tokens(req_id, [token])
                        assert accepted, (token, req_id, scheduled_spec_decode_tokens)
                        state_advancements += 1
                    cumulative_index += 1
                if state_advancements > 0:
                    grammar.rollback(state_advancements)

        bitmask_tensor = self._grammar_bitmask
        # Only return the rows actually written for this batch.
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    def should_fill_bitmask(self, request: "Request") -> bool:
        """Return whether the grammar should constrain ``request`` right now.

        Always ``True`` without a reasoning parser or when
        ``enable_in_reasoning`` is set. Otherwise returns the request's
        ``reasoning_ended`` flag, initializing it from the prompt tokens the
        first time it is still ``None`` (a side effect on the request).
        """
        # NOTE (Hanchen) if enable_in_reasoning is True, it means that
        # the model needs to be constrained in reasoning. So we should always
        # enable the bitmask filling.
        if self.reasoner is not None:
            if self.enable_in_reasoning:
                return True
            assert request.structured_output_request is not None
            if request.structured_output_request.reasoning_ended is None:
                # This should be removed here, but since `openai_gptoss`
                # is an independent code path, it is kept for now.
                # After unifying the `openai_gptoss` and non-`openai_gptoss` styles,
                # it can be removed.
                request.structured_output_request.reasoning_ended = (
                    self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
                )
            return request.structured_output_request.reasoning_ended
        return True

    def should_advance(self, request: "Request") -> bool:
        """Return whether the request's grammar FSM may be advanced this step.

        Returns ``False`` for requests without structured output. Returns
        ``True`` when there is no reasoning parser, when ``enable_in_reasoning``
        is set, or once reasoning has ended. Otherwise it checks whether the
        newly added tokens end reasoning; if so, it sets ``reasoning_ended``
        but still returns ``False``, so advancing starts on the next step.
        """
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
        # for cases that don't use thinking mode.
        if self.reasoner is None:
            return True

        # if the model needs structured in reasoning, we should advance
        if self.enable_in_reasoning:
            return True

        structured_req = request.structured_output_request
        if structured_req.reasoning_ended:
            return True

        # Check if reasoning ends in *this* step
        delta_from = request.num_computed_tokens - request.num_output_placeholders
        all_token_ids = request.all_token_ids
        # A negative offset is interpreted relative to the end of the token
        # list (clamped at 0).
        start = (
            delta_from if delta_from >= 0 else max(len(all_token_ids) + delta_from, 0)
        )
        if self.reasoner.is_reasoning_end_streaming(
            all_token_ids, itertools.islice(all_token_ids, start, None)
        ):
            # Reasoning just ended, so we shouldn't advance til
            # next pass
            structured_req.reasoning_ended = True

        return False

    def clear_backend(self) -> None:
        """Call ``destroy()`` on the backend, if one was created.

        NOTE(review): ``self.backend`` is not reset to ``None`` afterwards, so a
        later grammar_init() would reuse the destroyed backend — confirm that
        this is intended.
        """
        if self.backend is not None:
            self.backend.destroy()