Unverified commit 9c884faa, authored by amittell, committed by GitHub
Browse files

[Bugfix] Preserve tool call id/type/name in streaming finish chunk (#31438)


Signed-off-by: amittell <mittell@me.com>
Signed-off-by: Alex Mittell <mittell@me.com>
parent 48d5ca4e
......@@ -1506,3 +1506,142 @@ async def test_tool_choice_validation_without_parser():
assert isinstance(response_named, ErrorResponse)
assert "tool_choice" in response_named.error.message
assert "--tool-call-parser" in response_named.error.message
class TestCreateRemainingArgsDelta:
    """Unit tests for ``OpenAIServingChat._create_remaining_args_delta``.

    The helper builds the delta that carries leftover tool-call arguments
    in the streaming finish chunk. These tests check that the id/type/name
    of the originating tool call are carried into that delta instead of
    being lost.
    """

    def test_preserves_id_type_name(self):
        """The matched tool call's id, type and name appear in the result."""
        from vllm.entrypoints.openai.protocol import (
            DeltaFunctionCall,
            DeltaMessage,
            DeltaToolCall,
        )
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

        weather_call = DeltaToolCall(
            index=0,
            id="call_abc123",
            type="function",
            function=DeltaFunctionCall(
                name="get_weather",
                arguments='{"location": "Paris"}',
            ),
        )
        streamed_delta = DeltaMessage(tool_calls=[weather_call])
        leftover_args = '", "unit": "celsius"}'

        finish_delta = OpenAIServingChat._create_remaining_args_delta(
            streamed_delta, leftover_args, 0
        )

        assert len(finish_delta.tool_calls) == 1
        call = finish_delta.tool_calls[0]
        assert call.index == 0
        assert call.id == "call_abc123"
        assert call.type == "function"
        assert call.function.name == "get_weather"
        # Only the leftover text is emitted, not the already-streamed part.
        assert call.function.arguments == leftover_args

    def test_matches_by_index(self):
        """With several tool calls, the one whose index matches is used."""
        from vllm.entrypoints.openai.protocol import (
            DeltaFunctionCall,
            DeltaMessage,
            DeltaToolCall,
        )
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

        first_call = DeltaToolCall(
            index=0,
            id="call_first",
            type="function",
            function=DeltaFunctionCall(name="func_a", arguments="{}"),
        )
        second_call = DeltaToolCall(
            index=1,
            id="call_second",
            type="function",
            function=DeltaFunctionCall(name="func_b", arguments="{}"),
        )
        streamed_delta = DeltaMessage(tool_calls=[first_call, second_call])

        finish_delta = OpenAIServingChat._create_remaining_args_delta(
            streamed_delta, '{"extra": true}', 1
        )

        assert len(finish_delta.tool_calls) == 1
        call = finish_delta.tool_calls[0]
        assert call.index == 1
        assert call.id == "call_second"
        assert call.function.name == "func_b"

    def test_no_matching_tool_call(self):
        """An index absent from the original delta yields empty metadata."""
        from vllm.entrypoints.openai.protocol import (
            DeltaFunctionCall,
            DeltaMessage,
            DeltaToolCall,
        )
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

        only_call = DeltaToolCall(
            index=0,
            id="call_zero",
            type="function",
            function=DeltaFunctionCall(name="func", arguments="{}"),
        )
        streamed_delta = DeltaMessage(tool_calls=[only_call])
        leftover_args = '{"arg": 1}'

        # Index 5 does not exist in streamed_delta.
        finish_delta = OpenAIServingChat._create_remaining_args_delta(
            streamed_delta, leftover_args, 5
        )

        assert len(finish_delta.tool_calls) == 1
        call = finish_delta.tool_calls[0]
        assert call.index == 5
        assert call.id is None
        assert call.type is None
        assert call.function.name is None
        assert call.function.arguments == leftover_args

    def test_function_is_none(self):
        """A matched tool call without a function still keeps its id/type."""
        from vllm.entrypoints.openai.protocol import (
            DeltaMessage,
            DeltaToolCall,
        )
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

        bare_call = DeltaToolCall(
            index=0,
            id="call_nofunc",
            type="function",
            function=None,
        )
        streamed_delta = DeltaMessage(tool_calls=[bare_call])
        leftover_args = '{"data": "value"}'

        finish_delta = OpenAIServingChat._create_remaining_args_delta(
            streamed_delta, leftover_args, 0
        )

        assert len(finish_delta.tool_calls) == 1
        call = finish_delta.tool_calls[0]
        assert call.index == 0
        assert call.id == "call_nofunc"
        assert call.type == "function"
        # There is no original function to copy a name from.
        assert call.function.name is None
        assert call.function.arguments == leftover_args
......@@ -1208,15 +1208,8 @@ class OpenAIServingChat(OpenAIServing):
# check to see if there's anything left to stream
remaining_call = expected_call.replace(actual_call, "", 1)
# set that as a delta message
delta_message = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=index,
function=DeltaFunctionCall(
arguments=remaining_call
).model_dump(exclude_none=True),
)
]
delta_message = self._create_remaining_args_delta(
delta_message, remaining_call, index
)
# Send the finish response for each request.n only once
......@@ -1803,6 +1796,35 @@ class OpenAIServingChat(OpenAIServing):
and delta_message.tool_calls[0].function.arguments is not None
)
@staticmethod
def _create_remaining_args_delta(
    delta_message: DeltaMessage,
    remaining_call: str,
    index: int,
) -> DeltaMessage:
    """
    Build a delta holding the remaining tool-call arguments.

    The first tool call in ``delta_message`` whose ``index`` equals
    ``index`` supplies the ``id``, ``type`` and function ``name`` of the
    returned tool call, so those fields are preserved alongside
    ``remaining_call``. If no tool call matches, all three are ``None``;
    if the match has no function, only ``name`` is ``None``.

    Returns a ``DeltaMessage`` containing exactly one ``DeltaToolCall``.
    """
    # Locate the tool call this remainder belongs to (first match wins).
    source_call = None
    for candidate in delta_message.tool_calls:
        if candidate.index == index:
            source_call = candidate
            break

    if source_call:
        call_id = source_call.id
        call_type = source_call.type
        source_fn = source_call.function
    else:
        call_id = None
        call_type = None
        source_fn = None

    fn_name = source_fn.name if source_fn else None

    remaining_fn = DeltaFunctionCall(name=fn_name, arguments=remaining_call)
    remaining_tool_call = DeltaToolCall(
        index=index,
        id=call_id,
        type=call_type,
        function=remaining_fn,
    )
    return DeltaMessage(tool_calls=[remaining_tool_call])
def _make_request_with_harmony(
self,
request: ChatCompletionRequest,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment