Unverified commit b77a02cd, authored by DarkSharpness and committed by GitHub

[Performance] Support both xgrammar and outlines for constrained decoding (#1752)

parent 30643fed
@@ -51,6 +51,21 @@ except ImportError:
return build_regex_from_schema(schema, whitespace_pattern)
try:
from xgrammar import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)
except ImportError:
class Dummy:
pass
GrammarMatcher = Dummy
GrammarMatcherInitContext = Dummy
GrammarMatcherInitContextCache = Dummy
__all__ = [
"RegexGuide",
"FSMInfo",
@@ -60,4 +75,7 @@ __all__ = [
"disk_cache",
"disable_cache",
"make_byte_level_fsm",
"GrammarMatcher",
"GrammarMatcherInitContext",
"GrammarMatcherInitContextCache",
]
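A hedged sketch of how downstream code can probe for the xgrammar backend through this fallback; the helper below is hypothetical, not part of the diff:

from sglang.srt.constrained import GrammarMatcher

def xgrammar_is_available() -> bool:
    # On ImportError the three names are bound to the Dummy placeholder,
    # so a real GrammarMatcher can never be constructed or matched.
    return GrammarMatcher.__name__ != "Dummy"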
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Cache for the compressed finite state machine."""
from typing import Tuple
from transformers import AutoTokenizer
from sglang.srt.constrained import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)
# Maximum number of tokens the matcher may roll back when a jump-forward
# requires retokenizing the tail of the output.
MAX_ROLLBACK_TOKENS = 10
class BNFCache:
grammar_cache: GrammarMatcherInitContextCache
def __init__(
self,
tokenizer_path,
tokenizer_args_dict,
skip_tokenizer_init=False,
whitespace_patterns=None,
):
# TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
if skip_tokenizer_init:
return
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
self.grammar_cache = GrammarMatcherInitContextCache(
tokenizer_or_vocab=tokenizer
)
def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
key_type, key_string = key
if key_type == "json":
return self.grammar_cache.get_init_context_for_json_schema(key_string)
elif key_type == "regex":
raise ValueError(f"regex hasn't been supported by xgrammar yet")
else:
raise ValueError(f"Invalid key_type: {key_type}")
def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
ctx = self.get_context(key)
return GrammarMatcher(
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
)
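A minimal usage sketch of BNFCache (the tokenizer path and vocabulary size are placeholders, not taken from this diff):

from sglang.srt.constrained.bnf_cache import BNFCache

cache = BNFCache(
    tokenizer_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tokenizer_args_dict={},
)
schema = '{"type": "object", "properties": {"a": {"type": "integer"}}}'
# query() returns a GrammarMatcher sized to the model's vocabulary.
matcher = cache.query(("json", schema), vocab_size=128256)  # placeholder size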
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Cache for the compressed finite state machine."""
import logging
from typing import List, Optional, Tuple, Union
import torch
from sglang.srt.constrained import GrammarMatcher, RegexGuide
from sglang.srt.constrained.bnf_cache import BNFCache
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
# from sglang.srt.managers.schedule_batch import Req
logger = logging.getLogger(__name__)
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
class XGrammarJump:
    # Marker type: the xgrammar matcher computes jump-forward strings itself,
    # so no precomputed jump map is needed.
    pass
class JumpHelper:
data: Union[List, str]
state: int
suffix_ids: List[int]
    def __init__(
        self,
        data: Union[List, str] = "",
        state: int = -1,
        suffix_ids: Optional[List[int]] = None,
    ) -> None:
        self.data = data
        self.state = state
        # A fresh list per instance avoids the shared mutable-default pitfall.
        self.suffix_ids = suffix_ids if suffix_ids is not None else []
def can_jump(self):
return len(self.data) > 0
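# Note: `data` carries the backend-specific jump payload, a jump-forward
# string for xgrammar or a token-id list for outlines; `state` and
# `suffix_ids` are only meaningful on the outlines path.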
class Grammar:
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
jump_map: Union[XGrammarJump, JumpForwardMap, None]
def __init__(
self,
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
jump_map: Union[XGrammarJump, JumpForwardMap, None],
) -> None:
self.grammar = grammar
self.jump_map = jump_map
def accept_token(self, token: int):
if isinstance(self.grammar, GrammarMatcher):
assert self.grammar.accept_token(token)
else:
guide, state = self.grammar
self.grammar = guide, guide.get_next_state(state, token)
def try_jump(self, tokenizer) -> JumpHelper:
if isinstance(self.jump_map, XGrammarJump):
assert isinstance(self.grammar, GrammarMatcher)
return JumpHelper(self.grammar.find_jump_forward_string())
elif isinstance(self.jump_map, JumpForwardMap):
            assert isinstance(self.grammar, tuple)
_, state = self.grammar
jump_forward_bytes = self.jump_map.jump_forward_byte(state)
if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
return JumpHelper() # can't jump
            # Preprocess the jump-forward string: leading UTF-8 continuation
            # bytes (0x80-0xBF) complete a multi-byte character begun by the
            # previous token and must be consumed and retokenized first,
            # otherwise the decoded text would be corrupted.
            suffix_bytes = []
            continuation_range = range(0x80, 0xC0)
cur_state = state
while (
len(jump_forward_bytes)
and jump_forward_bytes[0][0] in continuation_range
):
# continuation bytes
byte_edge = jump_forward_bytes.pop(0)
suffix_bytes.append(byte_edge[0])
cur_state = byte_edge[1]
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
return JumpHelper(suffix_ids, cur_state, suffix_bytes)
else:
return JumpHelper() # can't jump
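    # Worked example (illustrative): if the previous token ended with 0xF0,
    # the first byte of a 4-byte UTF-8 character such as an emoji, and the
    # jump map offers bytes 0x9F 0x98 0x80..., those leading continuation
    # bytes complete that character and are mapped to byte-fallback tokens
    # like "<0x9F>" before the plain-text jump continues.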
def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
if isinstance(helper.data, str):
return helper.data, -1
else:
assert isinstance(self.jump_map, JumpForwardMap)
return self.jump_map.jump_forward_symbol(helper.state)
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
if isinstance(self.grammar, GrammarMatcher):
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == new_output_ids[i]:
k = i + 1
else:
break
# rollback to the last token that is the same
if k < len(old_output_ids):
self.grammar.rollback(len(old_output_ids) - k)
for i in range(k, len(new_output_ids)):
assert self.grammar.accept_token(new_output_ids[i])
else:
self.grammar = self.grammar[0], next_state
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
if isinstance(self.grammar, GrammarMatcher):
# Note that this bitmask is a bitset, not bool
bitmask = self.grammar.find_next_token_bitmask()
# Mask the tokens that are not allowed
vocab_mask[
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
] = 1
else:
guide, state = self.grammar
vocab_mask.fill_(1)
vocab_mask[guide.get_next_instruction(state).tokens] = 0
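# A hedged sketch (not part of this file) of one decode step driving the
# Grammar wrapper above; `grammar`, `logits`, and `vocab_size` are
# illustrative stand-ins:
#
#     vocab_mask = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
#     grammar.fill_vocab_mask(vocab_mask, vocab_size)  # True = rejected token
#     logits = logits.masked_fill(vocab_mask, float("-inf"))
#     next_token = int(torch.argmax(logits))  # greedy pick, for brevity
#     grammar.accept_token(next_token)  # advance the matcher / FSM state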
class GrammarCache:
grammar_cache: Union[BNFCache, FSMCache]
jump_cache: Union[XGrammarJump, JumpForwardCache, None]
def __init__(
self,
tokenizer_path,
tokenizer_args_dict,
skip_tokenizer_init=False,
whitespace_patterns=None,
backend=None,
allow_jump=False,
):
if backend == "xgrammar":
self.grammar_cache = BNFCache(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
whitespace_patterns=whitespace_patterns,
)
self.jump_cache = XGrammarJump() if allow_jump else None
else:
assert backend == "outlines"
self.grammar_cache = FSMCache(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
constrained_json_whitespace_pattern=whitespace_patterns,
enable=True,
)
self.jump_cache = JumpForwardCache() if allow_jump else None
def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
if isinstance(self.grammar_cache, BNFCache):
assert not isinstance(self.jump_cache, JumpForwardCache)
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
else:
jump_map = None
guide, regex = self.grammar_cache.query(key)
if isinstance(self.jump_cache, JumpForwardCache):
jump_map = self.jump_cache.query(regex)
return Grammar((guide, 0), jump_map)
def reset(self):
if isinstance(self.grammar_cache, FSMCache):
self.grammar_cache.reset()
if isinstance(self.jump_cache, JumpForwardCache):
self.jump_cache.reset()
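For illustration, a hedged sketch of constructing and querying the cache for either backend (the tokenizer path and vocabulary size are placeholders):

from sglang.srt.constrained.grammar import GrammarCache

grammar_cache = GrammarCache(
    tokenizer_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    tokenizer_args_dict={"trust_remote_code": False},
    backend="xgrammar",  # or "outlines"
    allow_jump=True,  # enable jump-forward decoding
)
grammar = grammar_cache.query(("json", '{"type": "object"}'), vocab_size=128256)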
@@ -37,8 +37,7 @@ import torch
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.constrained.grammar import Grammar
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -247,9 +246,7 @@ class Req:
self.embedding = None
# Constrained decoding
self.regex_fsm: RegexGuide = None
self.regex_fsm_state: int = 0
self.jump_forward_map: JumpForwardMap = None
self.grammar: Optional[Grammar] = None
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object
@@ -359,6 +356,8 @@ class Req:
return
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
assert self.grammar is not None and self.tokenizer is not None
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
@@ -398,7 +397,8 @@
self.surr_offset = self.read_offset - i
break
self.regex_fsm_state = next_state
# update the inner state of the grammar
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
if self.return_logprob:
# For fast-forward part's logprobs
@@ -468,8 +468,8 @@ class ScheduleBatch:
# Stream
has_stream: bool = False
# Has regex
has_regex: bool = False
# Has grammar
has_grammar: bool = False
# device
device: str = "cuda"
@@ -477,7 +477,7 @@
@classmethod
def init_new(
cls,
reqs,
reqs: List[Req],
req_to_token_pool,
token_to_kv_pool,
tree_cache,
@@ -491,7 +491,7 @@
model_config=model_config,
return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs),
has_regex=any(req.regex_fsm for req in reqs),
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
)
@@ -803,26 +803,10 @@ class ScheduleBatch:
keep_indices = set(i for i in range(len(self.reqs)))
for i, req in enumerate(self.reqs):
if req.jump_forward_map is not None:
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
req.regex_fsm_state
)
if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
suffix_bytes = []
continuation_range = range(0x80, 0xC0)
cur_state = req.regex_fsm_state
while (
len(jump_forward_bytes)
and jump_forward_bytes[0][0] in continuation_range
):
# continuation bytes
byte_edge = jump_forward_bytes.pop(0)
suffix_bytes.append(byte_edge[0])
cur_state = byte_edge[1]
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
if req.grammar is not None:
jump_helper = req.grammar.try_jump(req.tokenizer)
if jump_helper.can_jump():
suffix_ids = jump_helper.suffix_ids
# Current ids, for cache and revert
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
cur_output_ids = req.output_ids
@@ -836,10 +820,8 @@
(
jump_forward_str,
next_state,
) = req.jump_forward_map.jump_forward_symbol(cur_state)
) = req.grammar.jump_forward_str_state(jump_helper)
# Make the incrementally decoded text part of jump_forward_str
                    # so that partial UTF-8 sequences are not corrupted
jump_forward_str = new_text + jump_forward_str
if not req.jump_forward_and_retokenize(
jump_forward_str, next_state
@@ -946,7 +928,7 @@
self.top_logprobs_nums = None
self.has_stream = any(req.stream for req in self.reqs)
self.has_regex = any(req.regex_fsm for req in self.reqs)
self.has_grammar = any(req.grammar for req in self.reqs)
self.sampling_info.filter_batch(keep_indices, new_indices)
@@ -979,7 +961,7 @@
self.return_logprob = self.return_logprob or other.return_logprob
self.has_stream = self.has_stream or other.has_stream
self.has_regex = self.has_regex or other.has_regex
self.has_grammar = self.has_grammar or other.has_grammar
def get_model_worker_batch(self):
if self.forward_mode.is_decode():
@@ -989,13 +971,10 @@
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
self.sampling_info.regex_fsm_states = [
req.regex_fsm_state for req in self.reqs
]
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.regex_fsms = None
self.sampling_info.grammars = None
global bid
bid += 1
@@ -29,8 +29,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.constrained.grammar import GrammarCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
@@ -225,17 +224,20 @@ class Scheduler:
)
        # Init the grammar cache for constrained generation
self.grammar_cache = None
if not server_args.skip_tokenizer_init:
self.regex_fsm_cache = FSMCache(
self.grammar_cache = GrammarCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
backend=server_args.grammar_backend,
allow_jump=not server_args.disable_regex_jump_forward,
)
self.jump_forward_cache = JumpForwardCache()
# Init new token estimation
assert (
@@ -402,22 +404,20 @@
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(recv_req.input_ids) - 1
# Init regex FSM
        # Init the regex FSM (outlines) or the BNF grammar matcher (xgrammar)
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
):
assert self.grammar_cache is not None
if req.sampling_params.json_schema is not None:
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("json", req.sampling_params.json_schema)
req.grammar = self.grammar_cache.query(
("json", req.sampling_params.json_schema),
self.model_config.vocab_size,
)
elif req.sampling_params.regex is not None:
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("regex", req.sampling_params.regex)
)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
req.grammar = self.grammar_cache.query(
("regex", req.sampling_params.regex), self.model_config.vocab_size
)
# Truncate prompts that are too long
@@ -796,10 +796,8 @@ class Scheduler:
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
self.tree_cache.cache_unfinished_req(req)
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.get_next_state(
req.regex_fsm_state, next_token_ids[i]
)
if req.grammar is not None:
req.grammar.accept_token(next_token_ids[i])
if req.return_logprob:
logprob_pt += self.add_logprob_return_values(
@@ -855,10 +853,8 @@
req.output_ids.append(next_token_id)
req.check_finished()
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.get_next_state(
req.regex_fsm_state, next_token_id
)
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
if req.finished():
self.tree_cache.cache_finished_req(req)
@@ -1056,7 +1052,9 @@
):
self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.regex_fsm_cache.reset()
if self.grammar_cache is not None:
self.grammar_cache.reset()
# TODO(dark): reset the bnf cache
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
torch.cuda.empty_cache()
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.grammar import Grammar
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
# Bias Tensors
vocab_size: int
logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None
vocab_mask: Optional[torch.Tensor] = None
# FSM states
regex_fsms: List[RegexGuide] = None
regex_fsm_states: List[int] = None
grammars: Optional[List[Optional[Grammar]]] = None
# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self):
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
if not has_regex:
if not self.grammars or not any(grammar for grammar in self.grammars):
self.vocab_mask = None
return
@@ -147,12 +144,9 @@
dtype=torch.bool,
device=self.device,
)
for i, regex_fsm in enumerate(self.regex_fsms):
if regex_fsm is not None:
self.vocab_mask[i].fill_(1)
self.vocab_mask[i][
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
] = 0
for i, grammar in enumerate(self.grammars):
if grammar is not None:
grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
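    # A hedged sketch of how a sampler could consume the mask built above;
    # the sampler itself is outside this diff and `logits` is illustrative:
    #
    #     self.update_regex_vocab_mask()
    #     if self.vocab_mask is not None:
    #         # One row per request; True entries are grammar-rejected tokens.
    #         logits = logits.masked_fill(self.vocab_mask, float("-inf"))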
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
if self.penalizer_orchestrator:
@@ -102,6 +102,7 @@ class ServerArgs:
# Kernel backend
attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = "outlines"
# Optimization/debug options
disable_flashinfer: bool = False
@@ -537,6 +538,13 @@ class ServerArgs:
default=ServerArgs.sampling_backend,
help="Choose the kernels for sampling layers.",
)
parser.add_argument(
"--grammar-backend",
type=str,
choices=["xgrammar", "outlines"],
default=ServerArgs.grammar_backend,
help="Choose the backend for constrained decoding.",
)
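        # A hedged launch example selecting the new backend; the model path
        # is a placeholder, not part of this diff:
        #
        #     python -m sglang.launch_server \
        #         --model-path meta-llama/Llama-3.1-8B-Instruct \
        #         --grammar-backend xgrammar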
# Optimization/debug options
parser.add_argument(