Unverified Commit 35bf5d08 authored by cjackal's avatar cjackal Committed by GitHub
Browse files

[bugfix] Fix online serving crash when text type response_format is received (#26822)


Signed-off-by: default avatarcjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: default avatarj0shuajun <59368606+j0shuajun@users.noreply.github.com>
Co-authored-by: default avatarj0shuajun <59368606+j0shuajun@users.noreply.github.com>
parent 5de6dd06
...@@ -671,6 +671,25 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): ...@@ -671,6 +671,25 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded assert loaded == {"result": 2}, loaded
@pytest.mark.asyncio
async def test_response_format_text(client: openai.AsyncOpenAI):
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[
{
"role": "user",
"content": "what is 1+1?",
}
],
max_completion_tokens=10,
response_format={"type": "text"},
)
content = resp.choices[0].message.content
assert content is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extra_fields_allowed(client: openai.AsyncOpenAI): async def test_extra_fields_allowed(client: openai.AsyncOpenAI):
resp = await client.chat.completions.create( resp = await client.chat.completions.create(
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json import json
import time import time
from dataclasses import replace
from typing import Annotated, Any, ClassVar, Literal from typing import Annotated, Any, ClassVar, Literal
import torch import torch
...@@ -417,18 +418,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -417,18 +418,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
response_format = self.response_format response_format = self.response_format
if response_format is not None: if response_format is not None:
# If structured outputs wasn't already enabled, structured_outputs_kwargs = dict[str, Any]()
# we must enable it for these features to work
if self.structured_outputs is None:
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format # Set structured output params for response format
if response_format.type == "json_object": if response_format.type == "json_object":
self.structured_outputs.json_object = True structured_outputs_kwargs["json_object"] = True
elif response_format.type == "json_schema": elif response_format.type == "json_schema":
json_schema = response_format.json_schema json_schema = response_format.json_schema
assert json_schema is not None assert json_schema is not None
self.structured_outputs.json = json_schema.json_schema structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag": elif response_format.type == "structural_tag":
structural_tag = response_format structural_tag = response_format
assert structural_tag is not None and isinstance( assert structural_tag is not None and isinstance(
...@@ -439,7 +437,16 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -439,7 +437,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
), ),
) )
s_tag_obj = structural_tag.model_dump(by_alias=True) s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj) structured_outputs_kwargs["structural_tag"] = json.dumps(s_tag_obj)
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if len(structured_outputs_kwargs) > 0:
self.structured_outputs = (
StructuredOutputsParams(**structured_outputs_kwargs)
if self.structured_outputs is None
else replace(self.structured_outputs, **structured_outputs_kwargs)
)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params: if self.kv_transfer_params:
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json import json
import time import time
from dataclasses import replace
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
import torch import torch
...@@ -247,18 +248,15 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -247,18 +248,15 @@ class CompletionRequest(OpenAIBaseModel):
response_format = self.response_format response_format = self.response_format
if response_format is not None: if response_format is not None:
# If structured outputs wasn't already enabled, structured_outputs_kwargs = dict[str, Any]()
# we must enable it for these features to work
if self.structured_outputs is None:
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format # Set structured output params for response format
if response_format.type == "json_object": if response_format.type == "json_object":
self.structured_outputs.json_object = True structured_outputs_kwargs["json_object"] = True
elif response_format.type == "json_schema": elif response_format.type == "json_schema":
json_schema = response_format.json_schema json_schema = response_format.json_schema
assert json_schema is not None assert json_schema is not None
self.structured_outputs.json = json_schema.json_schema structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag": elif response_format.type == "structural_tag":
structural_tag = response_format structural_tag = response_format
assert structural_tag is not None and isinstance( assert structural_tag is not None and isinstance(
...@@ -269,7 +267,16 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -269,7 +267,16 @@ class CompletionRequest(OpenAIBaseModel):
), ),
) )
s_tag_obj = structural_tag.model_dump(by_alias=True) s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj) structured_outputs_kwargs["structural_tag"] = json.dumps(s_tag_obj)
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if len(structured_outputs_kwargs) > 0:
self.structured_outputs = (
StructuredOutputsParams(**structured_outputs_kwargs)
if self.structured_outputs is None
else replace(self.structured_outputs, **structured_outputs_kwargs)
)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params: if self.kv_transfer_params:
......
...@@ -9,7 +9,7 @@ from collections import deque ...@@ -9,7 +9,7 @@ from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from copy import copy from copy import copy
from dataclasses import dataclass from dataclasses import dataclass, replace
from http import HTTPStatus from http import HTTPStatus
from typing import Final from typing import Final
...@@ -467,15 +467,18 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -467,15 +467,18 @@ class OpenAIServingResponses(OpenAIServing):
if self.reasoning_parser is not None: if self.reasoning_parser is not None:
reasoning_parser = self.reasoning_parser(tokenizer) reasoning_parser = self.reasoning_parser(tokenizer)
if sampling_params.structured_outputs is None: if (
sampling_params.structured_outputs = StructuredOutputsParams() isinstance(
struct_out = sampling_params.structured_outputs struct_out := sampling_params.structured_outputs,
if struct_out.all_non_structural_tag_constraints_none(): StructuredOutputsParams,
sampling_params.structured_outputs.structural_tag = ( )
reasoning_parser.prepare_structured_tag( and struct_out.all_non_structural_tag_constraints_none()
sampling_params.structured_outputs.structural_tag, ):
self.tool_server, sampling_params.structured_outputs = replace(
) struct_out,
structural_tag=reasoning_parser.prepare_structured_tag(
struct_out.structural_tag, self.tool_server
),
) )
generator = self._generate_with_builtin_tools( generator = self._generate_with_builtin_tools(
request_id=request.request_id, request_id=request.request_id,
......
...@@ -67,6 +67,11 @@ class StructuredOutputsParams: ...@@ -67,6 +67,11 @@ class StructuredOutputsParams:
"You can only use one kind of structured outputs constraint " "You can only use one kind of structured outputs constraint "
f"but multiple are specified: {self.__dict__}" f"but multiple are specified: {self.__dict__}"
) )
if count < 1:
raise ValueError(
"You must use one kind of structured outputs constraint "
f"but none are specified: {self.__dict__}"
)
def all_constraints_none(self) -> bool: def all_constraints_none(self) -> bool:
""" """
......
...@@ -65,10 +65,11 @@ class ToolParser: ...@@ -65,10 +65,11 @@ class ToolParser:
# Set structured output params for tool calling # Set structured output params for tool calling
if json_schema_from_tool is not None: if json_schema_from_tool is not None:
if isinstance(request, ChatCompletionRequest): if isinstance(request, ChatCompletionRequest):
request.structured_outputs = StructuredOutputsParams()
# tool_choice: "Forced Function" or "required" will override # tool_choice: "Forced Function" or "required" will override
# structured output json settings to make tool calling work correctly # structured output json settings to make tool calling work correctly
request.structured_outputs.json = json_schema_from_tool request.structured_outputs = StructuredOutputsParams(
json=json_schema_from_tool
)
request.response_format = None request.response_format = None
if isinstance(request, ResponsesRequest): if isinstance(request, ResponsesRequest):
request.text = ResponseTextConfig() request.text = ResponseTextConfig()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment