# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)

# These names are only used in annotations (which are lazy strings because of
# `from __future__ import annotations`), so they are imported for static type
# checking only and are not needed at runtime.
if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.v1.request import Request

logger = init_logger(__name__)


class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Creates a single grammar backend lazily on first use, compiles request
    grammars on a thread pool, and builds the per-batch token bitmask that
    is handed to the GPU workers.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Created by grammar_init() from the first structured output request.
        self.backend: Optional[StructuredOutputBackend] = None
        self.vllm_config = vllm_config

        # Allocated on the first grammar_bitmask() call and reused afterwards.
        self._grammar_bitmask: Optional[torch.Tensor] = None

        # ThreadPoolExecutor's default max_workers (min(32, cpu_count + 4)
        # since Python 3.8) is sized for I/O-bound work, which is too high
        # since these tasks are CPU-bound. We also know we would never
        # dominate CPU usage with just grammar compilation, so we use half
        # the number of CPUs (rounded up, at least one).
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def grammar_init(self, request: Request) -> None:
        """Start asynchronous grammar compilation for ``request``.

        Requests without a structured output request are ignored. On the
        first call the backend named by the request's guided decoding
        parameters is created; it is reused for every later request. The
        Future returned by the executor (not the compiled grammar itself)
        is stored on ``request.structured_output_request.grammar``.

        Raises:
            ValueError: if the backend is neither ``"xgrammar"`` nor
                ``"guidance"``.
        """
        if request.structured_output_request is None:
            return

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
            backend = request.sampling_params.guided_decoding.backend
            if backend == "xgrammar":
                from vllm.v1.structured_output.backend_xgrammar import (
                    XgrammarBackend)

                self.backend = XgrammarBackend(self.vllm_config)
            elif backend == "guidance":
                self.backend = GuidanceBackend(self.vllm_config)
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for ``request`` (runs on the executor).

        The request's ``structured_output_key`` is unpacked as
        ``(request_type, grammar_spec)`` and passed to the backend's
        ``compile_grammar``.
        """
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> Optional[npt.NDArray[np.int32]]:
        """Build the token bitmask for the structured output requests.

        Args:
            requests: requests keyed by request id.
            structured_output_request_ids: ids of this batch's structured
                output requests, mapped to an integer that orders their
                rows (ascending).
            scheduled_spec_decode_tokens: speculative tokens scheduled per
                request id; ids absent from this dict have none.

        Returns:
            ``None`` if there are no structured output requests. Otherwise
            an array with one row per checked position: one per speculative
            token plus one for the bonus / non-speculative token, for each
            request in order. Rows for positions where the grammar is
            already terminated are all ones (-1), i.e. unrestricted. The
            array shares memory with an internal buffer that is rewritten
            on the next call.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
            if self.vllm_config.speculative_config is not None:
                max_num_spec_tokens = self.vllm_config.\
                    speculative_config.num_speculative_tokens
            else:
                max_num_spec_tokens = 0

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = \
                self.backend.allocate_token_bitmask(
                    max_batch_size * (1 + max_num_spec_tokens))

        bitmask = self._grammar_bitmask

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])
        # NOTE: This outer loop can likely be parallelized to improve
        # performance of bitmask generation for large batches.
        for req_id, _ in ordered_seq:
            request = requests[req_id].structured_output_request
            assert request is not None
            grammar = request.grammar
            assert grammar is not None

            state_advancements = 0
            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
            for token in req_tokens:
                if grammar.is_terminated():
                    # The buffer is reused across calls, so this row may
                    # still hold another request's mask from an earlier
                    # batch. Reset it to all ones (-1 as int32) so a stale
                    # mask is never applied.
                    bitmask[cumulative_index].fill_(-1)
                else:
                    grammar.fill_bitmask(bitmask, cumulative_index)
                    if token is not None:
                        # In order to generate the correct bitmask for each
                        # position in the speculative sequence, we advance
                        # the FSM state for each speculative token and
                        # rollback to restore the previous state when we are
                        # finished. The call is kept outside the `assert` so
                        # it still runs when asserts are disabled (-O).
                        accepted = grammar.accept_tokens(req_id, [token])
                        assert accepted, (
                            f"Speculative token {token} rejected by the "
                            f"grammar of request {req_id}")
                        state_advancements += 1
                cumulative_index += 1
            if state_advancements > 0:
                grammar.rollback(state_advancements)

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers. Slicing
        # to the rows actually used is a no-op when every row is used.
        return bitmask[:cumulative_index].numpy()

    def clear_backend(self) -> None:
        """Call ``destroy()`` on the backend, if one has been created."""
        if self.backend is not None:
            self.backend.destroy()