# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.v1.request import Request

# Logger for this module, created through vLLM's init_logger helper and
# named after the module (__name__).
logger = init_logger(__name__)


class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Lazily selects a single grammar backend on the first structured output
    request, compiles request grammars on a thread pool, and builds the
    per-batch token bitmask that is sent to the GPU workers.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Chosen lazily in grammar_init() from the first structured output
        # request; every later request reuses the same backend.
        self.backend: Optional[StructuredOutputBackend] = None
        self.vllm_config = vllm_config
        # Allocated on first use by grammar_bitmask() (max_num_seqs rows) and
        # reused across steps, so rows must be fully rewritten each batch.
        self._grammar_bitmask: Optional[torch.Tensor] = None

        # ThreadPoolExecutor's default max_workers (min(32, cpu_count + 4)
        # since Python 3.8) is tuned for I/O-bound work, which is too high
        # for CPU-bound grammar compilation. We also know we would never
        # dominate CPU usage with just grammar compilation, so we use half
        # the number of CPUs (rounded up, at least one).
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def grammar_init(self, request: Request) -> None:
        """Start compiling the grammar for ``request`` in the background.

        Does nothing when ``request.structured_output_request`` is None.
        On the first structured output request, the backend named by
        ``request.sampling_params.guided_decoding.backend_name``
        ("xgrammar" or "guidance") is instantiated; later requests reuse it
        regardless of the backend they name.

        The compile job is submitted to ``self.executor`` and the resulting
        ``Future`` is stored on ``request.structured_output_request.grammar``.

        Raises:
            ValueError: if the backend has not been chosen yet and the
                request names an unsupported backend.
        """
        if request.structured_output_request is None:
            return

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
            backend_name = request.sampling_params.guided_decoding.backend_name
            if backend_name == "xgrammar":
                # Imported lazily so the xgrammar module is only loaded when
                # that backend is actually selected.
                from vllm.v1.structured_output.backend_xgrammar import (
                    XgrammarBackend)

                self.backend = XgrammarBackend(self.vllm_config)
            elif backend_name == "guidance":
                self.backend = GuidanceBackend(self.vllm_config)
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend_name}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for ``request``; runs on ``self.executor``.

        ``structured_output_key`` is unpacked as ``(request_type,
        grammar_spec)`` and handed to the backend's ``compile_grammar``.
        """
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        batch_len: int,
    ) -> Optional[npt.NDArray[np.int32]]:
        """Build the structured output token bitmask for one batch.

        Args:
            requests: Requests keyed by request id; must contain every id in
                ``structured_output_request_ids``.
            structured_output_request_ids: Maps each structured output
                request id in the batch to its row in the bitmask.
            batch_len: Number of leading rows to return.

        Returns:
            None when the batch has no structured output requests; otherwise
            the first ``batch_len`` rows of the shared bitmask buffer as a
            numpy array. The array aliases the reused buffer, so it is only
            valid until the next call.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        if self._grammar_bitmask is None:
            assert self.backend is not None
            self._grammar_bitmask = self.backend.allocate_token_bitmask(
                self.vllm_config.scheduler_config.max_num_seqs)

        # Fill the bitmask using the index of each request equal to its
        # position in the batch.
        bitmask_tensor = self._grammar_bitmask
        for req_id, batch_index in structured_output_request_ids.items():
            request = requests[req_id].structured_output_request
            assert request is not None and request.grammar is not None
            if request.grammar.is_terminated():
                # The buffer is reused across steps, so without this the row
                # would keep whatever mask another request left there in an
                # earlier batch. All bits set (-1 as int32) is the
                # "every token allowed" mask for xgrammar/llguidance bitmasks
                # -- NOTE(review): confirm for any new backend.
                bitmask_tensor[batch_index] = -1
            else:
                request.grammar.fill_bitmask(bitmask_tensor, batch_index)

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        # Slicing past the end is a no-op, so no bounds check is needed.
        return bitmask_tensor[:batch_len].numpy()