"lib/mocker/src/vscode:/vscode.git/clone" did not exist on "af32579e89c2da91e0305ef452917489b3ba4154"
Unverified commit 5a71cdd7, authored by Chauncey, committed by GitHub
Browse files

[Bugfix] Fix crash when tool_choice=required exceeds max_tokens (#36841)


Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
parent f0d3658c
...@@ -514,3 +514,27 @@ async def test_inconsistent_tool_choice_and_tools( ...@@ -514,3 +514,27 @@ async def test_inconsistent_tool_choice_and_tools(
], ],
tool_choice={}, tool_choice={},
) )
@pytest.mark.asyncio
async def test_max_tokens_with_tool_choice_required(client: openai.AsyncOpenAI):
    """Check `tool_choice="required"` with a one-token completion budget.

    Requests a chat completion limited to a single completion token while
    requiring a tool call, then asserts the reply is truncated by length
    with no tool calls and empty content.
    """
    served_models = await client.models.list()
    model_name: str = served_models.data[0].id

    # This combination previously crashed the engine
    response = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        tools=tools,
        tool_choice="required",
        max_completion_tokens=1,
        temperature=0,
    )

    # When `tool_choice="required"` and the tokens of `tools` exceed
    # `max_tokens`, both `tool_calls` and `content` should be empty.
    # This behavior should be consistent with OpenAI.
    first_choice = response.choices[0]
    reply = first_choice.message
    assert first_choice.finish_reason == "length"
    assert len(reply.tool_calls) == 0
    assert reply.content == ""
...@@ -1507,7 +1507,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1507,7 +1507,7 @@ class OpenAIServingChat(OpenAIServing):
elif request.tool_choice and request.tool_choice == "required": elif request.tool_choice and request.tool_choice == "required":
tool_call_class_items = [] tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0 tool_calls = tool_calls or []
for idx, tool_call in enumerate(tool_calls): for idx, tool_call in enumerate(tool_calls):
# Use native ID if available, # Use native ID if available,
# otherwise generate ID with correct id_type # otherwise generate ID with correct id_type
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import contextlib
import json import json
import time import time
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
...@@ -13,7 +14,7 @@ from fastapi import Request ...@@ -13,7 +14,7 @@ from fastapi import Request
from openai.types.responses import ( from openai.types.responses import (
ToolChoiceFunction, ToolChoiceFunction,
) )
from pydantic import ConfigDict, TypeAdapter from pydantic import ConfigDict, TypeAdapter, ValidationError
from starlette.datastructures import Headers from starlette.datastructures import Headers
import vllm.envs as envs import vllm.envs as envs
...@@ -1125,16 +1126,18 @@ class OpenAIServing: ...@@ -1125,16 +1126,18 @@ class OpenAIServing:
) )
content = None # Clear content since tool is called. content = None # Clear content since tool is called.
elif request.tool_choice == "required": elif request.tool_choice == "required":
assert content is not None tool_calls = []
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) with contextlib.suppress(ValidationError):
function_calls.extend( content = content or ""
[ tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
content
)
for tool_call in tool_calls:
function_calls.append(
FunctionCall( FunctionCall(
name=tool_call.name, name=tool_call.name,
arguments=json.dumps(tool_call.parameters, ensure_ascii=False), arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
) )
for tool_call in tool_calls
]
) )
content = None # Clear content since tool is called. content = None # Clear content since tool is called.
elif ( elif (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment