Unverified Commit 132bfd45 authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Bugfix][ResponsesAPI] Fix crash when tool_choice=required exceeds max_output_tokens (#37258)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent 24b4272a
...@@ -134,6 +134,34 @@ async def test_function_tool_use( ...@@ -134,6 +134,34 @@ async def test_function_tool_use(
assert reasoning.type == "reasoning" assert reasoning.type == "reasoning"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_max_tokens_with_tool_choice_required(
client: openai.AsyncOpenAI, model_name: str
):
prompt = [
{
"role": "user",
"content": "Can you tell me what the current weather is in Berlin and the "
"forecast for the next 5 days, in fahrenheit?",
},
]
response = await client.responses.create(
model=model_name,
input=prompt,
tools=tools,
tool_choice="required",
max_output_tokens=10,
)
assert len(response.output) >= 1
for out in response.output:
# When `tool_choice="required"` and the tokens of `tools`
# exceed `max_output_tokens`,`function_call` should be empty.
# This behavior should be consistent with OpenAI
assert out.type != "function_call"
assert response.incomplete_details.reason == "max_output_tokens"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_named_tool_use(client: openai.AsyncOpenAI): async def test_named_tool_use(client: openai.AsyncOpenAI):
def get_weather(latitude: float, longitude: float) -> str: def get_weather(latitude: float, longitude: float) -> str:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import json import json
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
...@@ -18,7 +19,7 @@ from openai.types.responses.response_output_text import Logprob ...@@ -18,7 +19,7 @@ from openai.types.responses.response_output_text import Logprob
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent, Content as ResponseReasoningTextContent,
) )
from pydantic import TypeAdapter from pydantic import TypeAdapter, ValidationError
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
...@@ -422,14 +423,18 @@ class DelegatingParser(Parser): ...@@ -422,14 +423,18 @@ class DelegatingParser(Parser):
if request.tool_choice == "required": if request.tool_choice == "required":
# Required tool calls - parse JSON # Required tool calls - parse JSON
assert content is not None tool_calls = []
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) with contextlib.suppress(ValidationError):
function_calls.extend( content = content or ""
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
content
)
for tool_call in tool_calls:
function_calls.append(
FunctionCall( FunctionCall(
name=tool_call.name, name=tool_call.name,
arguments=json.dumps(tool_call.parameters, ensure_ascii=False), arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
) )
for tool_call in tool_calls
) )
return function_calls, None # Clear content since tool is called. return function_calls, None # Clear content since tool is called.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment