Unverified Commit 29283e89 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[Chore] Cleanup guided namespace, move to structured outputs config (#22772)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 05b044e6
......@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation."""
import copy
from dataclasses import dataclass
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
from typing import Annotated, Any, Optional, Union
import msgspec
from pydantic import BaseModel
from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
......@@ -28,60 +28,35 @@ class SamplingType(IntEnum):
# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
"""One of these fields will be used to build a logit processor."""
class StructuredOutputsParams:
# One of these fields will be used to build a logit processor.
json: Optional[Union[str, dict]] = None
regex: Optional[str] = None
choice: Optional[list[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
"""These are other options that can be set"""
backend: Optional[str] = None
backend_was_auto: bool = False
# These are other options that can be set.
disable_fallback: bool = False
disable_any_whitespace: bool = False
disable_additional_properties: bool = False
whitespace_pattern: Optional[str] = None
structural_tag: Optional[str] = None
@staticmethod
def from_optional(
json: Optional[Union[dict, BaseModel, str]] = None,
regex: Optional[str] = None,
choice: Optional[list[str]] = None,
grammar: Optional[str] = None,
json_object: Optional[bool] = None,
backend: Optional[str] = None,
whitespace_pattern: Optional[str] = None,
structural_tag: Optional[str] = None,
) -> Optional["GuidedDecodingParams"]:
if all(arg is None for arg in (json, regex, choice, grammar,
json_object, structural_tag)):
return None
# Extract json schemas from pydantic models
if isinstance(json, (BaseModel, type(BaseModel))):
json = json.model_json_schema()
return GuidedDecodingParams(
json=json,
regex=regex,
choice=choice,
grammar=grammar,
json_object=json_object,
backend=backend,
whitespace_pattern=whitespace_pattern,
structural_tag=structural_tag,
)
_backend: Optional[str] = field(default=None, init=False)
"""CAUTION: Should only be set by Processor._validate_structured_output"""
_backend_was_auto: bool = field(default=False, init=False)
"""CAUTION: Should only be set by Processor._validate_structured_output"""
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
count = sum([
self.json is not None, self.regex is not None, self.choice
is not None, self.grammar is not None, self.json_object is not None
])
if guide_count > 1:
if count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")
"You can only use one kind of structured outputs constraint "
f"but multiple are specified: {self.__dict__}")
class RequestOutputKind(Enum):
......@@ -196,9 +171,8 @@ class SamplingParams(
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None
"""If provided, the engine will construct a guided decoding logits
processor from these parameters."""
structured_outputs: Optional[StructuredOutputsParams] = None
"""Parameters for configuring structured outputs."""
logit_bias: Optional[dict[int, float]] = None
"""If provided, the engine will construct a logits processor that applies
these logit biases."""
......@@ -246,7 +220,7 @@ class SamplingParams(
msgspec.Meta(
ge=-1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
structured_outputs: Optional[StructuredOutputsParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None,
extra_args: Optional[dict[str, Any]] = None,
......@@ -288,7 +262,7 @@ class SamplingParams(
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
guided_decoding=guided_decoding,
structured_outputs=structured_outputs,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
extra_args=extra_args,
......@@ -559,7 +533,7 @@ class SamplingParams(
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
f"guided_decoding={self.guided_decoding}, "
f"structured_outputs={self.structured_outputs}, "
f"extra_args={self.extra_args})")
......
......@@ -274,7 +274,7 @@ class MistralTokenizer(TokenizerBase):
return tokenizer_file
# the following attributes are set to fit vLLM's design and are used
# by the guided structured output backends.
# by the structured output backends.
@property
def all_special_tokens_extended(self) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens
......@@ -463,9 +463,6 @@ class MistralTokenizer(TokenizerBase):
return decoded
# WARN: Outlines logits processors can overwrite this method.
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
# for more.
def decode(self,
ids: Union[list[int], int],
skip_special_tokens: bool = True) -> str:
......
......@@ -588,9 +588,6 @@ class AsyncLLM(EngineClient):
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def get_decoding_config(self):
raise ValueError("Not Supported on V1 yet.")
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor
......
......@@ -45,7 +45,7 @@ class Processor:
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config
self.structured_outputs_config = vllm_config.structured_outputs_config
self.tokenizer = tokenizer
self.generation_config_fields = (
......@@ -219,58 +219,57 @@ class Processor:
"[lora_path]` to use the LoRA tokenizer.")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
if not params.structured_outputs or not self.structured_outputs_config:
return
if self.model_config.skip_tokenizer_init and params.guided_decoding:
if self.model_config.skip_tokenizer_init and params.structured_outputs:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
engine_level_backend = self.decoding_config.backend
if params.guided_decoding.backend:
# Request-level backend selection is not supported in V1.
backend = self.structured_outputs_config.backend
if _backend := params.structured_outputs._backend:
# Request-level backend selection is not supported.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params.
if (params.guided_decoding.backend != engine_level_backend
and not (engine_level_backend == "auto"
and params.guided_decoding.backend_was_auto)):
# using the `_backend_was_auto` field set in the params.
if (backend != _backend
and not (backend == "auto"
and params.structured_outputs._backend_was_auto)):
raise ValueError(
"Request-level structured output backend selection is no "
"longer supported. The request specified "
f"'{params.guided_decoding.backend}', but vLLM was "
f"initialised with '{engine_level_backend}'. This error "
"can be resolved by removing backend selection from the "
"request.")
"Request-level structured output backend selection is not "
f"supported. The request specified '{_backend}', but vLLM "
f"was initialised with '{backend}'. This error can be "
"resolved by removing '_backend' from the request.")
else:
params.guided_decoding.backend = engine_level_backend
params.structured_outputs._backend = backend
# Request content validation
if (isinstance(params.guided_decoding.choice, list)
and not params.guided_decoding.choice):
if (isinstance(params.structured_outputs.choice, list)
and not params.structured_outputs.choice):
# It is invalid for choice to be an empty list
raise ValueError(f"Choice '{params.guided_decoding.choice}' "
"cannot be an empty list")
raise ValueError(
f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501
)
if engine_level_backend.startswith("xgrammar"):
if backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_xgrammar_grammar(params)
elif engine_level_backend.startswith("guidance"):
elif backend.startswith("guidance"):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
elif engine_level_backend == "outlines":
elif backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
elif engine_level_backend == "lm-format-enforcer":
elif backend == "lm-format-enforcer":
# lm format enforcer backend
validate_structured_output_request_lm_format_enforcer(params)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# NOTE: backend must be "auto" here, because we have
# checked supported_backends above.
# In this mode, we set opinionated defaults based on what we think
# will satisfy the most use cases without having to worry about
......@@ -278,15 +277,15 @@ class Processor:
# other setting where a specific backend was specified.
try:
validate_xgrammar_grammar(params)
params.guided_decoding.backend = "xgrammar"
params.structured_outputs._backend = "xgrammar"
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = "guidance"
params.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True
params.structured_outputs._backend_was_auto = True
def _maybe_build_mm_uuids(
self,
......
......@@ -67,7 +67,7 @@ class Request:
# Generative models.
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
if sampling_params.guided_decoding is not None:
if sampling_params.structured_outputs is not None:
self.status = RequestStatus.WAITING_FOR_FSM
self.use_structured_output = True
......
......@@ -61,11 +61,11 @@ class StructuredOutputManager:
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config)
reasoning_backend = \
self.vllm_config.decoding_config.reasoning_backend
if reasoning_backend:
reasoning_parser = \
self.vllm_config.structured_outputs_config.reasoning_parser
if reasoning_parser:
reasoner_cls = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoning_parser)
self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
def grammar_init(self, request: Request) -> None:
......@@ -74,15 +74,16 @@ class StructuredOutputManager:
if TYPE_CHECKING:
assert request.sampling_params is not None and \
request.sampling_params.guided_decoding is not None
request.sampling_params.structured_outputs is not None
# Initialize the backend the first time it is needed.
#
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
# _backend is set in Processor._validate_structured_output
if self.backend is None:
assert request.sampling_params is not None
backend = request.sampling_params.guided_decoding.backend
backend = request.sampling_params.structured_outputs._backend
vocab_size = self.vllm_config.model_config.get_vocab_size()
if backend == "xgrammar":
self.backend = XgrammarBackend(
......
......@@ -60,9 +60,9 @@ class GuidanceBackend(StructuredOutputBackend):
def __post_init__(self):
self.disable_any_whitespace = \
self.vllm_config.decoding_config.disable_any_whitespace
self.vllm_config.structured_outputs_config.disable_any_whitespace
self.disable_additional_properties = \
self.vllm_config.decoding_config.disable_additional_properties
self.vllm_config.structured_outputs_config.disable_additional_properties
self.ll_tokenizer = llguidance_hf.from_tokenizer(
self.tokenizer, self.vocab_size)
......
......@@ -138,30 +138,30 @@ class LMFormatEnforcerBackend(StructuredOutputBackend):
def validate_structured_output_request_lm_format_enforcer(
params: SamplingParams):
if params.guided_decoding is None:
if params.structured_outputs is None:
return
gd_params = params.guided_decoding
so_params = params.structured_outputs
if gd_params.regex:
if so_params.regex:
return
elif gd_params.json:
if isinstance(gd_params.json, str):
elif so_params.json:
if isinstance(so_params.json, str):
try:
# make sure schema is valid json
json.loads(gd_params.json)
json.loads(so_params.json)
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
try:
json.dumps(gd_params.json)
json.dumps(so_params.json)
except Exception as e:
raise ValueError(
f"Error serializing guided decoding jsonschema: {e}"
f"Error serializing structured outputs jsonschema: {e}"
) from e
return
elif gd_params.choice:
elif so_params.choice:
return
elif gd_params.grammar:
raise ValueError("LM Format Enforcer guided decoding backend "
elif so_params.grammar:
raise ValueError("LM Format Enforcer structured outputs backend "
"does not support grammar specifications")
......@@ -158,36 +158,36 @@ class OutlinesGrammar(StructuredOutputGrammar):
def validate_structured_output_request_outlines(params: SamplingParams):
if params.guided_decoding is None:
if params.structured_outputs is None:
return
gd_params = params.guided_decoding
so_params = params.structured_outputs
if gd_params.regex:
validate_regex_is_buildable(gd_params.regex)
elif gd_params.json:
if isinstance(gd_params.json, str):
if so_params.regex:
validate_regex_is_buildable(so_params.regex)
elif so_params.json:
if isinstance(so_params.json, str):
try:
# make sure schema is valid json
json.loads(gd_params.json)
schema = gd_params.json
json.loads(so_params.json)
schema = so_params.json
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
try:
schema = json.dumps(gd_params.json)
schema = json.dumps(so_params.json)
except Exception as e:
raise ValueError(
f"Error serializing guided decoding jsonschema: {e}"
f"Error serializing structured outputs jsonschema: {e}"
) from e
pattern = json_schema.build_regex_from_schema(schema)
validate_regex_is_buildable(pattern)
elif gd_params.choice:
choices = [regex_escape(str(choice)) for choice in gd_params.choice]
elif so_params.choice:
choices = [regex_escape(str(choice)) for choice in so_params.choice]
regex = "(" + "|".join(choices) + ")"
validate_regex_is_buildable(regex)
elif gd_params.grammar:
raise ValueError("Outlines guided decoding backend "
elif so_params.grammar:
raise ValueError("Outlines structured outputs backend "
"does not support grammar specifications")
......@@ -306,7 +306,7 @@ def validate_regex_is_buildable(pattern: str) -> None:
_check_unsupported(parsed)
except ValueError as e:
raise ValueError(
f"Regex uses unsupported feature for guided decoding: {e}. "
f"Regex uses unsupported feature for structured outputs: {e}. "
"Only basic matching constructs are supported—lookarounds, "
"backreferences, and unicode boundaries are not.") from e
......@@ -315,6 +315,6 @@ def validate_regex_is_buildable(pattern: str) -> None:
"Regex does not have a anchored universal start state"
"This means that the Regex uses anchors (^) or look-arounds "
"in a way which requires context before any token is matched."
"Guided decoding needs regexes that can match without needing "
"structured outputs needs regexes that can match without needing "
"that context. Try rewriting the pattern without using these "
f"constructs. Pattern:\n{pattern}")
......@@ -34,7 +34,7 @@ class XgrammarBackend(StructuredOutputBackend):
def __post_init__(self):
self.disable_any_whitespace = \
self.vllm_config.decoding_config.disable_any_whitespace
self.vllm_config.structured_outputs_config.disable_any_whitespace
if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
......@@ -248,37 +248,37 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
Raises ValueError if the request is not supported.
"""
if sampling_params.guided_decoding is None:
if sampling_params.structured_outputs is None:
return
gd_params = sampling_params.guided_decoding
so_params = sampling_params.structured_outputs
if gd_params.regex:
if so_params.regex:
try:
xgr.Grammar.from_regex(gd_params.regex)
xgr.Grammar.from_regex(so_params.regex)
except Exception as err:
raise ValueError("Failed to transform regex into a grammar: "
f"{err}") from err
if gd_params.choice:
choice_grammar = choice_as_grammar(gd_params.choice)
if so_params.choice:
choice_grammar = choice_as_grammar(so_params.choice)
try:
xgr.Grammar.from_ebnf(choice_grammar)
except Exception as err:
raise ValueError("Failed to transform choices into a grammar: "
"{err}") from err
gd_params.choice = None
gd_params.grammar = choice_grammar
so_params.choice = None
so_params.grammar = choice_grammar
return
if gd_params.json:
if isinstance(gd_params.json, str):
if so_params.json:
if isinstance(so_params.json, str):
try:
schema = json.loads(gd_params.json)
schema = json.loads(so_params.json)
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
schema = gd_params.json
schema = so_params.json
try:
xgr.Grammar.from_json_schema(schema)
......@@ -291,11 +291,11 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
"supported by xgrammar.")
return
if gd_params.grammar:
if grammar_is_likely_lark(gd_params.grammar):
if so_params.grammar:
if grammar_is_likely_lark(so_params.grammar):
# xgrammar supports EBNF grammars only
try:
gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
except ValueError as e:
raise ValueError(
"Failed to convert the grammar from Lark to EBNF. ") from e
......@@ -303,14 +303,14 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
# Test parsing EBNF grammar, possibly already converted from Lark
try:
# parse the grammar, but we aren't compiling it.
xgr.Grammar.from_ebnf(gd_params.grammar)
xgr.Grammar.from_ebnf(so_params.grammar)
except Exception as e:
raise ValueError("Invalid grammar specification.") from e
return
if gd_params.structural_tag:
if so_params.structural_tag:
try:
s_tag = json.loads(gd_params.structural_tag)
s_tag = json.loads(so_params.structural_tag)
tags = [
xgr.StructuralTagItem(
begin=s["begin"],
......
......@@ -60,7 +60,7 @@ class StructuredOutputRequest:
def get_structured_output_key(
sampling_params: SamplingParams) -> StructuredOutputKey:
params = sampling_params.guided_decoding
params = sampling_params.structured_outputs
assert params is not None, "params can't be None."
if params.json is not None:
if not isinstance(params.json, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment