Unverified Commit 7645bc52 authored by Mary's avatar Mary Committed by GitHub
Browse files

[OpenAI] Fix tool_choice=required streaming when output has trailing extra data (#31610)


Signed-off-by: default avatarmaylikenoother <ogedengbemary19@gmail.com>
Co-authored-by: default avatarChauncey <chaunceyjiang@gmail.com>
parent 1123a878
...@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len): ...@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
previous_text = current_text previous_text = current_text
assert len(messages) > 0 assert len(messages) > 0
combined_messages = "[" combined_messages = "["
for message in messages: for message in messages:
if message.tool_calls[0].function.name: if message.tool_calls[0].function.name:
...@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len): ...@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += "}]" combined_messages += "}]"
assert json.loads(combined_messages) == output assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json assert json.dumps(json.loads(combined_messages)) == output_json
def test_streaming_output_valid_with_trailing_extra_data():
self = MagicMock()
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
output_json = json.dumps(output) + "\nDONE"
previous_text = ""
function_name_returned = False
messages = []
delta_len = 3
for i in range(0, len(output_json), delta_len):
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
)
)
if delta_message:
messages.append(delta_message)
previous_text = current_text
assert len(messages) > 0
...@@ -13,6 +13,7 @@ import partial_json_parser ...@@ -13,6 +13,7 @@ import partial_json_parser
import regex as re import regex as re
from fastapi import Request from fastapi import Request
from openai_harmony import Message as OpenAIMessage from openai_harmony import Message as OpenAIMessage
from partial_json_parser.core.options import Allow
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
...@@ -76,6 +77,7 @@ from vllm.tokenizers.mistral import ( ...@@ -76,6 +77,7 @@ from vllm.tokenizers.mistral import (
) )
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
...@@ -511,8 +513,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -511,8 +513,12 @@ class OpenAIServingChat(OpenAIServing):
# if the current text is empty, we cannot parse it # if the current text is empty, we cannot parse it
return None, function_name_returned return None, function_name_returned
try: try:
obj = partial_json_parser.loads(current_text) flags = Allow.ALL
except partial_json_parser.core.exceptions.MalformedJSON: obj, _ = partial_json_loads(current_text, flags)
except (
partial_json_parser.core.exceptions.MalformedJSON,
json.JSONDecodeError,
):
logger.debug("not enough tokens to parse into JSON yet") logger.debug("not enough tokens to parse into JSON yet")
obj = None obj = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment