Unverified Commit 29283e89 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[Chore] Cleanup guided namespace, move to structured outputs config (#22772)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 05b044e6
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests # imports for structured outputs tests
import io import io
import json import json
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io import io
# imports for guided decoding tests # imports for structured outputs tests
import json import json
import httpx import httpx
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the SamplingParams class.
"""
import pytest
from vllm import SamplingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
MODEL_NAME = "Qwen/Qwen1.5-7B"
def test_max_tokens_none():
"""max_tokens=None should be allowed"""
SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
@pytest.fixture(scope="module")
def model_config():
return ModelConfig(
MODEL_NAME,
seed=0,
dtype="float16",
)
@pytest.fixture(scope="module")
def default_max_tokens():
return 4096
def test_sampling_params_from_request_with_no_guided_decoding_backend(
model_config, default_max_tokens):
# guided_decoding_backend is not present at request level
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# we do not expect any backend to be present and the default
# guided_decoding_backend at engine level will be used.
assert sampling_params.guided_decoding.backend is None
@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
[("xgrammar", "xgrammar"), ("guidance", "guidance"),
("outlines", "outlines")])
def test_sampling_params_from_request_with_guided_decoding_backend(
request_level_guided_decoding_backend: str, expected: str,
model_config, default_max_tokens):
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
'guided_decoding_backend':
request_level_guided_decoding_backend,
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# backend correctly identified in resulting sampling_params
assert sampling_params.guided_decoding.backend == expected
...@@ -68,7 +68,7 @@ EXAMPLE_TOOLS = [ ...@@ -68,7 +68,7 @@ EXAMPLE_TOOLS = [
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
should_match: bool): should_match: bool):
self = MagicMock(tool_choice="required", tools=tools) self = MagicMock(tool_choice="required", tools=tools)
schema = ChatCompletionRequest._get_guided_json_from_tool(self) schema = ChatCompletionRequest._get_json_schema_from_tool(self)
assert isinstance(schema, dict) assert isinstance(schema, dict)
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide # use build_regex_from_schema used in JSONLogitsProcessor to create Guide
...@@ -218,7 +218,7 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS] ...@@ -218,7 +218,7 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
} }
}, {}], False), }, {}], False),
]) ])
def test_guided_json(sample_output, should_match): def test_structured_outputs_json(sample_output, should_match):
_compile_and_check(tools=TypeAdapter( _compile_and_check(tools=TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS), list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
sample_output=sample_output, sample_output=sample_output,
...@@ -273,8 +273,9 @@ def update_parameters_empty_dict( ...@@ -273,8 +273,9 @@ def update_parameters_empty_dict(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"update_parameters", "update_parameters",
[update_parameters_none, update_parameters_empty_dict]) [update_parameters_none, update_parameters_empty_dict])
def test_guided_json_without_parameters(sample_output, should_match, def test_structured_outputs_json_without_parameters(sample_output,
update_parameters): should_match,
update_parameters):
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])] updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
tools = TypeAdapter( tools = TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(updated_tools) list[ChatCompletionToolsParam]).validate_python(updated_tools)
...@@ -334,4 +335,4 @@ def test_streaming_output_valid(output, empty_params, delta_len): ...@@ -334,4 +335,4 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += message.tool_calls[0].function.arguments combined_messages += message.tool_calls[0].function.arguments
combined_messages += "}]" combined_messages += "}]"
assert json.loads(combined_messages) == output assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json assert json.dumps(json.loads(combined_messages)) == output_json
\ No newline at end of file
...@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, ...@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import (MultiModalFeatureSpec, from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...@@ -1796,11 +1796,11 @@ def test_schedule_skip_tokenizer_init(): ...@@ -1796,11 +1796,11 @@ def test_schedule_skip_tokenizer_init():
def test_schedule_skip_tokenizer_init_structured_output_request(): def test_schedule_skip_tokenizer_init_structured_output_request():
scheduler = create_scheduler(skip_tokenizer_init=True) scheduler = create_scheduler(skip_tokenizer_init=True)
guided_params = GuidedDecodingParams(regex="[0-9]+") structured_outputs_params = StructuredOutputsParams(regex="[0-9]+")
sampling_params = SamplingParams( sampling_params = SamplingParams(
ignore_eos=False, ignore_eos=False,
max_tokens=16, max_tokens=16,
guided_decoding=guided_params, structured_outputs=structured_outputs_params,
) )
request = Request( request = Request(
request_id="0", request_id="0",
......
...@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import pytest import pytest
from vllm import LLM from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -97,7 +97,7 @@ def _get_test_sampling_params( ...@@ -97,7 +97,7 @@ def _get_test_sampling_params(
top_p=0.95, top_p=0.95,
n=n, n=n,
seed=seed, seed=seed,
guided_decoding=GuidedDecodingParams( structured_outputs=StructuredOutputsParams(
regex="[0-9]+") if structured_outputs else None, regex="[0-9]+") if structured_outputs else None,
) for n in n_list ) for n in n_list
], n_list ], n_list
......
...@@ -151,7 +151,7 @@ def sample_definition_json_schema(): ...@@ -151,7 +151,7 @@ def sample_definition_json_schema():
@pytest.fixture @pytest.fixture
def sample_guided_choice(): def sample_structured_outputs_choices():
return [ return [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
"Ruby", "Swift", "Kotlin" "Ruby", "Swift", "Kotlin"
......
...@@ -15,12 +15,13 @@ import torch ...@@ -15,12 +15,13 @@ import torch
from pydantic import BaseModel from pydantic import BaseModel
from tests.reasoning.utils import run_reasoning_extraction from tests.reasoning.utils import run_reasoning_extraction
from vllm.config import StructuredOutputsConfig
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import TokenizerMode from vllm.config import TokenizerMode
...@@ -90,7 +91,7 @@ def _load_json(s: str, backend: str) -> str: ...@@ -90,7 +91,7 @@ def _load_json(s: str, backend: str) -> str:
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, guided_decoding_backend, tokenizer_mode, speculative_config", "model_name, backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
def test_structured_output( def test_structured_output(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
...@@ -99,8 +100,8 @@ def test_structured_output( ...@@ -99,8 +100,8 @@ def test_structured_output(
sample_sql_ebnf: str, sample_sql_ebnf: str,
sample_sql_lark: str, sample_sql_lark: str,
sample_regex: str, sample_regex: str,
sample_guided_choice: str, sample_structured_outputs_choices: str,
guided_decoding_backend: str, backend: str,
tokenizer_mode: str, tokenizer_mode: str,
model_name: str, model_name: str,
speculative_config: dict[str, Any], speculative_config: dict[str, Any],
...@@ -115,16 +116,15 @@ def test_structured_output( ...@@ -115,16 +116,15 @@ def test_structured_output(
enforce_eager = bool(not current_platform.is_tpu()) enforce_eager = bool(not current_platform.is_tpu())
# Use a single LLM instance for several scenarios to # Use a single LLM instance for several scenarios to
# speed up the test suite. # speed up the test suite.
llm = LLM( llm = LLM(model=model_name,
model=model_name, enforce_eager=enforce_eager,
enforce_eager=enforce_eager, max_model_len=1024,
max_model_len=1024, structured_outputs_config=dict(backend=backend,
guided_decoding_backend=guided_decoding_backend, disable_any_whitespace=backend
guided_decoding_disable_any_whitespace=(guided_decoding_backend in {"xgrammar", "guidance"}),
in {"xgrammar", "guidance"}), seed=120,
seed=120, tokenizer_mode=tokenizer_mode,
tokenizer_mode=tokenizer_mode, speculative_config=speculative_config)
speculative_config=speculative_config)
# #
# Test 1: Generate JSON output based on a provided schema # Test 1: Generate JSON output based on a provided schema
...@@ -132,7 +132,7 @@ def test_structured_output( ...@@ -132,7 +132,7 @@ def test_structured_output(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=4096, max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=sample_json_schema)) structured_outputs=StructuredOutputsParams(json=sample_json_schema))
prompt = ("Give an example JSON for an employee profile that fits this " prompt = ("Give an example JSON for an employee profile that fits this "
"schema. Make the response as short as possible. Schema: " "schema. Make the response as short as possible. Schema: "
...@@ -152,7 +152,7 @@ def test_structured_output( ...@@ -152,7 +152,7 @@ def test_structured_output(
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
assert generated_text is not None assert generated_text is not None
if guided_decoding_backend != 'lm-format-enforcer': if backend != 'lm-format-enforcer':
assert "\n" not in generated_text assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
...@@ -161,12 +161,12 @@ def test_structured_output( ...@@ -161,12 +161,12 @@ def test_structured_output(
# #
# Test 2: Generate JSON object without a schema # Test 2: Generate JSON object without a schema
# #
if guided_decoding_backend != "outlines": if backend != "outlines":
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=4096, max_tokens=4096,
n=2, n=2,
guided_decoding=GuidedDecodingParams(json_object=True)) structured_outputs=StructuredOutputsParams(json_object=True))
outputs = llm.generate(prompts=( outputs = llm.generate(prompts=(
"Generate a JSON object with curly braces for a person with " "Generate a JSON object with curly braces for a person with "
...@@ -195,8 +195,9 @@ def test_structured_output( ...@@ -195,8 +195,9 @@ def test_structured_output(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=4096, max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) structured_outputs=StructuredOutputsParams(
if guided_decoding_backend.startswith("xgrammar"): json=unsupported_json_schema))
if backend.startswith("xgrammar"):
with pytest.raises(ValueError, with pytest.raises(ValueError,
match="The provided JSON schema contains features " match="The provided JSON schema contains features "
"not supported by xgrammar."): "not supported by xgrammar."):
...@@ -230,7 +231,7 @@ def test_structured_output( ...@@ -230,7 +231,7 @@ def test_structured_output(
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: if backend not in ["outlines", "lm-format-enforcer"]:
# #
# Test 4: Generate SQL statement using EBNF grammar # Test 4: Generate SQL statement using EBNF grammar
# #
...@@ -238,7 +239,8 @@ def test_structured_output( ...@@ -238,7 +239,8 @@ def test_structured_output(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) structured_outputs=StructuredOutputsParams(
grammar=sample_sql_ebnf))
outputs = llm.generate( outputs = llm.generate(
("Generate a sql statement that selects col_1 from " ("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as " "table_1 where it is equal to 1. Make the response as short as "
...@@ -271,7 +273,8 @@ def test_structured_output( ...@@ -271,7 +273,8 @@ def test_structured_output(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) structured_outputs=StructuredOutputsParams(
grammar=sample_sql_lark))
outputs = llm.generate( outputs = llm.generate(
("Generate a sql statement that selects col_1 from " ("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as " "table_1 where it is equal to 1. Make the response as short as "
...@@ -309,7 +312,8 @@ def test_structured_output( ...@@ -309,7 +312,8 @@ def test_structured_output(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar="not a grammar")) structured_outputs=StructuredOutputsParams(
grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "): with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate( llm.generate(
("Generate a sql statement that selects col_1 from " ("Generate a sql statement that selects col_1 from "
...@@ -325,7 +329,7 @@ def test_structured_output( ...@@ -325,7 +329,7 @@ def test_structured_output(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex)) structured_outputs=StructuredOutputsParams(regex=sample_regex))
prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. " prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.") f"Make the response as short as possible.")
...@@ -352,7 +356,8 @@ def test_structured_output( ...@@ -352,7 +356,8 @@ def test_structured_output(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) structured_outputs=StructuredOutputsParams(
choice=sample_structured_outputs_choices))
outputs = llm.generate( outputs = llm.generate(
("The best language for type-safe systems programming is " ("The best language for type-safe systems programming is "
...@@ -368,7 +373,7 @@ def test_structured_output( ...@@ -368,7 +373,7 @@ def test_structured_output(
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(generated_text) print(generated_text)
assert generated_text is not None assert generated_text is not None
assert generated_text in sample_guided_choice assert generated_text in sample_structured_outputs_choices
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# #
...@@ -378,7 +383,7 @@ def test_structured_output( ...@@ -378,7 +383,7 @@ def test_structured_output(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema)) structured_outputs=StructuredOutputsParams(json=json_schema))
outputs = llm.generate( outputs = llm.generate(
("Generate a JSON with the brand, model and car_type of the most " ("Generate a JSON with the brand, model and car_type of the most "
...@@ -422,7 +427,7 @@ def test_structured_output( ...@@ -422,7 +427,7 @@ def test_structured_output(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=4096, max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=json_schema)) structured_outputs=StructuredOutputsParams(json=json_schema))
outputs = llm.generate( outputs = llm.generate(
("Generate a description of a frog using 50 characters. " ("Generate a description of a frog using 50 characters. "
...@@ -444,7 +449,7 @@ def test_structured_output( ...@@ -444,7 +449,7 @@ def test_structured_output(
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema) jsonschema.validate(instance=output_json, schema=json_schema)
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: if backend not in ["outlines", "lm-format-enforcer"]:
# #
# Test 11: Generate structured output using structural_tag format # Test 11: Generate structured output using structural_tag format
# #
...@@ -470,7 +475,7 @@ def test_structured_output( ...@@ -470,7 +475,7 @@ def test_structured_output(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.0, temperature=0.0,
max_tokens=4096, max_tokens=4096,
guided_decoding=GuidedDecodingParams( structured_outputs=StructuredOutputsParams(
structural_tag=json.dumps(structural_tag_config))) structural_tag=json.dumps(structural_tag_config)))
prompt = """ prompt = """
...@@ -547,7 +552,7 @@ Make the response as short as possible. ...@@ -547,7 +552,7 @@ Make the response as short as possible.
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 "model_name, backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501
[ [
("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
"deepseek_r1", NGRAM_SPEC_CONFIG), "deepseek_r1", NGRAM_SPEC_CONFIG),
...@@ -556,7 +561,7 @@ Make the response as short as possible. ...@@ -556,7 +561,7 @@ Make the response as short as possible.
) )
def test_structured_output_with_reasoning_matrices( def test_structured_output_with_reasoning_matrices(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str, backend: str,
tokenizer_mode: TokenizerMode, tokenizer_mode: TokenizerMode,
reasoning_parser: str, reasoning_parser: str,
model_name: str, model_name: str,
...@@ -576,10 +581,11 @@ def test_structured_output_with_reasoning_matrices( ...@@ -576,10 +581,11 @@ def test_structured_output_with_reasoning_matrices(
enforce_eager=bool(not current_platform.is_tpu()), enforce_eager=bool(not current_platform.is_tpu()),
max_model_len=1024, max_model_len=1024,
max_num_seqs=16, max_num_seqs=16,
guided_decoding_backend=guided_decoding_backend, structured_outputs_config=dict(backend=backend,
guided_decoding_disable_any_whitespace=True, disable_any_whitespace=backend
in {"xgrammar", "guidance"},
reasoning_parser=reasoning_parser),
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
reasoning_parser=reasoning_parser,
speculative_config=speculative_config, speculative_config=speculative_config,
) )
tokenizer = llm.get_tokenizer() tokenizer = llm.get_tokenizer()
...@@ -603,7 +609,7 @@ def test_structured_output_with_reasoning_matrices( ...@@ -603,7 +609,7 @@ def test_structured_output_with_reasoning_matrices(
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.1, temperature=0.1,
max_tokens=8192, max_tokens=8192,
guided_decoding=GuidedDecodingParams(json=reasoning_schema), structured_outputs=StructuredOutputsParams(json=reasoning_schema),
) )
outputs = llm.generate( outputs = llm.generate(
[reasoning_prompt], [reasoning_prompt],
...@@ -640,13 +646,14 @@ def test_structured_output_auto_mode( ...@@ -640,13 +646,14 @@ def test_structured_output_auto_mode(
llm = LLM(model=model_name, llm = LLM(model=model_name,
max_model_len=1024, max_model_len=1024,
guided_decoding_backend="auto", structured_outputs_config=dict(backend="auto"),
tokenizer_mode=tokenizer_mode) tokenizer_mode=tokenizer_mode)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) structured_outputs=StructuredOutputsParams(
json=unsupported_json_schema))
prompts = ( prompts = (
"Give an example JSON object for a grade " "Give an example JSON object for a grade "
...@@ -681,9 +688,10 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): ...@@ -681,9 +688,10 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
max_model_len=1024, max_model_len=1024,
guided_decoding_backend="guidance", structured_outputs_config=dict(
guided_decoding_disable_any_whitespace=True, backend="guidance",
guided_decoding_disable_additional_properties=True) disable_any_whitespace=True,
disable_additional_properties=True))
schema = { schema = {
'type': 'object', 'type': 'object',
...@@ -709,14 +717,15 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): ...@@ -709,14 +717,15 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
"<|im_end|>\n<|im_start|>assistant\n") "<|im_end|>\n<|im_start|>assistant\n")
def generate_with_backend(backend): def generate_with_backend(backend):
guided_params = GuidedDecodingParams( structured_outputs_params = StructuredOutputsParams(
json=schema, json=schema,
backend=backend, backend=backend,
disable_any_whitespace=True, disable_any_whitespace=True,
disable_additional_properties=True) disable_additional_properties=True)
sampling_params = SamplingParams(temperature=0, sampling_params = SamplingParams(
max_tokens=256, temperature=0,
guided_decoding=guided_params) max_tokens=256,
structured_outputs=structured_outputs_params)
outputs = llm.generate(prompt, sampling_params=sampling_params) outputs = llm.generate(prompt, sampling_params=sampling_params)
assert outputs is not None assert outputs is not None
...@@ -736,12 +745,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): ...@@ -736,12 +745,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
assert "a6" not in generated assert "a6" not in generated
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("backend", ["guidance", "xgrammar", "outlines"])
["guidance", "xgrammar", "outlines"]) def test_structured_output_batched_with_non_structured_outputs_requests(
def test_structured_output_batched_with_non_guided_requests(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any], sample_json_schema: dict[str, Any],
guided_decoding_backend: str, backend: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
...@@ -753,24 +761,25 @@ def test_structured_output_batched_with_non_guided_requests( ...@@ -753,24 +761,25 @@ def test_structured_output_batched_with_non_guided_requests(
model="meta-llama/Meta-Llama-3.1-8B-Instruct", model="meta-llama/Meta-Llama-3.1-8B-Instruct",
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_model_len=1024, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend, structured_outputs_config=StructuredOutputsConfig(
guided_decoding_disable_any_whitespace=(guided_decoding_backend backend=backend,
in {"xgrammar", "guidance"}), disable_any_whitespace=backend in {"xgrammar", "guidance"},
),
) )
guided_prompt = ( structured_outputs_prompt = (
"Give an example JSON for an employee profile that fits this " "Give an example JSON for an employee profile that fits this "
"schema. Make the response as short as possible. Schema: " "schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}") f"{sample_json_schema}")
non_guided_prompt = "The diameter of the Earth in kilometers is " non_structured_outputs_prompt = "The diameter of the Earth in kilometers is "
prompts = [guided_prompt, non_guided_prompt] prompts = [structured_outputs_prompt, non_structured_outputs_prompt]
sampling_params = [ sampling_params = [
SamplingParams( SamplingParams(temperature=1.0,
temperature=1.0, max_tokens=400,
max_tokens=400, structured_outputs=StructuredOutputsParams(
guided_decoding=GuidedDecodingParams(json=sample_json_schema)), json=sample_json_schema)),
# No max tokens, temp=0 to assert on contents # No max tokens, temp=0 to assert on contents
SamplingParams( SamplingParams(
seed=42, seed=42,
...@@ -801,16 +810,16 @@ def test_structured_output_batched_with_non_guided_requests( ...@@ -801,16 +810,16 @@ def test_structured_output_batched_with_non_guided_requests(
print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}") print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")
if index == 0: if index == 0:
# First prompt is guided, expect valid JSON # First prompt is structured outputs, expect valid JSON
assert "\n" not in generated_text assert "\n" not in generated_text
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, jsonschema.validate(instance=output_json,
schema=sample_json_schema) schema=sample_json_schema)
else: else:
# Second prompt is not guided, expect valid output # Second prompt is not structured outputs, expect valid output
# Cannot assert on exact output, but we can expect it to be factual # Cannot assert on exact output, but we can expect it to be factual
assert "12,742" in generated_text assert "12,742" in generated_text
# non-guided requests should not return a valid JSON here # non-structured outputs requests should not return a valid JSON here
with pytest.raises(ValueError): with pytest.raises(ValueError):
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
...@@ -77,7 +77,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, ...@@ -77,7 +77,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI,
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }],
extra_body={"guided_json": invalid_json_schema}, extra_body={"structured_outputs": {
"json": invalid_json_schema
}},
) )
...@@ -99,7 +101,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): ...@@ -99,7 +101,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
"content": prompt, "content": prompt,
}], }],
extra_body={ extra_body={
"guided_regex": r"[.*", "structured_outputs": {
"regex": r"[.*"
},
"stop": ["\n"] "stop": ["\n"]
}, },
) )
...@@ -134,5 +138,9 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): ...@@ -134,5 +138,9 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
"role": "user", "role": "user",
"content": prompt, "content": prompt,
}], }],
extra_body={"guided_grammar": invalid_simplified_sql_grammar}, extra_body={
"structured_outputs": {
"grammar": invalid_simplified_sql_grammar
}
},
) )
...@@ -627,7 +627,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, ...@@ -627,7 +627,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI,
await client.completions.create( await client.completions.create(
model=model_name, model=model_name,
prompt=prompt, prompt=prompt,
extra_body={"guided_json": invalid_json_schema}, extra_body={"structured_outputs": {
"json": invalid_json_schema
}},
) )
...@@ -646,7 +648,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): ...@@ -646,7 +648,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
model=model_name, model=model_name,
prompt=prompt, prompt=prompt,
extra_body={ extra_body={
"guided_regex": r"[.*", "structured_outputs": {
"regex": r"[.*"
},
"stop": ["\n"] "stop": ["\n"]
}, },
) )
...@@ -678,7 +682,11 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): ...@@ -678,7 +682,11 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
await client.completions.create( await client.completions.create(
model=model_name, model=model_name,
prompt=prompt, prompt=prompt,
extra_body={"guided_grammar": invalid_simplified_sql_grammar}, extra_body={
"structured_outputs": {
"grammar": invalid_simplified_sql_grammar
}
},
) )
......
...@@ -2277,34 +2277,34 @@ def get_served_model_name(model: str, ...@@ -2277,34 +2277,34 @@ def get_served_model_name(model: str,
return served_model_name return served_model_name
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines", StructuredOutputsBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"] "lm-format-enforcer"]
@config @config
@dataclass @dataclass
class DecodingConfig: class StructuredOutputsConfig:
"""Dataclass which contains the decoding strategy of the engine.""" """Dataclass which contains structured outputs config for the engine."""
backend: GuidedDecodingBackend = "auto" backend: StructuredOutputsBackend = "auto"
"""Which engine will be used for guided decoding (JSON schema / regex etc) """Which engine will be used for structured outputs (e.g. JSON schema,
by default. With "auto", we will make opinionated choices based on request regex, etc) by default. With "auto", we will make opinionated choices
contents and what the backend libraries currently support, so the behavior based on request contents and what the backend libraries currently support,
is subject to change in each release.""" so the behavior is subject to change in each release."""
disable_fallback: bool = False disable_fallback: bool = False
"""If `True`, vLLM will not fallback to a different backend on error.""" """If `True`, vLLM will not fallback to a different backend on error."""
disable_any_whitespace: bool = False disable_any_whitespace: bool = False
"""If `True`, the model will not generate any whitespace during guided """If `True`, the model will not generate any whitespace during structured
decoding. This is only supported for xgrammar and guidance backends.""" outputs. This is only supported for xgrammar and guidance backends."""
disable_additional_properties: bool = False disable_additional_properties: bool = False
"""If `True`, the `guidance` backend will not use `additionalProperties` """If `True`, the `guidance` backend will not use `additionalProperties`
in the JSON schema. This is only supported for the `guidance` backend and in the JSON schema. This is only supported for the `guidance` backend and
is used to better align its behaviour with `outlines` and `xgrammar`.""" is used to better align its behaviour with `outlines` and `xgrammar`."""
reasoning_backend: str = "" reasoning_parser: str = ""
"""Select the reasoning parser depending on the model that you're using. """Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format.""" This is used to parse the reasoning content into OpenAI API format."""
...@@ -2451,8 +2451,9 @@ class VllmConfig: ...@@ -2451,8 +2451,9 @@ class VllmConfig:
"""LoRA configuration.""" """LoRA configuration."""
speculative_config: Optional[SpeculativeConfig] = None speculative_config: Optional[SpeculativeConfig] = None
"""Speculative decoding configuration.""" """Speculative decoding configuration."""
decoding_config: DecodingConfig = field(default_factory=DecodingConfig) structured_outputs_config: StructuredOutputsConfig = field(
"""Decoding configuration.""" default_factory=StructuredOutputsConfig)
"""Structured outputs configuration."""
observability_config: Optional[ObservabilityConfig] = None observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration.""" """Observability configuration."""
quant_config: Optional[QuantizationConfig] = None quant_config: Optional[QuantizationConfig] = None
...@@ -2543,8 +2544,8 @@ class VllmConfig: ...@@ -2543,8 +2544,8 @@ class VllmConfig:
vllm_factors.append(self.speculative_config.compute_hash()) vllm_factors.append(self.speculative_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
if self.decoding_config: if self.structured_outputs_config:
vllm_factors.append(self.decoding_config.compute_hash()) vllm_factors.append(self.structured_outputs_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
if self.observability_config: if self.observability_config:
...@@ -3063,7 +3064,7 @@ class VllmConfig: ...@@ -3063,7 +3064,7 @@ class VllmConfig:
f"enforce_eager={self.model_config.enforce_eager}, " f"enforce_eager={self.model_config.enforce_eager}, "
f"kv_cache_dtype={self.cache_config.cache_dtype}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, "
f"device_config={self.device_config.device}, " f"device_config={self.device_config.device}, "
f"decoding_config={self.decoding_config!r}, " f"structured_outputs_config={self.structured_outputs_config!r}, "
f"observability_config={self.observability_config!r}, " f"observability_config={self.observability_config!r}, "
f"seed={self.model_config.seed}, " f"seed={self.model_config.seed}, "
f"served_model_name={self.model_config.served_model_name}, " f"served_model_name={self.model_config.served_model_name}, "
......
...@@ -22,17 +22,16 @@ from typing_extensions import TypeIs, deprecated ...@@ -22,17 +22,16 @@ from typing_extensions import TypeIs, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigType, ConvertOption, DecodingConfig, ConfigType, ConvertOption, DetailedTraceModules,
DetailedTraceModules, Device, DeviceConfig, Device, DeviceConfig, DistributedExecutorBackend,
DistributedExecutorBackend, EPLBConfig, EPLBConfig, HfOverrides, KVEventsConfig,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode, KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
ModelDType, ModelImpl, ObservabilityConfig, ModelDType, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy, RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode, SpeculativeConfig, StructuredOutputsConfig,
VllmConfig, get_attr_docs) TaskOption, TokenizerMode, VllmConfig, get_attr_docs)
from vllm.config.multimodal import MMCacheType, MultiModalConfig from vllm.config.multimodal import MMCacheType, MultiModalConfig
from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.parallel import ExpertPlacementStrategy
from vllm.config.utils import get_field from vllm.config.utils import get_field
...@@ -418,12 +417,15 @@ class EngineArgs: ...@@ -418,12 +417,15 @@ class EngineArgs:
disable_hybrid_kv_cache_manager: bool = ( disable_hybrid_kv_cache_manager: bool = (
SchedulerConfig.disable_hybrid_kv_cache_manager) SchedulerConfig.disable_hybrid_kv_cache_manager)
guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend structured_outputs_config: StructuredOutputsConfig = get_field(
guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback VllmConfig, "structured_outputs_config")
guided_decoding_disable_any_whitespace: bool = \ reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
DecodingConfig.disable_any_whitespace # Deprecated guided decoding fields
guided_decoding_disable_additional_properties: bool = \ guided_decoding_backend: Optional[str] = None
DecodingConfig.disable_additional_properties guided_decoding_disable_fallback: Optional[bool] = None
guided_decoding_disable_any_whitespace: Optional[bool] = None
guided_decoding_disable_additional_properties: Optional[bool] = None
logits_processor_pattern: Optional[ logits_processor_pattern: Optional[
str] = ModelConfig.logits_processor_pattern str] = ModelConfig.logits_processor_pattern
...@@ -462,7 +464,6 @@ class EngineArgs: ...@@ -462,7 +464,6 @@ class EngineArgs:
additional_config: dict[str, Any] = \ additional_config: dict[str, Any] = \
get_field(VllmConfig, "additional_config") get_field(VllmConfig, "additional_config")
reasoning_parser: str = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location pt_load_map_location: str = LoadConfig.pt_load_map_location
...@@ -618,28 +619,29 @@ class EngineArgs: ...@@ -618,28 +619,29 @@ class EngineArgs:
load_group.add_argument('--pt-load-map-location', load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"]) **load_kwargs["pt_load_map_location"])
# Guided decoding arguments # Structured outputs arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig) structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
guided_decoding_group = parser.add_argument_group( structured_outputs_group = parser.add_argument_group(
title="DecodingConfig", title="StructuredOutputsConfig",
description=DecodingConfig.__doc__, description=StructuredOutputsConfig.__doc__,
) )
guided_decoding_group.add_argument("--guided-decoding-backend", structured_outputs_group.add_argument(
**guided_decoding_kwargs["backend"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-fallback",
**guided_decoding_kwargs["disable_fallback"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-any-whitespace",
**guided_decoding_kwargs["disable_any_whitespace"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-additional-properties",
**guided_decoding_kwargs["disable_additional_properties"])
guided_decoding_group.add_argument(
"--reasoning-parser", "--reasoning-parser",
# This choice is a special case because it's not static # This choice is a special case because it's not static
choices=list(ReasoningParserManager.reasoning_parsers), choices=list(ReasoningParserManager.reasoning_parsers),
**guided_decoding_kwargs["reasoning_backend"]) **structured_outputs_kwargs["reasoning_parser"])
# Deprecated guided decoding arguments
for arg, type in [
("--guided-decoding-backend", str),
("--guided-decoding-disable-fallback", bool),
("--guided-decoding-disable-any-whitespace", bool),
("--guided-decoding-disable-additional-properties", bool),
]:
structured_outputs_group.add_argument(
arg,
type=type,
help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
deprecated=True)
# Parallel arguments # Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig) parallel_kwargs = get_kwargs(ParallelConfig)
...@@ -934,6 +936,8 @@ class EngineArgs: ...@@ -934,6 +936,8 @@ class EngineArgs:
**vllm_kwargs["compilation_config"]) **vllm_kwargs["compilation_config"])
vllm_group.add_argument("--additional-config", vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"]) **vllm_kwargs["additional_config"])
vllm_group.add_argument('--structured-outputs-config',
**vllm_kwargs["structured_outputs_config"])
# Other arguments # Other arguments
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
...@@ -1421,14 +1425,25 @@ class EngineArgs: ...@@ -1421,14 +1425,25 @@ class EngineArgs:
load_config = self.create_load_config() load_config = self.create_load_config()
decoding_config = DecodingConfig( # Pass reasoning_parser into StructuredOutputsConfig
backend=self.guided_decoding_backend, if self.reasoning_parser:
disable_fallback=self.guided_decoding_disable_fallback, self.structured_outputs_config.reasoning_parser = \
disable_any_whitespace=self.guided_decoding_disable_any_whitespace, self.reasoning_parser
disable_additional_properties=\
self.guided_decoding_disable_additional_properties, # Forward the deprecated CLI args to the StructuredOutputsConfig
reasoning_backend=self.reasoning_parser so_config = self.structured_outputs_config
) if self.guided_decoding_backend is not None:
so_config.guided_decoding_backend = \
self.guided_decoding_backend
if self.guided_decoding_disable_fallback is not None:
so_config.guided_decoding_disable_fallback = \
self.guided_decoding_disable_fallback
if self.guided_decoding_disable_any_whitespace is not None:
so_config.guided_decoding_disable_any_whitespace = \
self.guided_decoding_disable_any_whitespace
if self.guided_decoding_disable_additional_properties is not None:
so_config.guided_decoding_disable_additional_properties = \
self.guided_decoding_disable_additional_properties
observability_config = ObservabilityConfig( observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=( show_hidden_metrics_for_version=(
...@@ -1446,7 +1461,7 @@ class EngineArgs: ...@@ -1446,7 +1461,7 @@ class EngineArgs:
lora_config=lora_config, lora_config=lora_config,
speculative_config=speculative_config, speculative_config=speculative_config,
load_config=load_config, load_config=load_config,
decoding_config=decoding_config, structured_outputs_config=self.structured_outputs_config,
observability_config=observability_config, observability_config=observability_config,
compilation_config=self.compilation_config, compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config, kv_transfer_config=self.kv_transfer_config,
......
...@@ -10,9 +10,8 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, ...@@ -10,9 +10,8 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
from weakref import ReferenceType from weakref import ReferenceType
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig, from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig) SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
...@@ -955,10 +954,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -955,10 +954,6 @@ class AsyncLLMEngine(EngineClient):
"""Get the parallel configuration of the vLLM engine.""" """Get the parallel configuration of the vLLM engine."""
return self.engine.get_parallel_config() return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig: async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine.""" """Get the scheduling configuration of the vLLM engine."""
return self.engine.get_scheduler_config() return self.engine.get_scheduler_config()
......
...@@ -16,9 +16,8 @@ import torch ...@@ -16,9 +16,8 @@ import torch
from typing_extensions import TypeVar from typing_extensions import TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, ModelConfig, ObservabilityConfig, from vllm.config import (LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, SchedulerConfig, VllmConfig) ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase, Stats from vllm.engine.metrics_types import StatLoggerBase, Stats
...@@ -213,8 +212,7 @@ class LLMEngine: ...@@ -213,8 +212,7 @@ class LLMEngine:
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config # noqa self.speculative_config = vllm_config.speculative_config # noqa
self.load_config = vllm_config.load_config self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa self.structured_outputs_config = vllm_config.structured_outputs_config
)
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
) )
...@@ -364,10 +362,9 @@ class LLMEngine: ...@@ -364,10 +362,9 @@ class LLMEngine:
self.observability_config.otlp_traces_endpoint) self.observability_config.otlp_traces_endpoint)
# Initialize reasoning parser if reasoning backend is set. # Initialize reasoning parser if reasoning backend is set.
if self.decoding_config.reasoning_backend and \ if self.structured_outputs_config.reasoning_parser and self.tokenizer:
self.tokenizer:
reasoner_class = ReasoningParserManager.get_reasoning_parser( reasoner_class = ReasoningParserManager.get_reasoning_parser(
self.decoding_config.reasoning_backend) self.structured_outputs_config.reasoning_parser)
self.reasoner: ReasoningParser = reasoner_class( self.reasoner: ReasoningParser = reasoner_class(
self.tokenizer.get_lora_tokenizer()) self.tokenizer.get_lora_tokenizer())
...@@ -381,7 +378,8 @@ class LLMEngine: ...@@ -381,7 +378,8 @@ class LLMEngine:
self.seq_counter, self.seq_counter,
stop_checker=StopChecker( stop_checker=StopChecker(
self.scheduler_config.max_model_len, self.scheduler_config.max_model_len,
self.reasoner if self.decoding_config.reasoning_backend self.reasoner
if self.structured_outputs_config.reasoning_parser
and self.tokenizer else None, and self.tokenizer else None,
), ),
)) ))
...@@ -772,10 +770,6 @@ class LLMEngine: ...@@ -772,10 +770,6 @@ class LLMEngine:
"""Gets the parallel configuration.""" """Gets the parallel configuration."""
return self.parallel_config return self.parallel_config
def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig: def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration.""" """Gets the scheduler configuration."""
return self.scheduler_config return self.scheduler_config
......
...@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod ...@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
...@@ -248,11 +248,6 @@ class EngineClient(ABC): ...@@ -248,11 +248,6 @@ class EngineClient(ABC):
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
... ...
@abstractmethod
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
...
@abstractmethod @abstractmethod
async def get_input_preprocessor(self) -> InputPreprocessor: async def get_input_preprocessor(self) -> InputPreprocessor:
"""Get the input processor of the vLLM engine.""" """Get the input processor of the vLLM engine."""
......
...@@ -15,8 +15,8 @@ import vllm.envs as envs ...@@ -15,8 +15,8 @@ import vllm.envs as envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, BeamSearchSequence,
create_sort_beams_key_function) create_sort_beams_key_function)
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, from vllm.config import (CompilationConfig, ModelDType,
is_init_field) StructuredOutputsConfig, TokenizerMode, is_init_field)
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides, from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
PoolerConfig, RunnerOption) PoolerConfig, RunnerOption)
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
...@@ -192,6 +192,8 @@ class LLM: ...@@ -192,6 +192,8 @@ class LLM:
hf_overrides: Optional[HfOverrides] = None, hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
override_pooler_config: Optional[PoolerConfig] = None, override_pooler_config: Optional[PoolerConfig] = None,
structured_outputs_config: Optional[Union[dict[
str, Any], StructuredOutputsConfig]] = None,
kv_cache_memory_bytes: Optional[int] = None, kv_cache_memory_bytes: Optional[int] = None,
compilation_config: Optional[Union[int, dict[str, Any], compilation_config: Optional[Union[int, dict[str, Any],
CompilationConfig]] = None, CompilationConfig]] = None,
...@@ -236,14 +238,30 @@ class LLM: ...@@ -236,14 +238,30 @@ class LLM:
compilation_config_instance = CompilationConfig( compilation_config_instance = CompilationConfig(
level=compilation_config) level=compilation_config)
elif isinstance(compilation_config, dict): elif isinstance(compilation_config, dict):
predicate = lambda x: is_init_field(CompilationConfig, x[0])
compilation_config_instance = CompilationConfig( compilation_config_instance = CompilationConfig(
**dict(filter(predicate, compilation_config.items()))) **{
k: v
for k, v in compilation_config.items()
if is_init_field(CompilationConfig, k)
})
else: else:
compilation_config_instance = compilation_config compilation_config_instance = compilation_config
else: else:
compilation_config_instance = CompilationConfig() compilation_config_instance = CompilationConfig()
if structured_outputs_config is not None:
if isinstance(structured_outputs_config, dict):
structured_outputs_instance = StructuredOutputsConfig(
**{
k: v
for k, v in structured_outputs_config.items()
if is_init_field(StructuredOutputsConfig, k)
})
else:
structured_outputs_instance = structured_outputs_config
else:
structured_outputs_instance = StructuredOutputsConfig()
engine_args = EngineArgs( engine_args = EngineArgs(
model=model, model=model,
runner=runner, runner=runner,
...@@ -271,6 +289,7 @@ class LLM: ...@@ -271,6 +289,7 @@ class LLM:
hf_overrides=hf_overrides, hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config, override_pooler_config=override_pooler_config,
structured_outputs_config=structured_outputs_instance,
compilation_config=compilation_config_instance, compilation_config=compilation_config_instance,
logits_processors=logits_processors, logits_processors=logits_processors,
**kwargs, **kwargs,
......
...@@ -1678,7 +1678,7 @@ async def init_app_state( ...@@ -1678,7 +1678,7 @@ async def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser, tool_parser=args.tool_call_parser,
tool_server=tool_server, tool_server=tool_server,
reasoning_parser=args.reasoning_parser, reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs, enable_log_outputs=args.enable_log_outputs,
...@@ -1697,7 +1697,7 @@ async def init_app_state( ...@@ -1697,7 +1697,7 @@ async def init_app_state(
exclude_tools_when_tool_choice_none=args. exclude_tools_when_tool_choice_none=args.
exclude_tools_when_tool_choice_none, exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser, tool_parser=args.tool_call_parser,
reasoning_parser=args.reasoning_parser, reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs, enable_log_outputs=args.enable_log_outputs,
...@@ -1800,10 +1800,10 @@ def validate_api_server_args(args): ...@@ -1800,10 +1800,10 @@ def validate_api_server_args(args):
f"(chose from {{ {','.join(valid_tool_parses)} }})") f"(chose from {{ {','.join(valid_tool_parses)} }})")
valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
if args.reasoning_parser \ if ((reasoning_parser := args.structured_outputs_config.reasoning_parser)
and args.reasoning_parser not in valid_reasoning_parses: and reasoning_parser not in valid_reasoning_parses):
raise KeyError( raise KeyError(
f"invalid reasoning parser: {args.reasoning_parser} " f"invalid reasoning parser: {reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})") f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
......
...@@ -54,8 +54,8 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam, ...@@ -54,8 +54,8 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
RequestOutputKind, SamplingParams) SamplingParams, StructuredOutputsParams)
from vllm.utils import random_uuid, resolve_obj_by_qualname from vllm.utils import random_uuid, resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -373,11 +373,12 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -373,11 +373,12 @@ class ResponsesRequest(OpenAIBaseModel):
stop_token_ids = default_sampling_params.get("stop_token_ids") stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output # Structured output
guided_decoding = None structured_outputs = None
if self.text is not None and self.text.format is not None: if self.text is not None and self.text.format is not None:
response_format = self.text.format response_format = self.text.format
if response_format.type == "json_schema": if (response_format.type == "json_schema"
guided_decoding = GuidedDecodingParams.from_optional( and response_format.schema_ is not None):
structured_outputs = StructuredOutputsParams(
json=response_format.schema_) json=response_format.schema_)
elif response_format.type == "json_object": elif response_format.type == "json_object":
raise NotImplementedError("json_object is not supported") raise NotImplementedError("json_object is not supported")
...@@ -392,7 +393,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -392,7 +393,7 @@ class ResponsesRequest(OpenAIBaseModel):
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
output_kind=(RequestOutputKind.DELTA output_kind=(RequestOutputKind.DELTA
if self.stream else RequestOutputKind.FINAL_ONLY), if self.stream else RequestOutputKind.FINAL_ONLY),
guided_decoding=guided_decoding, structured_outputs=structured_outputs,
) )
def is_include_output_logprobs(self) -> bool: def is_include_output_logprobs(self) -> bool:
...@@ -547,42 +548,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -547,42 +548,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
guided_json: Optional[Union[str, dict, BaseModel]] = Field( structured_outputs: Optional[StructuredOutputsParams] = Field(
default=None, default=None,
description=("If specified, the output will follow the JSON schema."), description="Additional kwargs for structured outputs",
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[list[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
structural_tag: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the structural tag schema."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."),
) )
priority: int = Field( priority: int = Field(
default=0, default=0,
...@@ -701,31 +669,33 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -701,31 +669,33 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo: if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs prompt_logprobs = self.top_logprobs
guided_json_object = None response_format = self.response_format
if self.response_format is not None: json_schema_from_tool = self._get_json_schema_from_tool()
if self.response_format.type == "json_object": if response_format is not None or json_schema_from_tool is not None:
guided_json_object = True # If structured outputs wasn't already enabled,
elif self.response_format.type == "json_schema": # we must enable it for these features to work
json_schema = self.response_format.json_schema if self.structured_outputs is None:
assert json_schema is not None self.structured_outputs = StructuredOutputsParams()
self.guided_json = json_schema.json_schema
elif self.response_format.type == "structural_tag": # Set structured output params for response format
structural_tag = self.response_format if response_format is not None:
assert structural_tag is not None and isinstance( if response_format.type == "json_object":
structural_tag, StructuralTagResponseFormat) self.structured_outputs.json_object = True
s_tag_obj = structural_tag.model_dump(by_alias=True) elif response_format.type == "json_schema":
self.structural_tag = json.dumps(s_tag_obj) json_schema = response_format.json_schema
assert json_schema is not None
guided_decoding = GuidedDecodingParams.from_optional( self.structured_outputs.json = json_schema.json_schema
json=self._get_guided_json_from_tool() or self.guided_json, elif response_format.type == "structural_tag":
regex=self.guided_regex, structural_tag = response_format
choice=self.guided_choice, assert structural_tag is not None and isinstance(
grammar=self.guided_grammar, structural_tag, StructuralTagResponseFormat)
json_object=guided_json_object, s_tag_obj = structural_tag.model_dump(by_alias=True)
backend=self.guided_decoding_backend, self.structured_outputs.structural_tag = json.dumps(
whitespace_pattern=self.guided_whitespace_pattern, s_tag_obj)
structural_tag=self.structural_tag,
) # Set structured output params for tool calling
if json_schema_from_tool is not None:
self.structured_outputs.json = json_schema_from_tool
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params: if self.kv_transfer_params:
...@@ -757,15 +727,14 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -757,15 +727,14 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding, structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
bad_words= self.bad_words, bad_words=self.bad_words,
allowed_token_ids=self.allowed_token_ids, allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None, extra_args=extra_args or None,
) )
def _get_guided_json_from_tool( def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]:
self) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool # user has chosen to not use any tool
if self.tool_choice == "none" or self.tools is None: if self.tool_choice == "none" or self.tools is None:
return None return None
...@@ -875,28 +844,31 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -875,28 +844,31 @@ class ChatCompletionRequest(OpenAIBaseModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_guided_decoding_count(cls, data): def check_structured_outputs_count(cls, data):
if isinstance(data, ValueError): if isinstance(data, ValueError):
raise data raise data
guide_count = sum([ if "structured_outputs" not in data:
"guided_json" in data and data["guided_json"] is not None, return data
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None structured_outputs_kwargs = data['structured_outputs']
]) count = sum(
# you can only use one kind of guided decoding structured_outputs_kwargs.get(k) is not None
if guide_count > 1: for k in ("json", "regex", "choice"))
# you can only use one kind of constraints for structured outputs
if count > 1:
raise ValueError( raise ValueError(
"You can only use one kind of guided decoding " "You can only use one kind of constraints for structured "
"('guided_json', 'guided_regex' or 'guided_choice').") "outputs ('json', 'regex' or 'choice').")
# you can only either use guided decoding or tools, not both # you can only either use structured outputs or tools, not both
if guide_count > 1 and data.get("tool_choice", "none") not in ( if count > 1 and data.get("tool_choice", "none") not in (
"none", "none",
"auto", "auto",
"required", "required",
): ):
raise ValueError( raise ValueError(
"You can only either use guided decoding or tools, not both.") "You can only either use constraints for structured outputs "
"or tools, not both.")
return data return data
@model_validator(mode="before") @model_validator(mode="before")
...@@ -1049,37 +1021,9 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1049,37 +1021,9 @@ class CompletionRequest(OpenAIBaseModel):
", {'type': 'structural_tag'}, or {'type': 'text' } is supported." ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
), ),
) )
guided_json: Optional[Union[str, dict, BaseModel]] = Field( structured_outputs: Optional[StructuredOutputsParams] = Field(
default=None,
description="If specified, the output will follow the JSON schema.",
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[list[str]] = Field(
default=None, default=None,
description=( description="Additional kwargs for structured outputs",
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."),
) )
priority: int = Field( priority: int = Field(
default=0, default=0,
...@@ -1210,20 +1154,10 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1210,20 +1154,10 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0 echo_without_generation = self.echo and self.max_tokens == 0
guided_json_object = None if (self.structured_outputs is not None
if (self.response_format is not None and self.response_format is not None
and self.response_format.type == "json_object"): and self.response_format.type == "json_object"):
guided_json_object = True self.structured_outputs.json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern,
)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params: if self.kv_transfer_params:
...@@ -1255,7 +1189,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1255,7 +1189,7 @@ class CompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding, structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids, allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None, extra_args=extra_args or None,
...@@ -1263,16 +1197,18 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1263,16 +1197,18 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_guided_decoding_count(cls, data): def check_structured_outputs_count(cls, data):
guide_count = sum([ if "structured_outputs" not in data:
"guided_json" in data and data["guided_json"] is not None, return data
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None structured_outputs_kwargs = data['structured_outputs']
]) count = sum(
if guide_count > 1: structured_outputs_kwargs.get(k) is not None
for k in ("json", "regex", "choice"))
if count > 1:
raise ValueError( raise ValueError(
"You can only use one kind of guided decoding " "You can only use one kind of constraints for structured "
"('guided_json', 'guided_regex' or 'guided_choice').") "outputs ('json', 'regex' or 'choice').")
return data return data
@model_validator(mode="before") @model_validator(mode="before")
......
...@@ -993,7 +993,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -993,7 +993,7 @@ class OpenAIServingChat(OpenAIServing):
# check to make sure we haven't "forgotten" to stream # check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously # any tokens that were generated but previously
# matched by partial json parsing # matched by partial json parsing
# only happens if we are NOT using guided decoding # only happens if we are NOT using structured outputs
auto_tools_called = False auto_tools_called = False
if tool_parser: if tool_parser:
auto_tools_called = len( auto_tools_called = len(
......
...@@ -262,9 +262,9 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig): ...@@ -262,9 +262,9 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod @staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None: def verify_and_update_config(vllm_config: "VllmConfig") -> None:
decoding_config = vllm_config.decoding_config structured_outputs_config = vllm_config.structured_outputs_config
if decoding_config.reasoning_backend == "": if structured_outputs_config.reasoning_parser == "":
decoding_config.reasoning_backend = "openai_gptoss" structured_outputs_config.reasoning_parser = "openai_gptoss"
# Increase the max capture size from 512 to 1024 for performance. # Increase the max capture size from 512 to 1024 for performance.
# NOTE(woosuk): This will increase the number of CUDA graphs # NOTE(woosuk): This will increase the number of CUDA graphs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment