Unverified Commit 19120f71 authored by DarkSharpness's avatar DarkSharpness Committed by GitHub
Browse files

[Fix & Style] Refactor the grammar backend to reduce human errors and improve readability (#4030)

parent 2415ec38
......@@ -13,31 +13,130 @@
# ==============================================================================
"""The baseclass of a backend for grammar-guided constrained decoding."""
import logging
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from threading import Event, Lock
from typing import Any, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
class BaseGrammarObject(ABC):
    """Abstract interface for one compiled grammar instance.

    Concrete grammars (e.g. the llguidance- and xgrammar-backed classes in
    this package) implement jump-forward helpers and vocabulary-mask
    operations used during grammar-constrained decoding.
    """

    @abstractmethod
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Try to jump forward in the grammar.

        Returns:
            A jump forward helper which may be used in `jump_forward_str_state`.
            None if the jump forward is not possible.
        """
        raise NotImplementedError

    @abstractmethod
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Jump forward for the grammar.

        Returns:
            A tuple of the jump forward string and the next state of the grammar
            (which can be used in `jump_and_retokenize` if needed).
        """
        raise NotImplementedError

    @abstractmethod
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """
        Jump forward occurs, and update the grammar state if needed.
        """
        raise NotImplementedError

    @abstractmethod
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a vocabulary-mask tensor for `batch_size` requests on `device`."""
        raise NotImplementedError

    @abstractmethod
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Fill entry `idx` of `vocab_mask` in place for this grammar's state."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Move `vocab_mask` to `device` and return the resulting tensor."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply `vocab_mask` to `logits` in place."""
        raise NotImplementedError

    @abstractmethod
    def copy(self) -> "BaseGrammarObject":
        """Return an independent copy of this grammar object."""
        raise NotImplementedError
@dataclass
class CacheEntry:
    """One slot in the grammar cache of `BaseGrammarBackend`.

    `value` is None while the grammar is still being built, or permanently
    None when compilation was skipped/failed; `event` is set once `value`
    is final so concurrent requests for the same key can wait on it
    (see `_init_value`).
    """

    # Quoted forward reference: BaseGrammarObject is defined elsewhere in
    # this module. The stale duplicate `value: Any` annotation (left over
    # from the previous version; `Any` is no longer imported) was removed.
    value: "Optional[BaseGrammarObject]"
    event: Event
class BaseGrammarObject:
pass
class BaseGrammarBackend:
class BaseGrammarBackend(ABC):
def __init__(self):
    """Create the shared worker pool, the grammar cache, and its lock."""
    # Pool that runs grammar compilation off the request path.
    self.executor = ThreadPoolExecutor()
    # Maps (key_type, key_string) -> CacheEntry; guarded by `cache_lock`.
    # Fix: removed the stale duplicate `self.cache = {}` assignment left
    # over from the previous version of this method.
    self.cache: Dict[Tuple[str, str], CacheEntry] = {}
    self.cache_lock = Lock()
def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
def _not_supported(self, key_type: str, key_string: str) -> None:
    """Log-and-skip hook used by the default dispatch_* implementations.

    Emits a warning and (implicitly) returns None, meaning "no grammar".
    """
    message = f"Skip unsupported {key_type}: {key_type}={key_string}"
    logger.warning(message)
def dispatch_fallback(
    self, key_type: str, key_string: str
) -> Optional[BaseGrammarObject]:
    """Last-resort handler for key types with no dedicated dispatch hook.

    Always raises; reaching this method indicates a caller bug.
    """
    raise ValueError(f"Invalid key_type: {key_type}={key_string}")
@abstractmethod
def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Hook: build a grammar from a JSON-schema string.

    The default body logs "unsupported" and yields None; subclasses that
    do not support JSON call it via super().
    """
    return self._not_supported("json", key_string)

@abstractmethod
def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Hook: build a grammar from a regular-expression string.

    Default body logs "unsupported" and yields None.
    """
    return self._not_supported("regex", key_string)

@abstractmethod
def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Hook: build a grammar from an EBNF grammar string.

    Default body logs "unsupported" and yields None.
    """
    return self._not_supported("ebnf", key_string)

@abstractmethod
def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Hook: build a grammar from a JSON-encoded structural-tag spec.

    Default body logs "unsupported" and yields None.
    NOTE(review): these hooks are decorated @abstractmethod yet still carry
    callable default bodies that subclasses invoke via super() — confirm the
    abstract marking is intended.
    """
    return self._not_supported("structural_tag", key_string)
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
    """Route a (key_type, key_string) cache key to the matching hook.

    Unknown key types fall through to `dispatch_fallback`, which raises.
    """
    key_type, key_string = key
    dispatch_table = {
        "json": self.dispatch_json,
        "regex": self.dispatch_regex,
        "ebnf": self.dispatch_ebnf,
        "structural_tag": self.dispatch_structural_tag,
    }
    handler = dispatch_table.get(key_type)
    if handler is None:
        return self.dispatch_fallback(key_type, key_string)
    return handler(key_string)
def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock:
if key in self.cache:
cache_hit = True
......@@ -50,13 +149,10 @@ class BaseGrammarBackend:
if cache_hit:
entry.event.wait()
else:
entry.value = self.init_value_impl(key)
entry.value = self._init_value_dispatch(key)
entry.event.set()
return entry.value.copy() if entry.value else None
def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
raise NotImplementedError()
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock:
entry = self.cache.get(key)
......@@ -66,7 +162,7 @@ class BaseGrammarBackend:
return val.copy() if val else None
def get_future_value(self, key: Tuple[str, str]) -> Future:
    """Schedule grammar construction for `key` on the thread pool.

    Returns a Future resolving to a fresh copy of the grammar object, or
    None when the key was unsupported/invalid (see `_init_value`).

    Fix: dropped the stale duplicate `return self.executor.submit(
    self.init_value, key)` line — `init_value` was renamed to `_init_value`,
    so the leftover line raised AttributeError on every call.
    """
    return self.executor.submit(self._init_value, key)
def reset(self):
with self.cache_lock:
......
......@@ -48,7 +48,7 @@ class GuidanceGrammar(BaseGrammarObject):
self.finished = False
self.bitmask = None
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
if len(self.pending_ff_tokens) > 0:
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
ff_tokens = self.pending_ff_tokens
......@@ -125,22 +125,27 @@ class GuidanceBackend(BaseGrammarBackend):
)
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar:
mode, value = key
if mode == "json":
json_schema = value
compiler = llguidance.JsonCompiler(
whitespace_flexible=self.whitespace_flexible
)
serialized_grammar = compiler.compile(json_schema)
elif mode == "regex":
compiler = llguidance.RegexCompiler()
serialized_grammar = compiler.compile(regex=value)
elif mode == "ebnf":
compiler = llguidance.LarkCompiler()
serialized_grammar = compiler.compile(any_to_lark(value))
def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
    """Wrap an already-serialized grammar in a GuidanceGrammar bound to
    this backend's llguidance tokenizer."""
    return GuidanceGrammar(
        llguidance_tokenizer=self.llguidance_tokenizer,
        serialized_grammar=serialized_grammar,
    )
def dispatch_json(self, key_string: str) -> GuidanceGrammar:
    """Compile a JSON-schema string into a llguidance-backed grammar."""
    compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
    return self._from_serialized(compiler.compile(key_string))
def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
    """Compile a regular-expression string into a llguidance-backed grammar."""
    return self._from_serialized(llguidance.RegexCompiler().compile(regex=key_string))
def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
    """Convert an EBNF grammar to Lark form, then compile it with llguidance."""
    return self._from_serialized(llguidance.LarkCompiler().compile(any_to_lark(key_string)))
def dispatch_structural_tag(self, key_string: str):
    """Structural tags are not handled by this backend; defer to the
    base-class default, which logs a warning and returns None."""
    return super().dispatch_structural_tag(key_string)
......@@ -141,24 +141,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
)
self.whitespace_pattern = whitespace_pattern
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
key_type, key_string = key
if key_type == "json":
try:
regex = build_regex_from_object(
key_string,
whitespace_pattern=self.whitespace_pattern,
)
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
logger.warning(
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None
elif key_type == "regex":
regex = key_string
else:
raise ValueError(f"Invalid key_type: {key_type}")
def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
try:
if hasattr(RegexGuide, "from_regex"):
# outlines >= 0.1.1
......@@ -173,6 +156,25 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
jump_forward_map = None
return OutlinesGrammar(guide, jump_forward_map)
def dispatch_ebnf(self, key_string: str):
    """EBNF keys are not supported by this backend; defer to the
    base-class default, which logs a warning and returns None."""
    return super().dispatch_ebnf(key_string)

def dispatch_structural_tag(self, key_string: str):
    """Structural-tag keys are not supported by this backend; defer to
    the base-class default, which logs a warning and returns None."""
    return super().dispatch_structural_tag(key_string)
def dispatch_json(self, key_string: str):
    """Convert a JSON schema into a regex, then compile it with Outlines.

    Returns None (meaning "skip constrained decoding for this request")
    when the schema cannot be converted, matching the other dispatch hooks.
    """
    try:
        regex = build_regex_from_object(
            key_string,
            whitespace_pattern=self.whitespace_pattern,
        )
    except (NotImplementedError, json.decoder.JSONDecodeError) as e:
        logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        # Bug fix: without this early return, `regex` is unbound and the
        # line below raises UnboundLocalError right after the warning.
        return None
    return self._compile_regex(regex)
def dispatch_regex(self, key_string: str):
    """A "regex" key is already a regular expression; compile it directly."""
    return self._compile_regex(key_string)
def build_regex_from_object(
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
......
......@@ -57,7 +57,7 @@ class XGrammarGrammar(BaseGrammarObject):
def accept_token(self, token: int):
assert self.matcher.accept_token(token)
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
s = self.matcher.find_jump_forward_string()
if s:
return [], s
......@@ -128,55 +128,56 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
key_type, key_string = key
if key_type == "json":
try:
if key_string == "$$ANY$$":
ctx = self.grammar_compiler.compile_builtin_json_grammar()
else:
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
except RuntimeError as e:
logging.warning(
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None
elif key_type == "ebnf":
try:
ctx = self.grammar_compiler.compile_grammar(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None
elif key_type == "regex":
try:
ctx = self.grammar_compiler.compile_regex(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
elif key_type == "structural_tag":
try:
structural_tag = json.loads(key_string)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
else:
raise ValueError(f"Invalid key_type: {key_type}")
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
    """Wrap a compiled grammar context in a fresh matcher and grammar object."""
    matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
    return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a JSON-schema key; "$$ANY$$" selects the builtin JSON grammar.

    Returns None (skip) when xgrammar rejects the schema.
    """
    try:
        use_builtin = key_string == "$$ANY$$"
        compiled = (
            self.grammar_compiler.compile_builtin_json_grammar()
            if use_builtin
            else self.grammar_compiler.compile_json_schema(schema=key_string)
        )
    except RuntimeError as e:
        logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        return None
    return self._from_context(compiled)
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile an EBNF grammar key; returns None (skip) when invalid."""
    try:
        compiled = self.grammar_compiler.compile_grammar(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
        return None
    return self._from_context(compiled)
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a regular-expression key; returns None (skip) when invalid."""
    try:
        compiled = self.grammar_compiler.compile_regex(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
        return None
    return self._from_context(compiled)
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a JSON-encoded structural-tag spec into a grammar.

    `key_string` is a JSON document with a "structures" list (each item
    carrying "begin"/"schema"/"end") and a "triggers" list, mirroring the
    keys read below. Returns None (skip) when the spec is invalid.
    """
    try:
        structural_tag = json.loads(key_string)
        tags = [
            StructuralTagItem(
                begin=structure["begin"],
                schema=json.dumps(structure["schema"]),
                end=structure["end"],
            )
            for structure in structural_tag["structures"]
        ]
        ctx = self.grammar_compiler.compile_structural_tag(
            tags, structural_tag["triggers"]
        )
    except (RuntimeError, json.JSONDecodeError, KeyError) as e:
        # Fix: the message previously said "Skip invalid regex" (copy-paste
        # from dispatch_regex); also catch malformed JSON / missing keys so
        # a bad spec is skipped like every other dispatch hook instead of
        # crashing the caller.
        logging.warning(
            f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
        )
        return None
    return self._from_context(ctx)
def reset(self):
    """Invalidate all cached compilations held by the grammar compiler."""
    compiler = self.grammar_compiler
    if compiler:
        compiler.clear_cache()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment