# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with xgrammar backend."""

import json
import logging
from typing import List, Optional, Tuple

import torch
from xgrammar import (
    CompiledGrammar,
    GrammarCompiler,
    GrammarMatcher,
    StructuralTagItem,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)

logger = logging.getLogger(__name__)


MAX_ROLLBACK_TOKENS = 200


class XGrammarGrammar(BaseGrammarObject):
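    """A grammar object for one generation request.

    Wraps a stateful xgrammar ``GrammarMatcher`` together with the
    ``CompiledGrammar`` it was built from; ``copy()`` creates a fresh matcher
    over the same compiled grammar.
    """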
    def __init__(
        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
    ) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.finished = False

    def accept_token(self, token: int):
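        # Advance the matcher by one generated token; the assert fails if the
        # token is rejected by the grammar in the current state.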
        assert self.matcher.accept_token(token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
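        """Return ``([], s)`` where ``s`` is the string the grammar forces
        next from the current matcher state, or None if there is no
        jump-forward string."""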
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
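        """Re-synchronize the matcher after jump-forward retokenization:
        roll back to the longest common prefix of the old and new output ids,
        then accept the remaining new tokens."""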
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
            else:
                break

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
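        # Allocate xgrammar's packed token bitmask (one bit per vocab token per
        # batch row); it is moved to the target device later by move_vocab_mask.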
        return allocate_token_bitmask(batch_size, vocab_size)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
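        # In place: logits of tokens whose bit is cleared in the mask are set
        # to -inf so they cannot be sampled.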
        apply_token_bitmask_inplace(logits, vocab_mask)

    def copy(self):
        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)


class XGrammarGrammarBackend(BaseGrammarBackend):
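    """Grammar backend that compiles constraints with xgrammar.

    A single ``GrammarCompiler`` (with its internal cache) is created per
    tokenizer; ``init_value_impl`` turns a ``(key_type, key_string)`` pair
    into an ``XGrammarGrammar`` and returns None when compilation fails.
    """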
    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        super().__init__()

        tokenizer_info = TokenizerInfo.from_huggingface(
            tokenizer, vocab_size=vocab_size
        )
        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size

    def init_value_impl(self, key: Tuple[str, str]) -> Optional[XGrammarGrammar]:
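        """Compile a grammar for ``key = (key_type, key_string)``.

        Supported key types: "json" (a JSON schema, or "$$ANY$$" for the
        builtin JSON grammar), "ebnf", "regex", and "structural_tag" (a JSON
        payload with "structures" and "triggers"). Returns None if the
        specification cannot be compiled.
        """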
        key_type, key_string = key
        if key_type == "json":
            try:
                if key_string == "$$ANY$$":
                    ctx = self.grammar_compiler.compile_builtin_json_grammar()
                else:
                    ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
            except RuntimeError as e:
                logger.warning(
                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                )
                return None
        elif key_type == "ebnf":
            try:
                ctx = self.grammar_compiler.compile_grammar(key_string)
            except RuntimeError as e:
                logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
                return None
        elif key_type == "regex":
            try:
                ctx = self.grammar_compiler.compile_regex(key_string)
            except RuntimeError as e:
                logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
                return None
        elif key_type == "structural_tag":
            try:
                structural_tag = json.loads(key_string)
                tags = [
                    StructuralTagItem(
                        begin=structure["begin"],
                        schema=json.dumps(structure["schema"]),
                        end=structure["end"],
                    )
                    for structure in structural_tag["structures"]
                ]
                ctx = self.grammar_compiler.compile_structural_tag(
                    tags, structural_tag["triggers"]
                )
            except RuntimeError as e:
                logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
                return None
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
        return XGrammarGrammar(matcher, self.vocab_size, ctx)

    def reset(self):
        if self.grammar_compiler:
            self.grammar_compiler.clear_cache()
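

# Illustrative usage sketch (not executed). In SGLang the scheduler drives these
# calls, and `init_value_impl` is normally reached through the base backend's
# cache; `tokenizer`, `logits`, and `device` below are placeholder names:
#
#   backend = XGrammarGrammarBackend(tokenizer, vocab_size=len(tokenizer))
#   grammar = backend.init_value_impl(("json", "$$ANY$$"))
#   vocab_mask = grammar.allocate_vocab_mask(backend.vocab_size, batch_size=1, device=device)
#   grammar.fill_vocab_mask(vocab_mask, idx=0)
#   vocab_mask = XGrammarGrammar.move_vocab_mask(vocab_mask, device)
#   XGrammarGrammar.apply_vocab_mask(logits, vocab_mask)  # disallowed logits -> -inf
#   next_token = int(torch.argmax(logits[0]))
#   grammar.accept_token(next_token)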