backend_xgrammar.py 12.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Any
7
8
9

import torch

10
import vllm.envs
11
from vllm.logger import init_logger
12
from vllm.sampling_params import SamplingParams
13
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
14
from vllm.utils.import_utils import LazyLoader
15
16
17
18
19
20
21
22
23
24
from vllm.v1.structured_output.backend_types import (
    StructuredOutputBackend,
    StructuredOutputGrammar,
    StructuredOutputOptions,
)
from vllm.v1.structured_output.utils import (
    choice_as_grammar,
    convert_lark_to_ebnf,
    grammar_is_likely_lark,
)
25
26
27
28
29
30
31
32
33

if TYPE_CHECKING:
    # Static type checkers see a real import so ``xgr.*`` annotations resolve.
    import xgrammar as xgr
else:
    # At runtime ``xgr`` is a LazyLoader proxy for the ``xgrammar`` module;
    # presumably this defers the real import until first use — confirm
    # against vllm.utils.import_utils.LazyLoader.
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


34
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Structured-output backend that compiles constraints with xgrammar.

    ``__post_init__`` builds a cached ``xgr.GrammarCompiler`` from the
    configured tokenizer; ``compile_grammar`` turns a request spec into an
    ``XgrammarGrammar`` that wraps a fresh ``xgr.GrammarMatcher``.
    """

    def __post_init__(self):
        so_config = self.vllm_config.structured_outputs_config
        self.disable_any_whitespace = so_config.disable_any_whitespace

        tokenizer = self.tokenizer
        if not isinstance(tokenizer, MistralTokenizer):
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                tokenizer,
                vocab_size=self.vocab_size,
            )
        else:
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            #
            # not self.tokenizer.vocab_size as self.tokenizer.vocab
            # collapses all decoded errors into a single token.
            self.vocab_size = len(tokenizer.vocab)
            # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
            vocab_type = (
                xgr.VocabType.RAW
                if tokenizer.is_tekken
                else xgr.VocabType.BYTE_FALLBACK
            )
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=tokenizer.vocab,
                vocab_type=vocab_type,
                vocab_size=self.vocab_size,
                stop_token_ids=[tokenizer.eos_token_id],
                add_prefix_space=True,
            )

        cache_limit_bytes = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=cache_limit_bytes,
        )

        # Passed to every matcher as ``max_rollback_tokens`` in
        # compile_grammar; zero when speculative decoding is not configured.
        spec_config = self.vllm_config.speculative_config
        self.num_speculative_tokens = (
            0 if spec_config is None else spec_config.num_speculative_tokens
        )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile *grammar_spec* according to *request_type*.

        Returns an ``XgrammarGrammar`` holding the compiled grammar and a new
        matcher. Raises ValueError for an unrecognised *request_type*.
        """
        any_whitespace = not self.disable_any_whitespace

        if request_type == StructuredOutputOptions.JSON:
            ctx = self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=any_whitespace
            )
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            # Any JSON object is acceptable; grammar_spec is not consulted.
            ctx = self.compiler.compile_json_schema(
                '{"type": "object"}', any_whitespace=any_whitespace
            )
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            parsed = json.loads(grammar_spec)
            if "structures" not in parsed:
                ctx = self.compiler.compile_structural_tag(grammar_spec)
            else:
                # Legacy "structures"/"triggers" payload: fall back to the
                # deprecated item-list form of compile_structural_tag.
                items = [
                    xgr.StructuralTagItem(
                        begin=entry["begin"],
                        schema=json.dumps(entry["schema"]),
                        end=entry["end"],
                    )
                    for entry in parsed["structures"]
                ]
                ctx = self.compiler.compile_structural_tag(items, parsed["triggers"])
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})"
            )

        matcher = xgr.GrammarMatcher(
            ctx,
            max_rollback_tokens=self.num_speculative_tokens,
        )
        return XgrammarGrammar(matcher=matcher, vocab_size=self.vocab_size, ctx=ctx)

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate an xgrammar token bitmask for *max_num_seqs* sequences."""
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
        """Drop this backend's reference to the grammar compiler."""
        del self.compiler

130
131
132
133
134
135
136
137
138
139
140
141
142

@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an ``xgr.GrammarMatcher``.

    Counts accepted tokens in ``num_processed_tokens`` and caches the
    matcher's termination status in ``_is_terminated``, which is what
    ``is_terminated()`` returns.
    """

    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Tokens accepted so far; managed internally, not a constructor argument.
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )
    # Cached matcher.is_terminated(); refreshed after a successful
    # accept_tokens(), and after rollback() and reset().
    _is_terminated: bool = field(default=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance, or if it had already
        terminated before this call.
        """
        if self._is_terminated:
            return False
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.",
                    request_id,
                    token,
                )
                # Tokens accepted before this one stay applied to the matcher.
                return False
            self.num_processed_tokens += 1
        self._is_terminated = self.matcher.is_terminated()
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the FSM in sequence.
        Will not advance the FSM.

        Returns the prefix list of tokens that are accepted by the FSM.
        """
        accepted_tokens = []
        for token in tokens:
            if self.matcher.accept_token(token):
                accepted_tokens.append(token)
            else:
                break
        if len(accepted_tokens) > 0:
            # Rollback the FSM to the state it had before this call.
            self.matcher.rollback(len(accepted_tokens))
        return accepted_tokens

    def rollback(self, num_tokens: int) -> None:
        """Undo the last *num_tokens* accepted tokens."""
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens
        self._is_terminated = self.matcher.is_terminated()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-next-token mask into *bitmask* at index *idx*."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Return the cached termination status of the matcher."""
        return self._is_terminated

    def reset(self):
        """Return the matcher to its initial state and clear the counter."""
        self.num_processed_tokens = 0
        self.matcher.reset()
        # Refresh the cached flag as well. Otherwise a grammar that had
        # terminated would keep reporting termination, and accept_tokens()
        # would keep rejecting tokens, even after the matcher was reset.
        self._is_terminated = self.matcher.is_terminated()
200
201


202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# cf https://github.com/mlc-ai/xgrammar/blob/a32ac892676d2eedc0327416105b9b06edfb94b2/cpp/json_schema_converter.cc
# String ``format`` values accepted for ``"type": "string"`` schema nodes; any
# other format makes has_xgrammar_unsupported_json_features() return True.
STRING_SUPPORTED_FORMATS = {
    # Dates and times
    "date",
    "time",
    "date-time",
    "duration",
    # Network identifiers
    "email",
    "hostname",
    "ipv4",
    "ipv6",
    # Identifiers and references
    "uuid",
    "uri",
    "uri-reference",
    "uri-template",
    "json-pointer",
    "relative-json-pointer",
}


221
222
223
224
225
226
227
228
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    Every dict reachable from *schema* is inspected: dict values, and dict
    elements of list values (lists nested directly in lists are not entered).
    Returns True as soon as a node uses one of these keywords for its
    declared ``type``:

    * ``integer`` / ``number``: ``multipleOf``
    * ``array``: ``uniqueItems``, ``contains``, ``minContains``,
      ``maxContains``
    * ``string``: a ``format`` not in ``STRING_SUPPORTED_FORMATS``
    * ``object``: ``minProperties``, ``maxProperties``, ``propertyNames``,
      ``patternProperties``

    A non-dict *schema* yields False.
    """
    array_keywords = ("uniqueItems", "contains", "minContains", "maxContains")
    object_keywords = (
        "minProperties",
        "maxProperties",
        "propertyNames",
        "patternProperties",
    )

    def node_is_unsupported(node: dict[str, Any]) -> bool:
        # Only the keywords relevant to this node's declared type matter.
        declared = node.get("type")
        if declared in ("integer", "number"):
            return "multipleOf" in node
        if declared == "array":
            return any(keyword in node for keyword in array_keywords)
        if declared == "string":
            return "format" in node and node["format"] not in STRING_SUPPORTED_FORMATS
        if declared == "object":
            return any(keyword in node for keyword in object_keywords)
        return False

    def walk(node: Any) -> bool:
        if not isinstance(node, dict):
            return False
        if node_is_unsupported(node):
            return True
        for child in node.values():
            if isinstance(child, dict):
                if walk(child):
                    return True
            elif isinstance(child, list):
                for item in child:
                    # walk() returns False for non-dict items without
                    # descending into them.
                    if walk(item):
                        return True
        return False

    return walk(schema)


def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Each constraint on ``sampling_params.structured_outputs`` is parsed with
    xgrammar (``xgr.Grammar.from_*``) so that malformed specifications are
    rejected before any grammar is compiled.

    Side effects on ``sampling_params.structured_outputs``:

    * ``choice`` is converted into an EBNF ``grammar`` and ``choice`` is set
      to None.
    * A ``grammar`` that looks like Lark is rewritten in place as EBNF.

    Raises:
        ValueError: if the request is not supported.
    """
    if sampling_params.structured_outputs is None:
        return

    so_params = sampling_params.structured_outputs

    if so_params.regex:
        # NOTE: unlike the branches below, this one does not return early.
        try:
            xgr.Grammar.from_regex(so_params.regex)
        except Exception as err:
            raise ValueError(
                f"Failed to transform regex into a grammar: {err}"
            ) from err

    if so_params.choice:
        # Choices are expressed as an EBNF grammar; once it parses, the
        # request is rewritten to carry that grammar instead of the choices.
        choice_grammar = choice_as_grammar(so_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            raise ValueError(
                f"Failed to transform choices into a grammar: {err}"
            ) from err
        so_params.choice = None
        so_params.grammar = choice_grammar
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                schema = json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = so_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError(
                f"Failed to transform json schema into a grammar: {err}"
            ) from err

        # A schema can parse yet still use keywords that
        # has_xgrammar_unsupported_json_features() flags as unsupported.
        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError(
                "The provided JSON schema contains features not supported by xgrammar."
            )
        return

    if so_params.grammar:
        if grammar_is_likely_lark(so_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. "
                ) from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(so_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)

            # Using the deprecated method of compiling structural tag
            if "structures" in s_tag:
                tags = [
                    xgr.StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in s_tag["structures"]
                ]
                xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
            else:
                xgr.Grammar.from_structural_tag(so_params.structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e