# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend

if TYPE_CHECKING:
    # These names are only needed for annotations, which are lazy strings
    # thanks to ``from __future__ import annotations``.
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.reasoning import ReasoningParser
    from vllm.v1.request import Request
else:
    # torch is used at runtime (e.g. ``torch.tensor`` in the manager), but is
    # wrapped in vLLM's LazyLoader — presumably so the import is deferred
    # until first attribute access.
    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)


class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Holds a single structured-output backend (created lazily from the first
    request that needs one), a thread pool on which grammars are compiled,
    an optional reasoning parser, and a bitmask buffer that is allocated once
    and refilled on every call to ``grammar_bitmask``.
    """

    def __init__(self, vllm_config: VllmConfig):
        self.backend: Optional[StructuredOutputBackend] = None
        self.reasoner: Optional[ReasoningParser] = None
        self.vllm_config = vllm_config

        # Allocated on first use in ``grammar_bitmask`` and reused afterwards.
        self._grammar_bitmask: Optional[torch.Tensor] = None
        # -1 as int32 has every bit set. Rows left at this value are not
        # constrained by a grammar (assumes a set bit means "token allowed").
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        # ThreadPoolExecutor's default max_workers (min(32, CPUs + 4) since
        # Python 3.8) is sized for I/O-bound work, and grammar compilation is
        # CPU-bound. We also know we would never dominate CPU usage with just
        # grammar compilation, so we use half the number of CPUs.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        # Tokenizer obtained with no LoRA request (presumably the base model
        # tokenizer); shared by the backend and the reasoning parser.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.vllm_config.model_config,
            scheduler_config=self.vllm_config.scheduler_config,
            lora_config=self.vllm_config.lora_config,
        ).get_lora_tokenizer(None)
        reasoning_backend = vllm_config.decoding_config.reasoning_backend
        if reasoning_backend:
            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                reasoning_backend)
            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

    def grammar_init(self, request: Request) -> None:
        """Start compiling the grammar for ``request`` on the executor.

        Does nothing for requests without a structured output request. On the
        first call that needs it, this also creates the backend named by the
        request's guided-decoding params. The ``Future`` returned by the
        executor is stored on ``request.structured_output_request.grammar``.

        Raises:
            ValueError: if the backend requested when the backend is first
                created is neither ``"xgrammar"`` nor ``"guidance"``.
        """
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert request.sampling_params.guided_decoding is not None

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...). Once a
        # backend exists, later requests' backend field is not consulted here.
        if self.backend is None:
            backend = request.sampling_params.guided_decoding.backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for ``request``; submitted by ``grammar_init``.

        The request's ``structured_output_key`` is unpacked into
        ``(request_type, grammar_spec)`` and passed to the backend.
        """
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> Optional[npt.NDArray[np.int32]]:
        """Build the token bitmask for this batch's structured output requests.

        Args:
            requests: requests keyed by request id; must contain every id in
                ``structured_output_request_ids``.
            structured_output_request_ids: ids of the structured output
                requests in this batch, mapped to an integer used only to
                order their rows (ascending).
            scheduled_spec_decode_tokens: speculative tokens scheduled per
                request id; ids without an entry have none.

        Returns:
            ``None`` when ``structured_output_request_ids`` is empty.
            Otherwise an array whose rows are, for each request in order,
            one row per scheduled speculative token plus one row for the
            bonus / non-speculative token. Rows of requests whose grammar is
            terminated, or whose reasoning has not ended, keep the ``-1``
            fill value.
        """
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens
        # One row per speculative position plus one for the bonus token.
        rows_per_request = 1 + max_num_spec_tokens

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = self.backend.allocate_token_bitmask(
                max_batch_size * rows_per_request)

        bitmask_tensor = self._grammar_bitmask

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])

        # The buffer is reused across calls, so reset every row this batch may
        # touch. Rows that no grammar fills below (reasoning not finished, or
        # grammar terminated) then hold the full mask instead of stale data.
        bitmask_tensor[:len(ordered_seq) * rows_per_request].fill_(
            self._full_mask)

        cumulative_index = 0
        # NOTE: This outer loop can likely be parallelized to improve
        # performance of bitmask generation for large batches.
        for req_id, _ in ordered_seq:
            request = requests[req_id]
            structured_output_request = request.structured_output_request

            if TYPE_CHECKING:
                assert structured_output_request is not None
                assert structured_output_request.grammar is not None
            grammar = structured_output_request.grammar

            # With a reasoning parser, the grammar is only enforced once
            # reasoning has ended. The first check looks at the prompt; the
            # result is cached on the structured output request.
            apply_bitmask: bool = True
            if self.reasoner is not None:
                if structured_output_request.reasoning_ended is None:
                    structured_output_request.reasoning_ended = \
                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
                apply_bitmask = structured_output_request.reasoning_ended

            state_advancements = 0
            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
            for token in req_tokens:
                if apply_bitmask and not grammar.is_terminated():
                    grammar.fill_bitmask(bitmask_tensor, cumulative_index)
                    if token is not None:
                        # In order to generate the correct bitmask for each
                        # position in the speculative sequence, we advance
                        # the FSM state for each speculative token and rollback
                        # to restore the previous state when we are finished.
                        #
                        # The call must not live inside the ``assert``: asserts
                        # are stripped under ``python -O``, which would skip
                        # the advance while still counting it for rollback.
                        accepted = grammar.accept_tokens(req_id, [token])
                        assert accepted, (token, req_id)
                        state_advancements += 1
                cumulative_index += 1
            if state_advancements > 0:
                grammar.rollback(state_advancements)

        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        # NOTE(review): ``Tensor.numpy()`` shares memory with the reused
        # ``self._grammar_bitmask`` buffer, so the next call overwrites this
        # array — callers presumably consume/serialize it first; confirm.
        return bitmask_tensor.numpy()

    def should_advance(self, request: Request) -> bool:
        """Return whether the grammar FSM of ``request`` may be advanced.

        Requests that do not use structured output never advance. Without a
        reasoning parser, structured output requests always advance. With one,
        the FSM is held back while reasoning is in progress. In the step where
        the end of reasoning is first found in ``request.all_token_ids``, this
        method records that fact and still returns ``False``. Later calls
        return ``True``.
        """
        if not request.use_structured_output:
            return False

        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None

        # By default (no thinking mode) we always advance.
        if self.reasoner is None:
            return True

        structured_req = request.structured_output_request
        if structured_req.reasoning_ended:
            return True

        # Check if reasoning ends in *this* step. If it just ended, we
        # shouldn't advance until the next pass.
        if self.reasoner.is_reasoning_end(request.all_token_ids):
            structured_req.reasoning_ended = True
        return False

    def clear_backend(self) -> None:
        """Call ``destroy()`` on the backend, if one has been created."""
        if self.backend is not None:
            self.backend.destroy()