backend_xgrammar.py 12.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import json
7
from dataclasses import dataclass, field
8
from typing import TYPE_CHECKING, Any
9
10
11

import torch

12
import vllm.envs
13
from vllm.logger import init_logger
14
from vllm.sampling_params import SamplingParams
15
16
17
18
19
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
20
21
22
from vllm.v1.structured_output.utils import (choice_as_grammar,
                                             convert_lark_to_ebnf,
                                             grammar_is_likely_lark)
23
24
25
26
27
28
29
30
31

# Static type checkers see a plain ``import xgrammar``; at runtime ``xgr`` is
# bound to a ``LazyLoader`` wrapper instead, presumably so xgrammar is only
# imported on first use (TODO confirm against vllm.utils.LazyLoader).
if TYPE_CHECKING:
    import xgrammar as xgr
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


32
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Structured-output backend that compiles constraints with xgrammar.

    ``__post_init__`` derives an ``xgr.TokenizerInfo`` from
    ``self.tokenizer`` (building it by hand for ``MistralTokenizer``) and
    wraps it in an ``xgr.GrammarCompiler``; ``compile_grammar`` then turns
    each request's constraint into an ``XgrammarGrammar``.
    """

    def __post_init__(self):
        so_config = self.vllm_config.structured_outputs_config
        self.disable_any_whitespace = so_config.disable_any_whitespace

        if isinstance(self.tokenizer, MistralTokenizer):
            tokenizer_info = self._mistral_tokenizer_info()
        else:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer,
                vocab_size=self.vocab_size,
            )

        # VLLM_XGRAMMAR_CACHE_MB is given in MiB; the compiler takes bytes.
        cache_limit_bytes = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=cache_limit_bytes,
        )

        # Every matcher built in compile_grammar gets this many rollback
        # slots, presumably so speculatively accepted tokens can be undone.
        spec_config = self.vllm_config.speculative_config
        self.num_speculative_tokens = (0 if spec_config is None else
                                       spec_config.num_speculative_tokens)

    def _mistral_tokenizer_info(self):
        """Build an ``xgr.TokenizerInfo`` by hand for a MistralTokenizer.

        Raises:
            ValueError: if the tokenizer is missing an attribute needed to
                read its vocabulary.
        """
        # NOTE: ideally, xgrammar should handle this accordingly.
        # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
        tokenizer = self.tokenizer
        try:
            if tokenizer.is_tekken:
                encoded_vocab = tokenizer._vocab
            else:
                # Emit token strings in ascending token-id order.
                by_id = sorted(tokenizer.get_vocab().items(),
                               key=lambda pair: pair[1])
                encoded_vocab = [token for token, _ in by_id]
            eos_token_id = getattr(tokenizer, "eos_token_id", None)
            stop_token_ids = (None
                              if eos_token_id is None else [eos_token_id])
        except AttributeError as e:
            raise ValueError(
                f"Cannot get the vocabulary of the tokenizer "
                f"{type(self.tokenizer)}. The tokenizer should have a "
                "get_vocab method.") from e

        # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
        vocab_type = (xgr.VocabType.RAW
                      if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK)
        return xgr.TokenizerInfo(  # type: ignore
            encoded_vocab=encoded_vocab,
            vocab_type=vocab_type,
            vocab_size=self.vocab_size,
            stop_token_ids=stop_token_ids,
            add_prefix_space=True,
        )

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` and pair it with a fresh matcher.

        Raises:
            ValueError: if ``request_type`` is not one of the handled
                ``StructuredOutputOptions`` members.
        """
        ctx = self._compile_to_ctx(request_type, grammar_spec)
        matcher = xgr.GrammarMatcher(
            ctx,
            max_rollback_tokens=self.num_speculative_tokens,
        )
        return XgrammarGrammar(matcher=matcher,
                               vocab_size=self.vocab_size,
                               ctx=ctx)

    def _compile_to_ctx(self, request_type: StructuredOutputOptions,
                        grammar_spec: str):
        """Dispatch ``grammar_spec`` to the matching compiler entry point."""
        if request_type == StructuredOutputOptions.JSON:
            return self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace)
        if request_type == StructuredOutputOptions.JSON_OBJECT:
            # The spec is not consulted: any JSON object is accepted.
            return self.compiler.compile_json_schema(
                '{"type": "object"}',
                any_whitespace=not self.disable_any_whitespace)
        if request_type == StructuredOutputOptions.GRAMMAR:
            return self.compiler.compile_grammar(grammar_spec)
        if request_type == StructuredOutputOptions.REGEX:
            return self.compiler.compile_regex(grammar_spec)
        if request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            structural_tag = self._parse_legacy_structural_tag(grammar_spec)
            return self.compiler.compile_structural_tag(structural_tag)

        logger.error(
            "Validation should have already occurred. Please file an issue."
        )
        raise ValueError(
            f"grammar is not of valid supported types. ({request_type!s})")

    @staticmethod
    def _parse_legacy_structural_tag(spec: str):
        """Convert a legacy structural-tag JSON string to xgr.StructuralTag.

        ``spec`` must contain a ``"structures"`` list of objects with
        ``"begin"``, ``"schema"`` and ``"end"`` keys, plus ``"triggers"``.
        """
        parsed = json.loads(spec)
        items = [
            xgr.StructuralTagItem(
                begin=structure["begin"],
                schema=json.dumps(structure["schema"]),
                end=structure["end"],
            ) for structure in parsed["structures"]
        ]
        return xgr.StructuralTag.from_legacy_structural_tag(
            items, parsed["triggers"])

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate an xgrammar token bitmask for ``max_num_seqs`` sequences
        over ``self.vocab_size`` tokens."""
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
        """Drop this backend's reference to its grammar compiler."""
        del self.compiler

@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Per-request grammar state wrapping an xgrammar ``GrammarMatcher``.

    ``_is_terminated`` caches ``matcher.is_terminated()``; every method that
    leaves the matcher in a new state (accept_tokens, rollback, reset)
    refreshes it so the cache never goes stale.
    """
    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Tokens accepted through accept_tokens, net of rollback(); zeroed by
    # reset().
    num_processed_tokens: int = field(default=0,
                                      repr=False,
                                      hash=False,
                                      init=False)
    _is_terminated: bool = field(default=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if every token was accepted. Returns False without
        touching the FSM if it has already terminated, or as soon as a token
        is rejected. In the latter case the tokens before the rejected one
        stay accepted (they are not rolled back) and are counted in
        ``num_processed_tokens``.
        """
        if self._is_terminated:
            return False
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.", request_id, token)
                # The accepted prefix may already have terminated the
                # matcher; keep the cached flag in sync with it.
                self._is_terminated = self.matcher.is_terminated()
                return False
            self.num_processed_tokens += 1
        self._is_terminated = self.matcher.is_terminated()
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the FSM in sequence.
        Will not advance the FSM.

        Returns the prefix list of tokens that are accepted by the FSM.
        """
        accepted_tokens = []
        for token in tokens:
            if self.matcher.accept_token(token):
                accepted_tokens.append(token)
            else:
                break
        if len(accepted_tokens) > 0:
            # Undo the tokens accepted above so the FSM is back in the state
            # it was in when this method was called.
            self.matcher.rollback(len(accepted_tokens))
        return accepted_tokens

    def rollback(self, num_tokens: int) -> None:
        """Undo the last ``num_tokens`` accepted tokens."""
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens
        self._is_terminated = self.matcher.is_terminated()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write this request's allowed-next-token mask into ``bitmask`` at
        index ``idx`` (presumably this request's row in the batch)."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Return the cached termination state of the matcher."""
        return self._is_terminated

    def reset(self):
        """Return the matcher to its initial state and clear the counters."""
        self.num_processed_tokens = 0
        self.matcher.reset()
        # Without this, a grammar that had terminated would keep rejecting
        # every accept_tokens() call and reporting is_terminated() as True
        # after the reset.
        self._is_terminated = self.matcher.is_terminated()


def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Return True if ``schema`` uses keywords xgrammar cannot enforce.

    The whole schema is walked: every dict value and every dict found
    directly inside a list value is inspected. A node is flagged when its
    ``"type"`` is a string listed in the table below and the node carries
    one of that type's keywords. A list-valued ``"type"`` (e.g.
    ``["string", "null"]``) never matches.
    """
    # Keywords treated as unsupported, keyed by the node's "type".
    unsupported_keywords: dict[str, tuple[str, ...]] = {
        "integer": ("multipleOf", ),
        "number": ("multipleOf", ),
        "array": ("uniqueItems", "contains", "minContains", "maxContains"),
        "string": ("format", ),
        "object": ("minProperties", "maxProperties", "propertyNames",
                   "patternProperties"),
    }

    def _uses_unsupported(node: Any) -> bool:
        if not isinstance(node, dict):
            return False

        node_type = node.get("type")
        if isinstance(node_type, str):
            flagged = unsupported_keywords.get(node_type, ())
            if any(keyword in node for keyword in flagged):
                return True

        # Descend into sub-schemas: dict values directly, list values
        # element-wise (non-dict elements are rejected by the guard above).
        for child in node.values():
            candidates = child if isinstance(child, list) else (child, )
            if any(_uses_unsupported(candidate) for candidate in candidates):
                return True
        return False

    return _uses_unsupported(schema)


def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    The first constraint that is set (checked in the order regex, choice,
    json, grammar, structural_tag) is parsed into an ``xgr.Grammar``;
    nothing is compiled here.

    Side effects on ``sampling_params.structured_outputs``:
      * a ``choice`` constraint is rewritten into an EBNF ``grammar`` and
        ``choice`` is cleared;
      * a Lark ``grammar`` is replaced in place by its EBNF conversion.

    Raises ValueError if the request is not supported.
    """
    so_params = sampling_params.structured_outputs
    if so_params is None:
        return

    if so_params.regex:
        try:
            xgr.Grammar.from_regex(so_params.regex)
        except Exception as err:
            raise ValueError("Failed to transform regex into a grammar: "
                             f"{err}") from err
        # Like every other branch, stop after the constraint that is set.
        return

    if so_params.choice:
        choice_grammar = choice_as_grammar(so_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # f-prefix is required, otherwise the literal "{err}" is shown.
            raise ValueError("Failed to transform choices into a grammar: "
                             f"{err}") from err
        # Rewrite the request as the equivalent EBNF grammar constraint.
        so_params.choice = None
        so_params.grammar = choice_grammar
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                schema = json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = so_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError("Failed to transform json schema into a grammar: "
                             f"{err}") from err

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError("The provided JSON schema contains features not "
                             "supported by xgrammar.")
        return

    if so_params.grammar:
        if grammar_is_likely_lark(so_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. ") from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(so_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                ) for s in s_tag["structures"]
            ]
            structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
                tags, s_tag["triggers"])
            xgr.Grammar.from_structural_tag(structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e