Unverified Commit d084e9fc authored by Aleksandr Samarin's avatar Aleksandr Samarin Committed by GitHub
Browse files

[MODEL] Fix handling of multiple channels for gpt-oss with speculative decoding (#26291)


Signed-off-by: default avatarAleksandr Samarin <astrlrd@nebius.com>
Signed-off-by: default avatarsouthfreebird <yvorott@gmail.com>
Co-authored-by: default avatarsouthfreebird <yvorott@gmail.com>
parent 3a612322
...@@ -35,6 +35,7 @@ from .utils import ( ...@@ -35,6 +35,7 @@ from .utils import (
) )
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b" GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
GPT_OSS_SPECULATOR_NAME = "RedHatAI/gpt-oss-20b-speculator.eagle3"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -66,7 +67,8 @@ def exclude_tools_when_tool_choice_none(request) -> bool: ...@@ -66,7 +67,8 @@ def exclude_tools_when_tool_choice_none(request) -> bool:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args( def default_server_args(
with_tool_parser: bool, exclude_tools_when_tool_choice_none: bool with_tool_parser: bool,
exclude_tools_when_tool_choice_none: bool,
): ):
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
...@@ -76,7 +78,7 @@ def default_server_args( ...@@ -76,7 +78,7 @@ def default_server_args(
"--reasoning-parser", "--reasoning-parser",
"openai_gptoss", "openai_gptoss",
"--gpu-memory-utilization", "--gpu-memory-utilization",
"0.8", "0.85",
] ]
if with_tool_parser: if with_tool_parser:
args.extend( args.extend(
...@@ -91,23 +93,42 @@ def default_server_args( ...@@ -91,23 +93,42 @@ def default_server_args(
return args return args
@pytest.fixture(scope="module") @pytest.fixture(scope="class")
def gptoss_server(default_server_args: list[str]): def gptoss_server(default_server_args: list[str]):
server_args = default_server_args + ["--attention-backend=TRITON_ATTN"] server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server: with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
yield remote_server yield remote_server
@pytest.fixture(scope="class")
def gptoss_speculative_server(default_server_args: list[str]):
server_args = default_server_args + [
"--speculative-config",
f'{{"model": "{GPT_OSS_SPECULATOR_NAME}", '
f'"method": "eagle3", "num_speculative_tokens": 3}}',
"--attention-backend=TRITON_ATTN",
]
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def gptoss_client(gptoss_server): async def gptoss_client(gptoss_server):
async with gptoss_server.get_async_client() as async_client: async with gptoss_server.get_async_client() as async_client:
yield async_client yield async_client
@pytest.mark.asyncio @pytest_asyncio.fixture
async def test_gpt_oss_chat_tool_call_streaming( async def gptoss_speculative_client(gptoss_speculative_server):
gptoss_client: OpenAI, with_tool_parser: bool async with gptoss_speculative_server.get_async_client() as async_client:
): yield async_client
class TestGPTOSSChat:
@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
self, gptoss_client: OpenAI, with_tool_parser: bool
):
tools = [ tools = [
{ {
"type": "function", "type": "function",
...@@ -162,9 +183,10 @@ async def test_gpt_oss_chat_tool_call_streaming( ...@@ -162,9 +183,10 @@ async def test_gpt_oss_chat_tool_call_streaming(
assert len(args_buf) == 0 assert len(args_buf) == 0
assert len(content_buf) > 0 assert len(content_buf) > 0
@pytest.mark.asyncio
@pytest.mark.asyncio async def test_gpt_oss_multi_turn_chat(
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool): self, gptoss_client: OpenAI, with_tool_parser: bool
):
if not with_tool_parser: if not with_tool_parser:
pytest.skip("skip non-tool for multi-turn tests") pytest.skip("skip non-tool for multi-turn tests")
tools = [ tools = [
...@@ -191,7 +213,10 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: ...@@ -191,7 +213,10 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser:
messages = [ messages = [
{"role": "system", "content": "you are a helpful assistant"}, {"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "What is the weather in Dallas, TX with celsius?"}, {
"role": "user",
"content": "What is the weather in Dallas, TX with celsius?",
},
] ]
first = await gptoss_client.chat.completions.create( first = await gptoss_client.chat.completions.create(
...@@ -224,11 +249,10 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: ...@@ -224,11 +249,10 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser:
second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0 second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
) )
@pytest.mark.asyncio
@pytest.mark.asyncio async def test_gpt_oss_tool_message_array_content(
async def test_gpt_oss_tool_message_array_content( self, gptoss_client: OpenAI, with_tool_parser: bool
gptoss_client: OpenAI, with_tool_parser: bool ):
):
"""Test that tool messages support both string and array content formats.""" """Test that tool messages support both string and array content formats."""
if not with_tool_parser: if not with_tool_parser:
pytest.skip("skip non-tool for array content tests") pytest.skip("skip non-tool for array content tests")
...@@ -350,13 +374,13 @@ async def test_gpt_oss_tool_message_array_content( ...@@ -350,13 +374,13 @@ async def test_gpt_oss_tool_message_array_content(
assert response_multi_array is not None assert response_multi_array is not None
assert response_multi_array.choices[0].message is not None assert response_multi_array.choices[0].message is not None
@pytest.mark.asyncio
@pytest.mark.asyncio async def test_gpt_oss_tool_choice_none(
async def test_gpt_oss_tool_choice_none( self,
gptoss_client: OpenAI, gptoss_client: OpenAI,
with_tool_parser: bool, with_tool_parser: bool,
exclude_tools_when_tool_choice_none: bool, exclude_tools_when_tool_choice_none: bool,
): ):
if not (with_tool_parser and exclude_tools_when_tool_choice_none): if not (with_tool_parser and exclude_tools_when_tool_choice_none):
pytest.skip( pytest.skip(
"skip tool_choice tests when non-tool or " "skip tool_choice tests when non-tool or "
...@@ -414,6 +438,42 @@ async def test_gpt_oss_tool_choice_none( ...@@ -414,6 +438,42 @@ async def test_gpt_oss_tool_choice_none(
assert len(msg.tool_calls) == 0 assert len(msg.tool_calls) == 0
class TestGPTOSSSpeculativeChat:
@pytest.mark.asyncio
async def test_gpt_oss_speculative_reasoning_leakage(
self,
gptoss_speculative_client: OpenAI,
with_tool_parser: bool,
):
if not with_tool_parser:
pytest.skip("skip non-tool for array content tests")
messages = [
{"role": "user", "content": "Calculate 2+2. Return the answer 4 only."},
]
stream = await gptoss_speculative_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME,
messages=messages,
stream=True,
temperature=0.0,
)
content = ""
reasoning_content = ""
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
content += delta.content
chunk_reasoning = getattr(delta, "reasoning", None)
if chunk_reasoning:
reasoning_content += delta.reasoning
assert len(reasoning_content) > 0, "No reasoning was generated."
assert content.strip() == "4"
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2" MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}" CHAT_TEMPLATE = "Dummy chat template for testing {}"
......
...@@ -10,6 +10,7 @@ from unittest.mock import patch ...@@ -10,6 +10,7 @@ from unittest.mock import patch
import pytest import pytest
from vllm.entrypoints.openai.chat_completion.stream_harmony import ( from vllm.entrypoints.openai.chat_completion.stream_harmony import (
TokenState,
extract_harmony_streaming_delta, extract_harmony_streaming_delta,
) )
...@@ -42,12 +43,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -42,12 +43,14 @@ class TestExtractHarmonyStreamingDelta:
def test_final_channel_returns_content_delta(self, delta_text, expected_content): def test_final_channel_returns_content_delta(self, delta_text, expected_content):
"""Test that final channel returns a DeltaMessage with content.""" """Test that final channel returns a DeltaMessage with content."""
parser = MockStreamableParser() parser = MockStreamableParser()
# Updated to use TokenState list
token_states = [TokenState(channel="final", recipient=None, text=delta_text)]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel="final", token_states=token_states,
cur_recipient=None,
prev_recipient=None, prev_recipient=None,
delta_text=delta_text,
include_reasoning=False, include_reasoning=False,
) )
...@@ -65,18 +68,19 @@ class TestExtractHarmonyStreamingDelta: ...@@ -65,18 +68,19 @@ class TestExtractHarmonyStreamingDelta:
def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message): def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
"""Test analysis channel respects include_reasoning flag.""" """Test analysis channel respects include_reasoning flag."""
parser = MockStreamableParser() parser = MockStreamableParser()
text = "Let me think..."
token_states = [TokenState(channel="analysis", recipient=None, text=text)]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel="analysis", token_states=token_states,
cur_recipient=None,
prev_recipient=None, prev_recipient=None,
delta_text="Let me think...",
include_reasoning=include_reasoning, include_reasoning=include_reasoning,
) )
if expected_has_message: if expected_has_message:
assert delta_message is not None assert delta_message is not None
assert delta_message.reasoning == "Let me think..." assert delta_message.reasoning == text
else: else:
assert delta_message is None assert delta_message is None
assert tools_streamed is False assert tools_streamed is False
...@@ -88,12 +92,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -88,12 +92,14 @@ class TestExtractHarmonyStreamingDelta:
mock_make_tool_call_id.return_value = "call_test123" mock_make_tool_call_id.return_value = "call_test123"
parser = MockStreamableParser() parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient="functions.get_weather", text="")
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient="functions.get_weather",
prev_recipient=None, prev_recipient=None,
delta_text="",
include_reasoning=False, include_reasoning=False,
) )
...@@ -111,20 +117,25 @@ class TestExtractHarmonyStreamingDelta: ...@@ -111,20 +117,25 @@ class TestExtractHarmonyStreamingDelta:
def test_tool_call_argument_streaming(self, channel): def test_tool_call_argument_streaming(self, channel):
"""Test streaming tool call arguments (same recipient).""" """Test streaming tool call arguments (same recipient)."""
parser = MockStreamableParser() parser = MockStreamableParser()
args_text = '{"location": "Paris"}'
token_states = [
TokenState(
channel=channel, recipient="functions.get_weather", text=args_text
)
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient="functions.get_weather",
prev_recipient="functions.get_weather", prev_recipient="functions.get_weather",
delta_text='{"location": "Paris"}',
include_reasoning=False, include_reasoning=False,
) )
assert delta_message is not None assert delta_message is not None
tool_call = delta_message.tool_calls[0] tool_call = delta_message.tool_calls[0]
assert tool_call.id is None assert tool_call.id is None
assert tool_call.function.arguments == '{"location": "Paris"}' assert tool_call.function.arguments == args_text
assert tool_call.index == 0 assert tool_call.index == 0
assert tools_streamed is True assert tools_streamed is True
...@@ -133,12 +144,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -133,12 +144,14 @@ class TestExtractHarmonyStreamingDelta:
"""Test empty delta_text with same recipient returns None.""" """Test empty delta_text with same recipient returns None."""
parser = MockStreamableParser() parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient="functions.get_weather", text="")
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient="functions.get_weather",
prev_recipient="functions.get_weather", prev_recipient="functions.get_weather",
delta_text="",
include_reasoning=False, include_reasoning=False,
) )
...@@ -154,12 +167,14 @@ class TestExtractHarmonyStreamingDelta: ...@@ -154,12 +167,14 @@ class TestExtractHarmonyStreamingDelta:
] ]
parser = MockStreamableParser(messages=messages) parser = MockStreamableParser(messages=messages)
token_states = [
TokenState(channel="commentary", recipient="functions.tool2", text="args")
]
delta_message, _ = extract_harmony_streaming_delta( delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel="commentary", token_states=token_states,
cur_recipient="functions.tool2",
prev_recipient="functions.tool2", prev_recipient="functions.tool2",
delta_text="args",
include_reasoning=False, include_reasoning=False,
) )
...@@ -173,15 +188,18 @@ class TestExtractHarmonyStreamingDelta: ...@@ -173,15 +188,18 @@ class TestExtractHarmonyStreamingDelta:
], ],
) )
def test_returns_tool_call_preambles(self, channel, recipient): def test_returns_tool_call_preambles(self, channel, recipient):
"""Test that invalid channel/recipient combinations return None.""" """Test that invalid tool recipient on commentary is treated as content."""
parser = MockStreamableParser() parser = MockStreamableParser()
delta_text = "some text" delta_text = "some text"
token_states = [
TokenState(channel=channel, recipient=recipient, text=delta_text)
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient=recipient,
prev_recipient=None, prev_recipient=None,
delta_text=delta_text,
include_reasoning=True, include_reasoning=True,
) )
...@@ -199,14 +217,140 @@ class TestExtractHarmonyStreamingDelta: ...@@ -199,14 +217,140 @@ class TestExtractHarmonyStreamingDelta:
"""Test that invalid channel/recipient combinations return None.""" """Test that invalid channel/recipient combinations return None."""
parser = MockStreamableParser() parser = MockStreamableParser()
token_states = [
TokenState(channel=channel, recipient=recipient, text="some text")
]
delta_message, tools_streamed = extract_harmony_streaming_delta( delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser, harmony_parser=parser,
cur_channel=channel, token_states=token_states,
cur_recipient=recipient,
prev_recipient=None, prev_recipient=None,
delta_text="some text",
include_reasoning=True, include_reasoning=True,
) )
assert delta_message is None assert delta_message is None
assert tools_streamed is False assert tools_streamed is False
def test_consecutive_token_grouping(self):
"""
Test that consecutive tokens with the same channel/recipient
are merged into a single processing group.
"""
parser = MockStreamableParser()
token_states = [
TokenState("final", None, "H"),
TokenState("final", None, "el"),
TokenState("final", None, "lo"),
TokenState("final", None, ","),
TokenState("final", None, " World"),
]
delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser,
token_states=token_states,
prev_recipient=None,
include_reasoning=False,
)
assert delta_message is not None
assert delta_message.content == "Hello, World"
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_complex_batch_permutation(self, mock_make_id):
"""
Test a complex permutation: Reasoning -> Tool Call -> Content.
This verifies that multiple distinct actions in one batch
are all captured in the single DeltaMessage.
"""
mock_make_id.return_value = "call_batch_test"
parser = MockStreamableParser()
token_states = [
# 1. Reasoning
TokenState("analysis", None, "Reasoning about query..."),
# 2. Tool Calling
TokenState("commentary", "functions.search", '{"query":'),
TokenState("commentary", "functions.search", ' "vllm"}'),
# 3. Final Content
TokenState("final", None, "."),
]
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
token_states=token_states,
prev_recipient=None,
include_reasoning=True,
)
assert delta_message is not None
assert delta_message.reasoning == "Reasoning about query..."
# We expect 2 objects for 1 logical tool call:
# 1. The definition (id, name, type)
# 2. The arguments payload
assert len(delta_message.tool_calls) == 2
header = delta_message.tool_calls[0]
payload = delta_message.tool_calls[1]
assert header.function.name == "search"
assert header.id == "call_batch_test"
assert header.index == 0
assert payload.index == 0
assert payload.function.arguments == '{"query": "vllm"}'
assert delta_message.content == "."
assert tools_streamed is True
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id):
"""
Test that an ongoing tool call continuation and subsequent new calls
maintain correct indexing when interleaved with content.
"""
mock_make_id.side_effect = ["id_b", "id_c"]
messages = [
MockMessage(channel="commentary", recipient="functions.previous_tool")
]
parser = MockStreamableParser(messages=messages)
token_states = [
TokenState("commentary", "functions.tool_a", '{"key_a": "val_a"}'),
TokenState("final", None, "Thinking..."),
TokenState("commentary", "functions.tool_b", '{"key_b": "val_b"}'),
TokenState("final", None, " Thinking again..."),
TokenState("commentary", "functions.tool_c", '{"key_c": "val_c"}'),
]
delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser,
token_states=token_states,
prev_recipient="functions.tool_a",
include_reasoning=False,
)
assert delta_message is not None
tool_a_deltas = [t for t in delta_message.tool_calls if t.index == 1]
assert len(tool_a_deltas) > 0
assert tool_a_deltas[0].id is None
assert tool_a_deltas[0].function.arguments == '{"key_a": "val_a"}'
tool_b_header = next(t for t in delta_message.tool_calls if t.id == "id_b")
assert tool_b_header.index == 2
tool_b_args = next(
t for t in delta_message.tool_calls if t.index == 2 and t.id is None
)
assert tool_b_args.function.arguments == '{"key_b": "val_b"}'
tool_c_start = next(t for t in delta_message.tool_calls if t.id == "id_c")
assert tool_c_start.index == 3
tool_c_args = next(
t for t in delta_message.tool_calls if t.index == 3 and t.id is None
)
assert tool_c_args.function.arguments == '{"key_c": "val_c"}'
assert delta_message.content == "Thinking... Thinking again..."
...@@ -36,6 +36,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -36,6 +36,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatMessage, ChatMessage,
) )
from vllm.entrypoints.openai.chat_completion.stream_harmony import ( from vllm.entrypoints.openai.chat_completion.stream_harmony import (
TokenState,
extract_harmony_streaming_delta, extract_harmony_streaming_delta,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
...@@ -826,12 +827,22 @@ class OpenAIServingChat(OpenAIServing): ...@@ -826,12 +827,22 @@ class OpenAIServingChat(OpenAIServing):
if self.use_harmony: if self.use_harmony:
harmony_parser = harmony_parsers[i] harmony_parser = harmony_parsers[i]
prev_recipient = harmony_parser.current_recipient prev_recipient = harmony_parser.current_recipient
delta_text = ""
# Track accumulated content per token with their state
token_states: list[TokenState] = []
for token_id in output.token_ids: for token_id in output.token_ids:
harmony_parser.process(token_id) harmony_parser.process(token_id)
delta_text += harmony_parser.last_content_delta or "" token_delta = harmony_parser.last_content_delta or ""
token_states.append(
TokenState(
harmony_parser.current_channel,
harmony_parser.current_recipient,
token_delta,
)
)
delta_text = "".join(delta for _, _, delta in token_states)
cur_channel = harmony_parser.current_channel cur_channel = harmony_parser.current_channel
cur_recipient = harmony_parser.current_recipient
# handle the case where several tokens where generated at once # handle the case where several tokens where generated at once
# including the final token, leading to a delta in the text # including the final token, leading to a delta in the text
# but the current channel to be empty (start state) # but the current channel to be empty (start state)
...@@ -869,10 +880,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -869,10 +880,8 @@ class OpenAIServingChat(OpenAIServing):
delta_message, tools_streamed_flag = ( delta_message, tools_streamed_flag = (
extract_harmony_streaming_delta( extract_harmony_streaming_delta(
harmony_parser=harmony_parser, harmony_parser=harmony_parser,
cur_channel=cur_channel, token_states=token_states,
cur_recipient=cur_recipient,
prev_recipient=prev_recipient, prev_recipient=prev_recipient,
delta_text=delta_text,
include_reasoning=request.include_reasoning, include_reasoning=request.include_reasoning,
) )
) )
...@@ -1139,17 +1148,23 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1139,17 +1148,23 @@ class OpenAIServingChat(OpenAIServing):
# Log streaming delta if output logging is enabled # Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger: if self.enable_log_outputs and self.request_logger:
delta_content = "" delta_content_parts = []
if delta_message.content: if delta_message.content:
delta_content = delta_message.content delta_content_parts.append(delta_message.content)
elif delta_message.tool_calls: if delta_message.reasoning_content:
delta_content = "".join( reasoning = delta_message.reasoning_content
delta_content_parts.append(f"[reasoning: {reasoning}]")
if delta_message.tool_calls:
tool_args = "".join(
tc.function.arguments tc.function.arguments
for tc in delta_message.tool_calls for tc in delta_message.tool_calls
if tc.function and tc.function.arguments if tc.function and tc.function.arguments
) )
if tool_args:
delta_content_parts.append(f"[tool_calls: {tool_args}]")
if delta_content and self.enable_log_deltas: if delta_content_parts and self.enable_log_deltas:
delta_content = " ".join(delta_content_parts)
self.request_logger.log_outputs( self.request_logger.log_outputs(
request_id=request_id, request_id=request_id,
outputs=delta_content, outputs=delta_content,
......
...@@ -7,6 +7,8 @@ This module handles the extraction of DeltaMessage objects from ...@@ -7,6 +7,8 @@ This module handles the extraction of DeltaMessage objects from
harmony parser state during streaming chat completions. harmony parser state during streaming chat completions.
""" """
from typing import NamedTuple
from openai_harmony import StreamableParser from openai_harmony import StreamableParser
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
...@@ -17,12 +19,16 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -17,12 +19,16 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
class TokenState(NamedTuple):
channel: str | None
recipient: str | None
text: str
def extract_harmony_streaming_delta( def extract_harmony_streaming_delta(
harmony_parser: StreamableParser, harmony_parser: StreamableParser,
cur_channel: str | None, token_states: list[TokenState],
cur_recipient: str | None,
prev_recipient: str | None, prev_recipient: str | None,
delta_text: str,
include_reasoning: bool, include_reasoning: bool,
) -> tuple[DeltaMessage | None, bool]: ) -> tuple[DeltaMessage | None, bool]:
""" """
...@@ -30,25 +36,47 @@ def extract_harmony_streaming_delta( ...@@ -30,25 +36,47 @@ def extract_harmony_streaming_delta(
Args: Args:
harmony_parser: The StreamableParser instance tracking parse state harmony_parser: The StreamableParser instance tracking parse state
cur_channel: Current channel ("final", "analysis", "commentary", etc.) token_states: List of TokenState tuples for each token
cur_recipient: Current recipient (e.g., "functions.my_func")
prev_recipient: Previous recipient for detecting tool call transitions prev_recipient: Previous recipient for detecting tool call transitions
delta_text: The text delta to include in the message
include_reasoning: Whether to include reasoning content include_reasoning: Whether to include reasoning content
Returns: Returns:
A tuple of (DeltaMessage or None, tools_streamed_flag) A tuple of (DeltaMessage or None, tools_streamed_flag)
""" """
if not token_states:
return None, False
tools_streamed = False tools_streamed = False
if cur_channel == "final": # Group consecutive tokens with same channel/recipient
delta_message = DeltaMessage(content=delta_text) groups: list[TokenState] = []
elif (
(cur_channel == "commentary" or cur_channel == "analysis") current_channel = token_states[0].channel
and cur_recipient current_recipient = token_states[0].recipient
and cur_recipient.startswith("functions.") current_text = token_states[0].text
):
# Count completed tool calls to determine index for i in range(1, len(token_states)):
state = token_states[i]
if state.channel == current_channel and state.recipient == current_recipient:
current_text += state.text
else:
groups.append(TokenState(current_channel, current_recipient, current_text))
current_channel = state.channel
current_recipient = state.recipient
current_text = state.text
groups.append(TokenState(current_channel, current_recipient, current_text))
# Process each group and create delta messages
delta_message = None
combined_content = ""
combined_reasoning = ""
tool_messages = []
content_encountered = False
# Calculate base_index once before the loop
# This counts completed tool calls in messages
base_index = 0 base_index = 0
for msg in harmony_parser.messages: for msg in harmony_parser.messages:
if ( if (
...@@ -58,10 +86,31 @@ def extract_harmony_streaming_delta( ...@@ -58,10 +86,31 @@ def extract_harmony_streaming_delta(
): ):
base_index += 1 base_index += 1
if prev_recipient != cur_recipient: # If there's an ongoing tool call from previous chunk,
tool_name = cur_recipient.split("functions.", 1)[1] # the next new tool call starts at base_index + 1
delta_message = DeltaMessage( if prev_recipient and prev_recipient.startswith("functions."):
tool_calls=[ next_tool_index = base_index + 1
# Ongoing call is at base_index
ongoing_tool_index = base_index
else:
# No ongoing call, next new call is at base_index
next_tool_index = base_index
ongoing_tool_index = None
for group in groups:
if group.channel == "final":
combined_content += group.text
content_encountered = True
elif (
(group.channel == "commentary" or group.channel == "analysis")
and group.recipient
and group.recipient.startswith("functions.")
):
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_messages.append(
DeltaToolCall( DeltaToolCall(
id=make_tool_call_id(), id=make_tool_call_id(),
type="function", type="function",
...@@ -69,32 +118,53 @@ def extract_harmony_streaming_delta( ...@@ -69,32 +118,53 @@ def extract_harmony_streaming_delta(
name=tool_name, name=tool_name,
arguments="", arguments="",
), ),
index=base_index, index=next_tool_index,
) )
]
) )
elif delta_text: opened_new_call = True
delta_message = DeltaMessage( prev_recipient = group.recipient
tool_calls=[ # Increment for subsequent new tool calls
next_tool_index += 1
if group.text:
# Stream arguments for the ongoing tool call
if opened_new_call:
# Just opened in this group
tool_call_index = next_tool_index - 1
else:
# Continuing from previous chunk
# If ongoing_tool_index is None here, it means
# we're continuing a call but prev_recipient
# wasn't a function. Use base_index.
tool_call_index = (
ongoing_tool_index
if ongoing_tool_index is not None
else base_index
)
tool_messages.append(
DeltaToolCall( DeltaToolCall(
index=base_index, index=tool_call_index,
function=DeltaFunctionCall(arguments=delta_text), function=DeltaFunctionCall(arguments=group.text),
) )
]
) )
else: elif group.channel == "commentary":
delta_message = None # Tool call preambles meant to be shown to the user
combined_content += group.text
content_encountered = True
elif group.channel == "analysis" and include_reasoning:
combined_reasoning += group.text
if delta_message is not None: # Combine all non-empty fields into a single message
if content_encountered or combined_reasoning or tool_messages:
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
if content_encountered:
delta_kwargs["content"] = combined_content
if combined_reasoning:
delta_kwargs["reasoning"] = combined_reasoning
if tool_messages:
delta_kwargs["tool_calls"] = tool_messages
tools_streamed = True tools_streamed = True
elif cur_channel == "commentary": delta_message = DeltaMessage(**delta_kwargs)
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
elif cur_channel == "analysis":
if include_reasoning:
delta_message = DeltaMessage(reasoning=delta_text)
else:
delta_message = None
else: else:
delta_message = None delta_message = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment