backend_xgrammar.py 11.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Any
7
8
9

import torch

10
import vllm.envs
11
from vllm.logger import init_logger
12
from vllm.sampling_params import SamplingParams
13
14
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
15
16
17
18
19
20
21
22
23
24
from vllm.v1.structured_output.backend_types import (
    StructuredOutputBackend,
    StructuredOutputGrammar,
    StructuredOutputOptions,
)
from vllm.v1.structured_output.utils import (
    choice_as_grammar,
    convert_lark_to_ebnf,
    grammar_is_likely_lark,
)
25
26
27
28
29
30
31
32
33

# Static type checkers resolve `xgr` to the real xgrammar module. At runtime
# `xgr` is a vllm `LazyLoader` proxy instead. NOTE(review): presumably this
# defers importing xgrammar until `xgr` is first used; confirm against
# vllm.utils.LazyLoader.
if TYPE_CHECKING:
    import xgrammar as xgr
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

# Module-level logger used by the backend, the grammar wrapper and validation.
logger = init_logger(__name__)


34
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Structured-output backend that compiles grammars with xgrammar.

    Builds on the fields provided by ``StructuredOutputBackend``; this class
    reads ``vllm_config``, ``tokenizer`` and ``vocab_size``.
    ``__post_init__`` builds a single ``xgr.GrammarCompiler``. Each
    ``compile_grammar`` call then wraps a compiled grammar in a fresh
    ``XgrammarGrammar`` with its own matcher.
    """

    def __post_init__(self):
        so_config = self.vllm_config.structured_outputs_config
        self.disable_any_whitespace = so_config.disable_any_whitespace

        tokenizer_info = self._make_tokenizer_info()
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            # VLLM_XGRAMMAR_CACHE_MB is expressed in MiB; xgrammar wants bytes.
            cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
        )

        # Forwarded as `max_rollback_tokens` to every matcher created in
        # compile_grammar; stays 0 when no speculative config is present.
        spec_config = self.vllm_config.speculative_config
        self.num_speculative_tokens = (
            0 if spec_config is None else spec_config.num_speculative_tokens
        )

    def _make_tokenizer_info(self):
        """Describe the tokenizer's vocabulary to xgrammar.

        For a ``MistralTokenizer`` this also overwrites ``self.vocab_size``
        with the size of the tokenizer's ``vocab``.
        """
        tok = self.tokenizer
        if not isinstance(tok, MistralTokenizer):
            return xgr.TokenizerInfo.from_huggingface(
                tok,
                vocab_size=self.vocab_size,
            )

        # NOTE: ideally, xgrammar should handle this accordingly.
        # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
        eos_ids = [tok.eos_token_id]

        # not tok.vocab_size as tok.vocab
        # collapses all decoded errors into a single token.
        self.vocab_size = len(tok.vocab)
        return xgr.TokenizerInfo(  # type: ignore
            encoded_vocab=tok.vocab,
            # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
            vocab_type=(
                xgr.VocabType.RAW if tok.is_tekken else xgr.VocabType.BYTE_FALLBACK
            ),
            vocab_size=self.vocab_size,
            stop_token_ids=eos_ids,
            add_prefix_space=True,
        )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` and wrap it with a new matcher.

        Args:
            request_type: Which kind of specification ``grammar_spec`` holds.
            grammar_spec: The specification text (a JSON schema, EBNF
                grammar, regex or structural-tag JSON, depending on
                ``request_type``; ignored for ``JSON_OBJECT``).

        Returns:
            An ``XgrammarGrammar`` whose matcher may roll back up to
            ``self.num_speculative_tokens`` tokens.

        Raises:
            ValueError: if ``request_type`` is not handled by this backend.
        """
        compiled = self._compile_spec(request_type, grammar_spec)
        matcher = xgr.GrammarMatcher(
            compiled,
            max_rollback_tokens=self.num_speculative_tokens,
        )
        return XgrammarGrammar(
            matcher=matcher,
            vocab_size=self.vocab_size,
            ctx=compiled,
        )

    def _compile_spec(self, request_type, grammar_spec):
        """Dispatch to the xgrammar compile call matching ``request_type``."""
        if request_type == StructuredOutputOptions.JSON:
            return self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace
            )
        if request_type == StructuredOutputOptions.JSON_OBJECT:
            # Any JSON object: only the top-level type is constrained.
            return self.compiler.compile_json_schema(
                '{"type": "object"}', any_whitespace=not self.disable_any_whitespace
            )
        if request_type == StructuredOutputOptions.GRAMMAR:
            return self.compiler.compile_grammar(grammar_spec)
        if request_type == StructuredOutputOptions.REGEX:
            return self.compiler.compile_regex(grammar_spec)
        if request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            return self.compiler.compile_structural_tag(
                self._legacy_structural_tag(grammar_spec)
            )

        logger.error(
            "Validation should have already occurred. Please file an issue."
        )
        raise ValueError(
            f"grammar is not of valid supported types. ({request_type!s})"
        )

    @staticmethod
    def _legacy_structural_tag(grammar_spec: str):
        """Build an ``xgr.StructuralTag`` from a legacy structural-tag JSON.

        The JSON is read for a ``"structures"`` list (each entry with
        ``"begin"``, ``"schema"`` and ``"end"``) and a ``"triggers"`` value.
        """
        spec = json.loads(grammar_spec)
        items = [
            xgr.StructuralTagItem(
                begin=entry["begin"],
                schema=json.dumps(entry["schema"]),
                end=entry["end"],
            )
            for entry in spec["structures"]
        ]
        return xgr.StructuralTag.from_legacy_structural_tag(
            items, spec["triggers"]
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate an xgrammar token bitmask for ``max_num_seqs`` rows.

        The vocabulary width is ``self.vocab_size``.
        """
        width = self.vocab_size
        return xgr.allocate_token_bitmask(max_num_seqs, width)

    def destroy(self):
        """Drop this backend's reference to the grammar compiler."""
        del self.compiler

129
130
131
132
133
134
135
136
137
138
139
140
141

@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an ``xgr.GrammarMatcher``.

    Tracks how many tokens have been accepted and caches whether the matcher
    has terminated. ``accept_tokens``, ``rollback`` and ``reset`` keep that
    cache in sync with the matcher.
    """

    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Count of tokens accepted via accept_tokens, minus those rolled back.
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )
    # Cached result of matcher.is_terminated(); refreshed after every
    # state change made through this wrapper.
    _is_terminated: bool = field(default=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.

        Returns False immediately, without touching the matcher, once the
        grammar has terminated. Tokens are fed one at a time. On the first
        rejected token an error is logged and False is returned; tokens
        accepted before it are NOT rolled back and stay counted in
        ``num_processed_tokens``.
        """
        if self._is_terminated:
            return False
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.",
                    request_id,
                    token,
                )
                return False
            self.num_processed_tokens += 1
        self._is_terminated = self.matcher.is_terminated()
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the FSM in sequence.
        Will not advance the FSM.

        Returns the prefix list of tokens that are accepted by the FSM.

        Implemented by accepting tokens on the matcher and then rolling back
        exactly the accepted prefix. NOTE(review): this relies on the prefix
        length not exceeding the matcher's ``max_rollback_tokens``.
        """
        accepted_tokens = []
        for token in tokens:
            if self.matcher.accept_token(token):
                accepted_tokens.append(token)
            else:
                break
        if len(accepted_tokens) > 0:
            # Rollback the FSM to the initial state
            self.matcher.rollback(len(accepted_tokens))
        return accepted_tokens

    def rollback(self, num_tokens: int) -> None:
        """Undo the last ``num_tokens`` accepted tokens.

        Also updates the processed-token count and the termination flag.
        """
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens
        self._is_terminated = self.matcher.is_terminated()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Ask the matcher to write its next-token mask into ``bitmask``.

        ``idx`` presumably selects the row for this request; see
        ``GrammarMatcher.fill_next_token_bitmask``.
        """
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Return the cached termination state of the matcher."""
        return self._is_terminated

    def reset(self):
        """Return the grammar to its initial, nothing-accepted state."""
        self.num_processed_tokens = 0
        self.matcher.reset()
        # Re-sync the cached flag. Without this, a grammar that had
        # terminated before reset() would keep reporting is_terminated()
        # as True and reject every token in accept_tokens().
        self._is_terminated = self.matcher.is_terminated()
199
200
201
202
203
204
205
206
207
208


def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    The schema is walked recursively. Every dict value, and every dict found
    inside a list value, is inspected as a (sub)schema. A node is reported
    when it uses one of these keywords together with the matching ``type``:

    * ``integer`` / ``number``: ``multipleOf``
    * ``array``: ``uniqueItems``, ``contains``, ``minContains``,
      ``maxContains``
    * ``string``: ``format``
    * ``object``: ``minProperties``, ``maxProperties``, ``propertyNames``,
      ``patternProperties``

    ``type`` may be a single string or a list of strings, since JSON Schema
    allows both (e.g. ``{"type": ["integer", "null"]}``). A node matches if
    any of its declared types does.

    Args:
        schema: A parsed JSON schema. Non-dict input yields False.

    Returns:
        True if an unsupported feature appears anywhere in ``schema``.
    """

    def declared_types(obj: dict[str, Any]) -> set[str]:
        # JSON Schema permits "type" to be a string or a list of strings;
        # anything else (missing, or e.g. a property schema when `obj` is a
        # "properties" map) declares no type.
        declared = obj.get("type")
        if isinstance(declared, str):
            return {declared}
        if isinstance(declared, list):
            return {t for t in declared if isinstance(t, str)}
        return set()

    def check_object(obj: dict[str, Any]) -> bool:
        if not isinstance(obj, dict):
            return False

        types = declared_types(obj)

        # Check for numeric ranges
        if types & {"integer", "number"} and "multipleOf" in obj:
            return True

        # Check for array unsupported keywords
        if "array" in types and any(
            key in obj
            for key in ("uniqueItems", "contains", "minContains", "maxContains")
        ):
            return True

        # Unsupported keywords for strings
        if "string" in types and "format" in obj:
            return True

        # Unsupported keywords for objects
        if "object" in types and any(
            key in obj
            for key in (
                "minProperties",
                "maxProperties",
                "propertyNames",
                "patternProperties",
            )
        ):
            return True

        # Recursively check all nested objects and arrays
        for value in obj.values():
            if isinstance(value, dict):
                if check_object(value):
                    return True
            elif isinstance(value, list):
                if any(
                    isinstance(item, dict) and check_object(item) for item in value
                ):
                    return True

        return False

    return check_object(schema)


def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Each structured-output field set on ``sampling_params.structured_outputs``
    is parsed with xgrammar, without compiling it against a tokenizer.

    Side effects on ``sampling_params.structured_outputs``:
      * a ``choice`` request is rewritten into an equivalent EBNF ``grammar``
        and ``choice`` is cleared;
      * a grammar that looks like Lark is converted in place to EBNF.

    A valid ``regex`` does not return early; the remaining fields are still
    checked afterwards.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.structured_outputs is None:
        return

    so_params = sampling_params.structured_outputs

    if so_params.regex:
        try:
            xgr.Grammar.from_regex(so_params.regex)
        except Exception as err:
            raise ValueError(
                f"Failed to transform regex into a grammar: {err}"
            ) from err

    if so_params.choice:
        choice_grammar = choice_as_grammar(so_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # f-prefix is required so the underlying error is interpolated
            # instead of showing a literal "{err}" to the user.
            raise ValueError(
                f"Failed to transform choices into a grammar: {err}"
            ) from err
        # Downstream only needs to handle the grammar form.
        so_params.choice = None
        so_params.grammar = choice_grammar
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                schema = json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = so_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError(
                f"Failed to transform json schema into a grammar: {err}"
            ) from err

        # xgrammar may parse a schema whose constraints it does not enforce;
        # reject those explicitly.
        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError(
                "The provided JSON schema contains features not supported by xgrammar."
            )
        return

    if so_params.grammar:
        if grammar_is_likely_lark(so_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. "
                ) from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(so_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                )
                for s in s_tag["structures"]
            ]
            structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
                tags, s_tag["triggers"]
            )
            xgr.Grammar.from_structural_tag(structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e