"vscode:/vscode.git/clone" did not exist on "ea1292ad3ee724e44b3dfec2a26778cd614729f9"
Unverified Commit 29283e89 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[Chore] Cleanup guided namespace, move to structured outputs config (#22772)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 05b044e6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests
# imports for structured outputs tests
import io
import json
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
# imports for guided decoding tests
# imports for structured outputs tests
import json
import httpx
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the SamplingParams class.
"""
import pytest
from vllm import SamplingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
MODEL_NAME = "Qwen/Qwen1.5-7B"
def test_max_tokens_none():
"""max_tokens=None should be allowed"""
SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
@pytest.fixture(scope="module")
def model_config():
return ModelConfig(
MODEL_NAME,
seed=0,
dtype="float16",
)
@pytest.fixture(scope="module")
def default_max_tokens():
return 4096
def test_sampling_params_from_request_with_no_guided_decoding_backend(
model_config, default_max_tokens):
# guided_decoding_backend is not present at request level
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# we do not expect any backend to be present and the default
# guided_decoding_backend at engine level will be used.
assert sampling_params.guided_decoding.backend is None
@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
[("xgrammar", "xgrammar"), ("guidance", "guidance"),
("outlines", "outlines")])
def test_sampling_params_from_request_with_guided_decoding_backend(
request_level_guided_decoding_backend: str, expected: str,
model_config, default_max_tokens):
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
'guided_decoding_backend':
request_level_guided_decoding_backend,
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# backend correctly identified in resulting sampling_params
assert sampling_params.guided_decoding.backend == expected
......@@ -68,7 +68,7 @@ EXAMPLE_TOOLS = [
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
should_match: bool):
self = MagicMock(tool_choice="required", tools=tools)
schema = ChatCompletionRequest._get_guided_json_from_tool(self)
schema = ChatCompletionRequest._get_json_schema_from_tool(self)
assert isinstance(schema, dict)
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
......@@ -218,7 +218,7 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
}
}, {}], False),
])
def test_guided_json(sample_output, should_match):
def test_structured_outputs_json(sample_output, should_match):
_compile_and_check(tools=TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
sample_output=sample_output,
......@@ -273,8 +273,9 @@ def update_parameters_empty_dict(
@pytest.mark.parametrize(
"update_parameters",
[update_parameters_none, update_parameters_empty_dict])
def test_guided_json_without_parameters(sample_output, should_match,
update_parameters):
def test_structured_outputs_json_without_parameters(sample_output,
should_match,
update_parameters):
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
tools = TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(updated_tools)
......@@ -334,4 +335,4 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += message.tool_calls[0].function.arguments
combined_messages += "}]"
assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json
\ No newline at end of file
assert json.dumps(json.loads(combined_messages)) == output_json
......@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
......@@ -1796,11 +1796,11 @@ def test_schedule_skip_tokenizer_init():
def test_schedule_skip_tokenizer_init_structured_output_request():
scheduler = create_scheduler(skip_tokenizer_init=True)
guided_params = GuidedDecodingParams(regex="[0-9]+")
structured_outputs_params = StructuredOutputsParams(regex="[0-9]+")
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=16,
guided_decoding=guided_params,
structured_outputs=structured_outputs_params,
)
request = Request(
request_id="0",
......
......@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import pytest
from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
if TYPE_CHECKING:
......@@ -97,7 +97,7 @@ def _get_test_sampling_params(
top_p=0.95,
n=n,
seed=seed,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
regex="[0-9]+") if structured_outputs else None,
) for n in n_list
], n_list
......
......@@ -151,7 +151,7 @@ def sample_definition_json_schema():
@pytest.fixture
def sample_guided_choice():
def sample_structured_outputs_choices():
return [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
"Ruby", "Swift", "Kotlin"
......
......@@ -15,12 +15,13 @@ import torch
from pydantic import BaseModel
from tests.reasoning.utils import run_reasoning_extraction
from vllm.config import StructuredOutputsConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
if TYPE_CHECKING:
from vllm.config import TokenizerMode
......@@ -90,7 +91,7 @@ def _load_json(s: str, backend: str) -> str:
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
"model_name, backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
def test_structured_output(
monkeypatch: pytest.MonkeyPatch,
......@@ -99,8 +100,8 @@ def test_structured_output(
sample_sql_ebnf: str,
sample_sql_lark: str,
sample_regex: str,
sample_guided_choice: str,
guided_decoding_backend: str,
sample_structured_outputs_choices: str,
backend: str,
tokenizer_mode: str,
model_name: str,
speculative_config: dict[str, Any],
......@@ -115,16 +116,15 @@ def test_structured_output(
enforce_eager = bool(not current_platform.is_tpu())
# Use a single LLM instance for several scenarios to
# speed up the test suite.
llm = LLM(
model=model_name,
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
seed=120,
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
llm = LLM(model=model_name,
enforce_eager=enforce_eager,
max_model_len=1024,
structured_outputs_config=dict(backend=backend,
disable_any_whitespace=backend
in {"xgrammar", "guidance"}),
seed=120,
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
#
# Test 1: Generate JSON output based on a provided schema
......@@ -132,7 +132,7 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
structured_outputs=StructuredOutputsParams(json=sample_json_schema))
prompt = ("Give an example JSON for an employee profile that fits this "
"schema. Make the response as short as possible. Schema: "
......@@ -152,7 +152,7 @@ def test_structured_output(
generated_text = output.outputs[0].text
assert generated_text is not None
if guided_decoding_backend != 'lm-format-enforcer':
if backend != 'lm-format-enforcer':
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
......@@ -161,12 +161,12 @@ def test_structured_output(
#
# Test 2: Generate JSON object without a schema
#
if guided_decoding_backend != "outlines":
if backend != "outlines":
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))
structured_outputs=StructuredOutputsParams(json_object=True))
outputs = llm.generate(prompts=(
"Generate a JSON object with curly braces for a person with "
......@@ -195,8 +195,9 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
if guided_decoding_backend.startswith("xgrammar"):
structured_outputs=StructuredOutputsParams(
json=unsupported_json_schema))
if backend.startswith("xgrammar"):
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):
......@@ -230,7 +231,7 @@ def test_structured_output(
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
if backend not in ["outlines", "lm-format-enforcer"]:
#
# Test 4: Generate SQL statement using EBNF grammar
#
......@@ -238,7 +239,8 @@ def test_structured_output(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
structured_outputs=StructuredOutputsParams(
grammar=sample_sql_ebnf))
outputs = llm.generate(
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
......@@ -271,7 +273,8 @@ def test_structured_output(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
structured_outputs=StructuredOutputsParams(
grammar=sample_sql_lark))
outputs = llm.generate(
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
......@@ -309,7 +312,8 @@ def test_structured_output(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
structured_outputs=StructuredOutputsParams(
grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
("Generate a sql statement that selects col_1 from "
......@@ -325,7 +329,7 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))
structured_outputs=StructuredOutputsParams(regex=sample_regex))
prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
......@@ -352,7 +356,8 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
structured_outputs=StructuredOutputsParams(
choice=sample_structured_outputs_choices))
outputs = llm.generate(
("The best language for type-safe systems programming is "
......@@ -368,7 +373,7 @@ def test_structured_output(
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
assert generated_text in sample_structured_outputs_choices
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
......@@ -378,7 +383,7 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema))
structured_outputs=StructuredOutputsParams(json=json_schema))
outputs = llm.generate(
("Generate a JSON with the brand, model and car_type of the most "
......@@ -422,7 +427,7 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=json_schema))
structured_outputs=StructuredOutputsParams(json=json_schema))
outputs = llm.generate(
("Generate a description of a frog using 50 characters. "
......@@ -444,7 +449,7 @@ def test_structured_output(
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
if backend not in ["outlines", "lm-format-enforcer"]:
#
# Test 11: Generate structured output using structural_tag format
#
......@@ -470,7 +475,7 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(
structured_outputs=StructuredOutputsParams(
structural_tag=json.dumps(structural_tag_config)))
prompt = """
......@@ -547,7 +552,7 @@ Make the response as short as possible.
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501
"model_name, backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501
[
("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
"deepseek_r1", NGRAM_SPEC_CONFIG),
......@@ -556,7 +561,7 @@ Make the response as short as possible.
)
def test_structured_output_with_reasoning_matrices(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
backend: str,
tokenizer_mode: TokenizerMode,
reasoning_parser: str,
model_name: str,
......@@ -576,10 +581,11 @@ def test_structured_output_with_reasoning_matrices(
enforce_eager=bool(not current_platform.is_tpu()),
max_model_len=1024,
max_num_seqs=16,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
structured_outputs_config=dict(backend=backend,
disable_any_whitespace=backend
in {"xgrammar", "guidance"},
reasoning_parser=reasoning_parser),
tokenizer_mode=tokenizer_mode,
reasoning_parser=reasoning_parser,
speculative_config=speculative_config,
)
tokenizer = llm.get_tokenizer()
......@@ -603,7 +609,7 @@ def test_structured_output_with_reasoning_matrices(
sampling_params = SamplingParams(
temperature=0.1,
max_tokens=8192,
guided_decoding=GuidedDecodingParams(json=reasoning_schema),
structured_outputs=StructuredOutputsParams(json=reasoning_schema),
)
outputs = llm.generate(
[reasoning_prompt],
......@@ -640,13 +646,14 @@ def test_structured_output_auto_mode(
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend="auto",
structured_outputs_config=dict(backend="auto"),
tokenizer_mode=tokenizer_mode)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
structured_outputs=StructuredOutputsParams(
json=unsupported_json_schema))
prompts = (
"Give an example JSON object for a grade "
......@@ -681,9 +688,10 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
max_model_len=1024,
guided_decoding_backend="guidance",
guided_decoding_disable_any_whitespace=True,
guided_decoding_disable_additional_properties=True)
structured_outputs_config=dict(
backend="guidance",
disable_any_whitespace=True,
disable_additional_properties=True))
schema = {
'type': 'object',
......@@ -709,14 +717,15 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
"<|im_end|>\n<|im_start|>assistant\n")
def generate_with_backend(backend):
guided_params = GuidedDecodingParams(
structured_outputs_params = StructuredOutputsParams(
json=schema,
backend=backend,
disable_any_whitespace=True,
disable_additional_properties=True)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
guided_decoding=guided_params)
sampling_params = SamplingParams(
temperature=0,
max_tokens=256,
structured_outputs=structured_outputs_params)
outputs = llm.generate(prompt, sampling_params=sampling_params)
assert outputs is not None
......@@ -736,12 +745,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
assert "a6" not in generated
@pytest.mark.parametrize("guided_decoding_backend",
["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
@pytest.mark.parametrize("backend", ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_structured_outputs_requests(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
guided_decoding_backend: str,
backend: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
......@@ -753,24 +761,25 @@ def test_structured_output_batched_with_non_guided_requests(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
structured_outputs_config=StructuredOutputsConfig(
backend=backend,
disable_any_whitespace=backend in {"xgrammar", "guidance"},
),
)
guided_prompt = (
structured_outputs_prompt = (
"Give an example JSON for an employee profile that fits this "
"schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}")
non_guided_prompt = "The diameter of the Earth in kilometers is "
non_structured_outputs_prompt = "The diameter of the Earth in kilometers is "
prompts = [guided_prompt, non_guided_prompt]
prompts = [structured_outputs_prompt, non_structured_outputs_prompt]
sampling_params = [
SamplingParams(
temperature=1.0,
max_tokens=400,
guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
SamplingParams(temperature=1.0,
max_tokens=400,
structured_outputs=StructuredOutputsParams(
json=sample_json_schema)),
# No max tokens, temp=0 to assert on contents
SamplingParams(
seed=42,
......@@ -801,16 +810,16 @@ def test_structured_output_batched_with_non_guided_requests(
print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")
if index == 0:
# First prompt is guided, expect valid JSON
# First prompt is structured outputs, expect valid JSON
assert "\n" not in generated_text
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_json_schema)
else:
# Second prompt is not guided, expect valid output
# Second prompt is not structured outputs, expect valid output
# Cannot assert on exact output, but we can expect it to be factual
assert "12,742" in generated_text
# non-guided requests should not return a valid JSON here
# non-structured outputs requests should not return a valid JSON here
with pytest.raises(ValueError):
output_json = json.loads(generated_text)
......@@ -77,7 +77,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI,
"role": "user",
"content": prompt,
}],
extra_body={"guided_json": invalid_json_schema},
extra_body={"structured_outputs": {
"json": invalid_json_schema
}},
)
......@@ -99,7 +101,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
"content": prompt,
}],
extra_body={
"guided_regex": r"[.*",
"structured_outputs": {
"regex": r"[.*"
},
"stop": ["\n"]
},
)
......@@ -134,5 +138,9 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
"role": "user",
"content": prompt,
}],
extra_body={"guided_grammar": invalid_simplified_sql_grammar},
extra_body={
"structured_outputs": {
"grammar": invalid_simplified_sql_grammar
}
},
)
......@@ -627,7 +627,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI,
await client.completions.create(
model=model_name,
prompt=prompt,
extra_body={"guided_json": invalid_json_schema},
extra_body={"structured_outputs": {
"json": invalid_json_schema
}},
)
......@@ -646,7 +648,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
model=model_name,
prompt=prompt,
extra_body={
"guided_regex": r"[.*",
"structured_outputs": {
"regex": r"[.*"
},
"stop": ["\n"]
},
)
......@@ -678,7 +682,11 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
await client.completions.create(
model=model_name,
prompt=prompt,
extra_body={"guided_grammar": invalid_simplified_sql_grammar},
extra_body={
"structured_outputs": {
"grammar": invalid_simplified_sql_grammar
}
},
)
......
......@@ -2277,34 +2277,34 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"]
StructuredOutputsBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"]
@config
@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine."""
class StructuredOutputsConfig:
"""Dataclass which contains structured outputs config for the engine."""
backend: GuidedDecodingBackend = "auto"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior
is subject to change in each release."""
backend: StructuredOutputsBackend = "auto"
"""Which engine will be used for structured outputs (e.g. JSON schema,
regex, etc) by default. With "auto", we will make opinionated choices
based on request contents and what the backend libraries currently support,
so the behavior is subject to change in each release."""
disable_fallback: bool = False
"""If `True`, vLLM will not fallback to a different backend on error."""
disable_any_whitespace: bool = False
"""If `True`, the model will not generate any whitespace during guided
decoding. This is only supported for xgrammar and guidance backends."""
"""If `True`, the model will not generate any whitespace during structured
outputs. This is only supported for xgrammar and guidance backends."""
disable_additional_properties: bool = False
"""If `True`, the `guidance` backend will not use `additionalProperties`
in the JSON schema. This is only supported for the `guidance` backend and
is used to better align its behaviour with `outlines` and `xgrammar`."""
reasoning_backend: str = ""
reasoning_parser: str = ""
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format."""
......@@ -2451,8 +2451,9 @@ class VllmConfig:
"""LoRA configuration."""
speculative_config: Optional[SpeculativeConfig] = None
"""Speculative decoding configuration."""
decoding_config: DecodingConfig = field(default_factory=DecodingConfig)
"""Decoding configuration."""
structured_outputs_config: StructuredOutputsConfig = field(
default_factory=StructuredOutputsConfig)
"""Structured outputs configuration."""
observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration."""
quant_config: Optional[QuantizationConfig] = None
......@@ -2543,8 +2544,8 @@ class VllmConfig:
vllm_factors.append(self.speculative_config.compute_hash())
else:
vllm_factors.append("None")
if self.decoding_config:
vllm_factors.append(self.decoding_config.compute_hash())
if self.structured_outputs_config:
vllm_factors.append(self.structured_outputs_config.compute_hash())
else:
vllm_factors.append("None")
if self.observability_config:
......@@ -3063,7 +3064,7 @@ class VllmConfig:
f"enforce_eager={self.model_config.enforce_eager}, "
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
f"device_config={self.device_config.device}, "
f"decoding_config={self.decoding_config!r}, "
f"structured_outputs_config={self.structured_outputs_config!r}, "
f"observability_config={self.observability_config!r}, "
f"seed={self.model_config.seed}, "
f"served_model_name={self.model_config.served_model_name}, "
......
......@@ -22,17 +22,16 @@ from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigType, ConvertOption, DecodingConfig,
DetailedTraceModules, Device, DeviceConfig,
DistributedExecutorBackend, EPLBConfig,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
ConfigType, ConvertOption, DetailedTraceModules,
Device, DeviceConfig, DistributedExecutorBackend,
EPLBConfig, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
ModelDType, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs)
SpeculativeConfig, StructuredOutputsConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs)
from vllm.config.multimodal import MMCacheType, MultiModalConfig
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.config.utils import get_field
......@@ -418,12 +417,15 @@ class EngineArgs:
disable_hybrid_kv_cache_manager: bool = (
SchedulerConfig.disable_hybrid_kv_cache_manager)
guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
guided_decoding_disable_any_whitespace: bool = \
DecodingConfig.disable_any_whitespace
guided_decoding_disable_additional_properties: bool = \
DecodingConfig.disable_additional_properties
structured_outputs_config: StructuredOutputsConfig = get_field(
VllmConfig, "structured_outputs_config")
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
# Deprecated guided decoding fields
guided_decoding_backend: Optional[str] = None
guided_decoding_disable_fallback: Optional[bool] = None
guided_decoding_disable_any_whitespace: Optional[bool] = None
guided_decoding_disable_additional_properties: Optional[bool] = None
logits_processor_pattern: Optional[
str] = ModelConfig.logits_processor_pattern
......@@ -462,7 +464,6 @@ class EngineArgs:
additional_config: dict[str, Any] = \
get_field(VllmConfig, "additional_config")
reasoning_parser: str = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
......@@ -618,28 +619,29 @@ class EngineArgs:
load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"])
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
guided_decoding_group = parser.add_argument_group(
title="DecodingConfig",
description=DecodingConfig.__doc__,
# Structured outputs arguments
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
structured_outputs_group = parser.add_argument_group(
title="StructuredOutputsConfig",
description=StructuredOutputsConfig.__doc__,
)
guided_decoding_group.add_argument("--guided-decoding-backend",
**guided_decoding_kwargs["backend"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-fallback",
**guided_decoding_kwargs["disable_fallback"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-any-whitespace",
**guided_decoding_kwargs["disable_any_whitespace"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-additional-properties",
**guided_decoding_kwargs["disable_additional_properties"])
guided_decoding_group.add_argument(
structured_outputs_group.add_argument(
"--reasoning-parser",
# This choice is a special case because it's not static
choices=list(ReasoningParserManager.reasoning_parsers),
**guided_decoding_kwargs["reasoning_backend"])
**structured_outputs_kwargs["reasoning_parser"])
# Deprecated guided decoding arguments
for arg, type in [
("--guided-decoding-backend", str),
("--guided-decoding-disable-fallback", bool),
("--guided-decoding-disable-any-whitespace", bool),
("--guided-decoding-disable-additional-properties", bool),
]:
structured_outputs_group.add_argument(
arg,
type=type,
help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
deprecated=True)
# Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig)
......@@ -934,6 +936,8 @@ class EngineArgs:
**vllm_kwargs["compilation_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
vllm_group.add_argument('--structured-outputs-config',
**vllm_kwargs["structured_outputs_config"])
# Other arguments
parser.add_argument('--disable-log-stats',
......@@ -1421,14 +1425,25 @@ class EngineArgs:
load_config = self.create_load_config()
decoding_config = DecodingConfig(
backend=self.guided_decoding_backend,
disable_fallback=self.guided_decoding_disable_fallback,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
disable_additional_properties=\
self.guided_decoding_disable_additional_properties,
reasoning_backend=self.reasoning_parser
)
# Pass reasoning_parser into StructuredOutputsConfig
if self.reasoning_parser:
self.structured_outputs_config.reasoning_parser = \
self.reasoning_parser
# Forward the deprecated CLI args to the StructuredOutputsConfig
so_config = self.structured_outputs_config
if self.guided_decoding_backend is not None:
so_config.guided_decoding_backend = \
self.guided_decoding_backend
if self.guided_decoding_disable_fallback is not None:
so_config.guided_decoding_disable_fallback = \
self.guided_decoding_disable_fallback
if self.guided_decoding_disable_any_whitespace is not None:
so_config.guided_decoding_disable_any_whitespace = \
self.guided_decoding_disable_any_whitespace
if self.guided_decoding_disable_additional_properties is not None:
so_config.guided_decoding_disable_additional_properties = \
self.guided_decoding_disable_additional_properties
observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=(
......@@ -1446,7 +1461,7 @@ class EngineArgs:
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config,
structured_outputs_config=self.structured_outputs_config,
observability_config=observability_config,
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
......
......@@ -10,9 +10,8 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
from weakref import ReferenceType
import vllm.envs as envs
from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig,
from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
......@@ -955,10 +954,6 @@ class AsyncLLMEngine(EngineClient):
"""Get the parallel configuration of the vLLM engine."""
return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
return self.engine.get_scheduler_config()
......
......@@ -16,9 +16,8 @@ import torch
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import (DecodingConfig, ModelConfig, ObservabilityConfig,
from vllm.config import (LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase, Stats
......@@ -213,8 +212,7 @@ class LLMEngine:
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config # noqa
self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
self.structured_outputs_config = vllm_config.structured_outputs_config
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
......@@ -364,10 +362,9 @@ class LLMEngine:
self.observability_config.otlp_traces_endpoint)
# Initialize reasoning parser if reasoning backend is set.
if self.decoding_config.reasoning_backend and \
self.tokenizer:
if self.structured_outputs_config.reasoning_parser and self.tokenizer:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
self.decoding_config.reasoning_backend)
self.structured_outputs_config.reasoning_parser)
self.reasoner: ReasoningParser = reasoner_class(
self.tokenizer.get_lora_tokenizer())
......@@ -381,7 +378,8 @@ class LLMEngine:
self.seq_counter,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.reasoner if self.decoding_config.reasoning_backend
self.reasoner
if self.structured_outputs_config.reasoning_parser
and self.tokenizer else None,
),
))
......@@ -772,10 +770,6 @@ class LLMEngine:
"""Gets the parallel configuration."""
return self.parallel_config
def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config
......
......@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
......@@ -248,11 +248,6 @@ class EngineClient(ABC):
"""Get the model configuration of the vLLM engine."""
...
@abstractmethod
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
...
@abstractmethod
async def get_input_preprocessor(self) -> InputPreprocessor:
"""Get the input processor of the vLLM engine."""
......
......@@ -15,8 +15,8 @@ import vllm.envs as envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence,
create_sort_beams_key_function)
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
is_init_field)
from vllm.config import (CompilationConfig, ModelDType,
StructuredOutputsConfig, TokenizerMode, is_init_field)
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
PoolerConfig, RunnerOption)
from vllm.engine.llm_engine import LLMEngine
......@@ -192,6 +192,8 @@ class LLM:
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
override_pooler_config: Optional[PoolerConfig] = None,
structured_outputs_config: Optional[Union[dict[
str, Any], StructuredOutputsConfig]] = None,
kv_cache_memory_bytes: Optional[int] = None,
compilation_config: Optional[Union[int, dict[str, Any],
CompilationConfig]] = None,
......@@ -236,14 +238,30 @@ class LLM:
compilation_config_instance = CompilationConfig(
level=compilation_config)
elif isinstance(compilation_config, dict):
predicate = lambda x: is_init_field(CompilationConfig, x[0])
compilation_config_instance = CompilationConfig(
**dict(filter(predicate, compilation_config.items())))
**{
k: v
for k, v in compilation_config.items()
if is_init_field(CompilationConfig, k)
})
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = CompilationConfig()
if structured_outputs_config is not None:
if isinstance(structured_outputs_config, dict):
structured_outputs_instance = StructuredOutputsConfig(
**{
k: v
for k, v in structured_outputs_config.items()
if is_init_field(StructuredOutputsConfig, k)
})
else:
structured_outputs_instance = structured_outputs_config
else:
structured_outputs_instance = StructuredOutputsConfig()
engine_args = EngineArgs(
model=model,
runner=runner,
......@@ -271,6 +289,7 @@ class LLM:
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
structured_outputs_config=structured_outputs_instance,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
**kwargs,
......
......@@ -1678,7 +1678,7 @@ async def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
tool_server=tool_server,
reasoning_parser=args.reasoning_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
......@@ -1697,7 +1697,7 @@ async def init_app_state(
exclude_tools_when_tool_choice_none=args.
exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.reasoning_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
......@@ -1800,10 +1800,10 @@ def validate_api_server_args(args):
f"(chose from {{ {','.join(valid_tool_parses)} }})")
valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
if args.reasoning_parser \
and args.reasoning_parser not in valid_reasoning_parses:
if ((reasoning_parser := args.structured_outputs_config.reasoning_parser)
and reasoning_parser not in valid_reasoning_parses):
raise KeyError(
f"invalid reasoning parser: {args.reasoning_parser} "
f"invalid reasoning parser: {reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
......
......@@ -54,8 +54,8 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams, StructuredOutputsParams)
from vllm.utils import random_uuid, resolve_obj_by_qualname
logger = init_logger(__name__)
......@@ -373,11 +373,12 @@ class ResponsesRequest(OpenAIBaseModel):
stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output
guided_decoding = None
structured_outputs = None
if self.text is not None and self.text.format is not None:
response_format = self.text.format
if response_format.type == "json_schema":
guided_decoding = GuidedDecodingParams.from_optional(
if (response_format.type == "json_schema"
and response_format.schema_ is not None):
structured_outputs = StructuredOutputsParams(
json=response_format.schema_)
elif response_format.type == "json_object":
raise NotImplementedError("json_object is not supported")
......@@ -392,7 +393,7 @@ class ResponsesRequest(OpenAIBaseModel):
stop_token_ids=stop_token_ids,
output_kind=(RequestOutputKind.DELTA
if self.stream else RequestOutputKind.FINAL_ONLY),
guided_decoding=guided_decoding,
structured_outputs=structured_outputs,
)
def is_include_output_logprobs(self) -> bool:
......@@ -547,42 +548,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
structured_outputs: Optional[StructuredOutputsParams] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[list[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
structural_tag: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the structural tag schema."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."),
description="Additional kwargs for structured outputs",
)
priority: int = Field(
default=0,
......@@ -701,31 +669,33 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
guided_json_object = None
if self.response_format is not None:
if self.response_format.type == "json_object":
guided_json_object = True
elif self.response_format.type == "json_schema":
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
elif self.response_format.type == "structural_tag":
structural_tag = self.response_format
assert structural_tag is not None and isinstance(
structural_tag, StructuralTagResponseFormat)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structural_tag = json.dumps(s_tag_obj)
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern,
structural_tag=self.structural_tag,
)
response_format = self.response_format
json_schema_from_tool = self._get_json_schema_from_tool()
if response_format is not None or json_schema_from_tool is not None:
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if self.structured_outputs is None:
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
if response_format is not None:
if response_format.type == "json_object":
self.structured_outputs.json_object = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
self.structured_outputs.json = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag, StructuralTagResponseFormat)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(
s_tag_obj)
# Set structured output params for tool calling
if json_schema_from_tool is not None:
self.structured_outputs.json = json_schema_from_tool
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
......@@ -757,15 +727,14 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias,
bad_words= self.bad_words,
bad_words=self.bad_words,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
)
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]:
# user has chosen to not use any tool
if self.tool_choice == "none" or self.tools is None:
return None
......@@ -875,28 +844,31 @@ class ChatCompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
def check_structured_outputs_count(cls, data):
if isinstance(data, ValueError):
raise data
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
# you can only use one kind of guided decoding
if guide_count > 1:
if "structured_outputs" not in data:
return data
structured_outputs_kwargs = data['structured_outputs']
count = sum(
structured_outputs_kwargs.get(k) is not None
for k in ("json", "regex", "choice"))
# you can only use one kind of constraints for structured outputs
if count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice", "none") not in (
"You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice').")
# you can only either use structured outputs or tools, not both
if count > 1 and data.get("tool_choice", "none") not in (
"none",
"auto",
"required",
):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
"You can only either use constraints for structured outputs "
"or tools, not both.")
return data
@model_validator(mode="before")
......@@ -1049,37 +1021,9 @@ class CompletionRequest(OpenAIBaseModel):
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description="If specified, the output will follow the JSON schema.",
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[list[str]] = Field(
structured_outputs: Optional[StructuredOutputsParams] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."),
description="Additional kwargs for structured outputs",
)
priority: int = Field(
default=0,
......@@ -1210,20 +1154,10 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0
guided_json_object = None
if (self.response_format is not None
if (self.structured_outputs is not None
and self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern,
)
self.structured_outputs.json_object = True
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
......@@ -1255,7 +1189,7 @@ class CompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
......@@ -1263,16 +1197,18 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
def check_structured_outputs_count(cls, data):
if "structured_outputs" not in data:
return data
structured_outputs_kwargs = data['structured_outputs']
count = sum(
structured_outputs_kwargs.get(k) is not None
for k in ("json", "regex", "choice"))
if count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
"You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice').")
return data
@model_validator(mode="before")
......
......@@ -993,7 +993,7 @@ class OpenAIServingChat(OpenAIServing):
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using guided decoding
# only happens if we are NOT using structured outputs
auto_tools_called = False
if tool_parser:
auto_tools_called = len(
......
......@@ -262,9 +262,9 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
decoding_config = vllm_config.decoding_config
if decoding_config.reasoning_backend == "":
decoding_config.reasoning_backend = "openai_gptoss"
structured_outputs_config = vllm_config.structured_outputs_config
if structured_outputs_config.reasoning_parser == "":
structured_outputs_config.reasoning_parser = "openai_gptoss"
# Increase the max capture size from 512 to 1024 for performance.
# NOTE(woosuk): This will increase the number of CUDA graphs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment