Unverified Commit 7551498a authored by JC1DA, committed by GitHub

[Feature] Support llguidance for constrained decoding (#3298)

parent 44a2c4bd
@@ -17,10 +17,13 @@
 "\n",
 "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
 "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
+"- [Llguidance](https://github.com/guidance-ai/llguidance): Supports JSON schema, regular expression, and EBNF constraints.\n",
 "\n",
 "We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
 "\n",
-"To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. If no backend is specified, Outlines will be used as the default.\n",
+"To use Xgrammar, simply add `--grammar-backend xgrammar` when launching the server.\n",
+"To use llguidance, add `--grammar-backend llguidance` when launching the server.\n",
+"If no backend is specified, Outlines will be used as the default.\n",
 "\n",
 "For better output quality, **It's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n"
 ]
...
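As a request-side companion to the documentation change above (a minimal sketch, not part of this commit): once a server is running with, e.g., `--grammar-backend llguidance`, the constraint is passed per request through `sampling_params` on SGLang's native `/generate` endpoint. The URL and prompt below are placeholders; the schema mirrors the one used in the JSON tests later in this diff.

import json

import requests

BASE_URL = "http://localhost:30000"  # placeholder; point at your running server

schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = requests.post(
    f"{BASE_URL}/generate",
    json={
        "text": "Give the information of the capital of France in JSON format.",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            # Use exactly one constraint type per request:
            "json_schema": schema,  # or "regex": r"...", or "ebnf": 'root ::= ...'
        },
    },
)
print(response.json()["text"])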
@@ -38,6 +38,7 @@ runtime_common = [
     "xgrammar==0.1.10",
     "ninja",
     "transformers==4.48.3",
+    "llguidance>=0.6.15"
 ]
 srt = [
     "sglang[runtime_common]",
...
@@ -86,6 +86,13 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
         from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

         grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
+    elif server_args.grammar_backend == "llguidance":
+        from sglang.srt.constrained.llguidance_backend import GuidanceBackend
+
+        grammar_backend = GuidanceBackend(
+            tokenizer=tokenizer,
+            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+        )
     else:
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
...
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with llguidance backend."""

import json
import os
from typing import List, Optional, Tuple

import llguidance
import llguidance.hf
import llguidance.torch
import torch
from llguidance.gbnf_to_lark import any_to_lark

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)


class GuidanceGrammar(BaseGrammarObject):
    def __init__(
        self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
    ):
        self.llguidance_tokenizer = llguidance_tokenizer
        self.serialized_grammar = serialized_grammar

        # TODO: add support for fast-forward tokens in the future
        self.ll_interpreter = llguidance.LLInterpreter(
            self.llguidance_tokenizer,
            self.serialized_grammar,
            enable_backtrack=False,
            enable_ff_tokens=False,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )
        self.pending_ff_tokens: list[int] = []
        self.finished = False
        self.bitmask = None

    def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
        if len(self.pending_ff_tokens) > 0:
            s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
            ff_tokens = self.pending_ff_tokens
            self.pending_ff_tokens = []
            return (ff_tokens, s)
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        return "", -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        pass

    def accept_token(self, token: int):
        backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
        if len(ff_tokens) > 0 and backtrack == 0:
            # first token is last generated token
            ff_tokens = ff_tokens[1:]
            self.pending_ff_tokens.extend(ff_tokens)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        if len(self.pending_ff_tokens) > 0:
            # if we have pending fast-forward tokens,
            # just return them immediately
            ff_token = self.pending_ff_tokens.pop(0)
            vocab_mask[idx, :] = 0
            vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
            return

        if self.ll_interpreter.has_pending_stop():
            self.finished = True

        llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        if self.bitmask is None or self.bitmask.shape[0] < batch_size:
            # only create bitmask when batch gets larger
            self.bitmask = llguidance.torch.allocate_token_bitmask(
                batch_size, self.llguidance_tokenizer.vocab_size
            )
            bitmask = self.bitmask
        else:
            bitmask = self.bitmask[:batch_size]
        return bitmask

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask)

    def copy(self):
        return GuidanceGrammar(
            llguidance_tokenizer=self.llguidance_tokenizer,
            serialized_grammar=self.serialized_grammar,
        )


class GuidanceBackend(BaseGrammarBackend):
    def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
        super().__init__()

        self.tokenizer = tokenizer
        self.whitespace_flexible = (
            True if whitespace_pattern == "whitespace_flexible" else False
        )
        self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)

    def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar:
        mode, value = key
        if mode == "json":
            json_schema = value
            compiler = llguidance.JsonCompiler(
                whitespace_flexible=self.whitespace_flexible
            )
            serialized_grammar = compiler.compile(json_schema)
        elif mode == "regex":
            compiler = llguidance.RegexCompiler()
            serialized_grammar = compiler.compile(regex=value)
        elif mode == "ebnf":
            compiler = llguidance.LarkCompiler()
            serialized_grammar = compiler.compile(any_to_lark(value))

        return GuidanceGrammar(
            llguidance_tokenizer=self.llguidance_tokenizer,
            serialized_grammar=serialized_grammar,
        )
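A rough sketch (not part of the commit) of how the scheduler is expected to drive a GuidanceGrammar each decoding step, and of the mask layout that the fast-forward branch in fill_vocab_mask relies on: the mask is a packed int32 tensor with one bit per vocab token (32 tokens per word), so forcing a single token t means setting bit t % 32 in word t // 32. Class and method names come from the file above; the model name, logits, and device are placeholders.

# Illustrative only: assumes the llguidance wheel and a local HF tokenizer are available.
import torch
from transformers import AutoTokenizer

from sglang.srt.constrained.llguidance_backend import GuidanceBackend

hf_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")  # placeholder model
backend = GuidanceBackend(tokenizer=hf_tok, whitespace_pattern=None)

# Compile a JSON-schema constraint into a per-request grammar object.
grammar = backend.init_value_impl(("json", '{"type": "object"}'))

vocab_size = backend.llguidance_tokenizer.vocab_size
logits = torch.randn(1, vocab_size)  # stand-in for one row of model logits

# One decoding step: build the packed bitmask, apply it, pick a token, commit it.
mask = grammar.allocate_vocab_mask(vocab_size=vocab_size, batch_size=1, device="cpu")
grammar.fill_vocab_mask(mask, idx=0)
grammar.apply_vocab_mask(logits, grammar.move_vocab_mask(mask, device="cpu"))
next_token = int(torch.argmax(logits, dim=-1))
grammar.accept_token(next_token)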
@@ -698,7 +698,7 @@ class ServerArgs:
         parser.add_argument(
             "--grammar-backend",
             type=str,
-            choices=["xgrammar", "outlines"],
+            choices=["xgrammar", "outlines", "llguidance"],
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
...
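To make the new choice concrete (a sketch, not from the commit): launching the server through the standard `python -m sglang.launch_server` entry point with the flag added above. The model path and port are placeholders.

import subprocess

# Placeholder model path and port; --grammar-backend now also accepts "llguidance".
subprocess.run(
    [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        "meta-llama/Llama-3.2-1B-Instruct",
        "--grammar-backend",
        "llguidance",
        "--port",
        "30000",
    ],
    check=True,
)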
""" """
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
""" """
import json import json
...@@ -17,7 +19,7 @@ from sglang.test.test_utils import ( ...@@ -17,7 +19,7 @@ from sglang.test.test_utils import (
) )
def setup_class(cls, disable_overlap: bool): def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
cls.ebnf_grammar = 'root ::= "test"' # Default grammar cls.ebnf_grammar = 'root ::= "test"' # Default grammar
...@@ -26,7 +28,7 @@ def setup_class(cls, disable_overlap: bool): ...@@ -26,7 +28,7 @@ def setup_class(cls, disable_overlap: bool):
"--max-running-requests", "--max-running-requests",
"10", "10",
"--grammar-backend", "--grammar-backend",
"xgrammar", backend,
] ]
if disable_overlap: if disable_overlap:
...@@ -43,7 +45,7 @@ def setup_class(cls, disable_overlap: bool): ...@@ -43,7 +45,7 @@ def setup_class(cls, disable_overlap: bool):
class TestEBNFConstrained(unittest.TestCase): class TestEBNFConstrained(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
setup_class(cls, disable_overlap=False) setup_class(cls, "xgrammar", disable_overlap=False)
cls.check_jump_forward = False cls.check_jump_forward = False
@classmethod @classmethod
...@@ -236,5 +238,12 @@ class TestEBNFConstrained(unittest.TestCase): ...@@ -236,5 +238,12 @@ class TestEBNFConstrained(unittest.TestCase):
) )
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=False)
cls.check_jump_forward = False
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
""" """
python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
python3 -m unittest test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
""" """
import json import json
...@@ -30,6 +31,7 @@ def setup_class(cls, backend: str, disable_overlap: bool): ...@@ -30,6 +31,7 @@ def setup_class(cls, backend: str, disable_overlap: bool):
"population": {"type": "integer"}, "population": {"type": "integer"},
}, },
"required": ["name", "population"], "required": ["name", "population"],
"additionalProperties": False,
} }
) )
...@@ -146,5 +148,12 @@ class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend): ...@@ -146,5 +148,12 @@ class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
cls.check_jump_forward = False cls.check_jump_forward = False
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="llguidance", disable_overlap=False)
cls.check_jump_forward = False
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
""" """
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_greeting
""" """
import json import json
...@@ -17,7 +21,7 @@ from sglang.test.test_utils import ( ...@@ -17,7 +21,7 @@ from sglang.test.test_utils import (
) )
def setup_class(cls, disable_overlap: bool): def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
...@@ -25,7 +29,7 @@ def setup_class(cls, disable_overlap: bool): ...@@ -25,7 +29,7 @@ def setup_class(cls, disable_overlap: bool):
"--max-running-requests", "--max-running-requests",
"10", "10",
"--grammar-backend", "--grammar-backend",
"xgrammar", backend,
] ]
if disable_overlap: if disable_overlap:
...@@ -42,7 +46,7 @@ def setup_class(cls, disable_overlap: bool): ...@@ -42,7 +46,7 @@ def setup_class(cls, disable_overlap: bool):
class TestRegexConstrained(unittest.TestCase): class TestRegexConstrained(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
setup_class(cls, disable_overlap=False) setup_class(cls, "xgrammar", disable_overlap=False)
cls.check_jump_forward = False cls.check_jump_forward = False
@classmethod @classmethod
...@@ -178,9 +182,22 @@ class TestRegexConstrained(unittest.TestCase): ...@@ -178,9 +182,22 @@ class TestRegexConstrained(unittest.TestCase):
class TestJumpForward(TestRegexConstrained): class TestJumpForward(TestRegexConstrained):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
setup_class(cls, disable_overlap=True) setup_class(cls, "xgrammar", disable_overlap=True)
cls.check_jump_forward = True cls.check_jump_forward = True
class TestJumpForwardLLGuidance(TestRegexConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=True)
cls.check_jump_forward = True
class TestRegexConstrainedLLGuidance(TestRegexConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()