Unverified Commit 4bf82d4b authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[V1] Add regex structured output support with xgrammar (#14590)


Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
parent 9ab32671
......@@ -19,7 +19,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
outlines == 0.1.11
lark == 1.2.2
xgrammar == 0.1.11; platform_machine == "x86_64"
xgrammar == 0.1.15; platform_machine == "x86_64" or platform_machine == "aarch64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
......
# SPDX-License-Identifier: Apache-2.0
import json
import re
import jsonschema
import pytest
......@@ -219,25 +220,24 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Regex guided decoding is not supported."):
llm.generate(prompts=[
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
use_tqdm=True,
)
# Once regex is supported --
#assert outputs is not None
#for output in outputs:
# assert output is not None
# assert isinstance(output, RequestOutput)
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(generated_text)
# assert generated_text is not None
# assert re.fullmatch(sample_regex, generated_text) is not None
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert re.fullmatch(sample_regex, generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
......
......@@ -112,6 +112,8 @@ class StructuredOutputManager:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX:
ctx = self.compiler.compile_regex(grammar_spec)
else:
logger.error("Validation should have already occurred. "
"Please file an issue.")
......
......@@ -251,7 +251,11 @@ def validate_structured_output_request(
gd_params = sampling_params.guided_decoding
if gd_params.regex:
raise ValueError("Regex structured output is not supported.")
try:
xgr.Grammar.from_regex(gd_params.regex)
except Exception as err:
raise ValueError("Failed to transform regex into a grammar: "
f"{err}") from err
if gd_params.choice:
choice_grammar = choice_as_grammar(gd_params.choice)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment