# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with xgrammar backend."""

import logging
from typing import List, Optional, Tuple

import torch
from xgrammar import (
    CompiledGrammar,
    GrammarCompiler,
    GrammarMatcher,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)

logger = logging.getLogger(__name__)

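# Maximum number of accepted tokens a GrammarMatcher may roll back;
# jump_and_retokenize below relies on this headroom after a jump-forward.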
MAX_ROLLBACK_TOKENS = 200


class XGrammarGrammar(BaseGrammarObject):
    def __init__(
        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
    ) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.finished = False

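    # Feed one sampled token to the matcher. When sampling was done with the
    # vocab mask applied, the token is grammar-legal and the assert holds.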
    def accept_token(self, token: int):
        assert self.matcher.accept_token(token)

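    # Jump-forward decoding: if the grammar admits exactly one continuation
    # from the current state, return it as a string so the caller can append
    # it directly instead of decoding it token by token.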
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        return data, -1

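    # After a jump-forward, retokenizing the spliced text can change token ids
    # near the splice point, so roll the matcher back to the longest common
    # prefix of the old and new ids and re-accept the diverging suffix.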
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
            else:
                break

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])

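    # One mask row per request in the batch. xgrammar packs the vocabulary
    # into an int32 bitmask, far smaller than a boolean [batch, vocab] tensor.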
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return allocate_token_bitmask(batch_size, vocab_size)

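    # Write the set of grammar-legal next tokens into row `idx` of the mask.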
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

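    # The mask is produced on the host; copy it to the target device
    # asynchronously before it is applied to the logits.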
    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

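    # In-place update: logits of tokens the grammar forbids are set to -inf,
    # so they can never be sampled.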
    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        apply_token_bitmask_inplace(logits, vocab_mask)

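    # Copies share the compiled grammar (ctx); only the matcher, which holds
    # the per-request decoding state, is created fresh.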
    def copy(self):
        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)


class XGrammarGrammarBackend(BaseGrammarBackend):
    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        super().__init__()

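        # Build xgrammar's view of the tokenizer once and reuse it for all
        # compilations. vocab_size is passed explicitly because the model's
        # (possibly padded) vocab size can differ from the tokenizer's.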
        tokenizer_info = TokenizerInfo.from_huggingface(
            tokenizer, vocab_size=vocab_size
        )
        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size

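    # Compile a grammar for a (key_type, key_string) request. Invalid inputs
    # are skipped with a warning (returning None) instead of raising.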
    def init_value_impl(self, key: Tuple[str, str]) -> Optional[XGrammarGrammar]:
        key_type, key_string = key
        if key_type == "json":
            try:
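                # "$$ANY$$" is a sentinel for schema-free JSON output: use
                # the builtin JSON grammar instead of compiling a schema.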
                if key_string == "$$ANY$$":
                    ctx = self.grammar_compiler.compile_builtin_json_grammar()
                else:
                    ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
            except RuntimeError as e:
                logger.warning(
                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                )
                return None
        elif key_type == "ebnf":
            try:
                ctx = self.grammar_compiler.compile_grammar(key_string)
            except RuntimeError as e:
                logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
                return None
        elif key_type == "regex":
            try:
                ctx = self.grammar_compiler.compile_regex(key_string)
            except RuntimeError as e:
                logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
                return None
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
        return XGrammarGrammar(matcher, self.vocab_size, ctx)

    def reset(self):
        if self.grammar_compiler:
            self.grammar_compiler.clear_cache()
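

# Illustrative greedy decode loop (a sketch, not SGLang's scheduler; the names
# `tokenizer`, `vocab_size`, `logits`, and `max_new_tokens` are assumed here):
#
#   backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
#   grammar = backend.init_value_impl(("json", "$$ANY$$"))
#   mask = grammar.allocate_vocab_mask(vocab_size, batch_size=1, device="cuda")
#   for _ in range(max_new_tokens):
#       grammar.fill_vocab_mask(mask, idx=0)
#       device_mask = XGrammarGrammar.move_vocab_mask(mask, logits.device)
#       XGrammarGrammar.apply_vocab_mask(logits, device_mask)
#       grammar.accept_token(int(torch.argmax(logits, dim=-1)))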