Unverified Commit 29283e89 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[Chore] Cleanup guided namespace, move to structured outputs config (#22772)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 05b044e6
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy import copy
from dataclasses import dataclass from dataclasses import field
from enum import Enum, IntEnum from enum import Enum, IntEnum
from functools import cached_property from functools import cached_property
from typing import Annotated, Any, Optional, Union from typing import Annotated, Any, Optional, Union
import msgspec import msgspec
from pydantic import BaseModel from pydantic.dataclasses import dataclass
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor from vllm.logits_process import LogitsProcessor
...@@ -28,60 +28,35 @@ class SamplingType(IntEnum): ...@@ -28,60 +28,35 @@ class SamplingType(IntEnum):
# maybe make msgspec? # maybe make msgspec?
@dataclass @dataclass
class GuidedDecodingParams: class StructuredOutputsParams:
"""One of these fields will be used to build a logit processor.""" # One of these fields will be used to build a logit processor.
json: Optional[Union[str, dict]] = None json: Optional[Union[str, dict]] = None
regex: Optional[str] = None regex: Optional[str] = None
choice: Optional[list[str]] = None choice: Optional[list[str]] = None
grammar: Optional[str] = None grammar: Optional[str] = None
json_object: Optional[bool] = None json_object: Optional[bool] = None
"""These are other options that can be set""" # These are other options that can be set.
backend: Optional[str] = None
backend_was_auto: bool = False
disable_fallback: bool = False disable_fallback: bool = False
disable_any_whitespace: bool = False disable_any_whitespace: bool = False
disable_additional_properties: bool = False disable_additional_properties: bool = False
whitespace_pattern: Optional[str] = None whitespace_pattern: Optional[str] = None
structural_tag: Optional[str] = None structural_tag: Optional[str] = None
@staticmethod _backend: Optional[str] = field(default=None, init=False)
def from_optional( """CAUTION: Should only be set by Processor._validate_structured_output"""
json: Optional[Union[dict, BaseModel, str]] = None, _backend_was_auto: bool = field(default=False, init=False)
regex: Optional[str] = None, """CAUTION: Should only be set by Processor._validate_structured_output"""
choice: Optional[list[str]] = None,
grammar: Optional[str] = None,
json_object: Optional[bool] = None,
backend: Optional[str] = None,
whitespace_pattern: Optional[str] = None,
structural_tag: Optional[str] = None,
) -> Optional["GuidedDecodingParams"]:
if all(arg is None for arg in (json, regex, choice, grammar,
json_object, structural_tag)):
return None
# Extract json schemas from pydantic models
if isinstance(json, (BaseModel, type(BaseModel))):
json = json.model_json_schema()
return GuidedDecodingParams(
json=json,
regex=regex,
choice=choice,
grammar=grammar,
json_object=json_object,
backend=backend,
whitespace_pattern=whitespace_pattern,
structural_tag=structural_tag,
)
def __post_init__(self): def __post_init__(self):
"""Validate that some fields are mutually exclusive.""" """Validate that some fields are mutually exclusive."""
guide_count = sum([ count = sum([
self.json is not None, self.regex is not None, self.choice self.json is not None, self.regex is not None, self.choice
is not None, self.grammar is not None, self.json_object is not None is not None, self.grammar is not None, self.json_object is not None
]) ])
if guide_count > 1: if count > 1:
raise ValueError( raise ValueError(
"You can only use one kind of guided decoding but multiple are " "You can only use one kind of structured outputs constraint "
f"specified: {self.__dict__}") f"but multiple are specified: {self.__dict__}")
class RequestOutputKind(Enum): class RequestOutputKind(Enum):
...@@ -196,9 +171,8 @@ class SamplingParams( ...@@ -196,9 +171,8 @@ class SamplingParams(
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set) _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors # Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None structured_outputs: Optional[StructuredOutputsParams] = None
"""If provided, the engine will construct a guided decoding logits """Parameters for configuring structured outputs."""
processor from these parameters."""
logit_bias: Optional[dict[int, float]] = None logit_bias: Optional[dict[int, float]] = None
"""If provided, the engine will construct a logits processor that applies """If provided, the engine will construct a logits processor that applies
these logit biases.""" these logit biases."""
...@@ -246,7 +220,7 @@ class SamplingParams( ...@@ -246,7 +220,7 @@ class SamplingParams(
msgspec.Meta( msgspec.Meta(
ge=-1)]] = None, ge=-1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None, structured_outputs: Optional[StructuredOutputsParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None, allowed_token_ids: Optional[list[int]] = None,
extra_args: Optional[dict[str, Any]] = None, extra_args: Optional[dict[str, Any]] = None,
...@@ -288,7 +262,7 @@ class SamplingParams( ...@@ -288,7 +262,7 @@ class SamplingParams(
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind, output_kind=output_kind,
guided_decoding=guided_decoding, structured_outputs=structured_outputs,
logit_bias=logit_bias, logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids, allowed_token_ids=allowed_token_ids,
extra_args=extra_args, extra_args=extra_args,
...@@ -559,7 +533,7 @@ class SamplingParams( ...@@ -559,7 +533,7 @@ class SamplingParams(
"spaces_between_special_tokens=" "spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, " f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
f"guided_decoding={self.guided_decoding}, " f"structured_outputs={self.structured_outputs}, "
f"extra_args={self.extra_args})") f"extra_args={self.extra_args})")
......
...@@ -274,7 +274,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -274,7 +274,7 @@ class MistralTokenizer(TokenizerBase):
return tokenizer_file return tokenizer_file
# the following attributes are set to fit vLLM's design and are used # the following attributes are set to fit vLLM's design and are used
# by the guided structured output backends. # by the structured output backends.
@property @property
def all_special_tokens_extended(self) -> list[str]: def all_special_tokens_extended(self) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens from mistral_common.tokens.tokenizers.base import SpecialTokens
...@@ -463,9 +463,6 @@ class MistralTokenizer(TokenizerBase): ...@@ -463,9 +463,6 @@ class MistralTokenizer(TokenizerBase):
return decoded return decoded
# WARN: Outlines logits processors can overwrite this method.
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
# for more.
def decode(self, def decode(self,
ids: Union[list[int], int], ids: Union[list[int], int],
skip_special_tokens: bool = True) -> str: skip_special_tokens: bool = True) -> str:
......
...@@ -588,9 +588,6 @@ class AsyncLLM(EngineClient): ...@@ -588,9 +588,6 @@ class AsyncLLM(EngineClient):
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
return self.model_config return self.model_config
async def get_decoding_config(self):
raise ValueError("Not Supported on V1 yet.")
async def get_input_preprocessor(self) -> InputPreprocessor: async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor return self.processor.input_preprocessor
......
...@@ -45,7 +45,7 @@ class Processor: ...@@ -45,7 +45,7 @@ class Processor:
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config self.structured_outputs_config = vllm_config.structured_outputs_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.generation_config_fields = ( self.generation_config_fields = (
...@@ -219,58 +219,57 @@ class Processor: ...@@ -219,58 +219,57 @@ class Processor:
"[lora_path]` to use the LoRA tokenizer.") "[lora_path]` to use the LoRA tokenizer.")
def _validate_structured_output(self, params: SamplingParams) -> None: def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config: if not params.structured_outputs or not self.structured_outputs_config:
return return
if self.model_config.skip_tokenizer_init and params.guided_decoding: if self.model_config.skip_tokenizer_init and params.structured_outputs:
raise ValueError( raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
) )
engine_level_backend = self.decoding_config.backend backend = self.structured_outputs_config.backend
if params.guided_decoding.backend: if _backend := params.structured_outputs._backend:
# Request-level backend selection is not supported in V1. # Request-level backend selection is not supported.
# The values may differ if `params` is reused and was set # The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous # to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto` # request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params. # using the `_backend_was_auto` field set in the params.
if (params.guided_decoding.backend != engine_level_backend if (backend != _backend
and not (engine_level_backend == "auto" and not (backend == "auto"
and params.guided_decoding.backend_was_auto)): and params.structured_outputs._backend_was_auto)):
raise ValueError( raise ValueError(
"Request-level structured output backend selection is no " "Request-level structured output backend selection is not "
"longer supported. The request specified " f"supported. The request specified '{_backend}', but vLLM "
f"'{params.guided_decoding.backend}', but vLLM was " f"was initialised with '{backend}'. This error can be "
f"initialised with '{engine_level_backend}'. This error " "resolved by removing '_backend' from the request.")
"can be resolved by removing backend selection from the "
"request.")
else: else:
params.guided_decoding.backend = engine_level_backend params.structured_outputs._backend = backend
# Request content validation # Request content validation
if (isinstance(params.guided_decoding.choice, list) if (isinstance(params.structured_outputs.choice, list)
and not params.guided_decoding.choice): and not params.structured_outputs.choice):
# It is invalid for choice to be an empty list # It is invalid for choice to be an empty list
raise ValueError(f"Choice '{params.guided_decoding.choice}' " raise ValueError(
"cannot be an empty list") f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501
)
if engine_level_backend.startswith("xgrammar"): if backend.startswith("xgrammar"):
# xgrammar with no fallback # xgrammar with no fallback
validate_xgrammar_grammar(params) validate_xgrammar_grammar(params)
elif engine_level_backend.startswith("guidance"): elif backend.startswith("guidance"):
# TODO: ideally we would have the LLTokenizer here as Lark syntax # TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see # allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars. # Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None) validate_guidance_grammar(params, tokenizer=None)
elif engine_level_backend == "outlines": elif backend == "outlines":
# outlines backend # outlines backend
validate_structured_output_request_outlines(params) validate_structured_output_request_outlines(params)
elif engine_level_backend == "lm-format-enforcer": elif backend == "lm-format-enforcer":
# lm format enforcer backend # lm format enforcer backend
validate_structured_output_request_lm_format_enforcer(params) validate_structured_output_request_lm_format_enforcer(params)
else: else:
# NOTE: engine_level_backend must be "auto" here, because we have # NOTE: backend must be "auto" here, because we have
# checked supported_backends above. # checked supported_backends above.
# In this mode, we set opinionated defaults based on what we think # In this mode, we set opinionated defaults based on what we think
# will satisfy the most use cases without having to worry about # will satisfy the most use cases without having to worry about
...@@ -278,15 +277,15 @@ class Processor: ...@@ -278,15 +277,15 @@ class Processor:
# other setting where a specific backend was specified. # other setting where a specific backend was specified.
try: try:
validate_xgrammar_grammar(params) validate_xgrammar_grammar(params)
params.guided_decoding.backend = "xgrammar" params.structured_outputs._backend = "xgrammar"
except ValueError: except ValueError:
# The request either failed validation # The request either failed validation
# or includes some jsonschema feature(s) that # or includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance. # are not supported in xgrammar. Fall back to guidance.
validate_guidance_grammar(params, tokenizer=None) validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = "guidance" params.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically # Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True params.structured_outputs._backend_was_auto = True
def _maybe_build_mm_uuids( def _maybe_build_mm_uuids(
self, self,
......
...@@ -67,7 +67,7 @@ class Request: ...@@ -67,7 +67,7 @@ class Request:
# Generative models. # Generative models.
assert sampling_params.max_tokens is not None assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens self.max_tokens = sampling_params.max_tokens
if sampling_params.guided_decoding is not None: if sampling_params.structured_outputs is not None:
self.status = RequestStatus.WAITING_FOR_FSM self.status = RequestStatus.WAITING_FOR_FSM
self.use_structured_output = True self.use_structured_output = True
......
...@@ -61,11 +61,11 @@ class StructuredOutputManager: ...@@ -61,11 +61,11 @@ class StructuredOutputManager:
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config) model_config=self.vllm_config.model_config)
reasoning_backend = \ reasoning_parser = \
self.vllm_config.decoding_config.reasoning_backend self.vllm_config.structured_outputs_config.reasoning_parser
if reasoning_backend: if reasoning_parser:
reasoner_cls = ReasoningParserManager.get_reasoning_parser( reasoner_cls = ReasoningParserManager.get_reasoning_parser(
reasoning_backend) reasoning_parser)
self.reasoner = reasoner_cls(tokenizer=self.tokenizer) self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
def grammar_init(self, request: Request) -> None: def grammar_init(self, request: Request) -> None:
...@@ -74,15 +74,16 @@ class StructuredOutputManager: ...@@ -74,15 +74,16 @@ class StructuredOutputManager:
if TYPE_CHECKING: if TYPE_CHECKING:
assert request.sampling_params is not None and \ assert request.sampling_params is not None and \
request.sampling_params.guided_decoding is not None request.sampling_params.structured_outputs is not None
# Initialize the backend the first time it is needed. # Initialize the backend the first time it is needed.
# #
# NOTE: We only support a single backend. We do NOT support different # NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...). # backends on a per-request basis in V1 (for now, anyway...).
# _backend is set in Processor._validate_structured_output
if self.backend is None: if self.backend is None:
assert request.sampling_params is not None assert request.sampling_params is not None
backend = request.sampling_params.guided_decoding.backend backend = request.sampling_params.structured_outputs._backend
vocab_size = self.vllm_config.model_config.get_vocab_size() vocab_size = self.vllm_config.model_config.get_vocab_size()
if backend == "xgrammar": if backend == "xgrammar":
self.backend = XgrammarBackend( self.backend = XgrammarBackend(
......
...@@ -60,9 +60,9 @@ class GuidanceBackend(StructuredOutputBackend): ...@@ -60,9 +60,9 @@ class GuidanceBackend(StructuredOutputBackend):
def __post_init__(self): def __post_init__(self):
self.disable_any_whitespace = \ self.disable_any_whitespace = \
self.vllm_config.decoding_config.disable_any_whitespace self.vllm_config.structured_outputs_config.disable_any_whitespace
self.disable_additional_properties = \ self.disable_additional_properties = \
self.vllm_config.decoding_config.disable_additional_properties self.vllm_config.structured_outputs_config.disable_additional_properties
self.ll_tokenizer = llguidance_hf.from_tokenizer( self.ll_tokenizer = llguidance_hf.from_tokenizer(
self.tokenizer, self.vocab_size) self.tokenizer, self.vocab_size)
......
...@@ -138,30 +138,30 @@ class LMFormatEnforcerBackend(StructuredOutputBackend): ...@@ -138,30 +138,30 @@ class LMFormatEnforcerBackend(StructuredOutputBackend):
def validate_structured_output_request_lm_format_enforcer( def validate_structured_output_request_lm_format_enforcer(
params: SamplingParams): params: SamplingParams):
if params.guided_decoding is None: if params.structured_outputs is None:
return return
gd_params = params.guided_decoding so_params = params.structured_outputs
if gd_params.regex: if so_params.regex:
return return
elif gd_params.json: elif so_params.json:
if isinstance(gd_params.json, str): if isinstance(so_params.json, str):
try: try:
# make sure schema is valid json # make sure schema is valid json
json.loads(gd_params.json) json.loads(so_params.json)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e raise ValueError("Invalid JSON grammar specification.") from e
else: else:
try: try:
json.dumps(gd_params.json) json.dumps(so_params.json)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"Error serializing guided decoding jsonschema: {e}" f"Error serializing structured outputs jsonschema: {e}"
) from e ) from e
return return
elif gd_params.choice: elif so_params.choice:
return return
elif gd_params.grammar: elif so_params.grammar:
raise ValueError("LM Format Enforcer guided decoding backend " raise ValueError("LM Format Enforcer structured outputs backend "
"does not support grammar specifications") "does not support grammar specifications")
...@@ -158,36 +158,36 @@ class OutlinesGrammar(StructuredOutputGrammar): ...@@ -158,36 +158,36 @@ class OutlinesGrammar(StructuredOutputGrammar):
def validate_structured_output_request_outlines(params: SamplingParams): def validate_structured_output_request_outlines(params: SamplingParams):
if params.guided_decoding is None: if params.structured_outputs is None:
return return
gd_params = params.guided_decoding so_params = params.structured_outputs
if gd_params.regex: if so_params.regex:
validate_regex_is_buildable(gd_params.regex) validate_regex_is_buildable(so_params.regex)
elif gd_params.json: elif so_params.json:
if isinstance(gd_params.json, str): if isinstance(so_params.json, str):
try: try:
# make sure schema is valid json # make sure schema is valid json
json.loads(gd_params.json) json.loads(so_params.json)
schema = gd_params.json schema = so_params.json
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e raise ValueError("Invalid JSON grammar specification.") from e
else: else:
try: try:
schema = json.dumps(gd_params.json) schema = json.dumps(so_params.json)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"Error serializing guided decoding jsonschema: {e}" f"Error serializing structured outputs jsonschema: {e}"
) from e ) from e
pattern = json_schema.build_regex_from_schema(schema) pattern = json_schema.build_regex_from_schema(schema)
validate_regex_is_buildable(pattern) validate_regex_is_buildable(pattern)
elif gd_params.choice: elif so_params.choice:
choices = [regex_escape(str(choice)) for choice in gd_params.choice] choices = [regex_escape(str(choice)) for choice in so_params.choice]
regex = "(" + "|".join(choices) + ")" regex = "(" + "|".join(choices) + ")"
validate_regex_is_buildable(regex) validate_regex_is_buildable(regex)
elif gd_params.grammar: elif so_params.grammar:
raise ValueError("Outlines guided decoding backend " raise ValueError("Outlines structured outputs backend "
"does not support grammar specifications") "does not support grammar specifications")
...@@ -306,7 +306,7 @@ def validate_regex_is_buildable(pattern: str) -> None: ...@@ -306,7 +306,7 @@ def validate_regex_is_buildable(pattern: str) -> None:
_check_unsupported(parsed) _check_unsupported(parsed)
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(
f"Regex uses unsupported feature for guided decoding: {e}. " f"Regex uses unsupported feature for structured outputs: {e}. "
"Only basic matching constructs are supported—lookarounds, " "Only basic matching constructs are supported—lookarounds, "
"backreferences, and unicode boundaries are not.") from e "backreferences, and unicode boundaries are not.") from e
...@@ -315,6 +315,6 @@ def validate_regex_is_buildable(pattern: str) -> None: ...@@ -315,6 +315,6 @@ def validate_regex_is_buildable(pattern: str) -> None:
"Regex does not have a anchored universal start state" "Regex does not have a anchored universal start state"
"This means that the Regex uses anchors (^) or look-arounds " "This means that the Regex uses anchors (^) or look-arounds "
"in a way which requires context before any token is matched." "in a way which requires context before any token is matched."
"Guided decoding needs regexes that can match without needing " "structured outputs needs regexes that can match without needing "
"that context. Try rewriting the pattern without using these " "that context. Try rewriting the pattern without using these "
f"constructs. Pattern:\n{pattern}") f"constructs. Pattern:\n{pattern}")
...@@ -34,7 +34,7 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -34,7 +34,7 @@ class XgrammarBackend(StructuredOutputBackend):
def __post_init__(self): def __post_init__(self):
self.disable_any_whitespace = \ self.disable_any_whitespace = \
self.vllm_config.decoding_config.disable_any_whitespace self.vllm_config.structured_outputs_config.disable_any_whitespace
if isinstance(self.tokenizer, MistralTokenizer): if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly. # NOTE: ideally, xgrammar should handle this accordingly.
...@@ -248,37 +248,37 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: ...@@ -248,37 +248,37 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
Raises ValueError if the request is not supported. Raises ValueError if the request is not supported.
""" """
if sampling_params.guided_decoding is None: if sampling_params.structured_outputs is None:
return return
gd_params = sampling_params.guided_decoding so_params = sampling_params.structured_outputs
if gd_params.regex: if so_params.regex:
try: try:
xgr.Grammar.from_regex(gd_params.regex) xgr.Grammar.from_regex(so_params.regex)
except Exception as err: except Exception as err:
raise ValueError("Failed to transform regex into a grammar: " raise ValueError("Failed to transform regex into a grammar: "
f"{err}") from err f"{err}") from err
if gd_params.choice: if so_params.choice:
choice_grammar = choice_as_grammar(gd_params.choice) choice_grammar = choice_as_grammar(so_params.choice)
try: try:
xgr.Grammar.from_ebnf(choice_grammar) xgr.Grammar.from_ebnf(choice_grammar)
except Exception as err: except Exception as err:
raise ValueError("Failed to transform choices into a grammar: " raise ValueError("Failed to transform choices into a grammar: "
"{err}") from err "{err}") from err
gd_params.choice = None so_params.choice = None
gd_params.grammar = choice_grammar so_params.grammar = choice_grammar
return return
if gd_params.json: if so_params.json:
if isinstance(gd_params.json, str): if isinstance(so_params.json, str):
try: try:
schema = json.loads(gd_params.json) schema = json.loads(so_params.json)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e raise ValueError("Invalid JSON grammar specification.") from e
else: else:
schema = gd_params.json schema = so_params.json
try: try:
xgr.Grammar.from_json_schema(schema) xgr.Grammar.from_json_schema(schema)
...@@ -291,11 +291,11 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: ...@@ -291,11 +291,11 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
"supported by xgrammar.") "supported by xgrammar.")
return return
if gd_params.grammar: if so_params.grammar:
if grammar_is_likely_lark(gd_params.grammar): if grammar_is_likely_lark(so_params.grammar):
# xgrammar supports EBNF grammars only # xgrammar supports EBNF grammars only
try: try:
gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(
"Failed to convert the grammar from Lark to EBNF. ") from e "Failed to convert the grammar from Lark to EBNF. ") from e
...@@ -303,14 +303,14 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: ...@@ -303,14 +303,14 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
# Test parsing EBNF grammar, possibly already converted from Lark # Test parsing EBNF grammar, possibly already converted from Lark
try: try:
# parse the grammar, but we aren't compiling it. # parse the grammar, but we aren't compiling it.
xgr.Grammar.from_ebnf(gd_params.grammar) xgr.Grammar.from_ebnf(so_params.grammar)
except Exception as e: except Exception as e:
raise ValueError("Invalid grammar specification.") from e raise ValueError("Invalid grammar specification.") from e
return return
if gd_params.structural_tag: if so_params.structural_tag:
try: try:
s_tag = json.loads(gd_params.structural_tag) s_tag = json.loads(so_params.structural_tag)
tags = [ tags = [
xgr.StructuralTagItem( xgr.StructuralTagItem(
begin=s["begin"], begin=s["begin"],
......
...@@ -60,7 +60,7 @@ class StructuredOutputRequest: ...@@ -60,7 +60,7 @@ class StructuredOutputRequest:
def get_structured_output_key( def get_structured_output_key(
sampling_params: SamplingParams) -> StructuredOutputKey: sampling_params: SamplingParams) -> StructuredOutputKey:
params = sampling_params.guided_decoding params = sampling_params.structured_outputs
assert params is not None, "params can't be None." assert params is not None, "params can't be None."
if params.json is not None: if params.json is not None:
if not isinstance(params.json, str): if not isinstance(params.json, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment