backend_xgrammar.py 12.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import json
7
from dataclasses import dataclass, field
8
from typing import TYPE_CHECKING, Any
9
10
11

import torch

12
import vllm.envs
13
from vllm.logger import init_logger
14
from vllm.sampling_params import SamplingParams
15
16
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
17
18
19
20
21
22
23
24
25
26
from vllm.v1.structured_output.backend_types import (
    StructuredOutputBackend,
    StructuredOutputGrammar,
    StructuredOutputOptions,
)
from vllm.v1.structured_output.utils import (
    choice_as_grammar,
    convert_lark_to_ebnf,
    grammar_is_likely_lark,
)
27
28
29
30
31
32
33
34
35

if TYPE_CHECKING:
    import xgrammar as xgr
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


36
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Structured-output backend implemented on top of xgrammar.

    ``__post_init__`` derives an ``xgr.TokenizerInfo`` from ``self.tokenizer``
    (with a dedicated path for ``MistralTokenizer``) and builds a caching
    ``xgr.GrammarCompiler`` from it. ``compile_grammar`` turns a request spec
    into a per-request ``XgrammarGrammar``.
    """

    def __post_init__(self):
        """Read structured-output config and build the grammar compiler.

        Raises:
            ValueError: if a ``MistralTokenizer`` lacks an attribute needed
                to extract its vocabulary or EOS id.
        """
        so_config = self.vllm_config.structured_outputs_config
        self.disable_any_whitespace = so_config.disable_any_whitespace

        tok = self.tokenizer
        if not isinstance(tok, MistralTokenizer):
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                tok, vocab_size=self.vocab_size
            )
        else:
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            try:
                if tok.is_tekken:
                    encoded_vocab = tok._vocab
                else:
                    # Token strings ordered by ascending token id.
                    id_ordered = sorted(
                        tok.get_vocab().items(), key=lambda item: item[1]
                    )
                    encoded_vocab = [piece for piece, _ in id_ordered]
                # A missing or None eos_token_id means "no explicit stop ids".
                eos_id = getattr(tok, "eos_token_id", None)
                stop_token_ids = None if eos_id is None else [eos_id]
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(self.tokenizer)}. The tokenizer should have a "
                    "get_vocab method."
                ) from e

            # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
            if tok.is_tekken:
                vocab_type = xgr.VocabType.RAW
            else:
                vocab_type = xgr.VocabType.BYTE_FALLBACK
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=encoded_vocab,
                vocab_type=vocab_type,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )

        cache_limit = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=cache_limit,
        )

        # Matchers must be able to roll back as many tokens as speculative
        # decoding may propose; 0 when speculative decoding is disabled.
        spec_config = self.vllm_config.speculative_config
        self.num_speculative_tokens = (
            0 if spec_config is None else spec_config.num_speculative_tokens
        )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` for ``request_type`` into a grammar object.

        Args:
            request_type: Which kind of spec ``grammar_spec`` holds.
            grammar_spec: JSON schema text, grammar text, a regex, or a
                JSON-encoded legacy structural-tag description (keys
                ``structures`` and ``triggers``), depending on
                ``request_type``. Ignored for ``JSON_OBJECT``.

        Returns:
            An ``XgrammarGrammar`` whose matcher allows rolling back up to
            ``self.num_speculative_tokens`` tokens.

        Raises:
            ValueError: if ``request_type`` is not a supported option.
        """
        opts = StructuredOutputOptions
        if request_type in (opts.JSON, opts.JSON_OBJECT):
            # JSON_OBJECT means "any JSON object": use a generic schema.
            schema = (
                grammar_spec if request_type == opts.JSON else '{"type": "object"}'
            )
            ctx = self.compiler.compile_json_schema(
                schema, any_whitespace=not self.disable_any_whitespace
            )
        elif request_type == opts.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == opts.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == opts.STRUCTURAL_TAG:
            spec = json.loads(grammar_spec)
            items = [
                xgr.StructuralTagItem(
                    begin=entry["begin"],
                    schema=json.dumps(entry["schema"]),
                    end=entry["end"],
                )
                for entry in spec["structures"]
            ]
            ctx = self.compiler.compile_structural_tag(
                xgr.StructuralTag.from_legacy_structural_tag(items, spec["triggers"])
            )
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})"
            )

        matcher = xgr.GrammarMatcher(
            ctx,
            max_rollback_tokens=self.num_speculative_tokens,
        )
        return XgrammarGrammar(matcher=matcher, vocab_size=self.vocab_size, ctx=ctx)

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate an xgrammar token bitmask for ``max_num_seqs`` sequences
        over ``self.vocab_size`` tokens."""
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
        """Drop this backend's reference to its grammar compiler."""
        del self.compiler

@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an ``xgr.GrammarMatcher``.

    Counts accepted tokens and caches whether the matcher has terminated,
    so ``is_terminated`` does not query the matcher on every call.
    """

    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Net number of accepted tokens: incremented per token by accept_tokens,
    # decreased by rollback, zeroed by reset.
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )
    # Cached matcher.is_terminated(); refreshed after a successful
    # accept_tokens, after rollback, and on reset.
    _is_terminated: bool = field(default=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance, or if it had already
        terminated. On failure, tokens preceding the rejected one remain
        applied to the matcher and counted in ``num_processed_tokens``.
        """
        if self._is_terminated:
            return False
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.",
                    request_id,
                    token,
                )
                return False
            self.num_processed_tokens += 1
        self._is_terminated = self.matcher.is_terminated()
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the FSM in sequence.
        Will not advance the FSM.

        Returns the prefix list of tokens that are accepted by the FSM.
        The accepted prefix is rolled back afterwards, so its length must not
        exceed the matcher's ``max_rollback_tokens``.
        """
        accepted_tokens = []
        for token in tokens:
            if self.matcher.accept_token(token):
                accepted_tokens.append(token)
            else:
                break
        if len(accepted_tokens) > 0:
            # Rollback the FSM to the initial state
            self.matcher.rollback(len(accepted_tokens))
        return accepted_tokens

    def rollback(self, num_tokens: int) -> None:
        """Undo the last ``num_tokens`` accepted tokens and refresh the
        cached termination flag."""
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens
        self._is_terminated = self.matcher.is_terminated()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the currently allowed next tokens into row ``idx`` of
        ``bitmask``."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Return the cached termination state of the matcher."""
        return self._is_terminated

    def reset(self):
        """Return the matcher to its initial state and clear bookkeeping."""
        self.num_processed_tokens = 0
        self.matcher.reset()
        # Refresh the cached flag as well. Without this, a grammar that had
        # terminated before the reset would keep reporting
        # is_terminated() == True and reject every token in accept_tokens.
        self._is_terminated = self.matcher.is_terminated()


def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    Walks ``schema`` recursively, visiting every nested dict and every dict
    inside a nested list. It reports whether any sub-schema uses a keyword
    from the per-type deny-list below. ``"type"`` may be a single type name
    or a list of names (e.g. ``["integer", "null"]``). A keyword is flagged
    if it is on the deny-list of any declared type.

    Args:
        schema: A parsed JSON schema.

    Returns:
        True if an unsupported keyword was found, False otherwise.
    """
    # Keywords xgrammar does not support, keyed by the JSON Schema type they
    # constrain.
    unsupported_keywords: dict[str, tuple[str, ...]] = {
        # numeric ranges
        "integer": ("multipleOf",),
        "number": ("multipleOf",),
        "array": ("uniqueItems", "contains", "minContains", "maxContains"),
        "string": ("format",),
        "object": (
            "minProperties",
            "maxProperties",
            "propertyNames",
            "patternProperties",
        ),
    }

    def declared_types(obj: dict[str, Any]) -> list[str]:
        # "type" is either one type name or a list of names (a union type).
        # Anything else, e.g. a property literally named "type" inside a
        # "properties" map, declares no type.
        raw = obj.get("type")
        if isinstance(raw, str):
            return [raw]
        if isinstance(raw, list):
            return [t for t in raw if isinstance(t, str)]
        return []

    def check_object(obj: dict[str, Any]) -> bool:
        if not isinstance(obj, dict):
            return False

        for type_name in declared_types(obj):
            if any(key in obj for key in unsupported_keywords.get(type_name, ())):
                return True

        # Recursively check all nested objects and arrays
        for value in obj.values():
            if isinstance(value, dict):
                if check_object(value):
                    return True
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict) and check_object(item):
                        return True

        return False

    return check_object(schema)


def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    The first populated field of ``sampling_params.structured_outputs`` is
    validated, checked in this order: regex, choice, json, grammar,
    structural_tag. Later fields are not examined.

    Side effects: a ``choice`` list is converted into a grammar, stored in
    ``grammar``, and ``choice`` is cleared. A grammar that looks like Lark
    is converted to EBNF in place.

    Raises ValueError if the request is not supported.
    """
    so_params = sampling_params.structured_outputs
    if so_params is None:
        return

    if so_params.regex:
        try:
            xgr.Grammar.from_regex(so_params.regex)
        except Exception as err:
            raise ValueError(
                f"Failed to transform regex into a grammar: {err}"
            ) from err
        return

    if so_params.choice:
        choice_grammar = choice_as_grammar(so_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # f-prefix was missing here, so users saw a literal "{err}".
            raise ValueError(
                f"Failed to transform choices into a grammar: {err}"
            ) from err
        so_params.choice = None
        so_params.grammar = choice_grammar
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                schema = json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = so_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError(
                f"Failed to transform json schema into a grammar: {err}"
            ) from err

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError(
                "The provided JSON schema contains features not supported by xgrammar."
            )
        return

    if so_params.grammar:
        if grammar_is_likely_lark(so_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. "
                ) from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(so_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                )
                for s in s_tag["structures"]
            ]
            structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
                tags, s_tag["triggers"]
            )
            xgr.Grammar.from_structural_tag(structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e