Unverified commit d084e9fc, authored by Aleksandr Samarin and committed by GitHub
Browse files

[MODEL] Fix handling of multiple channels for gpt-oss with speculative decoding (#26291)


Signed-off-by: Aleksandr Samarin <astrlrd@nebius.com>
Signed-off-by: southfreebird <yvorott@gmail.com>
Co-authored-by: southfreebird <yvorott@gmail.com>
parent 3a612322
......@@ -10,6 +10,7 @@ from unittest.mock import patch
import pytest
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
TokenState,
extract_harmony_streaming_delta,
)
......@@ -42,12 +43,14 @@ class TestExtractHarmonyStreamingDelta:
def test_final_channel_returns_content_delta(self, delta_text, expected_content):
"""Test that final channel returns a DeltaMessage with content."""
parser = MockStreamableParser()
# Updated to use TokenState list
token_states = [TokenState(channel="final", recipient=None, text=delta_text)]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel="final",
cur_recipient=None,
token_states=token_states,
prev_recipient=None,
delta_text=delta_text,
include_reasoning=False,
)
......@@ -65,18 +68,19 @@ class TestExtractHarmonyStreamingDelta:
def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
"""Test analysis channel respects include_reasoning flag."""
parser = MockStreamableParser()
text = "Let me think..."
token_states = [TokenState(channel="analysis", recipient=None, text=text)]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel="analysis",
cur_recipient=None,
token_states=token_states,
prev_recipient=None,
delta_text="Let me think...",
include_reasoning=include_reasoning,
)
if expected_has_message:
assert delta_message is not None
assert delta_message.reasoning == "Let me think..."
assert delta_message.reasoning == text
else:
assert delta_message is None
assert tools_streamed is False
......@@ -88,12 +92,14 @@ class TestExtractHarmonyStreamingDelta:
mock_make_tool_call_id.return_value = "call_test123"
parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient="functions.get_weather", text="")
]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient="functions.get_weather",
token_states=token_states,
prev_recipient=None,
delta_text="",
include_reasoning=False,
)
......@@ -111,20 +117,25 @@ class TestExtractHarmonyStreamingDelta:
def test_tool_call_argument_streaming(self, channel):
"""Test streaming tool call arguments (same recipient)."""
parser = MockStreamableParser()
args_text = '{"location": "Paris"}'
token_states = [
TokenState(
channel=channel, recipient="functions.get_weather", text=args_text
)
]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient="functions.get_weather",
token_states=token_states,
prev_recipient="functions.get_weather",
delta_text='{"location": "Paris"}',
include_reasoning=False,
)
assert delta_message is not None
tool_call = delta_message.tool_calls[0]
assert tool_call.id is None
assert tool_call.function.arguments == '{"location": "Paris"}'
assert tool_call.function.arguments == args_text
assert tool_call.index == 0
assert tools_streamed is True
......@@ -133,12 +144,14 @@ class TestExtractHarmonyStreamingDelta:
"""Test empty delta_text with same recipient returns None."""
parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient="functions.get_weather", text="")
]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient="functions.get_weather",
token_states=token_states,
prev_recipient="functions.get_weather",
delta_text="",
include_reasoning=False,
)
......@@ -154,12 +167,14 @@ class TestExtractHarmonyStreamingDelta:
]
parser = MockStreamableParser(messages=messages)
token_states = [
TokenState(channel="commentary", recipient="functions.tool2", text="args")
]
delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel="commentary",
cur_recipient="functions.tool2",
token_states=token_states,
prev_recipient="functions.tool2",
delta_text="args",
include_reasoning=False,
)
......@@ -173,15 +188,18 @@ class TestExtractHarmonyStreamingDelta:
],
)
def test_returns_tool_call_preambles(self, channel, recipient):
"""Test that invalid channel/recipient combinations return None."""
"""Test that invalid tool recipient on commentary is treated as content."""
parser = MockStreamableParser()
delta_text = "some text"
token_states = [
TokenState(channel=channel, recipient=recipient, text=delta_text)
]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient=recipient,
token_states=token_states,
prev_recipient=None,
delta_text=delta_text,
include_reasoning=True,
)
......@@ -199,14 +217,140 @@ class TestExtractHarmonyStreamingDelta:
"""Test that invalid channel/recipient combinations return None."""
parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient=recipient, text="some text")
]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient=recipient,
token_states=token_states,
prev_recipient=None,
delta_text="some text",
include_reasoning=True,
)
assert delta_message is None
assert tools_streamed is False
def test_consecutive_token_grouping(self):
    """Adjacent tokens sharing a channel/recipient yield one merged content.

    Five "final"-channel fragments with no recipient are passed in a single
    call; the returned delta must carry their concatenation as its content.
    """
    fragments = ("H", "el", "lo", ",", " World")
    states = [TokenState("final", None, fragment) for fragment in fragments]

    message, _tools_streamed = extract_harmony_streaming_delta(
        harmony_parser=MockStreamableParser(),
        token_states=states,
        prev_recipient=None,
        include_reasoning=False,
    )

    assert message is not None
    assert message.content == "Hello, World"
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_complex_batch_permutation(self, mock_make_id):
    """One batch holding reasoning, a tool call and final content.

    The batch is ordered analysis -> commentary (tool call, two argument
    fragments) -> final. All three kinds of output must end up on the same
    DeltaMessage, and the tool call must be reported as streamed.
    """
    mock_make_id.return_value = "call_batch_test"
    search_recipient = "functions.search"
    batch = [
        # Reasoning segment.
        TokenState("analysis", None, "Reasoning about query..."),
        # Tool-call arguments, split over two tokens for the same recipient.
        TokenState("commentary", search_recipient, '{"query":'),
        TokenState("commentary", search_recipient, ' "vllm"}'),
        # User-visible content.
        TokenState("final", None, "."),
    ]

    message, streamed = extract_harmony_streaming_delta(
        harmony_parser=MockStreamableParser(),
        token_states=batch,
        prev_recipient=None,
        include_reasoning=True,
    )

    assert message is not None
    assert message.reasoning == "Reasoning about query..."

    # A single logical tool call is expected to surface as two entries:
    # the opening entry (id, name, type) followed by the arguments payload.
    assert len(message.tool_calls) == 2
    opening = message.tool_calls[0]
    arguments = message.tool_calls[1]

    assert opening.function.name == "search"
    assert opening.id == "call_batch_test"
    assert opening.index == 0

    assert arguments.index == 0
    assert arguments.function.arguments == '{"query": "vllm"}'

    assert message.content == "."
    assert streamed is True
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id):
    """Indices stay consistent for a continued call followed by new calls.

    The parser already holds one tool-call message and ``prev_recipient``
    names ``functions.tool_a``, so the first group continues that call.
    Two further tool calls are interleaved with "final" content. Expected:
    the continuation carries arguments at index 1 with no id, the new calls
    open at indices 2 and 3 with the mocked ids, and the content fragments
    are concatenated.
    """
    mock_make_id.side_effect = ["id_b", "id_c"]
    finished = [
        MockMessage(channel="commentary", recipient="functions.previous_tool")
    ]
    batch = [
        TokenState("commentary", "functions.tool_a", '{"key_a": "val_a"}'),
        TokenState("final", None, "Thinking..."),
        TokenState("commentary", "functions.tool_b", '{"key_b": "val_b"}'),
        TokenState("final", None, " Thinking again..."),
        TokenState("commentary", "functions.tool_c", '{"key_c": "val_c"}'),
    ]

    message, _tools_streamed = extract_harmony_streaming_delta(
        harmony_parser=MockStreamableParser(messages=finished),
        token_states=batch,
        prev_recipient="functions.tool_a",
        include_reasoning=False,
    )

    assert message is not None
    calls = message.tool_calls

    # Continuation of the call that was already open before this batch.
    continued = [call for call in calls if call.index == 1]
    assert len(continued) > 0
    assert continued[0].id is None
    assert continued[0].function.arguments == '{"key_a": "val_a"}'

    # Each newly opened call: an opening entry carrying the id, plus an
    # id-less entry at the same index carrying the arguments.
    expectations = (
        ("id_b", 2, '{"key_b": "val_b"}'),
        ("id_c", 3, '{"key_c": "val_c"}'),
    )
    for call_id, index, arguments in expectations:
        opening = next(call for call in calls if call.id == call_id)
        assert opening.index == index
        payload = next(
            call for call in calls if call.index == index and call.id is None
        )
        assert payload.function.arguments == arguments

    assert message.content == "Thinking... Thinking again..."
......@@ -36,6 +36,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatMessage,
)
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
TokenState,
extract_harmony_streaming_delta,
)
from vllm.entrypoints.openai.engine.protocol import (
......@@ -826,12 +827,22 @@ class OpenAIServingChat(OpenAIServing):
if self.use_harmony:
harmony_parser = harmony_parsers[i]
prev_recipient = harmony_parser.current_recipient
delta_text = ""
# Track accumulated content per token with their state
token_states: list[TokenState] = []
for token_id in output.token_ids:
harmony_parser.process(token_id)
delta_text += harmony_parser.last_content_delta or ""
token_delta = harmony_parser.last_content_delta or ""
token_states.append(
TokenState(
harmony_parser.current_channel,
harmony_parser.current_recipient,
token_delta,
)
)
delta_text = "".join(delta for _, _, delta in token_states)
cur_channel = harmony_parser.current_channel
cur_recipient = harmony_parser.current_recipient
# handle the case where several tokens were generated at once
# including the final token, leading to a delta in the text
# but the current channel to be empty (start state)
......@@ -869,10 +880,8 @@ class OpenAIServingChat(OpenAIServing):
delta_message, tools_streamed_flag = (
extract_harmony_streaming_delta(
harmony_parser=harmony_parser,
cur_channel=cur_channel,
cur_recipient=cur_recipient,
token_states=token_states,
prev_recipient=prev_recipient,
delta_text=delta_text,
include_reasoning=request.include_reasoning,
)
)
......@@ -1139,17 +1148,23 @@ class OpenAIServingChat(OpenAIServing):
# Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger:
delta_content = ""
delta_content_parts = []
if delta_message.content:
delta_content = delta_message.content
elif delta_message.tool_calls:
delta_content = "".join(
delta_content_parts.append(delta_message.content)
if delta_message.reasoning_content:
reasoning = delta_message.reasoning_content
delta_content_parts.append(f"[reasoning: {reasoning}]")
if delta_message.tool_calls:
tool_args = "".join(
tc.function.arguments
for tc in delta_message.tool_calls
if tc.function and tc.function.arguments
)
if tool_args:
delta_content_parts.append(f"[tool_calls: {tool_args}]")
if delta_content and self.enable_log_deltas:
if delta_content_parts and self.enable_log_deltas:
delta_content = " ".join(delta_content_parts)
self.request_logger.log_outputs(
request_id=request_id,
outputs=delta_content,
......
......@@ -7,6 +7,8 @@ This module handles the extraction of DeltaMessage objects from
harmony parser state during streaming chat completions.
"""
from typing import NamedTuple
from openai_harmony import StreamableParser
from vllm.entrypoints.chat_utils import make_tool_call_id
......@@ -17,12 +19,16 @@ from vllm.entrypoints.openai.engine.protocol import (
)
class TokenState(NamedTuple):
channel: str | None
recipient: str | None
text: str
def extract_harmony_streaming_delta(
harmony_parser: StreamableParser,
cur_channel: str | None,
cur_recipient: str | None,
token_states: list[TokenState],
prev_recipient: str | None,
delta_text: str,
include_reasoning: bool,
) -> tuple[DeltaMessage | None, bool]:
"""
......@@ -30,38 +36,81 @@ def extract_harmony_streaming_delta(
Args:
harmony_parser: The StreamableParser instance tracking parse state
cur_channel: Current channel ("final", "analysis", "commentary", etc.)
cur_recipient: Current recipient (e.g., "functions.my_func")
token_states: List of TokenState tuples for each token
prev_recipient: Previous recipient for detecting tool call transitions
delta_text: The text delta to include in the message
include_reasoning: Whether to include reasoning content
Returns:
A tuple of (DeltaMessage or None, tools_streamed_flag)
"""
if not token_states:
return None, False
tools_streamed = False
if cur_channel == "final":
delta_message = DeltaMessage(content=delta_text)
elif (
(cur_channel == "commentary" or cur_channel == "analysis")
and cur_recipient
and cur_recipient.startswith("functions.")
):
# Count completed tool calls to determine index
base_index = 0
for msg in harmony_parser.messages:
if (
(msg.channel == "commentary" or msg.channel == "analysis")
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1
if prev_recipient != cur_recipient:
tool_name = cur_recipient.split("functions.", 1)[1]
delta_message = DeltaMessage(
tool_calls=[
# Group consecutive tokens with same channel/recipient
groups: list[TokenState] = []
current_channel = token_states[0].channel
current_recipient = token_states[0].recipient
current_text = token_states[0].text
for i in range(1, len(token_states)):
state = token_states[i]
if state.channel == current_channel and state.recipient == current_recipient:
current_text += state.text
else:
groups.append(TokenState(current_channel, current_recipient, current_text))
current_channel = state.channel
current_recipient = state.recipient
current_text = state.text
groups.append(TokenState(current_channel, current_recipient, current_text))
# Process each group and create delta messages
delta_message = None
combined_content = ""
combined_reasoning = ""
tool_messages = []
content_encountered = False
# Calculate base_index once before the loop
# This counts completed tool calls in messages
base_index = 0
for msg in harmony_parser.messages:
if (
(msg.channel == "commentary" or msg.channel == "analysis")
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1
# If there's an ongoing tool call from previous chunk,
# the next new tool call starts at base_index + 1
if prev_recipient and prev_recipient.startswith("functions."):
next_tool_index = base_index + 1
# Ongoing call is at base_index
ongoing_tool_index = base_index
else:
# No ongoing call, next new call is at base_index
next_tool_index = base_index
ongoing_tool_index = None
for group in groups:
if group.channel == "final":
combined_content += group.text
content_encountered = True
elif (
(group.channel == "commentary" or group.channel == "analysis")
and group.recipient
and group.recipient.startswith("functions.")
):
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
type="function",
......@@ -69,32 +118,53 @@ def extract_harmony_streaming_delta(
name=tool_name,
arguments="",
),
index=base_index,
index=next_tool_index,
)
]
)
elif delta_text:
delta_message = DeltaMessage(
tool_calls=[
)
opened_new_call = True
prev_recipient = group.recipient
# Increment for subsequent new tool calls
next_tool_index += 1
if group.text:
# Stream arguments for the ongoing tool call
if opened_new_call:
# Just opened in this group
tool_call_index = next_tool_index - 1
else:
# Continuing from previous chunk
# If ongoing_tool_index is None here, it means
# we're continuing a call but prev_recipient
# wasn't a function. Use base_index.
tool_call_index = (
ongoing_tool_index
if ongoing_tool_index is not None
else base_index
)
tool_messages.append(
DeltaToolCall(
index=base_index,
function=DeltaFunctionCall(arguments=delta_text),
index=tool_call_index,
function=DeltaFunctionCall(arguments=group.text),
)
]
)
else:
delta_message = None
)
elif group.channel == "commentary":
# Tool call preambles meant to be shown to the user
combined_content += group.text
content_encountered = True
elif group.channel == "analysis" and include_reasoning:
combined_reasoning += group.text
if delta_message is not None:
# Combine all non-empty fields into a single message
if content_encountered or combined_reasoning or tool_messages:
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
if content_encountered:
delta_kwargs["content"] = combined_content
if combined_reasoning:
delta_kwargs["reasoning"] = combined_reasoning
if tool_messages:
delta_kwargs["tool_calls"] = tool_messages
tools_streamed = True
elif cur_channel == "commentary":
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
elif cur_channel == "analysis":
if include_reasoning:
delta_message = DeltaMessage(reasoning=delta_text)
else:
delta_message = None
delta_message = DeltaMessage(**delta_kwargs)
else:
delta_message = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment