Unverified Commit d6902ce7 authored by Nathan Hoos's avatar Nathan Hoos Committed by GitHub
Browse files

[V0][V1][Core] Add outlines integration for V1, and update V0 integration. (#15975)


Signed-off-by: default avatarNathan Hoos <thwackyy.y@gmail.com>
parent 5e53c89a
......@@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
outlines_core == 0.2.10
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10
......
......@@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [
# Separate backends which support grammars vs ones
# which only support regex based constraints in tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
@pytest.fixture(scope="module")
def llm():
......@@ -39,7 +43,7 @@ def llm():
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
......@@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
......@@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
......@@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# A list is not what was intended, but is still valid
# json.
assert isinstance(parsed_json, (dict, list))
class CarType(str, Enum):
......@@ -395,7 +402,7 @@ class CarDescription(BaseModel):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
......@@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sample_output_schema = {
......
......@@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
whitespace_pattern=None,
reasoner=None)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
tensor = regex_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor)
tensor = json_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
......@@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(
......@@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
# allowed tokens at state 0
tensor = regex_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
......@@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
tensor = json_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
......@@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning(
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
......@@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning(
regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend)
assert regex_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
......@@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
......@@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning(
# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process</think> Then")
"<think>here is the thinking process</think>")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
......@@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
......
......@@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
assert isinstance(schema, dict)
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
......
......@@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
NGRAM_SPEC_CONFIG),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
......@@ -106,11 +110,13 @@ def test_structured_output(
enforce_eager = bool(not current_platform.is_tpu())
# Use a single LLM instance for several scenarios to
# speed up the test suite.
llm = LLM(model=model_name,
llm = LLM(
model=model_name,
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
......@@ -146,14 +152,15 @@ def test_structured_output(
#
# Test 2: Generate JSON object without a schema
#
if guided_decoding_backend != "outlines":
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))
outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
outputs = llm.generate(prompts=(
"Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old. "
"Make the response as short as possible."),
sampling_params=sampling_params,
......@@ -210,6 +217,7 @@ def test_structured_output(
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
if guided_decoding_backend != "outlines":
#
# Test 4: Generate SQL statement using EBNF grammar
#
......@@ -293,8 +301,8 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
prompts=
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short "
"as possible."),
sampling_params=sampling_params,
......@@ -421,6 +429,7 @@ def test_structured_output(
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
if guided_decoding_backend != "outlines":
#
# Test 11: Generate structured output using structural_tag format
#
......
......@@ -3580,7 +3580,8 @@ def get_served_model_name(model: str,
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]
......
......@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False
......@@ -847,6 +848,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V1_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1",
# Gap between padding buckets for the forward pass. So we have
# 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
......
......@@ -79,20 +79,33 @@ def maybe_backend_fallback(
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")
"grammar failed to convert to GBNF.", "guidance")
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")
"xgrammar module cannot be imported successfully.", "guidance")
if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
if guided_params.backend == "outlines":
if guided_params.json_object is not None:
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.", "guidance")
"outlines does not support json_object.",
"guidance")
elif guided_params.grammar is not None:
# outlines grammar support has been removed, fallback to guidance
# if it is a lark-based grammar and xgrammar otherwise
if grammar_is_likely_lark(guided_params.grammar):
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"guidance")
else:
# The grammar is likely already GBNF format.
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"xgrammar")
return guided_params
......@@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor(
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
......@@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
......
......@@ -12,7 +12,7 @@ from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
......@@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
GRAMMAR = "grammar"
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# the main difference is that we changed the start: value to
# start: object | array, so we are denying scalar values as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
JSON_GRAMMAR = r"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""
global_thread_pool = None # used for generating logits processor fsm
# It's not yet clear that using more provides a benefit, and it could
......@@ -60,16 +32,12 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(guided_params)
......@@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor(
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, guided_params.whitespace_pattern,
......@@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
......@@ -130,9 +93,10 @@ def _get_guide_and_mode(
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif guided_params.grammar:
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
elif guided_params.json_object:
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
raise ValueError(
"The `outlines` guided decoding backend no longer supports grammar "
"guided generation. Please use either the `xgrammar` or `guidance` "
"backend")
else:
return None, None
......@@ -143,13 +107,11 @@ def _get_logits_processor(
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer, reasoner)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer, reasoner)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
......@@ -23,6 +23,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
......@@ -193,6 +195,9 @@ class Processor:
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
elif engine_level_backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.
......
......@@ -88,6 +88,15 @@ class StructuredOutputManager:
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
elif backend == "outlines":
from vllm.v1.structured_output.backend_outlines import (
OutlinesBackend)
self.backend = OutlinesBackend(
self.vllm_config,
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
else:
raise ValueError(
f"Unsupported structured output backend: {backend}")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import ast
import importlib
import json
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from regex import escape as regex_escape
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
OutlinesVocabulary, get_cache, get_vocabulary)
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
if TYPE_CHECKING:
import outlines_core as oc
import outlines_core.json_schema as json_schema
else:
oc = LazyLoader("oc", globals(), "outlines_core")
json_schema = LazyLoader("json_schema", globals(),
"outlines_core.json_schema")
# Python 3.11+ sre_parse and sre_constants
# are deprecated, so we must import them from re
if sys.version_info >= (3, 11):
# Hack to get around pre-commit regex module rule
# because going through re is the only way to get sre_parse
# and sre_constants in Python 3.11+
_re = importlib.import_module("re")
sre_parse = _re._parser
sre_constants = _re._constants
else:
import sre_constants
import sre_parse
@dataclass
class OutlinesBackend(StructuredOutputBackend):
def __post_init__(self):
self.vocabulary = get_vocabulary(self.tokenizer)
self.cache = get_cache()
def _compile_index(self, regex_string: str,
vocabulary: OutlinesVocabulary) -> oc.Index:
cache_key = f"{vocabulary._hash}_{regex_string}"
if cache_key in self.cache:
return self.cache[cache_key]
index = oc.Index(regex_string, vocabulary.inner)
self.cache[cache_key] = index
return index
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
if request_type == StructuredOutputOptions.JSON:
regex = json_schema.build_regex_from_schema(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX:
regex = grammar_spec
elif request_type == StructuredOutputOptions.CHOICE:
choices = ast.literal_eval(grammar_spec)
choices = [regex_escape(c) for c in choices]
regex = "(" + "|".join(choices) + ")"
else:
raise ValueError(
f"Invalid request type for Outlines backend ({request_type!s})"
)
index = self._compile_index(regex, self.vocabulary)
max_rollback_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config is not None else 0)
return OutlinesGrammar(vocab_size=self.vocab_size,
guide=oc.Guide(
index, max_rollback=max_rollback_tokens))
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
return torch.full(
(max_num_seqs, (self.vocab_size + 31) // 32),
-1,
dtype=torch.int32,
pin_memory=torch.cuda.is_available(),
)
def destroy(self):
pass
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
vocab_size: int
guide: oc.Guide = field(hash=False)
num_processed_tokens: int = field(default_factory=lambda: 0,
repr=False,
hash=False,
init=False)
# outlines_core signals done on DFA accept; vLLM expects done after EOS.
# We delay the finished flag by one step so EOS can still be emitted.
_prev_finished: bool = field(default=False,
init=False,
repr=False,
hash=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
if self.guide.accepts_tokens(tokens):
# Advance cannot fail because we checked Guide.accepts_tokens()
for t in tokens:
self.guide.advance(t)
self.num_processed_tokens += 1
return True
return False
def rollback(self, num_tokens: int) -> None:
self.guide.rollback_state(num_tokens)
self.num_processed_tokens -= num_tokens
def validate_tokens(self, tokens: list[int]) -> list[int]:
accepted: list[int] = []
for tok in tokens:
accepted.append(tok)
if not self.guide.accepts_tokens(accepted):
accepted.pop()
break
return accepted
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
mask = bitmask[idx]
self.guide.write_mask_into(mask.data_ptr(), mask.numel(),
mask.element_size())
def is_terminated(self) -> bool:
curr = self.guide.is_finished()
prev = self._prev_finished
self._prev_finished = curr
return prev
def reset(self):
self.num_processed_tokens = 0
self._prev_finished = False
self.guide.reset()
def validate_structured_output_request_outlines(params: SamplingParams):
if params.guided_decoding is None:
return
gd_params = params.guided_decoding
if gd_params.regex:
validate_regex_is_buildable(gd_params.regex)
elif gd_params.json:
if isinstance(gd_params.json, str):
try:
# make sure schema is valid json
json.loads(gd_params.json)
schema = gd_params.json
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
try:
schema = json.dumps(gd_params.json)
except Exception as e:
raise ValueError(
f"Error serializing guided decoding jsonschema: {e}"
) from e
pattern = json_schema.build_regex_from_schema(schema)
validate_regex_is_buildable(pattern)
elif gd_params.choice:
choices = [regex_escape(str(choice)) for choice in gd_params.choice]
regex = "(" + "|".join(choices) + ")"
validate_regex_is_buildable(regex)
elif gd_params.grammar:
raise ValueError("Outlines guided decoding backend "
"does not support grammar specifications")
def _prefix_needs_context(parsed) -> bool:
"""Return True if there's a look-around/anchor before any consumer."""
def subpattern_consumes(parsed) -> bool:
"""Return True if subpattern can consume at least one character."""
tokens = parsed.data if hasattr(parsed, 'data') else parsed
for ttype, tval in tokens:
# literal, character class, or dot always consumes
if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
return True
# quantified subpattern: check inner pattern
elif ttype == sre_parse.MAX_REPEAT:
_, mx, sub = tval
if mx != 0 and subpattern_consumes(sub):
return True
# alternation: if any branch consumes, the whole does
elif ttype == sre_parse.BRANCH:
_, branches = tval
if any(subpattern_consumes(br) for br in branches):
return True
# grouped subpattern: recurse into its contents
elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(
tval[3]):
return True
# No consumers, return False
return False
tokens = parsed.data if hasattr(parsed, 'data') else parsed
for ttype, tval in tokens:
# Direct anchors or look-around
if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT,
sre_constants.ASSERT_NOT):
return True
# Nested subpattern: check
if ttype == sre_parse.SUBPATTERN:
# tval: (group, add_flags, del_flags, subpattern)
if _prefix_needs_context(tval[3]):
return True
if subpattern_consumes(tval[3]):
return False
# if any branch has a prefix anchor => True,
# else if at least one branch consumes => prefix ends => False
elif ttype == sre_parse.BRANCH:
saw_consumer = False
for br in tval[1]:
if _prefix_needs_context(br):
return True
if subpattern_consumes(br):
saw_consumer = True
if saw_consumer:
return False
# Immediate consumer tokens
elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
return False
# if subpattern has anchor => True, if it can consume => stop
elif ttype == sre_parse.MAX_REPEAT:
if _prefix_needs_context(tval[2]):
return True
if subpattern_consumes(tval[2]):
return False
return False
def _check_unsupported(parsed) -> None:
"""Check for regex features unsupported by regex-automata"""
tokens = parsed.data if hasattr(parsed, 'data') else parsed
for ttype, tval in tokens:
# backreference
if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
raise ValueError("Backreferences are unsupported.")
# look-around assertion
elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
raise ValueError("Look-Around assertion are unsupported.")
# unicode word boundaries
elif ttype == sre_parse.AT:
if tval in (sre_constants.AT_BOUNDARY,
sre_constants.AT_NON_BOUNDARY):
raise ValueError("Unicode word boundaries are unsupported.")
elif ttype == sre_parse.BRANCH:
# tval is (None, branches)
for branch in tval[1]:
_check_unsupported(branch)
# tval is (min, max, subpattern)
elif ttype == sre_parse.MAX_REPEAT:
_check_unsupported(tval[2])
def validate_regex_is_buildable(pattern: str) -> None:
"""
Validates that the input regex is not using unsupported features
of the `regex-automata` crate (outlines_core regex engine) and has a
universal start state.
definition of universal start state used can be found at:
https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state
"""
try:
parsed = sre_parse.parse(pattern)
except sre_constants.error as e:
raise ValueError(f"Error parsing regex: {e}") from e
try:
_check_unsupported(parsed)
except ValueError as e:
raise ValueError(
f"Regex uses unsupported feature for guided decoding: {e}. "
"Only basic matching constructs are supported—lookarounds, "
"backreferences, and unicode boundaries are not.") from e
if _prefix_needs_context(parsed):
raise ValueError(
"Regex does not have a anchored universal start state"
"This means that the Regex uses anchors (^) or look-arounds "
"in a way which requires context before any token is matched."
"Guided decoding needs regexes that can match without needing "
"that context. Try rewriting the pattern without using these "
f"constructs. Pattern:\n{pattern}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment