Unverified commit 2269cf1e, authored by Lianmin Zheng and committed by GitHub
Browse files

[Auto Sync] Update base_grammar_backend.py, llguidance_back... (20250911) (#10333)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 151e287d
......@@ -14,8 +14,9 @@
"""The baseclass of a backend for grammar-guided constrained decoding."""
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from dataclasses import dataclass, field
from threading import Event
from typing import Dict, List, Optional, Tuple
......@@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@dataclass
class GrammarStats:
    """Statistics attached to a grammar object through its ``grammar_stats`` attribute."""

    # Duration of the dispatch call that built the grammar, measured as a
    # ``time.perf_counter()`` delta by ``_init_value_dispatch``; None until recorded.
    compilation_time: Optional[float] = None
    # NOTE(review): presumably the number of schemas in the request — not set
    # anywhere in the visible code; confirm against the writer.
    schema_count: Optional[int] = None
    # NOTE(review): presumably the size of the EBNF grammar text — unit
    # (characters/bytes) is not visible here; confirm.
    ebnf_size: Optional[int] = None
    # XGrammarGrammar sets this to True (via ``dataclasses.replace``) on the
    # stats it hands to a new grammar built from its existing compiled context.
    is_cache_hit: bool = False
    # NOTE(review): presumably marks a grammar whose request was aborted — confirm.
    is_grammar_aborted: bool = False
    # ``default_factory`` gives every instance its own list instead of a shared one.
    # NOTE(review): element meaning (per-step timings?) is not visible here — confirm.
    tree_traversal_time: List[float] = field(default_factory=list)
class BaseGrammarObject:
def __init__(self):
self._finished = False
self.grammar_stats = None
self.current_token = None
def accept_token(self, token: int) -> None:
"""
......@@ -137,19 +150,26 @@ class BaseGrammarBackend:
return self._not_supported("structural_tag", key_string)
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
    """Build the grammar object for ``key`` and record how long that took.

    Args:
        key: A ``(key_type, key_string)`` pair. ``key_type`` selects the
            ``dispatch_*`` method; ``key_string`` is passed to it unchanged.
            Unrecognized key types go to ``dispatch_fallback``.

    Returns:
        Whatever the selected ``dispatch_*`` method returned (may be None).
        When the result is not None and carries a non-None ``grammar_stats``,
        its ``compilation_time`` is set to the elapsed ``time.perf_counter()``
        delta of the dispatch call.
    """
    start = time.perf_counter()
    key_type, key_string = key

    # Every branch assigns instead of returning so that all key types share
    # the single exit point below where the compilation time is recorded.
    if key_type == "json":
        grammar = self.dispatch_json(key_string)
    elif key_type == "regex":
        grammar = self.dispatch_regex(key_string)
    elif key_type == "ebnf":
        grammar = self.dispatch_ebnf(key_string)
    elif key_type == "structural_tag":
        grammar = self.dispatch_structural_tag(key_string)
    elif key_type == "structural_pattern":
        grammar = self.dispatch_structural_pattern(key_string)
    elif key_type == "structural_pattern_v2":
        grammar = self.dispatch_structural_pattern_v2(key_string)
    else:
        grammar = self.dispatch_fallback(key_type, key_string)

    # Grammar objects without stats (grammar_stats is None) are returned untouched.
    if grammar is not None and grammar.grammar_stats is not None:
        grammar.grammar_stats.compilation_time = time.perf_counter() - start
    return grammar
def get_cached_or_future_value(
self, key: Tuple[str, str]
......@@ -167,20 +187,36 @@ class BaseGrammarBackend:
self.cache.clear()
# Custom grammar backends keyed by backend name. create_grammar_backend looks
# a name up here before trying the built-in backends and calls the stored
# factory as factory(server_args, tokenizer, vocab_size, eos_token_ids).
GRAMMAR_BACKEND_REGISTRY = {}


def register_grammar_backend(name, init_func):
    """Register ``init_func`` as the factory for the grammar backend ``name``.

    Args:
        name: Backend name; matched against ``server_args.grammar_backend``.
        init_func: Callable invoked with
            ``(server_args, tokenizer, vocab_size, eos_token_ids)`` whose
            return value is used as the grammar backend.

    Registering a name that is already present replaces the earlier factory.
    """
    GRAMMAR_BACKEND_REGISTRY[name] = init_func
def create_grammar_backend(
server_args: ServerArgs,
tokenizer,
vocab_size: int,
eos_token_ids: Optional[set] = None,
) -> Optional[BaseGrammarBackend]:
if server_args.grammar_backend == "outlines":
name = server_args.grammar_backend
# Custom grammar backend has the highest priority
if name in GRAMMAR_BACKEND_REGISTRY:
return GRAMMAR_BACKEND_REGISTRY[name](
server_args, tokenizer, vocab_size, eos_token_ids
)
# Default grammar backends
if name == "outlines":
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
grammar_backend = OutlinesGrammarBackend(
tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
elif server_args.grammar_backend == "xgrammar":
elif name == "xgrammar":
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
# Convert Set[int] to List[int] if needed
......@@ -189,17 +225,17 @@ def create_grammar_backend(
grammar_backend = XGrammarGrammarBackend(
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
)
elif server_args.grammar_backend == "llguidance":
elif name == "llguidance":
from sglang.srt.constrained.llguidance_backend import GuidanceBackend
grammar_backend = GuidanceBackend(
tokenizer=tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
elif server_args.grammar_backend == "none":
elif name == "none":
return None
else:
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
raise ValueError(f"Invalid grammar backend: {name}")
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
from sglang.srt.constrained.reasoner_grammar_backend import (
......
......@@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject):
self.serialized_grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self.finished = False
self.bitmask = None
def accept_token(self, token: int):
......
......@@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject):
self.guide = guide
self.jump_forward_map = jump_forward_map
self.state = 0
self.finished = False
def accept_token(self, token: int):
self.state = self.guide.get_next_state(self.state, token)
......
......@@ -13,6 +13,7 @@
# ==============================================================================
"""Constrained decoding with xgrammar backend."""
import dataclasses
import json
import logging
from typing import List, Optional, Tuple, Union
......@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
GrammarStats,
)
from sglang.srt.utils import is_hip
......@@ -41,9 +43,9 @@ else:
from sglang.srt.constrained.triton_ops.bitmask_ops import (
apply_token_bitmask_inplace_triton,
)
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
MAX_ROLLBACK_TOKENS = 200
......@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]],
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
grammar_stats: Optional[GrammarStats] = GrammarStats(),
) -> None:
super().__init__()
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.override_stop_tokens = override_stop_tokens
self.finished = False
self.accepted_tokens = []
self.key_string = key_string
self.grammar_stats = grammar_stats
def accept_token(self, token: int):
if not self.is_terminated():
self.current_token = token
accepted = self.matcher.accept_token(token)
if not accepted:
# log for debugging
......@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
self.ctx,
self.override_stop_tokens,
self.key_string,
dataclasses.replace(
self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
),
)
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
......@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject):
assert self.matcher.accept_token(new_output_ids[i])
def __repr__(self):
    """Return a debug string showing the grammar key, the accepted tokens and the current token."""
    # Split across two adjacent f-strings only to keep the line short; the
    # resulting text is the same as a single f-string.
    return (
        f"XGrammarGrammar({self.key_string=}, "
        f"{self.accepted_tokens=}, {self.current_token=})"
    )
class XGrammarGrammarBackend(BaseGrammarBackend):
......@@ -177,14 +185,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens
def _from_context(
    self,
    ctx: CompiledGrammar,
    key_string: str,
    grammar_stats: Optional[GrammarStats] = None,
) -> XGrammarGrammar:
    """Wrap a compiled grammar in a new ``XGrammarGrammar`` with its own matcher.

    Args:
        ctx: Compiled grammar the new ``GrammarMatcher`` is created from.
        key_string: Grammar source string, forwarded to the grammar object.
        grammar_stats: Stats object to attach to the grammar. When omitted,
            a fresh ``GrammarStats()`` is created, so callers using the
            two-argument form keep working and never share a stats instance.

    Returns:
        The newly constructed ``XGrammarGrammar``.
    """
    if grammar_stats is None:
        # Always hand over a per-grammar instance rather than relying on a
        # default object that would be shared between grammars.
        grammar_stats = GrammarStats()

    matcher = GrammarMatcher(
        ctx,
        max_rollback_tokens=MAX_ROLLBACK_TOKENS,
        override_stop_tokens=self.override_stop_tokens,
    )
    return XGrammarGrammar(
        matcher,
        self.vocab_size,
        ctx,
        self.override_stop_tokens,
        key_string,
        grammar_stats,
    )
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
......@@ -198,7 +213,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
return self._from_context(ctx, key_string, GrammarStats())
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
......@@ -206,7 +221,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e:
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
return self._from_context(ctx, key_string, GrammarStats())
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
......@@ -214,7 +229,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e:
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
return self._from_context(ctx, key_string, GrammarStats())
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
......@@ -233,7 +248,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
return self._from_context(ctx, key_string, GrammarStats())
def reset(self):
self.grammar_compiler.clear_cache()
......@@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
if req.grammar is not None:
# FIXME: this try-except block is for handling unexpected xgrammar issue.
try:
req.grammar.accept_token(req.output_ids[-1])
# if it is not None, then the grammar is from a retracted request, and we should not
# accept the token as it's already accepted
if req.grammar.current_token is None:
req.grammar.accept_token(req.output_ids[-1])
except ValueError as e:
# Grammar accept_token can raise ValueError if the token is not in the grammar.
# This can happen if the grammar is not set correctly or the token is invalid.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment