Unverified Commit 8f8fda26 authored by Ben Browning's avatar Ben Browning Committed by GitHub
Browse files

[Bugfix] Multiple fixes for gpt-oss Chat Completion prompting (#28729)


Signed-off-by: default avatarBen Browning <bbrownin@redhat.com>
Co-authored-by: default avatarChauncey <chaunceyjiang@gmail.com>
parent fe178710
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator
from typing import Any
from vllm.entrypoints.openai.protocol import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatMessage,
UsageInfo,
)
async def accumulate_streaming_response(
stream_generator: AsyncGenerator[str, None],
) -> ChatCompletionResponse:
"""
Accumulate streaming SSE chunks into a complete ChatCompletionResponse.
This helper parses the SSE format and builds up the complete response
by combining all the delta chunks.
"""
accumulated_content = ""
accumulated_reasoning = None
accumulated_tool_calls: list[dict[str, Any]] = []
role = None
finish_reason = None
response_id = None
created = None
model = None
index = 0
async for chunk_str in stream_generator:
# Skip empty lines and [DONE] marker
if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]":
continue
# Parse SSE format: "data: {json}\n\n"
if chunk_str.startswith("data: "):
json_str = chunk_str[6:].strip()
try:
chunk_data = json.loads(json_str)
# print(f"DEBUG: Parsed chunk_data: {chunk_data}")
chunk = ChatCompletionStreamResponse(**chunk_data)
# Store metadata from first chunk
if response_id is None:
response_id = chunk.id
created = chunk.created
model = chunk.model
# Process each choice in the chunk
for choice in chunk.choices:
if choice.delta.role:
role = choice.delta.role
if choice.delta.content:
accumulated_content += choice.delta.content
if choice.delta.reasoning:
if accumulated_reasoning is None:
accumulated_reasoning = ""
accumulated_reasoning += choice.delta.reasoning
if choice.delta.tool_calls:
# Accumulate tool calls
for tool_call_delta in choice.delta.tool_calls:
# Find or create the tool call at this index
while len(accumulated_tool_calls) <= tool_call_delta.index:
accumulated_tool_calls.append(
{
"id": None,
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
if tool_call_delta.id:
accumulated_tool_calls[tool_call_delta.index]["id"] = (
tool_call_delta.id
)
if tool_call_delta.function:
if tool_call_delta.function.name:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["name"] += tool_call_delta.function.name
if tool_call_delta.function.arguments:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["arguments"] += tool_call_delta.function.arguments
if choice.finish_reason:
finish_reason = choice.finish_reason
if choice.index is not None:
index = choice.index
except json.JSONDecodeError:
continue
# Build the final message
message_kwargs = {
"role": role or "assistant",
"content": accumulated_content if accumulated_content else None,
"reasoning": accumulated_reasoning,
}
# Only include tool_calls if there are any
if accumulated_tool_calls:
message_kwargs["tool_calls"] = [
{"id": tc["id"], "type": tc["type"], "function": tc["function"]}
for tc in accumulated_tool_calls
]
message = ChatMessage(**message_kwargs)
# Build the final response
choice = ChatCompletionResponseChoice(
index=index,
message=message,
finish_reason=finish_reason or "stop",
)
# Create usage info (with dummy values for tests)
usage = UsageInfo(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
response = ChatCompletionResponse(
id=response_id or "chatcmpl-test",
object="chat.completion",
created=created or 0,
model=model or "test-model",
choices=[choice],
usage=usage,
)
return response
def verify_harmony_messages(
messages: list[Any], expected_messages: list[dict[str, Any]]
):
assert len(messages) == len(expected_messages)
for msg, expected in zip(messages, expected_messages):
if "role" in expected:
assert msg.author.role == expected["role"]
if "author_name" in expected:
assert msg.author.name == expected["author_name"]
if "channel" in expected:
assert msg.channel == expected["channel"]
if "recipient" in expected:
assert msg.recipient == expected["recipient"]
if "content" in expected:
assert msg.content[0].text == expected["content"]
if "content_type" in expected:
assert msg.content_type == expected["content_type"]
if "tool_definitions" in expected:
# Check that the tool definitions match the expected list of tool names
actual_tools = [t.name for t in msg.content[0].tools["functions"].tools]
assert actual_tools == expected["tool_definitions"]
def verify_chat_response(
response: ChatCompletionResponse,
content: str | None = None,
reasoning: str | None = None,
tool_calls: list[tuple[str, str]] | None = None,
):
assert len(response.choices) == 1
message = response.choices[0].message
if content is not None:
assert message.content == content
else:
assert not message.content
if reasoning is not None:
assert message.reasoning == reasoning
else:
assert not message.reasoning
if tool_calls:
assert message.tool_calls is not None
assert len(message.tool_calls) == len(tool_calls)
for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls):
assert tc.function.name == expected_name
assert tc.function.arguments == expected_args
else:
assert not message.tool_calls
...@@ -232,7 +232,177 @@ def parse_response_input( ...@@ -232,7 +232,177 @@ def parse_response_input(
return msg return msg
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
"""
Parse a list of messages from request.messages in the Chat Completion API to
Harmony messages.
"""
msgs: list[Message] = []
tool_id_names: dict[str, str] = {}
# Collect tool id to name mappings for tool response recipient values
for chat_msg in chat_msgs:
for tool_call in chat_msg.get("tool_calls", []):
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
"name"
)
for chat_msg in chat_msgs:
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
msgs = auto_drop_analysis_messages(msgs)
return msgs
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
"""
Harmony models expect the analysis messages (representing raw chain of thought) to
be dropped after an assistant message to the final channel is produced from the
reasoning of those messages.
The openai-harmony library does this if the very last assistant message is to the
final channel, but it does not handle the case where we're in longer multi-turn
conversations and the client gave us reasoning content from previous turns of
the conversation with multiple assistant messages to the final channel in the
conversation.
So, we find the index of the last assistant message to the final channel and drop
all analysis messages that precede it, leaving only the analysis messages that
are relevant to the current part of the conversation.
"""
last_assistant_final_index = -1
for i in range(len(msgs) - 1, -1, -1):
msg = msgs[i]
if msg.author.role == "assistant" and msg.channel == "final":
last_assistant_final_index = i
break
cleaned_msgs: list[Message] = []
for i, msg in enumerate(msgs):
if i < last_assistant_final_index and msg.channel == "analysis":
continue
cleaned_msgs.append(msg)
return cleaned_msgs
def flatten_chat_text_content(content: str | list | None) -> str | None:
"""
Extract the text parts from a chat message content field and flatten them
into a single string.
"""
if isinstance(content, list):
return "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
return content
def parse_chat_input_to_harmony_message(
chat_msg, tool_id_names: dict[str, str] | None = None
) -> list[Message]:
"""
Parse a message from request.messages in the Chat Completion API to
Harmony messages.
"""
tool_id_names = tool_id_names or {}
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
msgs: list[Message] = []
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls", [])
if role == "assistant" and tool_calls:
content = flatten_chat_text_content(chat_msg.get("content"))
if content:
commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
commentary_msg = commentary_msg.with_channel("commentary")
msgs.append(commentary_msg)
reasoning_content = chat_msg.get("reasoning") or chat_msg.get(
"reasoning_content"
)
if reasoning_content:
analysis_msg = Message.from_role_and_content(
Role.ASSISTANT, reasoning_content
)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
# Officially, this should be `<|constrain|>json` but there is not clear
# evidence that improves accuracy over `json` and some anecdotes to the
# contrary. Further testing of the different content_types is needed.
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
tool_call_id = chat_msg.get("tool_call_id", "")
name = tool_id_names.get(tool_call_id, "")
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)
msg = (
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
)
.with_channel("commentary")
.with_recipient("assistant")
)
return [msg]
# Non-tool reasoning content
reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content")
if role == "assistant" and reasoning_content:
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
# Default: user/assistant/system messages with content
content = chat_msg.get("content") or ""
if content is None:
content = ""
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
# Only add assistant messages if they have content, as reasoning or tool calling
# assistant messages were already added above.
if role == "assistant" and contents and contents[0].text:
msg = Message.from_role_and_contents(role, contents)
# Send non-tool assistant messages to the final channel
msg = msg.with_channel("final")
msgs.append(msg)
# For user/system/developer messages, add them directly even if no content.
elif role != "assistant":
msg = Message.from_role_and_contents(role, contents)
msgs.append(msg)
return msgs
def parse_input_to_harmony_message(chat_msg) -> list[Message]: def parse_input_to_harmony_message(chat_msg) -> list[Message]:
"""
Parse a message from request.previous_input_messages in the Responsees API to
Harmony messages.
"""
if not isinstance(chat_msg, dict): if not isinstance(chat_msg, dict):
# Handle Pydantic models # Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True) chat_msg = chat_msg.model_dump(exclude_none=True)
...@@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]: ...@@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
if role == "tool": if role == "tool":
name = chat_msg.get("name", "") name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or "" content = chat_msg.get("content", "") or ""
if isinstance(content, list): content = flatten_chat_text_content(content)
# Handle array format for tool message content
# by concatenating all text parts.
content = "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
msg = Message.from_author_and_content( msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content Author.new(Role.TOOL, f"functions.{name}"), content
...@@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: ...@@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
def parse_chat_output( def parse_chat_output(
token_ids: Sequence[int], token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]: ) -> tuple[str | None, str | None, bool]:
"""
Parse the output of a Harmony chat completion into reasoning and final content.
Note that when the `openai` tool parser is used, serving_chat only uses this
for the reasoning content and gets the final content from the tool call parser.
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
in use,this needs to return the final content without any tool calls parsed.
Empty reasoning or final content is returned as None instead of an empty string.
"""
parser = parse_output_into_messages(token_ids) parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported is_tool_call = False # TODO: update this when tool call is supported
if len(output_msgs) == 0:
# The generation has stopped during reasoning. # Get completed messages from the parser
reasoning = parser.current_content reasoning_texts = [
final_content = None msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
elif len(output_msgs) == 1: ]
# The generation has stopped during final message. final_texts = [
reasoning = output_msgs[0].content[0].text msg.content[0].text for msg in output_msgs if msg.channel != "analysis"
final_content = parser.current_content ]
else:
reasoning_msg = output_msgs[:-1] # Extract partial messages from the parser
final_msg = output_msgs[-1] if parser.current_channel == "analysis" and parser.current_content:
reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg]) reasoning_texts.append(parser.current_content)
final_content = final_msg.content[0].text elif parser.current_channel != "analysis" and parser.current_content:
final_texts.append(parser.current_content)
# Flatten multiple messages into a single string
reasoning: str | None = "\n".join(reasoning_texts)
final_content: str | None = "\n".join(final_texts)
# Return None instead of empty string since existing callers check for None
reasoning = reasoning or None
final_content = final_content or None
return reasoning, final_content, is_tool_call return reasoning, final_content, is_tool_call
...@@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
get_stop_tokens_for_assistant_actions, get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant, get_streamable_parser_for_assistant,
get_system_message, get_system_message,
parse_chat_inputs_to_harmony_messages,
parse_chat_output, parse_chat_output,
parse_input_to_harmony_message,
render_for_completion, render_for_completion,
) )
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
...@@ -822,6 +822,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -822,6 +822,9 @@ class OpenAIServingChat(OpenAIServing):
if delta_message is not None: if delta_message is not None:
harmony_tools_streamed[i] = True harmony_tools_streamed[i] = True
elif cur_channel == "commentary":
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
else: else:
delta_message = None delta_message = None
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
...@@ -1770,6 +1773,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1770,6 +1773,11 @@ class OpenAIServingChat(OpenAIServing):
): ):
messages: list[OpenAIMessage] = [] messages: list[OpenAIMessage] = []
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
# Add system message. # Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default # NOTE: In Chat Completion API, browsing is enabled by default
# if the model supports it. TODO: Support browsing. # if the model supports it. TODO: Support browsing.
...@@ -1788,8 +1796,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1788,8 +1796,7 @@ class OpenAIServingChat(OpenAIServing):
messages.append(dev_msg) messages.append(dev_msg)
# Add user message. # Add user message.
for chat_msg in request.messages: messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))
messages.extend(parse_input_to_harmony_message(chat_msg))
# Render prompt token ids. # Render prompt token ids.
prompt_token_ids = render_for_completion(messages) prompt_token_ids = render_for_completion(messages)
......
...@@ -43,6 +43,7 @@ class OpenAIToolParser(ToolParser): ...@@ -43,6 +43,7 @@ class OpenAIToolParser(ToolParser):
parser = parse_output_into_messages(token_ids) parser = parse_output_into_messages(token_ids)
tool_calls = [] tool_calls = []
final_content = None final_content = None
commentary_content = None
if len(parser.messages) > 0: if len(parser.messages) > 0:
for msg in parser.messages: for msg in parser.messages:
...@@ -75,11 +76,15 @@ class OpenAIToolParser(ToolParser): ...@@ -75,11 +76,15 @@ class OpenAIToolParser(ToolParser):
) )
elif msg.channel == "final": elif msg.channel == "final":
final_content = msg_text final_content = msg_text
elif msg.channel == "commentary" and not msg.recipient:
commentary_content = msg_text
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0, tools_called=len(tool_calls) > 0,
tool_calls=tool_calls, tool_calls=tool_calls,
content=final_content, # prefer final content over commentary content if both are present
# commentary content is tool call preambles meant to be shown to the user
content=final_content or commentary_content,
) )
def extract_tool_calls_streaming( def extract_tool_calls_streaming(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment