Unverified Commit d6902ce7 authored by Nathan Hoos, committed by GitHub
Browse files

[V0][V1][Core] Add outlines integration for V1, and update V0 integration. (#15975)


Signed-off-by: Nathan Hoos <thwackyy.y@gmail.com>
parent 5e53c89a
...@@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0 ...@@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11 lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11 outlines_core == 0.2.10
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2 lark == 1.2.2
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10 typing_extensions >= 4.10
......
...@@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput ...@@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [
# Separate backends which support grammars vs ones
# which only support regex based constraints in tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace), # (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False), ("lm-format-enforcer", False),
("xgrammar", True), ("xgrammar", True),
("guidance", True), ("guidance", True),
] ]
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
...@@ -39,7 +43,7 @@ def llm(): ...@@ -39,7 +43,7 @@ def llm():
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, ...@@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
regex=sample_regex, regex=sample_regex,
backend=guided_decoding_backend, backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace)) disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[ outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}" f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2, ] * 2,
...@@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, ...@@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm, def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm, ...@@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm, def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm, ...@@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm, def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, ...@@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm, def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm, ...@@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm, def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm, ...@@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm, def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): ...@@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str, def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str, ...@@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) # A list is not what was intended, but is still valid
# json.
assert isinstance(parsed_json, (dict, list))
class CarType(str, Enum): class CarType(str, Enum):
...@@ -395,7 +402,7 @@ class CarDescription(BaseModel): ...@@ -395,7 +402,7 @@ class CarDescription(BaseModel):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
...@@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, ...@@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
sample_output_schema = { sample_output_schema = {
......
...@@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, ...@@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
whitespace_pattern=None, whitespace_pattern=None,
reasoner=None) reasoner=None)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000) tensor = torch.rand(32000)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor) tensor = regex_LP([], tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor) assert not torch.allclose(tensor, original_tensor)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000) tensor = torch.rand(32000)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor) tensor = json_LP([], tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor) assert not torch.allclose(tensor, original_tensor)
...@@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, ...@@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0, seed=0,
dtype="bfloat16", dtype="bfloat16",
) )
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor( regex_lp = get_local_guided_decoding_logits_processor(
...@@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, ...@@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert regex_lp is not None assert regex_lp is not None
tensor = torch.rand(32000) tensor = torch.rand(32000)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor) # allowed tokens at state 0
tensor = regex_lp([], tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor) assert not torch.allclose(tensor, original_tensor)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema, json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend) backend=backend)
json_lp = await get_guided_decoding_logits_processor( json_lp = await get_guided_decoding_logits_processor(
...@@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, ...@@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert json_lp is not None assert json_lp is not None
tensor = torch.rand(32000) tensor = torch.rand(32000)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor) tensor = json_lp([], tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor) assert not torch.allclose(tensor, original_tensor)
...@@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning( ...@@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning(
dtype="bfloat16", dtype="bfloat16",
) )
token_ids = deepseek_r1_qwen_tokenizer.encode( token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process") "<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
...@@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning( ...@@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning(
regex_request, deepseek_r1_qwen_tokenizer, config, regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) reasoning_backend)
assert regex_lp is not None assert regex_lp is not None
tensor = torch.rand(32000) tensor = torch.rand(151664)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor) tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor) assert torch.allclose(tensor, original_tensor)
token_ids = deepseek_r1_qwen_tokenizer.encode( token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process") "<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema, json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend) backend=backend)
...@@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning( ...@@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None assert json_lp is not None
tensor = torch.rand(32000) tensor = torch.rand(151664)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor) tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
...@@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning( ...@@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning(
# Thinking is over, so the tensor should change. # Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode( token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}." "<think>here is the thinking process</think>")
"<think>here is the thinking process</think> Then")
json_request = GuidedDecodingParams(json=sample_json_schema, json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend) backend=backend)
json_lp = get_local_guided_decoding_logits_processor( json_lp = get_local_guided_decoding_logits_processor(
...@@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning( ...@@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None assert json_lp is not None
tensor = torch.rand(32000) tensor = torch.rand(151664)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor) tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
......
...@@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, ...@@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
assert isinstance(schema, dict) assert isinstance(schema, dict)
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide # use build_regex_from_schema used in JSONLogitsProcessor to create Guide
from outlines_core.fsm.json_schema import build_regex_from_schema from outlines_core.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema)) regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex) compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None matches = compiled.fullmatch(json.dumps(sample_output)) is not None
......
...@@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ...@@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
NGRAM_SPEC_CONFIG),
#FIXME: This test is flaky on CI thus disabled #FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
...@@ -106,11 +110,13 @@ def test_structured_output( ...@@ -106,11 +110,13 @@ def test_structured_output(
enforce_eager = bool(not current_platform.is_tpu()) enforce_eager = bool(not current_platform.is_tpu())
# Use a single LLM instance for several scenarios to # Use a single LLM instance for several scenarios to
# speed up the test suite. # speed up the test suite.
llm = LLM(model=model_name, llm = LLM(
model=model_name,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_model_len=1024, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend, guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True, guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config) speculative_config=speculative_config)
...@@ -146,14 +152,15 @@ def test_structured_output( ...@@ -146,14 +152,15 @@ def test_structured_output(
# #
# Test 2: Generate JSON object without a schema # Test 2: Generate JSON object without a schema
# #
if guided_decoding_backend != "outlines":
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=4096, max_tokens=4096,
n=2, n=2,
guided_decoding=GuidedDecodingParams(json_object=True)) guided_decoding=GuidedDecodingParams(json_object=True))
outputs = llm.generate( outputs = llm.generate(prompts=(
prompts=("Generate a JSON object with curly braces for a person with " "Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old. " "name and age fields for John Smith who is 31 years old. "
"Make the response as short as possible."), "Make the response as short as possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
...@@ -210,6 +217,7 @@ def test_structured_output( ...@@ -210,6 +217,7 @@ def test_structured_output(
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
if guided_decoding_backend != "outlines":
# #
# Test 4: Generate SQL statement using EBNF grammar # Test 4: Generate SQL statement using EBNF grammar
# #
...@@ -293,8 +301,8 @@ def test_structured_output( ...@@ -293,8 +301,8 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(grammar="not a grammar")) guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "): with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate( llm.generate(
prompts=( prompts=
"Generate a sql statement that selects col_1 from " ("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short " "table_1 where it is equal to 1. Make the response as short "
"as possible."), "as possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
...@@ -421,6 +429,7 @@ def test_structured_output( ...@@ -421,6 +429,7 @@ def test_structured_output(
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema) jsonschema.validate(instance=output_json, schema=json_schema)
if guided_decoding_backend != "outlines":
# #
# Test 11: Generate structured output using structural_tag format # Test 11: Generate structured output using structural_tag format
# #
......
...@@ -3580,7 +3580,8 @@ def get_served_model_name(model: str, ...@@ -3580,7 +3580,8 @@ def get_served_model_name(model: str,
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"] "xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0, GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1] GuidedDecodingBackendV1]
......
...@@ -117,6 +117,7 @@ if TYPE_CHECKING: ...@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM: bool = False
...@@ -847,6 +848,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -847,6 +848,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V0_USE_OUTLINES_CACHE": "VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V1_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1",
# Gap between padding buckets for the forward pass. So we have # Gap between padding buckets for the forward pass. So we have
# 8, we will run forward pass with [16, 24, 32, ...]. # 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP": "VLLM_TPU_BUCKET_PADDING_GAP":
......
...@@ -79,20 +79,33 @@ def maybe_backend_fallback( ...@@ -79,20 +79,33 @@ def maybe_backend_fallback(
fallback_or_error( fallback_or_error(
guided_params, guided_params,
"xgrammar does not support Lark grammars and the " "xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines") "grammar failed to convert to GBNF.", "guidance")
# If the xgrammar module cannot be imported successfully, # If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback. # we should still allow users to use guided decoding with a fallback.
elif not xgr_installed: elif not xgr_installed:
fallback_or_error( fallback_or_error(
guided_params, guided_params,
"xgrammar module cannot be imported successfully.", "outlines") "xgrammar module cannot be imported successfully.", "guidance")
if (guided_params.backend == "outlines" if guided_params.backend == "outlines":
and guided_params.json_object is not None): if guided_params.json_object is not None:
# outlines doesn't support json_object, fallback to guidance # outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params, fallback_or_error(guided_params,
"outlines does not support json_object.", "guidance") "outlines does not support json_object.",
"guidance")
elif guided_params.grammar is not None:
# outlines grammar support has been removed, fallback to guidance
# if it is a lark-based grammar and xgrammar otherwise
if grammar_is_likely_lark(guided_params.grammar):
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"guidance")
else:
# The grammar is likely already GBNF format.
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"xgrammar")
return guided_params return guided_params
...@@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor( ...@@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor(
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines': if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
...@@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor( ...@@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend) reasoning_backend)
reasoner = reasoner_class(tokenizer) reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines': if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
......
...@@ -12,7 +12,7 @@ from regex import escape as regex_escape ...@@ -12,7 +12,7 @@ from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
...@@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum): ...@@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum):
JSON = "json" JSON = "json"
REGEX = "regex" REGEX = "regex"
CHOICE = "choice" CHOICE = "choice"
GRAMMAR = "grammar"
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# the main difference is that we changed the start: value to
# start: object | array, so we are denying scalar values as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
JSON_GRAMMAR = r"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""
global_thread_pool = None # used for generating logits processor fsm global_thread_pool = None # used for generating logits processor fsm
# It's not yet clear that using more provides a benefit, and it could # It's not yet clear that using more provides a benefit, and it could
...@@ -60,16 +32,12 @@ _MAX_THREADPOOL_WORKERS = 16 ...@@ -60,16 +32,12 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
tokenizer: PreTrainedTokenizerBase, reasoner: Optional[ReasoningParser]
reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
""" """
Given an OpenAI-compatible request, check for guided decoding parameters Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide. and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
""" """
global global_thread_pool global global_thread_pool
guide, mode = _get_guide_and_mode(guided_params) guide, mode = _get_guide_and_mode(guided_params)
...@@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor(
global_thread_pool = concurrent.futures.ThreadPoolExecutor( global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers) max_workers=max_workers)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor(global_thread_pool, return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer, _get_logits_processor, guide, tokenizer,
mode, guided_params.whitespace_pattern, mode, guided_params.whitespace_pattern,
...@@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
tokenizer: PreTrainedTokenizerBase, reasoner: Optional[ReasoningParser]
reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
""" """
Given an OpenAI-compatible request, check for guided decoding parameters Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide. and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
""" """
guide, mode = _get_guide_and_mode(guided_params) guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode: if not guide or not mode:
...@@ -130,9 +93,10 @@ def _get_guide_and_mode( ...@@ -130,9 +93,10 @@ def _get_guide_and_mode(
choices_regex = "(" + "|".join(choices) + ")" choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE return choices_regex, GuidedDecodingMode.CHOICE
elif guided_params.grammar: elif guided_params.grammar:
return guided_params.grammar, GuidedDecodingMode.GRAMMAR raise ValueError(
elif guided_params.json_object: "The `outlines` guided decoding backend no longer supports grammar "
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR "guided generation. Please use either the `xgrammar` or `guidance` "
"backend")
else: else:
return None, None return None, None
...@@ -143,13 +107,11 @@ def _get_logits_processor( ...@@ -143,13 +107,11 @@ def _get_logits_processor(
mode: GuidedDecodingMode, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
reasoner) reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer, reasoner) return RegexLogitsProcessor(guide, tokenizer, reasoner)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer, reasoner)
else: else:
raise ValueError(f"Unknown guided decoding mode {mode}") raise ValueError(f"Unknown guided decoding mode {mode}")
...@@ -23,6 +23,8 @@ from vllm.v1.engine import EngineCoreRequest ...@@ -23,6 +23,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import ( from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar) validate_guidance_grammar)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import ( from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar) validate_xgrammar_grammar)
...@@ -193,6 +195,9 @@ class Processor: ...@@ -193,6 +195,9 @@ class Processor:
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars. # Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None) validate_guidance_grammar(params, tokenizer=None)
elif engine_level_backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
else: else:
# NOTE: engine_level_backend must be "auto" here, because we have # NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above. # checked supported_backends above.
......
...@@ -88,6 +88,15 @@ class StructuredOutputManager: ...@@ -88,6 +88,15 @@ class StructuredOutputManager:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
vocab_size=vocab_size, vocab_size=vocab_size,
) )
elif backend == "outlines":
from vllm.v1.structured_output.backend_outlines import (
OutlinesBackend)
self.backend = OutlinesBackend(
self.vllm_config,
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
else: else:
raise ValueError( raise ValueError(
f"Unsupported structured output backend: {backend}") f"Unsupported structured output backend: {backend}")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import ast
import importlib
import json
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from regex import escape as regex_escape
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
OutlinesVocabulary, get_cache, get_vocabulary)
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
if TYPE_CHECKING:
import outlines_core as oc
import outlines_core.json_schema as json_schema
else:
oc = LazyLoader("oc", globals(), "outlines_core")
json_schema = LazyLoader("json_schema", globals(),
"outlines_core.json_schema")
# In Python 3.11+, the top-level sre_parse and sre_constants modules
# are deprecated, so we must access them through re instead
if sys.version_info >= (3, 11):
# Hack to get around pre-commit regex module rule
# because going through re is the only way to get sre_parse
# and sre_constants in Python 3.11+
_re = importlib.import_module("re")
sre_parse = _re._parser
sre_constants = _re._constants
else:
import sre_constants
import sre_parse
@dataclass
class OutlinesBackend(StructuredOutputBackend):
    """Structured-output backend built on ``outlines_core``.

    Each supported request type (JSON schema, regex, choice list) is lowered
    to a single regular expression, compiled into an ``oc.Index`` over the
    tokenizer vocabulary, and wrapped in an ``oc.Guide``.
    """

    def __post_init__(self):
        # Both helpers live in the V0 outlines logits-processor module.
        self.vocabulary = get_vocabulary(self.tokenizer)
        self.cache = get_cache()

    def _compile_index(self, regex_string: str,
                       vocabulary: OutlinesVocabulary) -> oc.Index:
        """Return the index for ``regex_string``, building it on a cache miss.

        The cache key combines the vocabulary hash with the regex, so the same
        pattern compiled for a different vocabulary gets its own entry.
        """
        key = f"{vocabulary._hash}_{regex_string}"
        store = self.cache
        if key not in store:
            built = oc.Index(regex_string, vocabulary.inner)
            store[key] = built
            return built
        return store[key]

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` into an ``OutlinesGrammar``.

        Raises:
            ValueError: if ``request_type`` is not JSON, REGEX or CHOICE.
        """
        if request_type == StructuredOutputOptions.REGEX:
            pattern = grammar_spec
        elif request_type == StructuredOutputOptions.JSON:
            pattern = json_schema.build_regex_from_schema(grammar_spec)
        elif request_type == StructuredOutputOptions.CHOICE:
            # The spec is the literal text of a Python list of strings; each
            # option is escaped so it is matched verbatim.
            options = ast.literal_eval(grammar_spec)
            escaped = [regex_escape(option) for option in options]
            pattern = "(" + "|".join(escaped) + ")"
        else:
            raise ValueError(
                f"Invalid request type for Outlines backend ({request_type!s})"
            )

        compiled = self._compile_index(pattern, self.vocabulary)

        # With speculative decoding the guide must be able to undo as many
        # tokens as can be speculated; otherwise no rollback is needed.
        if self.vllm_config.speculative_config is not None:
            rollback_budget = (
                self.vllm_config.speculative_config.num_speculative_tokens)
        else:
            rollback_budget = 0

        guide = oc.Guide(compiled, max_rollback=rollback_budget)
        return OutlinesGrammar(vocab_size=self.vocab_size, guide=guide)

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """Allocate a ``(max_num_seqs, ceil(vocab_size / 32))`` int32 tensor.

        Every element starts as -1, i.e. all 32 bits of each word set.
        Memory is pinned when CUDA is available.
        """
        words_per_row = (self.vocab_size + 31) // 32
        shape = (max_num_seqs, words_per_row)
        use_pinned = torch.cuda.is_available()
        return torch.full(shape, -1, dtype=torch.int32, pin_memory=use_pinned)

    def destroy(self):
        """Nothing to release for this backend."""
        return None
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an ``outlines_core`` guide."""

    vocab_size: int
    guide: oc.Guide = field(hash=False)
    # Count of tokens fed to the guide so far (decremented on rollback).
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    # outlines_core signals done on DFA accept; vLLM expects done after EOS.
    # We delay the finished flag by one step so EOS can still be emitted.
    _prev_finished: bool = field(default=False,
                                 init=False,
                                 repr=False,
                                 hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Advance the guide over ``tokens`` if the whole sequence is valid.

        Returns True when the guide was advanced, and False (leaving the
        guide untouched) when the sequence is rejected. ``request_id`` is
        not used here.
        """
        if not self.guide.accepts_tokens(tokens):
            return False

        # The sequence was pre-validated above, so each step is expected
        # to succeed.
        for token_id in tokens:
            self.guide.advance(token_id)
            self.num_processed_tokens += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        """Undo the last ``num_tokens`` tokens in the guide and the counter."""
        self.guide.rollback_state(num_tokens)
        self.num_processed_tokens -= num_tokens

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest prefix of ``tokens`` the guide would accept.

        The guide is only queried, never advanced.
        """
        prefix: list[int] = []
        for token_id in tokens:
            candidate = prefix + [token_id]
            if not self.guide.accepts_tokens(candidate):
                break
            prefix = candidate
        return prefix

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Have the guide write its current token mask into row ``idx``."""
        row = bitmask[idx]
        self.guide.write_mask_into(row.data_ptr(), row.numel(),
                                   row.element_size())

    def is_terminated(self) -> bool:
        """Return the guide's finished status as of the previous call.

        Note that each call also records the guide's current status, so this
        method is stateful.
        """
        finished_now = self.guide.is_finished()
        finished_before, self._prev_finished = (self._prev_finished,
                                                finished_now)
        return finished_before

    def reset(self):
        """Clear the counters and return the guide to its initial state."""
        self.num_processed_tokens = 0
        self._prev_finished = False
        self.guide.reset()
def validate_structured_output_request_outlines(params: SamplingParams):
    """Reject guided-decoding requests the Outlines backend cannot serve.

    The first populated field among ``regex``, ``json``, ``choice`` and
    ``grammar`` (checked in that order) decides what is validated. Requests
    without guided decoding, or with none of those fields set, pass through.

    Raises:
        ValueError: for an unparsable or unserializable JSON schema, for a
            regex rejected by ``validate_regex_is_buildable``, or when a
            grammar is supplied.
    """
    gd_params = params.guided_decoding
    if gd_params is None:
        return

    if gd_params.regex:
        validate_regex_is_buildable(gd_params.regex)
        return

    if gd_params.json:
        spec = gd_params.json
        if isinstance(spec, str):
            # A string is taken as already-serialized JSON; just check that
            # it parses.
            try:
                json.loads(spec)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
            schema = spec
        else:
            try:
                schema = json.dumps(spec)
            except Exception as e:
                raise ValueError(
                    f"Error serializing guided decoding jsonschema: {e}"
                ) from e
        # The schema is lowered to a regex, which is validated in turn.
        validate_regex_is_buildable(
            json_schema.build_regex_from_schema(schema))
        return

    if gd_params.choice:
        escaped = [regex_escape(str(option)) for option in gd_params.choice]
        validate_regex_is_buildable("(" + "|".join(escaped) + ")")
        return

    if gd_params.grammar:
        raise ValueError("Outlines guided decoding backend "
                         "does not support grammar specifications")
def _prefix_needs_context(parsed) -> bool:
    """Return True if there's a look-around/anchor before any consumer.

    ``parsed`` is the output of ``sre_parse.parse`` (or a nested token list
    from it). Tokens are scanned left to right: an anchor or look-around
    seen before any character-consuming construct returns True, while the
    first construct that can consume a character ends the scan with False.

    Single-character negated classes (``[^a]``, parsed as ``NOT_LITERAL``)
    and lazy quantifiers (``a+?``, parsed as ``MIN_REPEAT``) are treated
    exactly like their ``IN`` / ``MAX_REPEAT`` counterparts, so that e.g.
    ``[^a]$`` and ``a+?$`` get the same verdict as ``[^ab]$`` and ``a+$``.
    """
    # Tokens that always consume exactly one character.
    consumers = (sre_parse.LITERAL, sre_parse.NOT_LITERAL, sre_parse.IN,
                 sre_parse.ANY)
    # Greedy and lazy quantifiers share the (min, max, subpattern) payload.
    repeats = (sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT)

    def subpattern_consumes(parsed) -> bool:
        """Return True if subpattern can consume at least one character."""
        tokens = parsed.data if hasattr(parsed, 'data') else parsed
        for ttype, tval in tokens:
            # literal, character class, or dot always consumes
            if ttype in consumers:
                return True
            # quantified subpattern: check inner pattern
            elif ttype in repeats:
                _, mx, sub = tval
                if mx != 0 and subpattern_consumes(sub):
                    return True
            # alternation: if any branch consumes, the whole does
            elif ttype == sre_parse.BRANCH:
                _, branches = tval
                if any(subpattern_consumes(br) for br in branches):
                    return True
            # grouped subpattern: recurse into its contents
            elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(
                    tval[3]):
                return True
        # No consumers, return False
        return False

    tokens = parsed.data if hasattr(parsed, 'data') else parsed
    for ttype, tval in tokens:
        # Direct anchors or look-around
        if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT,
                                              sre_constants.ASSERT_NOT):
            return True

        # Nested subpattern: check
        if ttype == sre_parse.SUBPATTERN:
            # tval: (group, add_flags, del_flags, subpattern)
            if _prefix_needs_context(tval[3]):
                return True
            if subpattern_consumes(tval[3]):
                return False

        # if any branch has a prefix anchor => True,
        # else if at least one branch consumes => prefix ends => False
        elif ttype == sre_parse.BRANCH:
            saw_consumer = False
            for br in tval[1]:
                if _prefix_needs_context(br):
                    return True
                if subpattern_consumes(br):
                    saw_consumer = True
            if saw_consumer:
                return False

        # Immediate consumer tokens
        elif ttype in consumers:
            return False

        # if subpattern has anchor => True, if it can consume => stop
        elif ttype in repeats:
            if _prefix_needs_context(tval[2]):
                return True
            if subpattern_consumes(tval[2]):
                return False

    return False
def _check_unsupported(parsed) -> None:
    """Check for regex features unsupported by regex-automata.

    Walks the token tree produced by ``sre_parse.parse`` and raises on the
    first backreference, look-around assertion or word-boundary anchor. The
    walk descends into alternation branches, greedy and lazy quantifiers,
    and groups (capturing or not), so an unsupported construct is found
    wherever it is nested.

    Raises:
        ValueError: naming the unsupported feature that was found.
    """
    tokens = parsed.data if hasattr(parsed, 'data') else parsed
    for ttype, tval in tokens:

        # backreference
        if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
            raise ValueError("Backreferences are unsupported.")

        # look-around assertion
        elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
            raise ValueError("Look-Around assertion are unsupported.")

        # unicode word boundaries
        elif ttype == sre_parse.AT:
            if tval in (sre_constants.AT_BOUNDARY,
                        sre_constants.AT_NON_BOUNDARY):
                raise ValueError("Unicode word boundaries are unsupported.")

        elif ttype == sre_parse.BRANCH:
            # tval is (None, branches)
            for branch in tval[1]:
                _check_unsupported(branch)

        # tval is (min, max, subpattern); lazy quantifiers (MIN_REPEAT)
        # carry the same payload as greedy ones
        elif ttype in (sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT):
            _check_unsupported(tval[2])

        # tval is (group, add_flags, del_flags, subpattern); without this
        # anything wrapped in parentheses would escape the check
        elif ttype == sre_parse.SUBPATTERN:
            _check_unsupported(tval[3])
def validate_regex_is_buildable(pattern: str) -> None:
    """
    Validates that the input regex is not using unsupported features
    of the `regex-automata` crate (outlines_core regex engine) and has a
    universal start state.
    definition of universal start state used can be found at:
    https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state

    Raises:
        ValueError: if the pattern cannot be parsed, uses a look-around,
            backreference or word boundary, or begins with an anchor or
            look-around before any character is consumed.
    """
    try:
        parsed = sre_parse.parse(pattern)
    except (sre_constants.error, OverflowError) as e:
        # The parser reports oversized repetition counts with OverflowError
        # rather than its own error type; surface both as ValueError so
        # callers only have one exception type to handle.
        raise ValueError(f"Error parsing regex: {e}") from e

    try:
        _check_unsupported(parsed)
    except ValueError as e:
        raise ValueError(
            f"Regex uses unsupported feature for guided decoding: {e}. "
            "Only basic matching constructs are supported—lookarounds, "
            "backreferences, and unicode boundaries are not.") from e

    if _prefix_needs_context(parsed):
        # Each fragment ends with a space (or newline) so the sentences do
        # not run together once concatenated.
        raise ValueError(
            "Regex does not have an anchored universal start state. "
            "This means that the Regex uses anchors (^) or look-arounds "
            "in a way which requires context before any token is matched. "
            "Guided decoding needs regexes that can match without needing "
            "that context. Try rewriting the pattern without using these "
            f"constructs. Pattern:\n{pattern}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment