Unverified Commit 3bcf5ece authored by Enrique Shockwave's avatar Enrique Shockwave Committed by GitHub
Browse files

support regex in xgrammar backend (#2983)

parent 2c05f81f
...@@ -219,7 +219,7 @@ ...@@ -219,7 +219,7 @@
"SGLang supports two grammar backends:\n", "SGLang supports two grammar backends:\n",
"\n", "\n",
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.\n", "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
" - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n", " - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
"\n", "\n",
"Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n", "Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n",
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
"SGLang supports two grammar backends:\n", "SGLang supports two grammar backends:\n",
"\n", "\n",
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints and currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).\n", "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
" - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
"\n", "\n",
"We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", "We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
"\n", "\n",
......
...@@ -189,7 +189,7 @@ You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia ...@@ -189,7 +189,7 @@ You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia
SGLang supports two grammar backends: SGLang supports two grammar backends:
- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints. - [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.
- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints. - [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.
- XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)
Initialize the XGrammar backend using `--grammar-backend xgrammar` flag Initialize the XGrammar backend using `--grammar-backend xgrammar` flag
......
...@@ -23,7 +23,7 @@ runtime_common = [ ...@@ -23,7 +23,7 @@ runtime_common = [
"packaging", "pillow", "prometheus-client>=0.20.0", "packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart", "psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
"xgrammar>=0.1.6" "xgrammar>=0.1.10"
] ]
srt = [ srt = [
"sglang[runtime_common]", "cuda-python", "sglang[runtime_common]", "cuda-python",
......
...@@ -19,6 +19,7 @@ from typing import List, Tuple ...@@ -19,6 +19,7 @@ from typing import List, Tuple
import torch import torch
from xgrammar import ( from xgrammar import (
CompiledGrammar, CompiledGrammar,
Grammar,
GrammarCompiler, GrammarCompiler,
GrammarMatcher, GrammarMatcher,
TokenizerInfo, TokenizerInfo,
...@@ -133,10 +134,13 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -133,10 +134,13 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None return None
elif key_type == "regex": elif key_type == "regex":
logger.warning( try:
"regex hasn't been supported by xgrammar yet. This is skipped." ctx = self.grammar_compiler.compile_grammar(
) Grammar.from_regex(key_string)
return None )
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
else: else:
raise ValueError(f"Invalid key_type: {key_type}") raise ValueError(f"Invalid key_type: {key_type}")
......
...@@ -31,6 +31,7 @@ suites = { ...@@ -31,6 +31,7 @@ suites = {
"test_openai_server.py", "test_openai_server.py",
"test_pytorch_sampling_backend.py", "test_pytorch_sampling_backend.py",
"test_radix_attention.py", "test_radix_attention.py",
"test_regex_constrained.py",
"test_release_memory_occupation.py", "test_release_memory_occupation.py",
"test_request_length_validation.py", "test_request_length_validation.py",
"test_retract_decode.py", "test_retract_decode.py",
......
"""
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
"""
import json
import re
import unittest

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)
def setup_class(cls, disable_overlap: bool):
    """Start the shared test server and record its handles on ``cls``.

    Sets ``cls.model``, ``cls.base_url`` and ``cls.process`` (the value
    returned by ``popen_launch_server``). The server is always launched with
    ``--max-running-requests 10 --grammar-backend xgrammar``.

    Args:
        cls: Test class that receives the attributes.
        disable_overlap: When truthy, ``--disable-overlap-schedule`` is
            appended to the launch arguments.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST

    launch_flags = ["--max-running-requests", "10", "--grammar-backend", "xgrammar"]
    if disable_overlap:
        launch_flags.append("--disable-overlap-schedule")

    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=launch_flags,
    )
class TestRegexConstrained(unittest.TestCase):
    """End-to-end checks that regex-constrained generation obeys its pattern.

    One server is started per class via ``setup_class`` (XGrammar backend).
    Each test posts a prompt and a regex to ``/generate`` and verifies that
    every returned completion matches that regex in full.
    """

    @classmethod
    def setUpClass(cls):
        setup_class(cls, disable_overlap=False)
        cls.check_jump_forward = False

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(
        self,
        regex,
        prompt,
        return_logprob=False,
        top_logprobs_num=0,
        n=1,
    ):
        """Request ``n`` completions constrained by ``regex`` and validate them.

        Args:
            regex: Pattern sent as the ``regex`` sampling parameter and used
                to validate each completion.
            prompt: Text prompt sent to the server.
            return_logprob: Forwarded as ``return_logprob``.
            top_logprobs_num: Forwarded as ``top_logprobs_num``.
            n: Number of completions; temperature is 0 when ``n == 1`` and
                0.5 otherwise.

        Fails the test if the HTTP status is not 200, the payload has an
        unexpected shape, a completion is empty, or a completion does not
        match ``regex``.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "regex": regex,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        # Surface server-side errors directly instead of letting an error
        # payload be misreported as a shape or empty-text failure below.
        self.assertEqual(
            response.status_code,
            200,
            f"/generate returned HTTP {response.status_code}: {response.text}",
        )

        ret = response.json()
        print(json.dumps(ret, indent=2))
        print("=" * 100)

        # A single completion (the default n=1) may come back as one JSON
        # object rather than a list; normalize so both shapes are validated.
        if isinstance(ret, dict):
            ret = [ret]
        if not isinstance(ret, list):
            self.fail(f"Expected response to be a list, but got {type(ret)}")

        for item in ret:
            text = item.get("text", "").strip()
            if not text:
                self.fail("Generated text is empty.")
            if not self.regex_match(text, regex):
                self.fail(f"Text '{text}' does not match regex pattern.")

    def regex_match(self, text, pattern):
        """Return True only if ``pattern`` matches the whole of ``text``.

        ``re.fullmatch`` is used so that patterns without explicit ``^``/``$``
        anchors cannot accept output with trailing extra characters.
        """
        return re.fullmatch(pattern, text) is not None

    def test_regex_generate_email(self):
        # Fixed literal address.
        self.run_decode(
            regex=r"^user@example\.com$",
            prompt="Generate an email address:",
            n=3,
        )

    def test_regex_generate_greeting(self):
        # Simple alternation.
        self.run_decode(
            regex=r"^(Hello|Hi|Hey)$",
            prompt="Generate a greeting:",
            n=3,
        )

    def test_regex_generate_number(self):
        # Counted repetition of a character class.
        self.run_decode(
            regex=r"^\d{3}$",
            prompt="Generate a three-digit number:",
            n=3,
        )

    def test_regex_generate_phone(self):
        # Escaped literals mixed with counted digits.
        self.run_decode(
            regex=r"^\(\d{3}\) \d{3}-\d{4}$",
            prompt="Generate a phone number:",
            n=3,
        )

    def test_regex_generate_date(self):
        # Nested alternations restricting month and day ranges.
        self.run_decode(
            regex=r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$",
            prompt="Generate a date in YYYY-MM-DD format:",
            n=3,
        )

    def test_regex_generate_hex_color(self):
        # Explicit character ranges.
        self.run_decode(
            regex=r"^#[0-9A-F]{6}$",
            prompt="Generate a hex color code:",
            n=3,
        )

    def test_regex_generate_complex_json(self):
        # Long pattern with optional whitespace and unbounded repetition.
        self.run_decode(
            regex=r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$',
            prompt="Generate a simple JSON with name, age, and city:",
            n=3,
        )

    def test_regex_generate_custom_log_format(self):
        # Mostly literal text with one free-form lowercase word.
        self.run_decode(
            regex=r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$",
            prompt="Generate a log entry:",
            n=3,
        )
class TestJumpForward(TestRegexConstrained):
    """Repeat the inherited regex tests on a server launched with
    ``--disable-overlap-schedule``.
    """

    @classmethod
    def setUpClass(cls):
        # Only the launch mode and the flag value differ from the parent class.
        setup_class(cls=cls, disable_overlap=True)
        cls.check_jump_forward = True
# Allow running this file directly; unittest discovers the TestCase classes above.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment