Unverified commit bdbe5f81, authored by Michał Moskal and committed by GitHub
Browse files

update llguidance to 0.7.11; adds StructTag (#4870)

parent 9ad28f63
...@@ -24,7 +24,7 @@ runtime_common = [ ...@@ -24,7 +24,7 @@ runtime_common = [
"hf_transfer", "hf_transfer",
"huggingface_hub", "huggingface_hub",
"interegular", "interegular",
"llguidance>=0.6.15", "llguidance>=0.7.11,<0.8.0",
"modelscope", "modelscope",
"ninja", "ninja",
"orjson", "orjson",
......
...@@ -14,49 +14,48 @@ ...@@ -14,49 +14,48 @@
"""Constrained decoding with llguidance backend.""" """Constrained decoding with llguidance backend."""
import json import json
import logging
import os import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import llguidance
import llguidance.hf
import llguidance.torch
import torch import torch
from llguidance.gbnf_to_lark import any_to_lark from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
from llguidance.hf import from_tokenizer
from llguidance.torch import (
allocate_token_bitmask,
apply_token_bitmask_inplace,
fill_next_token_bitmask,
)
from sglang.srt.constrained.base_grammar_backend import ( from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend, BaseGrammarBackend,
BaseGrammarObject, BaseGrammarObject,
) )
logger = logging.getLogger(__name__)
class GuidanceGrammar(BaseGrammarObject): class GuidanceGrammar(BaseGrammarObject):
def __init__(
self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
):
super().__init__() super().__init__()
self.llguidance_tokenizer = llguidance_tokenizer self.llguidance_tokenizer = llguidance_tokenizer
self.serialized_grammar = serialized_grammar self.serialized_grammar = serialized_grammar
# TODO: add support for fast-forward tokens in the future self.ll_matcher = LLMatcher(
self.ll_interpreter = llguidance.LLInterpreter(
self.llguidance_tokenizer, self.llguidance_tokenizer,
self.serialized_grammar, self.serialized_grammar,
enable_backtrack=False,
enable_ff_tokens=False,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
) )
self.pending_ff_tokens: list[int] = []
self.finished = False self.finished = False
self.bitmask = None self.bitmask = None
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
if len(self.pending_ff_tokens) > 0: ff_tokens = self.ll_matcher.compute_ff_tokens()
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens) if ff_tokens:
ff_tokens = self.pending_ff_tokens return ff_tokens, ""
self.pending_ff_tokens = [] else:
return (ff_tokens, s) return None
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1 return "", -1
...@@ -67,32 +66,22 @@ class GuidanceGrammar(BaseGrammarObject): ...@@ -67,32 +66,22 @@ class GuidanceGrammar(BaseGrammarObject):
pass pass
def accept_token(self, token: int): def accept_token(self, token: int):
backtrack, ff_tokens = self.ll_interpreter.commit_token(token) if not self.ll_matcher.consume_token(token):
if len(ff_tokens) > 0 and backtrack == 0: logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
# first token is last generated token self.finished = True
ff_tokens = ff_tokens[1:]
self.pending_ff_tokens.extend(ff_tokens)
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
if len(self.pending_ff_tokens) > 0: if self.ll_matcher.is_stopped():
# if we have pending fast-forward tokens,
# just return them immediately
ff_token = self.pending_ff_tokens.pop(0)
vocab_mask[idx, :] = 0
vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
return
if self.ll_interpreter.has_pending_stop():
self.finished = True self.finished = True
llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx) fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)
def allocate_vocab_mask( def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device self, vocab_size: int, batch_size: int, device
) -> torch.Tensor: ) -> torch.Tensor:
if self.bitmask is None or self.bitmask.shape[0] < batch_size: if self.bitmask is None or self.bitmask.shape[0] < batch_size:
# only create bitmask when batch gets larger # only create bitmask when batch gets larger
self.bitmask = llguidance.torch.allocate_token_bitmask( self.bitmask = allocate_token_bitmask(
batch_size, self.llguidance_tokenizer.vocab_size batch_size, self.llguidance_tokenizer.vocab_size
) )
bitmask = self.bitmask bitmask = self.bitmask
...@@ -107,7 +96,7 @@ class GuidanceGrammar(BaseGrammarObject): ...@@ -107,7 +96,7 @@ class GuidanceGrammar(BaseGrammarObject):
@staticmethod @staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask) apply_token_bitmask_inplace(logits, vocab_mask)
def copy(self): def copy(self):
return GuidanceGrammar( return GuidanceGrammar(
...@@ -117,36 +106,64 @@ class GuidanceGrammar(BaseGrammarObject): ...@@ -117,36 +106,64 @@ class GuidanceGrammar(BaseGrammarObject):
class GuidanceBackend(BaseGrammarBackend): class GuidanceBackend(BaseGrammarBackend):
def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
def __init__(
self,
tokenizer,
whitespace_pattern: Optional[str] = None,
n_vocab: Optional[int] = None,
):
super().__init__() super().__init__()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.whitespace_flexible = ( self.whitespace_pattern = whitespace_pattern
True if whitespace_pattern == "whitespace_flexible" else False self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
)
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
try:
def _from_serialized(self, serialized_grammar) -> GuidanceGrammar: return GuidanceGrammar(
return GuidanceGrammar( llguidance_tokenizer=self.llguidance_tokenizer,
llguidance_tokenizer=self.llguidance_tokenizer, serialized_grammar=serialized_grammar,
serialized_grammar=serialized_grammar, )
except Exception as e:
logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
return None
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
serialized_grammar = LLMatcher.grammar_from_json_schema(
key_string,
defaults={
"whitespace_pattern": self.whitespace_pattern,
},
) )
def dispatch_json(self, key_string: str) -> GuidanceGrammar:
json_schema = key_string
compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
serialized_grammar = compiler.compile(json_schema)
return self._from_serialized(serialized_grammar)
def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
compiler = llguidance.RegexCompiler()
serialized_grammar = compiler.compile(regex=key_string)
return self._from_serialized(serialized_grammar) return self._from_serialized(serialized_grammar)
def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar: def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
compiler = llguidance.LarkCompiler() serialized_grammar = grammar_from("regex", key_string)
serialized_grammar = compiler.compile(any_to_lark(key_string))
return self._from_serialized(serialized_grammar) return self._from_serialized(serialized_grammar)
def dispatch_structural_tag(self, key_string: str): def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
return super().dispatch_structural_tag(key_string) try:
serialized_grammar = grammar_from("ebnf", key_string)
return self._from_serialized(serialized_grammar)
except ValueError as e:
logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
return None
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
    """Build a grammar object from a JSON-encoded structural-tag spec.

    ``key_string`` is parsed with ``json.loads``. The parsed object is
    expected to carry a ``"structures"`` list (each entry providing
    ``"begin"``, ``"schema"`` and ``"end"``) and a ``"triggers"`` list.
    One ``StructTag`` is built per structure, the tags are combined with
    ``StructTag.to_grammar`` and the result is passed to
    ``self._from_serialized``.

    Returns:
        Whatever ``self._from_serialized`` returns, or ``None`` when the
        spec cannot be parsed or converted (a warning is logged).
    """
    try:
        structural_tag = json.loads(key_string)
        tags = [
            StructTag(
                begin=structure["begin"],
                grammar=structure["schema"],
                end=structure["end"],
                # NOTE(review): every tag receives the first trigger only;
                # specs with several triggers may need per-tag matching.
                trigger=structural_tag["triggers"][0],  # TODO?
            )
            for structure in structural_tag["structures"]
        ]
        g = StructTag.to_grammar(tags)
        return self._from_serialized(g)
    except Exception as e:
        # Report through the module-level logger, like the other warnings
        # in this file, instead of the root logger (`logging.warning`).
        logger.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
        return None
...@@ -238,5 +238,11 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -238,5 +238,11 @@ class TestEBNFConstrained(CustomTestCase):
) )
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
    """Re-run the inherited ``TestEBNFConstrained`` cases with the "llguidance" setup."""

    @classmethod
    def setUpClass(cls):
        # Calls the shared setup helper with "llguidance" and overlap not disabled.
        # NOTE(review): presumably the string selects the grammar backend —
        # confirm against setup_class, which is defined outside this view.
        setup_class(cls, "llguidance", disable_overlap=False)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment