Unverified commit 25b3242d, authored by noobHappylife and committed by GitHub
Browse files

Fix Responses API streaming for multiple auto tool calls (#39626)


Signed-off-by: noobhappylife <aratar1991@hotmail.com>
parent b075604d
......@@ -249,38 +249,72 @@ async def test_function_calling_with_streaming_expected_arguments(
"additionalProperties": False,
},
"strict": True,
}
},
{
"type": "function",
"name": "get_time",
"description": "Get current local time for provided location.",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"},
},
"required": ["location"],
"additionalProperties": False,
},
"strict": True,
},
]
stream_response = await client.responses.create(
model=model_name,
input="Can you tell me what the current weather is in Berlin?",
input=(
"Use tools only. Call get_weather for Berlin and get_time for Tokyo. "
"Do not answer directly."
),
tools=tools,
stream=True,
)
tool_call_item = None
completed_event = None
tool_call_items = {}
arguments_done_events = {}
completed_events = {}
async for event in stream_response:
if (
event.type == "response.output_item.added"
and event.item.type == "function_call"
):
tool_call_item = event.item
elif event.type == "response.function_call_arguments.delta" and tool_call_item:
tool_call_items[event.output_index] = event.item
elif event.type == "response.function_call_arguments.delta":
tool_call_item = tool_call_items[event.output_index]
tool_call_item.arguments += event.delta
elif event.type == "response.function_call_arguments.done":
arguments_done_events[event.output_index] = event
elif (
event.type == "response.output_item.done"
and event.item.type == "function_call"
):
completed_event = event
assert tool_call_item is not None
assert tool_call_item.type == "function_call"
assert tool_call_item.name == "get_weather"
assert completed_event is not None
assert tool_call_item.arguments == completed_event.item.arguments
assert tool_call_item.name == completed_event.item.name
args = json.loads(tool_call_item.arguments)
completed_events[event.output_index] = event
assert len(tool_call_items) >= 2
assert len(arguments_done_events) >= 2
assert len(completed_events) >= 2
tool_calls_by_name = {
event.item.name: (
tool_call_items[output_index],
arguments_done_events[output_index],
event.item,
)
for output_index, event in completed_events.items()
}
assert {"get_weather", "get_time"}.issubset(tool_calls_by_name)
for added_item, arguments_done_event, completed_item in tool_calls_by_name.values():
assert added_item.type == "function_call"
assert added_item.arguments == arguments_done_event.arguments
assert added_item.arguments == completed_item.arguments
assert added_item.name == arguments_done_event.name
assert added_item.name == completed_item.name
args = json.loads(added_item.arguments)
assert "location" in args
assert args["location"] is not None
......
......@@ -27,7 +27,9 @@ from openai.types.responses.tool import (
import vllm.envs as envs
from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ErrorResponse,
RequestResponseMetadata,
)
......@@ -928,3 +930,197 @@ class TestStreamingReasoningToContentTransition:
]
assert len(item_done_events) == 1
assert isinstance(item_done_events[0].item, ResponseReasoningItem)
class TestAutoToolStreaming:
    """Streaming tests for function-call items when ``tool_choice="auto"``.

    Each test hands a scripted list of ``DeltaMessage`` objects to
    ``_collect_events`` and then inspects the events that
    ``_process_simple_streaming_events`` emitted for them.
    """

    @staticmethod
    def _events_of_type(events, event_type):
        """Return the events whose ``type`` equals ``event_type``, in order."""
        return [emitted for emitted in events if emitted.type == event_type]

    @staticmethod
    def _function_call_events(events, event_type):
        """Return ``event_type`` events whose item is a ``function_call``.

        ``getattr`` with a default is used so items without a ``type``
        attribute are filtered out instead of raising.
        """
        return [
            emitted
            for emitted in events
            if emitted.type == event_type
            and getattr(emitted.item, "type", None) == "function_call"
        ]

    @staticmethod
    async def _collect_events(delta_sequence: list[DeltaMessage]):
        """Run the simple streaming path over ``delta_sequence``.

        Builds a serving instance, installs the mocked parser for the given
        deltas, feeds one context per delta through
        ``_process_simple_streaming_events`` and returns every emitted event
        as a list.
        """
        serving = _make_serving_instance_with_reasoning()
        # NOTE(review): presumably makes the parser hand back the scripted
        # deltas one per step -- confirm against the helper's definition.
        _mock_parser_with_reasoning(serving, delta_sequence)

        # One context per scripted delta, built eagerly before streaming.
        scripted_contexts = [
            _make_simple_context_with_output("chunk", [position])
            for position, _ in enumerate(delta_sequence)
        ]

        async def replay_contexts():
            for scripted in scripted_contexts:
                yield scripted

        weather_tool = {
            "type": "function",
            "name": "get_weather",
            "description": "Get weather.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
                "additionalProperties": False,
            },
        }
        request = ResponsesRequest(
            input="hi",
            tools=[weather_tool],
            tool_choice="auto",
            stream=True,
        )
        sampling_params = SamplingParams(max_tokens=64)
        request_metadata = RequestResponseMetadata(request_id="req")

        # Reset the shared counter before each collection run.
        _identity_increment._counter = 0  # type: ignore

        return [
            emitted
            async for emitted in serving._process_simple_streaming_events(
                request=request,
                sampling_params=sampling_params,
                result_generator=replay_contexts(),
                context=SimpleContext(),
                model_name="test-model",
                tokenizer=MagicMock(),
                request_metadata=request_metadata,
                created_time=0,
                _increment_sequence_number_and_return=_identity_increment,
            )
        ]

    @pytest.mark.skip_global_cleanup
    @pytest.mark.asyncio
    async def test_auto_multi_tool_streaming_opens_one_item_per_tool(self, monkeypatch):
        """Two tool calls (index 0 then 1) yield two separate output items."""
        monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)

        # Tool call 0 is opened with a name and empty arguments ...
        vienna_opened = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    id="call_vienna",
                    type="function",
                    index=0,
                    function=DeltaFunctionCall(
                        name="get_weather",
                        arguments="",
                    ),
                )
            ]
        )
        # ... its arguments arrive in a follow-up delta ...
        vienna_arguments = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=0,
                    function=DeltaFunctionCall(
                        arguments='{"location":"Vienna"}',
                    ),
                )
            ]
        )
        # ... and tool call 1 carries name and arguments in a single delta.
        berlin_complete = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    id="call_berlin",
                    type="function",
                    index=1,
                    function=DeltaFunctionCall(
                        name="get_weather",
                        arguments='{"location":"Berlin"}',
                    ),
                )
            ]
        )

        events = await self._collect_events(
            [vienna_opened, vienna_arguments, berlin_complete]
        )

        expected_arguments = [
            '{"location":"Vienna"}',
            '{"location":"Berlin"}',
        ]

        added = self._function_call_events(events, "response.output_item.added")
        assert len(added) == 2
        assert [opened.item.name for opened in added] == [
            "get_weather",
            "get_weather",
        ]
        assert [opened.output_index for opened in added] == [0, 1]

        streamed_deltas = [
            chunk.delta
            for chunk in self._events_of_type(
                events, "response.function_call_arguments.delta"
            )
        ]
        assert streamed_deltas == expected_arguments

        arguments_done = self._events_of_type(
            events, "response.function_call_arguments.done"
        )
        assert [done.arguments for done in arguments_done] == expected_arguments
        assert [done.output_index for done in arguments_done] == [0, 1]

        items_done = self._function_call_events(events, "response.output_item.done")
        assert [done.item.arguments for done in items_done] == expected_arguments
        assert [done.output_index for done in items_done] == [0, 1]

    @pytest.mark.skip_global_cleanup
    @pytest.mark.asyncio
    async def test_auto_tool_choice_first_delta_tool_call_does_not_duplicate_item(
        self, monkeypatch
    ):
        """A tool call in the very first delta opens exactly one item."""
        monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)

        # The first delta names the tool with empty arguments.
        call_opened = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    id="call_test",
                    type="function",
                    index=0,
                    function=DeltaFunctionCall(
                        name="get_weather",
                        arguments="",
                    ),
                )
            ]
        )
        # The second delta supplies the arguments for the same index.
        call_arguments = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=0,
                    function=DeltaFunctionCall(
                        arguments='{"location":"Berlin"}',
                    ),
                )
            ]
        )

        events = await self._collect_events([call_opened, call_arguments])

        added = self._function_call_events(events, "response.output_item.added")
        assert len(added) == 1
        assert added[0].item.name == "get_weather"

        streamed_deltas = [
            chunk.delta
            for chunk in self._events_of_type(
                events, "response.function_call_arguments.delta"
            )
        ]
        assert "".join(streamed_deltas) == '{"location":"Berlin"}'
......@@ -1341,6 +1341,7 @@ class OpenAIServingResponses(OpenAIServing):
current_content_index = 0
current_output_index = 0
current_item_id = ""
current_tool_call_index: int | None = None
parser = self.parser(tokenizer, request.tools) if self.parser else None
first_delta_sent = False
previous_delta_messages: list[DeltaMessage] = []
......@@ -1368,6 +1369,7 @@ class OpenAIServingResponses(OpenAIServing):
)
if not delta_message:
continue
tool_call_item_started = False
if not first_delta_sent:
current_item_id = random_uuid()
if delta_message.tool_calls:
......@@ -1384,6 +1386,7 @@ class OpenAIServingResponses(OpenAIServing):
current_tool_call_name = delta_message.tool_calls[
0
].function.name
current_tool_call_index = delta_message.tool_calls[0].index
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
......@@ -1394,13 +1397,12 @@ class OpenAIServingResponses(OpenAIServing):
id=current_item_id,
call_id=current_tool_call_id,
name=current_tool_call_name,
arguments=delta_message.tool_calls[
0
].function.arguments,
arguments="",
status="in_progress",
),
)
)
tool_call_item_started = True
elif delta_message.reasoning:
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
......@@ -1572,6 +1574,79 @@ class OpenAIServingResponses(OpenAIServing):
# reset previous delta messages
previous_delta_messages = []
if delta_message.tool_calls and delta_message.tool_calls[0].function:
tool_call = delta_message.tool_calls[0]
tool_call_function = tool_call.function
if (
current_tool_call_index is not None
and tool_call.index is not None
and tool_call.index != current_tool_call_index
and tool_call_function is not None
and tool_call_function.name is not None
):
# From one tool call to another, finalize the previous
# function-call item before opening the next one.
parts = []
for pm in previous_delta_messages:
if pm.tool_calls:
previous_tool_call = pm.tool_calls[0]
if previous_tool_call.function is not None:
parts.append(
previous_tool_call.function.arguments or ""
)
tool_call_arguments = "".join(parts)
yield _increment_sequence_number_and_return(
ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
arguments=tool_call_arguments,
name=current_tool_call_name,
)
)
function_call_item = ResponseFunctionToolCall(
type="function_call",
name=current_tool_call_name,
arguments=tool_call_arguments,
status="completed",
id=current_item_id,
call_id=current_tool_call_id,
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=function_call_item,
)
)
# Reset previous delta messages so the next tool call
# does not reuse arguments from the completed item.
previous_delta_messages = []
current_output_index += 1
current_item_id = random_uuid()
current_tool_call_name = tool_call_function.name
current_tool_call_id = f"call_{random_uuid()}"
current_tool_call_index = tool_call.index
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseFunctionToolCallItem(
type="function_call",
id=current_item_id,
call_id=current_tool_call_id,
name=current_tool_call_name,
arguments="",
status="in_progress",
),
)
)
current_content_index = 0
tool_call_item_started = True
if delta_message.tool_calls[0].function.arguments:
yield _increment_sequence_number_and_return(
ResponseFunctionCallArgumentsDeltaEvent(
......@@ -1583,7 +1658,10 @@ class OpenAIServingResponses(OpenAIServing):
)
)
# tool call initiated with no arguments
elif delta_message.tool_calls[0].function.name:
elif (
delta_message.tool_calls[0].function.name
and not tool_call_item_started
):
# send done with current content part
# and add new function call item
yield _increment_sequence_number_and_return(
......@@ -1628,11 +1706,11 @@ class OpenAIServingResponses(OpenAIServing):
)
current_output_index += 1
current_item_id = random_uuid()
assert delta_message.tool_calls[0].function is not None
current_tool_call_name = delta_message.tool_calls[
0
].function.name
current_tool_call_id = f"call_{random_uuid()}"
current_tool_call_index = delta_message.tool_calls[0].index
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment