"tests/python/vscode:/vscode.git/clone" did not exist on "dadce86a782f4527978f2efb349ae65087a77f08"
Unverified commit b77a02cd authored by DarkSharpness, committed by GitHub

[Performance] Support both xgrammar and outlines for constrained decoding (#1752)

parent 30643fed
python/sglang/srt/constrained/__init__.py:

@@ -51,6 +51,21 @@ except ImportError:
         return build_regex_from_schema(schema, whitespace_pattern)

+try:
+    from xgrammar import (
+        GrammarMatcher,
+        GrammarMatcherInitContext,
+        GrammarMatcherInitContextCache,
+    )
+except ImportError:
+
+    class Dummy:
+        pass
+
+    GrammarMatcher = Dummy
+    GrammarMatcherInitContext = Dummy
+    GrammarMatcherInitContextCache = Dummy
+
 __all__ = [
     "RegexGuide",
     "FSMInfo",
@@ -60,4 +75,7 @@ __all__ = [
"disk_cache", "disk_cache",
"disable_cache", "disable_cache",
"make_byte_level_fsm", "make_byte_level_fsm",
"GrammarMatcher",
"GrammarMatcherInitContext",
"GrammarMatcherInitContextCache",
] ]
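The try/except block above keeps the xgrammar names importable even when the package is absent. A minimal standalone sketch of the same optional-dependency pattern; backend_is_xgrammar is a hypothetical helper, not part of the commit:

# With xgrammar installed this is a real type; without it, the placeholder
# class makes isinstance() checks simply return False everywhere.
try:
    from xgrammar import GrammarMatcher
except ImportError:
    class GrammarMatcher:  # placeholder type, never instantiated
        pass

def backend_is_xgrammar(grammar) -> bool:
    return isinstance(grammar, GrammarMatcher)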
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Cache for the compressed finite state machine."""
from typing import Tuple
from transformers import AutoTokenizer
from sglang.srt.constrained import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)
MAX_ROLLBACK_TOKENS = 10
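# MAX_ROLLBACK_TOKENS above is the rollback budget handed to each matcher:
# jump-forward retokenization (see Grammar.jump_and_retokenize) may need to
# undo the most recently accepted tokens, so matchers keep a small window.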
class BNFCache:
    grammar_cache: GrammarMatcherInitContextCache

    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        skip_tokenizer_init=False,
        whitespace_patterns=None,
    ):
        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
        if skip_tokenizer_init:
            return

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
        self.grammar_cache = GrammarMatcherInitContextCache(
            tokenizer_or_vocab=tokenizer
        )

    def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
        key_type, key_string = key
        if key_type == "json":
            return self.grammar_cache.get_init_context_for_json_schema(key_string)
        elif key_type == "regex":
            raise ValueError("regex is not yet supported by xgrammar")
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
        ctx = self.get_context(key)
        return GrammarMatcher(
            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
        )
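A usage sketch for the class above. The tokenizer path and vocab size are illustrative placeholders; the key tuple follows the ("json" | "regex", string) format expected by get_context:

cache = BNFCache("meta-llama/Llama-3.1-8B-Instruct", {"trust_remote_code": True})
matcher = cache.query(("json", '{"type": "object"}'), vocab_size=128256)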
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Cache for the compressed finite state machine."""
import logging
from typing import List, Optional, Tuple, Union
import torch
from sglang.srt.constrained import GrammarMatcher, RegexGuide
from sglang.srt.constrained.bnf_cache import BNFCache
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
# from sglang.srt.managers.schedule_batch import Req
logger = logging.getLogger(__name__)
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
class XGrammarJump:
pass
class JumpHelper:
data: Union[List, str]
state: int
suffix_ids: List[int]
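    # `data` is overloaded by backend: for xgrammar it holds the jump-forward
    # string itself; for outlines it holds suffix token ids, with `state`
    # recording the FSM state reached after consuming the suffix bytes.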
    def __init__(
        self, data: Union[List, str] = "", state: int = -1, suffix_ids=None
    ) -> None:
        self.data = data
        self.state = state
        self.suffix_ids = suffix_ids if suffix_ids is not None else []

    def can_jump(self):
        return len(self.data) > 0


class Grammar:
    grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
    jump_map: Union[XGrammarJump, JumpForwardMap, None]

    def __init__(
        self,
        grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
        jump_map: Union[XGrammarJump, JumpForwardMap, None],
    ) -> None:
        self.grammar = grammar
        self.jump_map = jump_map

    def accept_token(self, token: int):
        if isinstance(self.grammar, GrammarMatcher):
            assert self.grammar.accept_token(token)
        else:
            guide, state = self.grammar
            self.grammar = guide, guide.get_next_state(state, token)

    def try_jump(self, tokenizer) -> JumpHelper:
        if isinstance(self.jump_map, XGrammarJump):
            assert isinstance(self.grammar, GrammarMatcher)
            return JumpHelper(self.grammar.find_jump_forward_string())
        elif isinstance(self.jump_map, JumpForwardMap):
            assert isinstance(self.grammar, tuple)
            _, state = self.grammar
            jump_forward_bytes = self.jump_map.jump_forward_byte(state)
            if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
                return JumpHelper()  # can't jump

            # preprocess the jump forward string
            suffix_bytes = []
            continuation_range = range(0x80, 0xC0)
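            # 0x80-0xBF are UTF-8 continuation bytes. They can only complete
            # a multi-byte character already begun in the decoded text, so
            # they are split off as a token-id suffix rather than appended to
            # the jump-forward string, which would otherwise corrupt UTF-8.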
            cur_state = state
            while (
                len(jump_forward_bytes)
                and jump_forward_bytes[0][0] in continuation_range
            ):
                # continuation bytes
                byte_edge = jump_forward_bytes.pop(0)
                suffix_bytes.append(byte_edge[0])
                cur_state = byte_edge[1]

            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
            suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
            return JumpHelper(suffix_ids, cur_state, suffix_bytes)
        else:
            return JumpHelper()  # can't jump

    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
        if isinstance(helper.data, str):
            return helper.data, -1
        else:
            assert isinstance(self.jump_map, JumpForwardMap)
            return self.jump_map.jump_forward_symbol(helper.state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        if isinstance(self.grammar, GrammarMatcher):
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == new_output_ids[i]:
                    k = i + 1
                else:
                    break

            # rollback to the last token that is the same
            if k < len(old_output_ids):
                self.grammar.rollback(len(old_output_ids) - k)

            for i in range(k, len(new_output_ids)):
                assert self.grammar.accept_token(new_output_ids[i])
        else:
            self.grammar = self.grammar[0], next_state

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
        if isinstance(self.grammar, GrammarMatcher):
            # Note that this bitmask is a bitset, not bool
            bitmask = self.grammar.find_next_token_bitmask()
            # Mask the tokens that are not allowed
            vocab_mask[
                self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
            ] = 1
        else:
            guide, state = self.grammar
            vocab_mask.fill_(1)
            vocab_mask[guide.get_next_instruction(state).tokens] = 0
class GrammarCache:
    grammar_cache: Union[BNFCache, FSMCache]
    jump_cache: Union[XGrammarJump, JumpForwardCache, None]

    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        skip_tokenizer_init=False,
        whitespace_patterns=None,
        backend=None,
        allow_jump=False,
    ):
        if backend == "xgrammar":
            self.grammar_cache = BNFCache(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                whitespace_patterns=whitespace_patterns,
            )
            self.jump_cache = XGrammarJump() if allow_jump else None
        else:
            assert backend == "outlines"
            self.grammar_cache = FSMCache(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                constrained_json_whitespace_pattern=whitespace_patterns,
                enable=True,
            )
            self.jump_cache = JumpForwardCache() if allow_jump else None

    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
        if isinstance(self.grammar_cache, BNFCache):
            assert not isinstance(self.jump_cache, JumpForwardCache)
            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
        else:
            jump_map = None
            guide, regex = self.grammar_cache.query(key)
            if isinstance(self.jump_cache, JumpForwardCache):
                jump_map = self.jump_cache.query(regex)
            return Grammar((guide, 0), jump_map)

    def reset(self):
        if isinstance(self.grammar_cache, FSMCache):
            self.grammar_cache.reset()
        if isinstance(self.jump_cache, JumpForwardCache):
            self.jump_cache.reset()
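A sketch of how the two backends are selected through the shared cache; the tokenizer path and vocab size are illustrative placeholders, and the arguments follow __init__ and query above:

xgrammar_cache = GrammarCache(
    "meta-llama/Llama-3.1-8B-Instruct", {}, backend="xgrammar", allow_jump=True
)
outlines_cache = GrammarCache(
    "meta-llama/Llama-3.1-8B-Instruct", {}, backend="outlines", allow_jump=True
)
grammar = xgrammar_cache.query(("json", '{"type": "object"}'), vocab_size=128256)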
python/sglang/srt/managers/schedule_batch.py:

@@ -37,8 +37,7 @@ import torch
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import RegexGuide
-from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -247,9 +246,7 @@ class Req:
         self.embedding = None

         # Constrained decoding
-        self.regex_fsm: RegexGuide = None
-        self.regex_fsm_state: int = 0
-        self.jump_forward_map: JumpForwardMap = None
+        self.grammar: Optional[Grammar] = None

         # For Qwen2-VL
         self.mrope_position_delta = []  # use mutable object
@@ -359,6 +356,8 @@ class Req:
             return

     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        assert self.grammar is not None and self.tokenizer is not None
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -398,7 +397,8 @@ class Req:
                 self.surr_offset = self.read_offset - i
                 break

-        self.regex_fsm_state = next_state
+        # update the inner state of the grammar
+        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

         if self.return_logprob:
             # For fast-forward part's logprobs
@@ -468,8 +468,8 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False

-    # Has regex
-    has_regex: bool = False
+    # Has grammar
+    has_grammar: bool = False

     # device
     device: str = "cuda"
@@ -477,7 +477,7 @@ class ScheduleBatch:
     @classmethod
     def init_new(
         cls,
-        reqs,
+        reqs: List[Req],
         req_to_token_pool,
         token_to_kv_pool,
         tree_cache,
@@ -491,7 +491,7 @@ class ScheduleBatch:
             model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
-            has_regex=any(req.regex_fsm for req in reqs),
+            has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
         )
@@ -803,26 +803,10 @@ class ScheduleBatch:
         keep_indices = set(i for i in range(len(self.reqs)))
         for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
-                    req.regex_fsm_state
-                )
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+            if req.grammar is not None:
+                jump_helper = req.grammar.try_jump(req.tokenizer)
+                if jump_helper.can_jump():
+                    suffix_ids = jump_helper.suffix_ids

                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -836,10 +820,8 @@ class ScheduleBatch:
                     (
                         jump_forward_str,
                         next_state,
-                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+                    ) = req.grammar.jump_forward_str_state(jump_helper)

-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
@@ -946,7 +928,7 @@ class ScheduleBatch:
         self.top_logprobs_nums = None

         self.has_stream = any(req.stream for req in self.reqs)
-        self.has_regex = any(req.regex_fsm for req in self.reqs)
+        self.has_grammar = any(req.grammar for req in self.reqs)

         self.sampling_info.filter_batch(keep_indices, new_indices)
@@ -979,7 +961,7 @@ class ScheduleBatch:
         self.return_logprob = self.return_logprob or other.return_logprob
         self.has_stream = self.has_stream or other.has_stream
-        self.has_regex = self.has_regex or other.has_regex
+        self.has_grammar = self.has_grammar or other.has_grammar

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -989,13 +971,10 @@ class ScheduleBatch:
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens

-        if self.has_regex:
-            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+        if self.has_grammar:
+            self.sampling_info.grammars = [req.grammar for req in self.reqs]
         else:
-            self.sampling_info.regex_fsms = None
+            self.sampling_info.grammars = None

         global bid
         bid += 1
...
python/sglang/srt/managers/scheduler.py:

@@ -29,8 +29,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -225,17 +224,20 @@ class Scheduler:
         )

         # Init the FSM cache for constrained generation
+        self.grammar_cache = None
         if not server_args.skip_tokenizer_init:
-            self.regex_fsm_cache = FSMCache(
+            self.grammar_cache = GrammarCache(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                backend=server_args.grammar_backend,
+                allow_jump=not server_args.disable_regex_jump_forward,
             )
-            self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
         assert (
@@ -402,22 +404,20 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1

-        # Init regex FSM
+        # Init regex FSM or BNF
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
+            assert self.grammar_cache is not None
             if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("json", req.sampling_params.json_schema)
+                req.grammar = self.grammar_cache.query(
+                    ("json", req.sampling_params.json_schema),
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("regex", req.sampling_params.regex)
-                )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
-                    )
+                req.grammar = self.grammar_cache.query(
+                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
+                )

         # Truncate prompts that are too long
@@ -796,10 +796,8 @@ class Scheduler:
             elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                 self.tree_cache.cache_unfinished_req(req)

-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_ids[i]
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_ids[i])

             if req.return_logprob:
                 logprob_pt += self.add_logprob_return_values(
@@ -855,10 +853,8 @@ class Scheduler:
             req.output_ids.append(next_token_id)
             req.check_finished()

-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)

             if req.finished():
                 self.tree_cache.cache_finished_req(req)
@@ -1056,7 +1052,9 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
+            if self.grammar_cache is not None:
+                self.grammar_cache.reset()
+            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
...
python/sglang/srt/sampling/sampling_batch_info.py:

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
 import torch

 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.grammar import Grammar

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     logit_bias: torch.Tensor = None
-    vocab_mask: torch.Tensor = None
+    vocab_mask: Optional[torch.Tensor] = None

-    # FSM states
-    regex_fsms: List[RegexGuide] = None
-    regex_fsm_states: List[int] = None
+    grammars: Optional[List[Optional[Grammar]]] = None

     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
                 self.linear_penalties = penalizer.apply(self.linear_penalties)

     def update_regex_vocab_mask(self):
-        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-        if not has_regex:
+        if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
             return
@@ -147,12 +144,9 @@ class SamplingBatchInfo:
             dtype=torch.bool,
             device=self.device,
         )
-        for i, regex_fsm in enumerate(self.regex_fsms):
-            if regex_fsm is not None:
-                self.vocab_mask[i].fill_(1)
-                self.vocab_mask[i][
-                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                ] = 0
+        for i, grammar in enumerate(self.grammars):
+            if grammar is not None:
+                grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
...
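For orientation, a hedged sketch of how a boolean vocab_mask like the one built above is typically consumed at sampling time; the actual call site is outside this diff, and the shapes here are illustrative:

import torch

logits = torch.randn(2, 128256)                  # (batch, vocab_size)
vocab_mask = torch.zeros_like(logits, dtype=torch.bool)
vocab_mask[0, 1000:] = True                      # pretend these tokens are rejected
logits = logits.masked_fill(vocab_mask, float("-inf"))  # banned tokens can never be sampled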
python/sglang/srt/server_args.py:

@@ -102,6 +102,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
+    grammar_backend: Optional[str] = "outlines"

     # Optimization/debug options
     disable_flashinfer: bool = False
@@ -537,6 +538,13 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+        parser.add_argument(
+            "--grammar-backend",
+            type=str,
+            choices=["xgrammar", "outlines"],
+            default=ServerArgs.grammar_backend,
+            help="Choose the backend for constrained decoding.",
+        )

         # Optimization/debug options
         parser.add_argument(
...
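With the new flag, the backend becomes a launch-time choice. A hedged sketch, assuming ServerArgs exposes a model_path field alongside the fields shown above; the model name is a placeholder:

# CLI form: python -m sglang.launch_server --model-path <model> --grammar-backend xgrammar
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    grammar_backend="xgrammar",  # default remains "outlines"
)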