__init__.py 5.54 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import multiprocessing
5
from concurrent.futures import Future, ThreadPoolExecutor
6
7
8
9
10
11
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
12
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import xgrammar as xgr

    from vllm.v1.request import Request
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


class StructuredOutputManager:
    """Compiles xgrammar grammars for structured output requests and packs
    their per-step token bitmasks for a batch.

    The xgrammar-specific state (tokenizer info, grammar compiler, worker
    thread pool and the reusable bitmask buffer) is created lazily by
    ``_delayed_init`` the first time a structured output request arrives,
    so engines that never see one do not pay that cost.
    """

    def __init__(self, vllm_config: VllmConfig):
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.vllm_config = vllm_config
        # Flipped to True by _delayed_init() once the compiler, executor and
        # bitmask buffer have been created.
        self.init_complete = False

    def _delayed_init(self):
        """Initialization delayed until we know it is needed.

        Builds the tokenizer, the xgrammar ``GrammarCompiler``, the thread
        pool used for asynchronous grammar compilation, and a token bitmask
        buffer sized for ``max_num_seqs`` requests.
        """
        tokenizer_group = init_tokenizer_from_configs(
            model_config=self.vllm_config.model_config,
            scheduler_config=self.vllm_config.scheduler_config,
            parallel_config=self.vllm_config.parallel_config,
            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()

        # Passing None presumably selects the base (non-LoRA) tokenizer;
        # grammars are compiled against its vocabulary.
        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
            tokenizer, vocab_size=self.vocab_size)
        self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

        # The default max_workers if not specified is the number of CPUs * 5,
        # which is way too high since these tasks are CPU-bound, not I/O bound.
        # We also know we would never dominate CPU usage with just grammar
        # compilation, so we set it to half the number of CPUs.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        # Reused for every batch; grammar_bitmask() writes one row per
        # structured output request at that request's batch index.
        self._grammar_bitmask = xgr.allocate_token_bitmask(
            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)

        self.init_complete = True

    def grammar_init(self, request: Request) -> None:
        """Start compiling the grammar for ``request`` in the background.

        Does nothing for requests without a structured output request.
        Otherwise the ``Future`` returned by the thread pool is stored on
        ``request.structured_output_request.grammar``; it resolves to a
        ``Grammar``, or re-raises any exception from compilation.
        """
        if request.structured_output_request is None:
            return

        # The first time this is called, we need to finish initialization
        # of xgrammar. We defer it to avoid the import of xgrammar and
        # initialization cost if it is not going to be used.
        if not self.init_complete:
            self._delayed_init()

        grammar: Future[Grammar] = self.executor.submit(
            self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(self, request: Request) -> Grammar:
        """Compile the grammar described by ``request`` (runs on the pool).

        Raises:
            ValueError: if the request type is not a supported
                ``StructuredOutputOptions`` value.
        """
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        if request_type == StructuredOutputOptions.JSON:
            # TODO -- allow any_whitespace to be configurable
            # pending merge of https://github.com/vllm-project/vllm/pull/12744
            ctx = self.compiler.compile_json_schema(grammar_spec,
                                                    any_whitespace=False)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = self.compiler.compile_builtin_json_grammar()
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})")

        return Grammar(
            matcher=xgr.GrammarMatcher(ctx),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        batch_len: int,
    ) -> Optional[npt.NDArray[np.int32]]:
        """Build the token bitmask for the structured output requests in a
        batch.

        Args:
            requests: all known requests, keyed by request id.
            structured_output_request_ids: request id -> row (batch index)
                for each request in the batch that uses structured output.
            batch_len: number of rows to return.

        Returns:
            ``None`` when no request in the batch uses structured output;
            otherwise the first ``batch_len`` rows of the bitmask buffer as
            an ``np.ndarray``. The array shares memory with the reused
            buffer, so it is only valid until the next call.
        """
        if not structured_output_request_ids:
            return None

        bitmask_tensor = self._grammar_bitmask

        # The buffer is reused across batches, so clear the rows we are about
        # to return. -1 sets every bit (the same all-tokens-allowed state the
        # buffer is allocated with). Without this, a request whose matcher
        # has terminated (and is therefore not refilled below) would inherit
        # a mask written at the same index by an earlier step, possibly for
        # a different request.
        bitmask_tensor[:batch_len].fill_(-1)

        # Fill the bitmask using the index of each request equal to its
        # position in the batch.
        for req_id, batch_index in structured_output_request_ids.items():
            request = requests[req_id].structured_output_request
            assert request is not None and request.grammar is not None
            if not request.grammar.matcher.is_terminated():
                request.grammar.fill_bitmask(bitmask_tensor, batch_index)

        # Resize the bitmask down to the size of the batch (a no-op slice
        # when batch_len covers the whole buffer).
        if batch_len < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:batch_len]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()