backend_xgrammar.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Any
7
8
9

import torch

10
import vllm.envs
11
from vllm.logger import init_logger
12
from vllm.sampling_params import SamplingParams
13
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer
14
from vllm.utils.import_utils import LazyLoader
15
16
17
18
19
20
21
22
23
24
from vllm.v1.structured_output.backend_types import (
    StructuredOutputBackend,
    StructuredOutputGrammar,
    StructuredOutputOptions,
)
from vllm.v1.structured_output.utils import (
    choice_as_grammar,
    convert_lark_to_ebnf,
    grammar_is_likely_lark,
)
25
26
27
28
29
30
31
32
33

if TYPE_CHECKING:
    import xgrammar as xgr
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


34
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Structured-output backend that compiles grammars with xgrammar.

    ``__post_init__`` builds an ``xgr.TokenizerInfo`` for ``self.tokenizer``
    and one shared ``xgr.GrammarCompiler``; ``compile_grammar`` wraps every
    compiled grammar in an :class:`XgrammarGrammar`.
    """

    def __post_init__(self):
        """Build the tokenizer info, the grammar compiler and rollback budget."""
        self.disable_any_whitespace = (
            self.vllm_config.structured_outputs_config.disable_any_whitespace
        )

        if isinstance(self.tokenizer, MistralTokenizer):
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            stop_token_ids = [self.tokenizer.eos_token_id]

            # not self.tokenizer.vocab_size as self.tokenizer.vocab
            # collapses all decoded errors into a single token.
            self.vocab_size = len(self.tokenizer.vocab)
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=self.tokenizer.vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
                vocab_type=xgr.VocabType.RAW
                if self.tokenizer.is_tekken
                else xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )
        elif isinstance(self.tokenizer, DeepseekV32Tokenizer):
            # copy from xgr.TokenizerInfo.from_huggingface()
            # because we are using a custom tokenizer wrapper here.
            vocab_dict = self.tokenizer.get_vocab()
            tokenizer_vocab_size = max(len(vocab_dict), self.tokenizer.max_token_id + 1)
            # Store the effective size back on the backend (as the Mistral
            # branch does) so allocate_token_bitmask() and XgrammarGrammar
            # use the same vocab size that is handed to xgr.TokenizerInfo.
            self.vocab_size = self.vocab_size or tokenizer_vocab_size
            vocab_size = self.vocab_size
            # maintain tokenizer's indexing; ids >= vocab_size are dropped
            encoded_vocab = [""] * vocab_size
            for token, idx in vocab_dict.items():
                if idx < vocab_size:
                    encoded_vocab[idx] = token
            stop_token_ids = [self.tokenizer.eos_token_id]
            backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str()
            metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str)
            tokenizer_info = xgr.TokenizerInfo(
                encoded_vocab=encoded_vocab,
                vocab_type=metadata["vocab_type"],
                vocab_size=vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=metadata["add_prefix_space"],
            )
        else:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer,
                vocab_size=self.vocab_size,
            )

        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
        )

        # Passed as max_rollback_tokens to every GrammarMatcher created in
        # compile_grammar(); stays 0 when no speculative config is set.
        self.num_speculative_tokens = 0
        if self.vllm_config.speculative_config is not None:
            self.num_speculative_tokens = (
                self.vllm_config.speculative_config.num_speculative_tokens
            )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` and wrap it in a fresh matcher.

        Args:
            request_type: Which kind of specification ``grammar_spec`` is.
            grammar_spec: A JSON schema, EBNF grammar, regex or structural-tag
                JSON string, depending on ``request_type``. Ignored for
                ``JSON_OBJECT``, which always compiles ``{"type": "object"}``.

        Returns:
            An :class:`XgrammarGrammar` holding the compiled grammar and a new
            ``xgr.GrammarMatcher`` for it.

        Raises:
            ValueError: If ``request_type`` is not a supported option.
        """
        if request_type == StructuredOutputOptions.JSON:
            ctx = self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace
            )
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = self.compiler.compile_json_schema(
                '{"type": "object"}', any_whitespace=not self.disable_any_whitespace
            )
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            s_tag = json.loads(grammar_spec)
            if "structures" in s_tag:
                # Falling back to deprecated method of compiling structural tag
                tags = [
                    xgr.StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in s_tag["structures"]
                ]
                ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
            else:
                ctx = self.compiler.compile_structural_tag(grammar_spec)
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})"
            )

        return XgrammarGrammar(
            matcher=xgr.GrammarMatcher(
                ctx,
                max_rollback_tokens=self.num_speculative_tokens,
            ),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate an xgrammar token bitmask for ``max_num_seqs`` sequences
        over ``self.vocab_size`` tokens."""
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
        """Release the grammar compiler held by this backend."""
        del self.compiler

151
152
153
154
155
156
157
158
159
160
161
162
163

@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an ``xgr.GrammarMatcher``."""

    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Tokens currently consumed by the matcher, net of rollbacks.
    num_processed_tokens: int = field(
        default=0, repr=False, hash=False, init=False
    )
    # Cached value of matcher.is_terminated() after the last state change.
    _is_terminated: bool = field(default=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the grammar had already terminated, or if a token
        was rejected. In the rejection case the tokens before the rejected
        one stay consumed (and counted in ``num_processed_tokens``).
        """
        if self._is_terminated:
            return False
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.",
                    request_id,
                    token,
                )
                return False
            self.num_processed_tokens += 1
        self._is_terminated = self.matcher.is_terminated()
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the FSM in sequence.
        Will not advance the FSM.

        Returns the prefix list of tokens that are accepted by the FSM.
        """
        accepted_tokens = []
        for token in tokens:
            if self.matcher.accept_token(token):
                accepted_tokens.append(token)
            else:
                break
        if len(accepted_tokens) > 0:
            # Undo the probe so the matcher is back in the state it had
            # before this call.
            self.matcher.rollback(len(accepted_tokens))
        return accepted_tokens

    def rollback(self, num_tokens: int) -> None:
        """Undo the last ``num_tokens`` accepted tokens."""
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens
        self._is_terminated = self.matcher.is_terminated()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-next-token mask into ``bitmask`` at index ``idx``."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Return the termination state cached after the last update."""
        return self._is_terminated

    def reset(self):
        """Return the matcher and the bookkeeping to their initial state."""
        self.num_processed_tokens = 0
        self.matcher.reset()
        # Without this, a grammar reset after termination would make
        # accept_tokens() reject every token.
        self._is_terminated = False
221
222


223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# String "format" values xgrammar's JSON-schema converter understands; for a
# {"type": "string"} schema, any other format makes
# has_xgrammar_unsupported_json_features() report the schema as unsupported.
# cf https://github.com/mlc-ai/xgrammar/blob/a32ac892676d2eedc0327416105b9b06edfb94b2/cpp/json_schema_converter.cc
STRING_SUPPORTED_FORMATS = {
    # dates and times
    "date",
    "time",
    "date-time",
    "duration",
    # network / addressing
    "email",
    "hostname",
    "ipv4",
    "ipv6",
    # identifiers and references
    "uuid",
    "uri",
    "uri-reference",
    "uri-template",
    "json-pointer",
    "relative-json-pointer",
}


242
243
244
245
246
247
248
249
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    Every dict reachable from ``schema`` -- either as a dict value or as a
    dict element of a list value -- is checked against the keyword rules
    below. Lists nested directly inside lists are not descended into, and a
    non-dict ``schema`` yields False.
    """
    numeric_types = ("integer", "number")
    unsupported_array_keys = (
        "uniqueItems",
        "contains",
        "minContains",
        "maxContains",
    )
    unsupported_object_keys = (
        "minProperties",
        "maxProperties",
        "propertyNames",
        "patternProperties",
    )

    def _violates_rules(node: dict[str, Any]) -> bool:
        # Only the node's own keywords are inspected here; children are
        # handled by _walk.
        node_type = node.get("type")
        if node_type in numeric_types:
            return "multipleOf" in node
        if node_type == "array":
            return any(key in node for key in unsupported_array_keys)
        if node_type == "string":
            return (
                "format" in node
                and node["format"] not in STRING_SUPPORTED_FORMATS
            )
        if node_type == "object":
            return any(key in node for key in unsupported_object_keys)
        return False

    def _walk(node: Any) -> bool:
        if not isinstance(node, dict):
            return False
        if _violates_rules(node):
            return True
        for child in node.values():
            # A list contributes its elements; anything else contributes
            # itself. Only dict candidates are descended into.
            candidates = child if isinstance(child, list) else (child,)
            for candidate in candidates:
                if isinstance(candidate, dict) and _walk(candidate):
                    return True
        return False

    return _walk(schema)


def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Each specification is parsed with xgrammar, but not compiled. Side
    effects on ``sampling_params.structured_outputs``: a ``choice`` spec is
    replaced by the equivalent ``grammar``, and a Lark ``grammar`` is
    rewritten in place as EBNF.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.structured_outputs is None:
        return

    so_params = sampling_params.structured_outputs

    if so_params.regex:
        try:
            xgr.Grammar.from_regex(so_params.regex)
        except Exception as err:
            raise ValueError(
                f"Failed to transform regex into a grammar: {err}"
            ) from err

    if so_params.choice:
        choice_grammar = choice_as_grammar(so_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # f-string so the actual parse error is reported, not "{err}".
            raise ValueError(
                f"Failed to transform choices into a grammar: {err}"
            ) from err
        so_params.choice = None
        so_params.grammar = choice_grammar
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                schema = json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = so_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError(
                f"Failed to transform json schema into a grammar: {err}"
            ) from err

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError(
                "The provided JSON schema contains features not supported by xgrammar."
            )
        return

    if so_params.grammar:
        if grammar_is_likely_lark(so_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
            except ValueError as e:
                # Report the conversion failure like the other branches do.
                raise ValueError(
                    f"Failed to convert the grammar from Lark to EBNF. {e}"
                ) from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(so_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)

            # Using the deprecated method of compiling structural tag
            if "structures" in s_tag:
                tags = [
                    xgr.StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in s_tag["structures"]
                ]
                xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
            else:
                xgr.Grammar.from_structural_tag(so_params.structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e