Unverified Commit d084e9fc authored by Aleksandr Samarin's avatar Aleksandr Samarin Committed by GitHub
Browse files

[MODEL] Fix handling of multiple channels for gpt-oss with speculative decoding (#26291)


Signed-off-by: default avatarAleksandr Samarin <astrlrd@nebius.com>
Signed-off-by: default avatarsouthfreebird <yvorott@gmail.com>
Co-authored-by: default avatarsouthfreebird <yvorott@gmail.com>
parent 3a612322
...@@ -10,6 +10,7 @@ from unittest.mock import patch ...@@ -10,6 +10,7 @@ from unittest.mock import patch
import pytest import pytest
from vllm.entrypoints.openai.chat_completion.stream_harmony import ( from vllm.entrypoints.openai.chat_completion.stream_harmony import (
TokenState,
extract_harmony_streaming_delta, extract_harmony_streaming_delta,
) )
...@@ -42,12 +43,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -42,12 +43,14 @@ class TestExtractHarmonyStreamingDelta:
def test_final_channel_returns_content_delta(self, delta_text, expected_content): def test_final_channel_returns_content_delta(self, delta_text, expected_content):
"""Test that final channel returns a DeltaMessage with content.""" """Test that final channel returns a DeltaMessage with content."""
parser = MockStreamableParser() parser = MockStreamableParser()
# Updated to use TokenState list
token_states = [TokenState(channel="final", recipient=None, text=delta_text)]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel="final", token_states=token_states,
cur_recipient=None,
prev_recipient=None, prev_recipient=None,
delta_text=delta_text,
include_reasoning=False, include_reasoning=False,
) )
...@@ -65,18 +68,19 @@ class TestExtractHarmonyStreamingDelta: ...@@ -65,18 +68,19 @@ class TestExtractHarmonyStreamingDelta:
def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message): def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
"""Test analysis channel respects include_reasoning flag.""" """Test analysis channel respects include_reasoning flag."""
parser = MockStreamableParser() parser = MockStreamableParser()
text = "Let me think..."
token_states = [TokenState(channel="analysis", recipient=None, text=text)]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel="analysis", token_states=token_states,
cur_recipient=None,
prev_recipient=None, prev_recipient=None,
delta_text="Let me think...",
include_reasoning=include_reasoning, include_reasoning=include_reasoning,
) )
if expected_has_message: if expected_has_message:
assert delta_message is not None assert delta_message is not None
assert delta_message.reasoning == "Let me think..." assert delta_message.reasoning == text
else: else:
assert delta_message is None assert delta_message is None
assert tools_streamed is False assert tools_streamed is False
...@@ -88,12 +92,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -88,12 +92,14 @@ class TestExtractHarmonyStreamingDelta:
mock_make_tool_call_id.return_value = "call_test123" mock_make_tool_call_id.return_value = "call_test123"
parser = MockStreamableParser() parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient="functions.get_weather", text="")
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient="functions.get_weather",
prev_recipient=None, prev_recipient=None,
delta_text="",
include_reasoning=False, include_reasoning=False,
) )
...@@ -111,20 +117,25 @@ class TestExtractHarmonyStreamingDelta: ...@@ -111,20 +117,25 @@ class TestExtractHarmonyStreamingDelta:
def test_tool_call_argument_streaming(self, channel): def test_tool_call_argument_streaming(self, channel):
"""Test streaming tool call arguments (same recipient).""" """Test streaming tool call arguments (same recipient)."""
parser = MockStreamableParser() parser = MockStreamableParser()
args_text = '{"location": "Paris"}'
token_states = [
TokenState(
channel=channel, recipient="functions.get_weather", text=args_text
)
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient="functions.get_weather",
prev_recipient="functions.get_weather", prev_recipient="functions.get_weather",
delta_text='{"location": "Paris"}',
include_reasoning=False, include_reasoning=False,
) )
assert delta_message is not None assert delta_message is not None
tool_call = delta_message.tool_calls[0] tool_call = delta_message.tool_calls[0]
assert tool_call.id is None assert tool_call.id is None
assert tool_call.function.arguments == '{"location": "Paris"}' assert tool_call.function.arguments == args_text
assert tool_call.index == 0 assert tool_call.index == 0
assert tools_streamed is True assert tools_streamed is True
...@@ -133,12 +144,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -133,12 +144,14 @@ class TestExtractHarmonyStreamingDelta:
"""Test empty delta_text with same recipient returns None.""" """Test empty delta_text with same recipient returns None."""
parser = MockStreamableParser() parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient="functions.get_weather", text="")
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient="functions.get_weather",
prev_recipient="functions.get_weather", prev_recipient="functions.get_weather",
delta_text="",
include_reasoning=False, include_reasoning=False,
) )
...@@ -154,12 +167,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -154,12 +167,14 @@ class TestExtractHarmonyStreamingDelta:
] ]
parser = MockStreamableParser(messages=messages) parser = MockStreamableParser(messages=messages)
token_states = [
TokenState(channel="commentary", recipient="functions.tool2", text="args")
]
delta_message, _ = extract_harmony_streaming_delta( delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel="commentary", token_states=token_states,
cur_recipient="functions.tool2",
prev_recipient="functions.tool2", prev_recipient="functions.tool2",
delta_text="args",
include_reasoning=False, include_reasoning=False,
) )
...@@ -173,15 +188,18 @@ class TestExtractHarmonyStreamingDelta: ...@@ -173,15 +188,18 @@ class TestExtractHarmonyStreamingDelta:
], ],
) )
def test_returns_tool_call_preambles(self, channel, recipient): def test_returns_tool_call_preambles(self, channel, recipient):
"""Test that invalid channel/recipient combinations return None.""" """Test that invalid tool recipient on commentary is treated as content."""
parser = MockStreamableParser() parser = MockStreamableParser()
delta_text = "some text" delta_text = "some text"
token_states = [
TokenState(channel=channel, recipient=recipient, text=delta_text)
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient=recipient,
prev_recipient=None, prev_recipient=None,
delta_text=delta_text,
include_reasoning=True, include_reasoning=True,
) )
...@@ -199,14 +217,140 @@ class TestExtractHarmonyStreamingDelta: ...@@ -199,14 +217,140 @@ class TestExtractHarmonyStreamingDelta:
"""Test that invalid channel/recipient combinations return None.""" """Test that invalid channel/recipient combinations return None."""
parser = MockStreamableParser() parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient=recipient, text="some text")
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient=recipient,
prev_recipient=None, prev_recipient=None,
delta_text="some text",
include_reasoning=True, include_reasoning=True,
) )
assert delta_message is None assert delta_message is None
assert tools_streamed is False assert tools_streamed is False
def test_consecutive_token_grouping(self):
"""
Test that consecutive tokens with the same channel/recipient
are merged into a single processing group.
"""
parser = MockStreamableParser()
token_states = [
TokenState("final", None, "H"),
TokenState("final", None, "el"),
TokenState("final", None, "lo"),
TokenState("final", None, ","),
TokenState("final", None, " World"),
]
delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser,
token_states=token_states,
prev_recipient=None,
include_reasoning=False,
)
assert delta_message is not None
assert delta_message.content == "Hello, World"
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_complex_batch_permutation(self, mock_make_id):
"""
Test a complex permutation: Reasoning -> Tool Call -> Content.
This verifies that multiple distinct actions in one batch
are all captured in the single DeltaMessage.
"""
mock_make_id.return_value = "call_batch_test"
parser = MockStreamableParser()
token_states = [
# 1. Reasoning
TokenState("analysis", None, "Reasoning about query..."),
# 2. Tool Calling
TokenState("commentary", "functions.search", '{"query":'),
TokenState("commentary", "functions.search", ' "vllm"}'),
# 3. Final Content
TokenState("final", None, "."),
]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
token_states=token_states,
prev_recipient=None,
include_reasoning=True,
)
assert delta_message is not None
assert delta_message.reasoning == "Reasoning about query..."
# We expect 2 objects for 1 logical tool call:
# 1. The definition (id, name, type)
# 2. The arguments payload
assert len(delta_message.tool_calls) == 2
header = delta_message.tool_calls[0]
payload = delta_message.tool_calls[1]
assert header.function.name == "search"
assert header.id == "call_batch_test"
assert header.index == 0
assert payload.index == 0
assert payload.function.arguments == '{"query": "vllm"}'
assert delta_message.content == "."
assert tools_streamed is True
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id):
"""
Test that an ongoing tool call continuation and subsequent new calls
maintain correct indexing when interleaved with content.
"""
mock_make_id.side_effect = ["id_b", "id_c"]
messages = [
MockMessage(channel="commentary", recipient="functions.previous_tool")
]
parser = MockStreamableParser(messages=messages)
token_states = [
TokenState("commentary", "functions.tool_a", '{"key_a": "val_a"}'),
TokenState("final", None, "Thinking..."),
TokenState("commentary", "functions.tool_b", '{"key_b": "val_b"}'),
TokenState("final", None, " Thinking again..."),
TokenState("commentary", "functions.tool_c", '{"key_c": "val_c"}'),
]
delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser,
token_states=token_states,
prev_recipient="functions.tool_a",
include_reasoning=False,
)
assert delta_message is not None
tool_a_deltas = [t for t in delta_message.tool_calls if t.index == 1]
assert len(tool_a_deltas) > 0
assert tool_a_deltas[0].id is None
assert tool_a_deltas[0].function.arguments == '{"key_a": "val_a"}'
tool_b_header = next(t for t in delta_message.tool_calls if t.id == "id_b")
assert tool_b_header.index == 2
tool_b_args = next(
t for t in delta_message.tool_calls if t.index == 2 and t.id is None
)
assert tool_b_args.function.arguments == '{"key_b": "val_b"}'
tool_c_start = next(t for t in delta_message.tool_calls if t.id == "id_c")
assert tool_c_start.index == 3
tool_c_args = next(
t for t in delta_message.tool_calls if t.index == 3 and t.id is None
)
assert tool_c_args.function.arguments == '{"key_c": "val_c"}'
assert delta_message.content == "Thinking... Thinking again..."
...@@ -36,6 +36,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -36,6 +36,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatMessage, ChatMessage,
) )
from vllm.entrypoints.openai.chat_completion.stream_harmony import ( from vllm.entrypoints.openai.chat_completion.stream_harmony import (
TokenState,
extract_harmony_streaming_delta, extract_harmony_streaming_delta,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
...@@ -826,12 +827,22 @@ class OpenAIServingChat(OpenAIServing): ...@@ -826,12 +827,22 @@ class OpenAIServingChat(OpenAIServing):
if self.use_harmony: if self.use_harmony:
harmony_parser = harmony_parsers[i] harmony_parser = harmony_parsers[i]
prev_recipient = harmony_parser.current_recipient prev_recipient = harmony_parser.current_recipient
delta_text = ""
# Track accumulated content per token with their state
token_states: list[TokenState] = []
for token_id in output.token_ids: for token_id in output.token_ids:
harmony_parser.process(token_id) harmony_parser.process(token_id)
delta_text += harmony_parser.last_content_delta or "" token_delta = harmony_parser.last_content_delta or ""
token_states.append(
TokenState(
harmony_parser.current_channel,
harmony_parser.current_recipient,
token_delta,
)
)
delta_text = "".join(delta for _, _, delta in token_states)
cur_channel = harmony_parser.current_channel cur_channel = harmony_parser.current_channel
cur_recipient = harmony_parser.current_recipient
# handle the case where several tokens where generated at once # handle the case where several tokens where generated at once
# including the final token, leading to a delta in the text # including the final token, leading to a delta in the text
# but the current channel to be empty (start state) # but the current channel to be empty (start state)
...@@ -869,10 +880,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -869,10 +880,8 @@ class OpenAIServingChat(OpenAIServing):
delta_message, tools_streamed_flag = ( delta_message, tools_streamed_flag = (
extract_harmony_streaming_delta( extract_harmony_streaming_delta(
harmony_parser=harmony_parser, harmony_parser=harmony_parser,
cur_channel=cur_channel, token_states=token_states,
cur_recipient=cur_recipient,
prev_recipient=prev_recipient, prev_recipient=prev_recipient,
delta_text=delta_text,
include_reasoning=request.include_reasoning, include_reasoning=request.include_reasoning,
) )
) )
...@@ -1139,17 +1148,23 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1139,17 +1148,23 @@ class OpenAIServingChat(OpenAIServing):
# Log streaming delta if output logging is enabled # Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger: if self.enable_log_outputs and self.request_logger:
delta_content = "" delta_content_parts = []
if delta_message.content: if delta_message.content:
delta_content = delta_message.content delta_content_parts.append(delta_message.content)
elif delta_message.tool_calls: if delta_message.reasoning_content:
delta_content = "".join( reasoning = delta_message.reasoning_content
delta_content_parts.append(f"[reasoning: {reasoning}]")
if delta_message.tool_calls:
tool_args = "".join(
tc.function.arguments tc.function.arguments
for tc in delta_message.tool_calls for tc in delta_message.tool_calls
if tc.function and tc.function.arguments if tc.function and tc.function.arguments
) )
if tool_args:
delta_content_parts.append(f"[tool_calls: {tool_args}]")
if delta_content and self.enable_log_deltas: if delta_content_parts and self.enable_log_deltas:
delta_content = " ".join(delta_content_parts)
self.request_logger.log_outputs( self.request_logger.log_outputs(
request_id=request_id, request_id=request_id,
outputs=delta_content, outputs=delta_content,
......
...@@ -7,6 +7,8 @@ This module handles the extraction of DeltaMessage objects from ...@@ -7,6 +7,8 @@ This module handles the extraction of DeltaMessage objects from
harmony parser state during streaming chat completions. harmony parser state during streaming chat completions.
""" """
from typing import NamedTuple
from openai_harmony import StreamableParser from openai_harmony import StreamableParser
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
...@@ -17,12 +19,16 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -17,12 +19,16 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
class TokenState(NamedTuple):
channel: str | None
recipient: str | None
text: str
def extract_harmony_streaming_delta( def extract_harmony_streaming_delta(
harmony_parser: StreamableParser, harmony_parser: StreamableParser,
cur_channel: str | None, token_states: list[TokenState],
cur_recipient: str | None,
prev_recipient: str | None, prev_recipient: str | None,
delta_text: str,
include_reasoning: bool, include_reasoning: bool,
) -> tuple[DeltaMessage | None, bool]: ) -> tuple[DeltaMessage | None, bool]:
""" """
...@@ -30,38 +36,81 @@ def extract_harmony_streaming_delta( ...@@ -30,38 +36,81 @@ def extract_harmony_streaming_delta(
Args: Args:
harmony_parser: The StreamableParser instance tracking parse state harmony_parser: The StreamableParser instance tracking parse state
cur_channel: Current channel ("final", "analysis", "commentary", etc.) token_states: List of TokenState tuples for each token
cur_recipient: Current recipient (e.g., "functions.my_func")
prev_recipient: Previous recipient for detecting tool call transitions prev_recipient: Previous recipient for detecting tool call transitions
delta_text: The text delta to include in the message
include_reasoning: Whether to include reasoning content include_reasoning: Whether to include reasoning content
Returns: Returns:
A tuple of (DeltaMessage or None, tools_streamed_flag) A tuple of (DeltaMessage or None, tools_streamed_flag)
""" """
if not token_states:
return None, False
tools_streamed = False tools_streamed = False
if cur_channel == "final": # Group consecutive tokens with same channel/recipient
delta_message = DeltaMessage(content=delta_text) groups: list[TokenState] = []
elif (
(cur_channel == "commentary" or cur_channel == "analysis") current_channel = token_states[0].channel
and cur_recipient current_recipient = token_states[0].recipient
and cur_recipient.startswith("functions.") current_text = token_states[0].text
):
# Count completed tool calls to determine index for i in range(1, len(token_states)):
base_index = 0 state = token_states[i]
for msg in harmony_parser.messages: if state.channel == current_channel and state.recipient == current_recipient:
if ( current_text += state.text
(msg.channel == "commentary" or msg.channel == "analysis") else:
and msg.recipient groups.append(TokenState(current_channel, current_recipient, current_text))
and msg.recipient.startswith("functions.") current_channel = state.channel
): current_recipient = state.recipient
base_index += 1 current_text = state.text
if prev_recipient != cur_recipient: groups.append(TokenState(current_channel, current_recipient, current_text))
tool_name = cur_recipient.split("functions.", 1)[1]
delta_message = DeltaMessage( # Process each group and create delta messages
tool_calls=[ delta_message = None
combined_content = ""
combined_reasoning = ""
tool_messages = []
content_encountered = False
# Calculate base_index once before the loop
# This counts completed tool calls in messages
base_index = 0
for msg in harmony_parser.messages:
if (
(msg.channel == "commentary" or msg.channel == "analysis")
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1
# If there's an ongoing tool call from previous chunk,
# the next new tool call starts at base_index + 1
if prev_recipient and prev_recipient.startswith("functions."):
next_tool_index = base_index + 1
# Ongoing call is at base_index
ongoing_tool_index = base_index
else:
# No ongoing call, next new call is at base_index
next_tool_index = base_index
ongoing_tool_index = None
for group in groups:
if group.channel == "final":
combined_content += group.text
content_encountered = True
elif (
(group.channel == "commentary" or group.channel == "analysis")
and group.recipient
and group.recipient.startswith("functions.")
):
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_messages.append(
DeltaToolCall( DeltaToolCall(
id=make_tool_call_id(), id=make_tool_call_id(),
type="function", type="function",
...@@ -69,32 +118,53 @@ def extract_harmony_streaming_delta( ...@@ -69,32 +118,53 @@ def extract_harmony_streaming_delta(
name=tool_name, name=tool_name,
arguments="", arguments="",
), ),
index=base_index, index=next_tool_index,
) )
] )
) opened_new_call = True
elif delta_text: prev_recipient = group.recipient
delta_message = DeltaMessage( # Increment for subsequent new tool calls
tool_calls=[ next_tool_index += 1
if group.text:
# Stream arguments for the ongoing tool call
if opened_new_call:
# Just opened in this group
tool_call_index = next_tool_index - 1
else:
# Continuing from previous chunk
# If ongoing_tool_index is None here, it means
# we're continuing a call but prev_recipient
# wasn't a function. Use base_index.
tool_call_index = (
ongoing_tool_index
if ongoing_tool_index is not None
else base_index
)
tool_messages.append(
DeltaToolCall( DeltaToolCall(
index=base_index, index=tool_call_index,
function=DeltaFunctionCall(arguments=delta_text), function=DeltaFunctionCall(arguments=group.text),
) )
] )
) elif group.channel == "commentary":
else: # Tool call preambles meant to be shown to the user
delta_message = None combined_content += group.text
content_encountered = True
elif group.channel == "analysis" and include_reasoning:
combined_reasoning += group.text
if delta_message is not None: # Combine all non-empty fields into a single message
if content_encountered or combined_reasoning or tool_messages:
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
if content_encountered:
delta_kwargs["content"] = combined_content
if combined_reasoning:
delta_kwargs["reasoning"] = combined_reasoning
if tool_messages:
delta_kwargs["tool_calls"] = tool_messages
tools_streamed = True tools_streamed = True
elif cur_channel == "commentary": delta_message = DeltaMessage(**delta_kwargs)
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
elif cur_channel == "analysis":
if include_reasoning:
delta_message = DeltaMessage(reasoning=delta_text)
else:
delta_message = None
else: else:
delta_message = None delta_message = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment