"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Constrained decoding with xgrammar backend."""

import logging
from typing import List, Optional, Tuple

import torch

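# xgrammar is optional at import time; if it is unavailable, the error is
# stored here and surfaced when the backend is actually used.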
try:
    from xgrammar import (
        CachedGrammarCompiler,
        CompiledGrammar,
        GrammarMatcher,
        TokenizerInfo,
    )

    import_error = None
except ImportError as e:
    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
        ImportError
    )
    import_error = e

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)

logger = logging.getLogger(__name__)


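# Maximum number of tokens the GrammarMatcher may roll back, used when
# jump-forward decoding retokenizes part of the output.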
MAX_ROLLBACK_TOKENS = 10


class XGrammarGrammar(BaseGrammarObject):
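    """Per-request grammar state backed by an xgrammar GrammarMatcher.

    The compiled grammar context (ctx) is shared, so copy() can create a new
    matcher for another request without recompiling the grammar.
    """
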
    def __init__(
        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
    ) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx

    def accept_token(self, token: int):
        assert self.matcher.accept_token(token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
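        """Return ([], jump_str) when the grammar forces a unique next string
        (so decoding can jump forward), or None if no such string exists."""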
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
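        """Re-sync the matcher after a jump-forward retokenization: keep the
        longest common prefix of the old and new output ids, roll back the
        rest, and accept the new tokens."""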
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
            else:
                break

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])

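    # Token bitmask flow: allocate one bitmask per batch, fill the row for each
    # request index, then apply it to the logits in place before sampling.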
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)

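    # Create a fresh matcher for a new request, reusing the already-compiled
    # grammar context so no recompilation is needed.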
    def copy(self):
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            vocab_size=self.vocab_size,
        )
        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)


class XGrammarGrammarBackend(BaseGrammarBackend):
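    """Grammar backend that compiles JSON-schema constraints with xgrammar and
    caches the compiled grammars across requests."""
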
    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        super().__init__()

        if import_error:
            logger.warning(
                f"Ignore import error for the grammar backend: {import_error}"
            )
            self.grammar_cache = None
            return

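        # Build the tokenizer info once and cache compiled grammars so repeated
        # schemas are only compiled a single time.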
        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size

    def init_value_impl(self, key: Tuple[str, str]) -> Optional[XGrammarGrammar]:
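        """Compile the grammar described by (key_type, key_string) and wrap it
        in an XGrammarGrammar, or return None if the key cannot be handled."""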
        if import_error:
            raise import_error

        key_type, key_string = key
        if key_type == "json":
            try:
                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
            except RuntimeError as e:
                logger.warning(
                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                )
                return None
        elif key_type == "regex":
            logger.warning(
                "regex hasn't been supported by xgrammar yet. This is skipped."
            )
            return None
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            vocab_size=self.vocab_size,
        )
        return XGrammarGrammar(matcher, self.vocab_size, ctx)

    def reset(self):
        if self.grammar_cache:
            self.grammar_cache.clear()
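

# Minimal usage sketch (not part of this module). It assumes a HuggingFace
# tokenizer and precomputed logits; json_schema_str, logits, and
# sampled_token_id are placeholders, and the device choice is an assumption.
#
#   backend = XGrammarGrammarBackend(tokenizer, vocab_size=len(tokenizer))
#   grammar = backend.init_value_impl(("json", json_schema_str))
#   if grammar is not None:
#       mask = grammar.allocate_vocab_mask(backend.vocab_size, batch_size=1, device="cuda")
#       grammar.fill_vocab_mask(mask, idx=0)
#       XGrammarGrammar.apply_vocab_mask(logits, mask)
#       grammar.accept_token(sampled_token_id)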