Unverified Commit 19120f71 authored by DarkSharpness's avatar DarkSharpness Committed by GitHub
Browse files

[Fix & Style] Refactor the grammar backend to reduce human errors and improve readability (#4030)

parent 2415ec38
...@@ -13,31 +13,130 @@ ...@@ -13,31 +13,130 @@
# ============================================================================== # ==============================================================================
"""The baseclass of a backend for grammar-guided constrained decoding.""" """The baseclass of a backend for grammar-guided constrained decoding."""
import logging
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from threading import Event, Lock from threading import Event, Lock
from typing import Any, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
class BaseGrammarObject(ABC):
    """Abstract interface for one compiled grammar instance used during
    grammar-guided constrained decoding.

    Concrete implementations (e.g. xgrammar-, outlines-, or llguidance-based
    grammars) must provide jump-forward support, vocab-mask manipulation, and
    cloning for reuse across requests.
    """

    @abstractmethod
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Try to jump forward in the grammar.

        Returns:
            A jump forward helper which may be used in `jump_forward_str_state`.
            None if the jump forward is not possible.
        """
        raise NotImplementedError

    @abstractmethod
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Jump forward for the grammar.

        Returns:
            A tuple of the jump forward string and the next state of the grammar
            (which can be used in `jump_and_retokenize` if needed).
        """
        raise NotImplementedError

    @abstractmethod
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """
        Jump forward occurs, and update the grammar state if needed.
        """
        raise NotImplementedError

    @abstractmethod
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a token-mask tensor for `batch_size` requests over
        `vocab_size` tokens on `device`."""
        raise NotImplementedError

    @abstractmethod
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Fill entry `idx` of `vocab_mask` from the current grammar state."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Return `vocab_mask` moved to `device`."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply `vocab_mask` to `logits` in place (masking forbidden tokens)."""
        raise NotImplementedError

    @abstractmethod
    def copy(self) -> "BaseGrammarObject":
        """Return a usable copy of this grammar object for a new request."""
        raise NotImplementedError
@dataclass
class CacheEntry:
    """One slot of the grammar-compilation cache.

    `value` is the compiled grammar, or None while compilation is pending or
    after it failed; `event` is set once `value` has been populated so that
    concurrent waiters may proceed.
    """

    value: Optional[BaseGrammarObject]
    event: Event
class BaseGrammarObject: class BaseGrammarBackend(ABC):
pass
class BaseGrammarBackend:
def __init__(self): def __init__(self):
self.executor = ThreadPoolExecutor() self.executor = ThreadPoolExecutor()
self.cache = {} self.cache: Dict[Tuple[str, str], CacheEntry] = {}
self.cache_lock = Lock() self.cache_lock = Lock()
def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject: def _not_supported(self, key_type: str, key_string: str) -> None:
logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
    def dispatch_fallback(
        self, key_type: str, key_string: str
    ) -> Optional[BaseGrammarObject]:
        """Handle an unrecognized `key_type`.

        Reached from `_init_value_dispatch` only when `key_type` matches none
        of the known constraint kinds; raises to surface the programming error.
        """
        raise ValueError(f"Invalid key_type: {key_type}={key_string}")
    @abstractmethod
    def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
        """Compile a JSON-schema constraint; the default logs 'unsupported'
        and returns None."""
        return self._not_supported("json", key_string)
    @abstractmethod
    def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
        """Compile a regex constraint; the default logs 'unsupported' and
        returns None."""
        return self._not_supported("regex", key_string)
    @abstractmethod
    def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
        """Compile an EBNF constraint; the default logs 'unsupported' and
        returns None."""
        return self._not_supported("ebnf", key_string)
    @abstractmethod
    def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
        """Compile a structural-tag constraint; the default logs 'unsupported'
        and returns None."""
        return self._not_supported("structural_tag", key_string)
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
key_type, key_string = key
if key_type == "json":
return self.dispatch_json(key_string)
elif key_type == "regex":
return self.dispatch_regex(key_string)
elif key_type == "ebnf":
return self.dispatch_ebnf(key_string)
elif key_type == "structural_tag":
return self.dispatch_structural_tag(key_string)
else:
return self.dispatch_fallback(key_type, key_string)
def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock: with self.cache_lock:
if key in self.cache: if key in self.cache:
cache_hit = True cache_hit = True
...@@ -50,13 +149,10 @@ class BaseGrammarBackend: ...@@ -50,13 +149,10 @@ class BaseGrammarBackend:
if cache_hit: if cache_hit:
entry.event.wait() entry.event.wait()
else: else:
entry.value = self.init_value_impl(key) entry.value = self._init_value_dispatch(key)
entry.event.set() entry.event.set()
return entry.value.copy() if entry.value else None return entry.value.copy() if entry.value else None
def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
raise NotImplementedError()
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock: with self.cache_lock:
entry = self.cache.get(key) entry = self.cache.get(key)
...@@ -66,7 +162,7 @@ class BaseGrammarBackend: ...@@ -66,7 +162,7 @@ class BaseGrammarBackend:
return val.copy() if val else None return val.copy() if val else None
def get_future_value(self, key: Tuple[str, str]) -> Future: def get_future_value(self, key: Tuple[str, str]) -> Future:
return self.executor.submit(self.init_value, key) return self.executor.submit(self._init_value, key)
def reset(self): def reset(self):
with self.cache_lock: with self.cache_lock:
......
...@@ -48,7 +48,7 @@ class GuidanceGrammar(BaseGrammarObject): ...@@ -48,7 +48,7 @@ class GuidanceGrammar(BaseGrammarObject):
self.finished = False self.finished = False
self.bitmask = None self.bitmask = None
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
if len(self.pending_ff_tokens) > 0: if len(self.pending_ff_tokens) > 0:
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens) s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
ff_tokens = self.pending_ff_tokens ff_tokens = self.pending_ff_tokens
...@@ -125,22 +125,27 @@ class GuidanceBackend(BaseGrammarBackend): ...@@ -125,22 +125,27 @@ class GuidanceBackend(BaseGrammarBackend):
) )
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar: def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
mode, value = key
if mode == "json":
json_schema = value
compiler = llguidance.JsonCompiler(
whitespace_flexible=self.whitespace_flexible
)
serialized_grammar = compiler.compile(json_schema)
elif mode == "regex":
compiler = llguidance.RegexCompiler()
serialized_grammar = compiler.compile(regex=value)
elif mode == "ebnf":
compiler = llguidance.LarkCompiler()
serialized_grammar = compiler.compile(any_to_lark(value))
return GuidanceGrammar( return GuidanceGrammar(
llguidance_tokenizer=self.llguidance_tokenizer, llguidance_tokenizer=self.llguidance_tokenizer,
serialized_grammar=serialized_grammar, serialized_grammar=serialized_grammar,
) )
def dispatch_json(self, key_string: str) -> GuidanceGrammar:
json_schema = key_string
compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
serialized_grammar = compiler.compile(json_schema)
return self._from_serialized(serialized_grammar)
def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
compiler = llguidance.RegexCompiler()
serialized_grammar = compiler.compile(regex=key_string)
return self._from_serialized(serialized_grammar)
def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
compiler = llguidance.LarkCompiler()
serialized_grammar = compiler.compile(any_to_lark(key_string))
return self._from_serialized(serialized_grammar)
    def dispatch_structural_tag(self, key_string: str):
        """Structural tags are not supported by this backend; the base
        implementation logs a warning and returns None."""
        return super().dispatch_structural_tag(key_string)
...@@ -141,24 +141,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend): ...@@ -141,24 +141,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
) )
self.whitespace_pattern = whitespace_pattern self.whitespace_pattern = whitespace_pattern
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar: def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
key_type, key_string = key
if key_type == "json":
try:
regex = build_regex_from_object(
key_string,
whitespace_pattern=self.whitespace_pattern,
)
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
logger.warning(
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None
elif key_type == "regex":
regex = key_string
else:
raise ValueError(f"Invalid key_type: {key_type}")
try: try:
if hasattr(RegexGuide, "from_regex"): if hasattr(RegexGuide, "from_regex"):
# outlines >= 0.1.1 # outlines >= 0.1.1
...@@ -173,6 +156,25 @@ class OutlinesGrammarBackend(BaseGrammarBackend): ...@@ -173,6 +156,25 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
jump_forward_map = None jump_forward_map = None
return OutlinesGrammar(guide, jump_forward_map) return OutlinesGrammar(guide, jump_forward_map)
    def dispatch_ebnf(self, key_string: str):
        """EBNF is not supported by the outlines backend; the base
        implementation logs a warning and returns None."""
        return super().dispatch_ebnf(key_string)
    def dispatch_structural_tag(self, key_string: str):
        """Structural tags are not supported by the outlines backend; the base
        implementation logs a warning and returns None."""
        return super().dispatch_structural_tag(key_string)
def dispatch_json(self, key_string: str):
try:
regex = build_regex_from_object(
key_string,
whitespace_pattern=self.whitespace_pattern,
)
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
return self._compile_regex(regex)
def dispatch_regex(self, key_string: str):
return self._compile_regex(key_string)
def build_regex_from_object( def build_regex_from_object(
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
......
...@@ -57,7 +57,7 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -57,7 +57,7 @@ class XGrammarGrammar(BaseGrammarObject):
    def accept_token(self, token: int):
        # Advance the matcher by one generated token.
        # NOTE(review): `assert` is stripped under `python -O`; the matcher
        # still advances, but a rejected token would then pass silently.
        assert self.matcher.accept_token(token)
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
s = self.matcher.find_jump_forward_string() s = self.matcher.find_jump_forward_string()
if s: if s:
return [], s return [], s
...@@ -128,55 +128,56 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -128,55 +128,56 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens self.override_stop_tokens = override_stop_tokens
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
key_type, key_string = key
if key_type == "json":
try:
if key_string == "$$ANY$$":
ctx = self.grammar_compiler.compile_builtin_json_grammar()
else:
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
except RuntimeError as e:
logging.warning(
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None
elif key_type == "ebnf":
try:
ctx = self.grammar_compiler.compile_grammar(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None
elif key_type == "regex":
try:
ctx = self.grammar_compiler.compile_regex(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
elif key_type == "structural_tag":
try:
structural_tag = json.loads(key_string)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
else:
raise ValueError(f"Invalid key_type: {key_type}")
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens) return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
    def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a JSON-schema constraint with xgrammar.

        Returns None (skipping constrained decoding) if compilation fails.
        """
        try:
            # "$$ANY$$" is a sentinel meaning "any valid JSON", mapped to the
            # compiler's builtin JSON grammar.
            if key_string == "$$ANY$$":
                ctx = self.grammar_compiler.compile_builtin_json_grammar()
            else:
                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
        except RuntimeError as e:
            logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
            return None
        return self._from_context(ctx)
    def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile an EBNF grammar with xgrammar.

        Returns None (skipping constrained decoding) if compilation fails.
        """
        try:
            ctx = self.grammar_compiler.compile_grammar(key_string)
        except RuntimeError as e:
            logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
            return None
        return self._from_context(ctx)
    def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a regex constraint with xgrammar.

        Returns None (skipping constrained decoding) if compilation fails.
        """
        try:
            ctx = self.grammar_compiler.compile_regex(key_string)
        except RuntimeError as e:
            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
            return None
        return self._from_context(ctx)
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
structural_tag = json.loads(key_string)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
return self._from_context(ctx)
def reset(self): def reset(self):
if self.grammar_compiler: if self.grammar_compiler:
self.grammar_compiler.clear_cache() self.grammar_compiler.clear_cache()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment