"docs/source/vscode:/vscode.git/clone" did not exist on "a844d16ec0aa3c78fc3e3f675c7ad5f4e2f51866"
Unverified Commit 2269cf1e authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Auto Sync] Update base_grammar_backend.py, llguidance_back... (20250911) (#10333)


Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 151e287d
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
"""The baseclass of a backend for grammar-guided constrained decoding.""" """The baseclass of a backend for grammar-guided constrained decoding."""
import logging import logging
import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass, field
from threading import Event from threading import Event
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs ...@@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class GrammarStats:
    """Metrics collected while compiling and running a single grammar object."""

    # Wall-clock seconds spent compiling the grammar; set by the backend's
    # _init_value_dispatch via time.perf_counter(), None until measured.
    compilation_time: Optional[float] = None
    # presumably the number of schemas contained in the request — TODO confirm
    # with the producers (not set anywhere in this file).
    schema_count: Optional[int] = None
    # presumably the size of the EBNF representation — TODO confirm with callers.
    ebnf_size: Optional[int] = None
    # True when the compiled grammar was served from a cache (set on copies,
    # see XGrammarGrammar.copy which uses dataclasses.replace(is_cache_hit=True)).
    is_cache_hit: bool = False
    # True when grammar-guided decoding for this request was aborted.
    is_grammar_aborted: bool = False
    # Per-step traversal timings in seconds; default_factory gives each
    # instance its own list (never share a mutable default).
    tree_traversal_time: List[float] = field(default_factory=list)
class BaseGrammarObject: class BaseGrammarObject:
    def __init__(self):
        # Internal flag tracking whether the grammar object is finished;
        # managed by subclasses/callers, starts out not-finished.
        self._finished = False
        # GrammarStats instance attached by backends that collect compile/use
        # metrics; None when stats are not tracked for this grammar.
        self.grammar_stats = None
        # Most recent token handed to accept_token; None until a token is
        # accepted. The disaggregation decode path checks this to avoid
        # double-accepting a token for a retracted request.
        self.current_token = None
def accept_token(self, token: int) -> None: def accept_token(self, token: int) -> None:
""" """
...@@ -137,19 +150,26 @@ class BaseGrammarBackend: ...@@ -137,19 +150,26 @@ class BaseGrammarBackend:
return self._not_supported("structural_tag", key_string) return self._not_supported("structural_tag", key_string)
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
s = time.perf_counter()
key_type, key_string = key key_type, key_string = key
if key_type == "json": if key_type == "json":
return self.dispatch_json(key_string) grammar = self.dispatch_json(key_string)
elif key_type == "regex": elif key_type == "regex":
return self.dispatch_regex(key_string) grammar = self.dispatch_regex(key_string)
elif key_type == "ebnf": elif key_type == "ebnf":
return self.dispatch_ebnf(key_string) grammar = self.dispatch_ebnf(key_string)
elif key_type == "structural_tag": elif key_type == "structural_tag":
return self.dispatch_structural_tag(key_string) grammar = self.dispatch_structural_tag(key_string)
elif key_type == "structural_pattern": elif key_type == "structural_pattern":
return self.dispatch_structural_pattern(key_string) grammar = self.dispatch_structural_pattern(key_string)
elif key_type == "structural_pattern_v2":
grammar = self.dispatch_structural_pattern_v2(key_string)
else: else:
return self.dispatch_fallback(key_type, key_string) grammar = self.dispatch_fallback(key_type, key_string)
if grammar is not None and grammar.grammar_stats is not None:
grammar.grammar_stats.compilation_time = time.perf_counter() - s
return grammar
def get_cached_or_future_value( def get_cached_or_future_value(
self, key: Tuple[str, str] self, key: Tuple[str, str]
...@@ -167,20 +187,36 @@ class BaseGrammarBackend: ...@@ -167,20 +187,36 @@ class BaseGrammarBackend:
self.cache.clear() self.cache.clear()
# Registry of custom grammar backends: maps a backend name to a factory
# called as factory(server_args, tokenizer, vocab_size, eos_token_ids).
# Entries here take priority over the built-in backends in
# create_grammar_backend.
GRAMMAR_BACKEND_REGISTRY = {}


def register_grammar_backend(name: str, init_func) -> None:
    """Register a custom grammar backend factory under ``name``.

    ``init_func`` must accept (server_args, tokenizer, vocab_size,
    eos_token_ids) and return a BaseGrammarBackend instance.
    Registering an existing name overwrites the previous factory.
    """
    GRAMMAR_BACKEND_REGISTRY[name] = init_func
def create_grammar_backend( def create_grammar_backend(
server_args: ServerArgs, server_args: ServerArgs,
tokenizer, tokenizer,
vocab_size: int, vocab_size: int,
eos_token_ids: Optional[set] = None, eos_token_ids: Optional[set] = None,
) -> Optional[BaseGrammarBackend]: ) -> Optional[BaseGrammarBackend]:
if server_args.grammar_backend == "outlines": name = server_args.grammar_backend
# Custom grammar backend has the highest priority
if name in GRAMMAR_BACKEND_REGISTRY:
return GRAMMAR_BACKEND_REGISTRY[name](
server_args, tokenizer, vocab_size, eos_token_ids
)
# Default grammar backends
if name == "outlines":
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
grammar_backend = OutlinesGrammarBackend( grammar_backend = OutlinesGrammarBackend(
tokenizer, tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern, whitespace_pattern=server_args.constrained_json_whitespace_pattern,
) )
elif server_args.grammar_backend == "xgrammar": elif name == "xgrammar":
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
# Convert Set[int] to List[int] if needed # Convert Set[int] to List[int] if needed
...@@ -189,17 +225,17 @@ def create_grammar_backend( ...@@ -189,17 +225,17 @@ def create_grammar_backend(
grammar_backend = XGrammarGrammarBackend( grammar_backend = XGrammarGrammarBackend(
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
) )
elif server_args.grammar_backend == "llguidance": elif name == "llguidance":
from sglang.srt.constrained.llguidance_backend import GuidanceBackend from sglang.srt.constrained.llguidance_backend import GuidanceBackend
grammar_backend = GuidanceBackend( grammar_backend = GuidanceBackend(
tokenizer=tokenizer, tokenizer=tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern, whitespace_pattern=server_args.constrained_json_whitespace_pattern,
) )
elif server_args.grammar_backend == "none": elif name == "none":
return None return None
else: else:
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") raise ValueError(f"Invalid grammar backend: {name}")
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"): if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
from sglang.srt.constrained.reasoner_grammar_backend import ( from sglang.srt.constrained.reasoner_grammar_backend import (
......
...@@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject): ...@@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject):
self.serialized_grammar, self.serialized_grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
) )
self.finished = False
self.bitmask = None self.bitmask = None
def accept_token(self, token: int): def accept_token(self, token: int):
......
...@@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject): ...@@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject):
self.guide = guide self.guide = guide
self.jump_forward_map = jump_forward_map self.jump_forward_map = jump_forward_map
self.state = 0 self.state = 0
self.finished = False
def accept_token(self, token: int): def accept_token(self, token: int):
self.state = self.guide.get_next_state(self.state, token) self.state = self.guide.get_next_state(self.state, token)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# ============================================================================== # ==============================================================================
"""Constrained decoding with xgrammar backend.""" """Constrained decoding with xgrammar backend."""
import dataclasses
import json import json
import logging import logging
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import ( ...@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ, INVALID_GRAMMAR_OBJ,
BaseGrammarBackend, BaseGrammarBackend,
BaseGrammarObject, BaseGrammarObject,
GrammarStats,
) )
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
...@@ -41,9 +43,9 @@ else: ...@@ -41,9 +43,9 @@ else:
from sglang.srt.constrained.triton_ops.bitmask_ops import ( from sglang.srt.constrained.triton_ops.bitmask_ops import (
apply_token_bitmask_inplace_triton, apply_token_bitmask_inplace_triton,
) )
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
MAX_ROLLBACK_TOKENS = 200 MAX_ROLLBACK_TOKENS = 200
...@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
ctx: CompiledGrammar, ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]], override_stop_tokens: Optional[Union[List[int], int]],
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
grammar_stats: Optional[GrammarStats] = GrammarStats(),
) -> None: ) -> None:
super().__init__()
self.matcher = matcher self.matcher = matcher
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.ctx = ctx self.ctx = ctx
self.override_stop_tokens = override_stop_tokens self.override_stop_tokens = override_stop_tokens
self.finished = False
self.accepted_tokens = [] self.accepted_tokens = []
self.key_string = key_string self.key_string = key_string
self.grammar_stats = grammar_stats
def accept_token(self, token: int): def accept_token(self, token: int):
if not self.is_terminated(): if not self.is_terminated():
self.current_token = token
accepted = self.matcher.accept_token(token) accepted = self.matcher.accept_token(token)
if not accepted: if not accepted:
# log for debugging # log for debugging
...@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
self.ctx, self.ctx,
self.override_stop_tokens, self.override_stop_tokens,
self.key_string, self.key_string,
dataclasses.replace(
self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
),
) )
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
...@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject):
assert self.matcher.accept_token(new_output_ids[i]) assert self.matcher.accept_token(new_output_ids[i])
def __repr__(self): def __repr__(self):
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})" return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"
class XGrammarGrammarBackend(BaseGrammarBackend): class XGrammarGrammarBackend(BaseGrammarBackend):
...@@ -177,14 +185,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -177,14 +185,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens self.override_stop_tokens = override_stop_tokens
def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar: def _from_context(
self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
) -> XGrammarGrammar:
matcher = GrammarMatcher( matcher = GrammarMatcher(
ctx, ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS, max_rollback_tokens=MAX_ROLLBACK_TOKENS,
override_stop_tokens=self.override_stop_tokens, override_stop_tokens=self.override_stop_tokens,
) )
return XGrammarGrammar( return XGrammarGrammar(
matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string matcher,
self.vocab_size,
ctx,
self.override_stop_tokens,
key_string,
grammar_stats,
) )
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
...@@ -198,7 +213,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -198,7 +213,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except (RuntimeError, json.decoder.JSONDecodeError) as e: except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}") logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string) return self._from_context(ctx, key_string, GrammarStats())
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
...@@ -206,7 +221,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -206,7 +221,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e: except RuntimeError as e:
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}") logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string) return self._from_context(ctx, key_string, GrammarStats())
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
...@@ -214,7 +229,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -214,7 +229,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e: except RuntimeError as e:
logging.error(f"Hit invalid regex: {key_string=}, {e=}") logging.error(f"Hit invalid regex: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string) return self._from_context(ctx, key_string, GrammarStats())
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
...@@ -233,7 +248,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -233,7 +248,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except (RuntimeError, json.decoder.JSONDecodeError) as e: except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}") logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string) return self._from_context(ctx, key_string, GrammarStats())
    def reset(self):
        """Drop all cached compiled grammars held by the grammar compiler."""
        self.grammar_compiler.clear_cache()
...@@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin: ...@@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
if req.grammar is not None: if req.grammar is not None:
# FIXME: this try-except block is for handling unexpected xgrammar issue. # FIXME: this try-except block is for handling unexpected xgrammar issue.
try: try:
req.grammar.accept_token(req.output_ids[-1]) # if it is not None, then the grammar is from a retracted request, and we should not
# accept the token as it's already accepted
if req.grammar.current_token is None:
req.grammar.accept_token(req.output_ids[-1])
except ValueError as e: except ValueError as e:
# Grammar accept_token can raise ValueError if the token is not in the grammar. # Grammar accept_token can raise ValueError if the token is not in the grammar.
# This can happen if the grammar is not set correctly or the token is invalid. # This can happen if the grammar is not set correctly or the token is invalid.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment