Unverified Commit a4ec0c55 authored by daniel-salib's avatar daniel-salib Committed by GitHub
Browse files

[Frontend] Add MCP tool streaming support to Responses API (#31761)


Signed-off-by: default avatarDaniel Salib <danielsalib@meta.com>
parent 0fa8dd24
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from openai import OpenAI from openai import OpenAI
...@@ -13,57 +14,100 @@ from ...utils import RemoteOpenAIServer ...@@ -13,57 +14,100 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b" MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module") def test_get_tool_description():
def monkeypatch_module(): """Test MCPToolServer.get_tool_description filtering logic.
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch() Note: The wildcard "*" is normalized to None by
yield mpatch _extract_allowed_tools_from_mcp_requests before reaching this layer,
mpatch.undo() so we only test None and specific tool filtering here.
See test_serving_responses.py for "*" normalization tests.
"""
pytest.importorskip("mcp")
server = MCPToolServer()
tool1 = ToolDescription.new(
name="tool1", description="First", parameters={"type": "object"}
)
tool2 = ToolDescription.new(
name="tool2", description="Second", parameters={"type": "object"}
)
tool3 = ToolDescription.new(
name="tool3", description="Third", parameters={"type": "object"}
)
@pytest.fixture(scope="module") server.harmony_tool_descriptions = {
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch): "test_server": ToolNamespaceConfig(
args = ["--enforce-eager", "--tool-server", "demo"] name="test_server", description="test", tools=[tool1, tool2, tool3]
)
}
with monkeypatch_module.context() as m: # Nonexistent server
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") assert server.get_tool_description("nonexistent") is None
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
# Helps the model follow instructions better # None (no filter) - returns all tools
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1") result = server.get_tool_description("test_server", allowed_tools=None)
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: assert len(result.tools) == 3
yield remote_server
# Filter to specific tools
result = server.get_tool_description(
"test_server", allowed_tools=["tool1", "tool3"]
)
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool3"
# Single tool
result = server.get_tool_description(
"test_server",
allowed_tools=["tool2"],
)
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
# No matching tools - returns None
result = server.get_tool_description("test_server", allowed_tools=["nonexistent"])
assert result is None
# Empty list - returns None
assert server.get_tool_description("test_server", allowed_tools=[]) is None
class TestMCPEnabled:
"""Tests that require MCP tools to be enabled via environment variable."""
@pytest.fixture(scope="class")
def monkeypatch_class(self):
from _pytest.monkeypatch import MonkeyPatch
@pytest.fixture(scope="function") mpatch = MonkeyPatch()
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch): yield mpatch
mpatch.undo()
@pytest.fixture(scope="class")
def mcp_enabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"] args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m: with monkeypatch_class.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv("VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container") m.setenv(
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container"
)
# Helps the model follow instructions better # Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1") m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
@pytest_asyncio.fixture
@pytest_asyncio.fixture async def mcp_enabled_client(self, mcp_enabled_server):
async def mcp_disabled_client(mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client: async with mcp_enabled_server.get_async_client() as async_client:
yield async_client yield async_client
@pytest.mark.asyncio
@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_mcp_tool_env_flag_enabled(
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str): self, mcp_enabled_client: OpenAI, model_name: str
):
response = await mcp_enabled_client.responses.create( response = await mcp_enabled_client.responses.create(
model=model_name, model=model_name,
input=( input=(
...@@ -71,7 +115,8 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: ...@@ -71,7 +115,8 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name:
"import random; print(random.randint(1, 1000000))" "import random; print(random.randint(1, 1000000))"
), ),
instructions=( instructions=(
"You must use the Python tool to execute code. Never simulate execution." "You must use the Python tool to execute code. "
"Never simulate execution."
), ),
tools=[ tools=[
{ {
...@@ -107,22 +152,25 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: ...@@ -107,22 +152,25 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name:
) )
assert tool_call_found, "Should have found at least one Python tool call" assert tool_call_found, "Should have found at least one Python tool call"
assert tool_response_found, "Should have found at least one Python tool response" assert tool_response_found, (
"Should have found at least one Python tool response"
)
for message in response.input_messages: for message in response.input_messages:
assert message.get("author").get("role") != "developer", ( assert message.get("author").get("role") != "developer", (
"No developer messages should be present with valid mcp tool" "No developer messages should be present with valid mcp tool"
) )
@pytest.mark.asyncio
@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_mcp_tool_with_allowed_tools_star(
async def test_mcp_tool_with_allowed_tools_star( self, mcp_enabled_client: OpenAI, model_name: str
mcp_enabled_client: OpenAI, model_name: str ):
): """Test MCP tool with allowed_tools=['*'] to select all available
"""Test MCP tool with allowed_tools=['*'] to select all available tools. tools.
This E2E test verifies that the "*" wildcard works end-to-end. This E2E test verifies that the "*" wildcard works end-to-end.
See test_serving_responses.py for detailed unit tests of "*" normalization. See test_serving_responses.py for detailed unit tests of "*"
normalization.
""" """
response = await mcp_enabled_client.responses.create( response = await mcp_enabled_client.responses.create(
model=model_name, model=model_name,
...@@ -131,7 +179,8 @@ async def test_mcp_tool_with_allowed_tools_star( ...@@ -131,7 +179,8 @@ async def test_mcp_tool_with_allowed_tools_star(
"import random; print(random.randint(1, 1000000))" "import random; print(random.randint(1, 1000000))"
), ),
instructions=( instructions=(
"You must use the Python tool to execute code. Never simulate execution." "You must use the Python tool to execute code. "
"Never simulate execution."
), ),
tools=[ tools=[
{ {
...@@ -153,12 +202,109 @@ async def test_mcp_tool_with_allowed_tools_star( ...@@ -153,12 +202,109 @@ async def test_mcp_tool_with_allowed_tools_star(
if recipient and recipient.startswith("python"): if recipient and recipient.startswith("python"):
tool_call_found = True tool_call_found = True
break break
assert tool_call_found, "Should have found at least one Python tool call with '*'" assert tool_call_found, (
"Should have found at least one Python tool call with '*'"
)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_calling_streaming_types(
self, mcp_enabled_client: OpenAI, model_name: str
):
pairs_of_event_types = {
"response.completed": "response.created",
"response.output_item.done": "response.output_item.added",
"response.content_part.done": "response.content_part.added",
"response.output_text.done": "response.output_text.delta",
"response.reasoning_text.done": "response.reasoning_text.delta",
"response.reasoning_part.done": "response.reasoning_part.added",
"response.mcp_call_arguments.done": ("response.mcp_call_arguments.delta"),
"response.mcp_call.completed": "response.mcp_call.in_progress",
}
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
input_text = "What is 13 * 24? Use python to calculate the result."
stream_response = await mcp_enabled_client.responses.create(
model=model_name,
input=input_text,
tools=tools,
stream=True,
instructions=(
"You must use the Python tool to execute code. "
"Never simulate execution."
),
)
stack_of_event_types = []
saw_mcp_type = False
async for event in stream_response:
if event.type == "response.created":
stack_of_event_types.append(event.type)
elif event.type == "response.completed":
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
elif (
event.type.endswith("added")
or event.type == "response.mcp_call.in_progress"
):
stack_of_event_types.append(event.type)
elif event.type.endswith("delta"):
if stack_of_event_types[-1] == event.type:
continue
stack_of_event_types.append(event.type)
elif (
event.type.endswith("done")
or event.type == "response.mcp_call.completed"
):
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
if "mcp_call" in event.type:
saw_mcp_type = True
stack_of_event_types.pop()
assert len(stack_of_event_types) == 0
assert saw_mcp_type, "Should have seen at least one mcp call"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) class TestMCPDisabled:
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str): """Tests that verify behavior when MCP tools are disabled."""
@pytest.fixture(scope="class")
def monkeypatch_class(self):
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="class")
def mcp_disabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_class.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_disabled_client(self, mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_disabled(
self, mcp_disabled_client: OpenAI, model_name: str
):
response = await mcp_disabled_client.responses.create( response = await mcp_disabled_client.responses.create(
model=model_name, model=model_name,
input=( input=(
...@@ -204,58 +350,3 @@ async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_nam ...@@ -204,58 +350,3 @@ async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_nam
assert message.get("author").get("role") != "developer", ( assert message.get("author").get("role") != "developer", (
"No developer messages should be present without a valid tool" "No developer messages should be present without a valid tool"
) )
def test_get_tool_description():
"""Test MCPToolServer.get_tool_description filtering logic.
Note: The wildcard "*" is normalized to None by
_extract_allowed_tools_from_mcp_requests before reaching this layer,
so we only test None and specific tool filtering here.
See test_serving_responses.py for "*" normalization tests.
"""
pytest.importorskip("mcp")
server = MCPToolServer()
tool1 = ToolDescription.new(
name="tool1", description="First", parameters={"type": "object"}
)
tool2 = ToolDescription.new(
name="tool2", description="Second", parameters={"type": "object"}
)
tool3 = ToolDescription.new(
name="tool3", description="Third", parameters={"type": "object"}
)
server.harmony_tool_descriptions = {
"test_server": ToolNamespaceConfig(
name="test_server", description="test", tools=[tool1, tool2, tool3]
)
}
# Nonexistent server
assert server.get_tool_description("nonexistent") is None
# None (no filter) - returns all tools
result = server.get_tool_description("test_server", allowed_tools=None)
assert len(result.tools) == 3
# Filter to specific tools
result = server.get_tool_description(
"test_server", allowed_tools=["tool1", "tool3"]
)
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool3"
# Single tool
result = server.get_tool_description("test_server", allowed_tools=["tool2"])
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
# No matching tools - returns None
result = server.get_tool_description("test_server", allowed_tools=["nonexistent"])
assert result is None
# Empty list - returns None
assert server.get_tool_description("test_server", allowed_tools=[]) is None
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import importlib.util import importlib.util
import json import json
import time import time
...@@ -44,6 +43,8 @@ def server(): ...@@ -44,6 +43,8 @@ def server():
env_dict = dict( env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1", VLLM_ENABLE_RESPONSES_API_STORE="1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv", PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS="code_interpreter,container,web_search_preview",
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS="1",
) )
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
...@@ -855,6 +856,237 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str): ...@@ -855,6 +856,237 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str):
assert event.response.output_text is not None assert event.response.output_text is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_no_code_interpreter_events(
client: OpenAI, model_name: str
):
"""Verify that function calls don't trigger code_interpreter events.
This test ensures that function calls (functions.*) use their own
function_call event types and don't incorrectly emit code_interpreter
events during streaming.
"""
tools = [GET_WEATHER_SCHEMA]
input_list = [
{
"role": "user",
"content": "What's the weather like in Paris today?",
}
]
stream_response = await client.responses.create(
model=model_name,
input=input_list,
tools=tools,
stream=True,
)
# Track which event types we see
event_types_seen = set()
function_call_found = False
async for event in stream_response:
event_types_seen.add(event.type)
if (
event.type == "response.output_item.added"
and event.item.type == "function_call"
):
function_call_found = True
# Ensure NO code_interpreter events are emitted for function calls
assert "code_interpreter" not in event.type, (
"Found code_interpreter event "
f"'{event.type}' during function call. Function calls should only "
"emit function_call events, not code_interpreter events."
)
# Verify we actually saw a function call
assert function_call_found, "Expected to see a function_call in the stream"
# Verify we saw the correct function call event types
assert (
"response.function_call_arguments.delta" in event_types_seen
or "response.function_call_arguments.done" in event_types_seen
), "Expected to see function_call_arguments events"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server):
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
input_text = (
"Calculate 15 * 32 using python. "
"The python interpreter is not stateful and you must print to see the output."
)
stream_response = await client.responses.create(
model=model_name,
input=input_text,
tools=tools,
stream=True,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
)
mcp_call_added = False
mcp_call_in_progress = False
mcp_arguments_delta_seen = False
mcp_arguments_done = False
mcp_call_completed = False
mcp_item_done = False
code_interpreter_events_seen = False
async for event in stream_response:
if "code_interpreter" in event.type:
code_interpreter_events_seen = True
if event.type == "response.output_item.added":
if hasattr(event.item, "type") and event.item.type == "mcp_call":
mcp_call_added = True
assert event.item.name == "python"
assert event.item.server_label == "code_interpreter"
elif event.type == "response.mcp_call.in_progress":
mcp_call_in_progress = True
elif event.type == "response.mcp_call_arguments.delta":
mcp_arguments_delta_seen = True
assert event.delta is not None
elif event.type == "response.mcp_call_arguments.done":
mcp_arguments_done = True
assert event.name == "python"
assert event.arguments is not None
elif event.type == "response.mcp_call.completed":
mcp_call_completed = True
elif (
event.type == "response.output_item.done"
and hasattr(event.item, "type")
and event.item.type == "mcp_call"
):
mcp_item_done = True
assert event.item.name == "python"
assert event.item.status == "completed"
assert mcp_call_added, "MCP call was not added"
assert mcp_call_in_progress, "MCP call in_progress event not seen"
assert mcp_arguments_delta_seen, "MCP arguments delta event not seen"
assert mcp_arguments_done, "MCP arguments done event not seen"
assert mcp_call_completed, "MCP call completed event not seen"
assert mcp_item_done, "MCP item done event not seen"
assert not code_interpreter_events_seen, (
"Should not see code_interpreter events when using MCP type"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server):
"""Test MCP tool calling across multiple turns.
This test verifies that MCP tools work correctly in multi-turn conversations,
maintaining state across turns via the previous_response_id mechanism.
"""
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
# First turn - make a calculation
response1 = await client.responses.create(
model=model_name,
input="Calculate 123 * 456 using python and print the result.",
tools=tools,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
extra_body={"enable_response_messages": True},
)
assert response1 is not None
assert response1.status == "completed"
# Verify MCP call in first response by checking output_messages
tool_call_found = False
tool_response_found = False
for message in response1.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
# Verify MCP tools were actually used
assert tool_call_found, "MCP tool call not found in output_messages"
assert tool_response_found, "MCP tool response not found in output_messages"
# Verify input messages: Should have system message with tool, NO developer message
developer_messages = [
msg for msg in response1.input_messages if msg["author"]["role"] == "developer"
]
assert len(developer_messages) == 0, (
"No developer message expected for elevated tools"
)
# Second turn - reference previous calculation
response2 = await client.responses.create(
model=model_name,
input="Now divide that result by 2.",
tools=tools,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
previous_response_id=response1.id,
extra_body={"enable_response_messages": True},
)
assert response2 is not None
assert response2.status == "completed"
# Verify input messages are correct: should have two messages -
# one to the python recipient on analysis channel and one from tool role
mcp_recipient_messages = []
tool_role_messages = []
for msg in response2.input_messages:
if msg["author"]["role"] == "assistant":
# Check if this is a message to MCP recipient on analysis channel
if msg.get("channel") == "analysis" and msg.get("recipient"):
recipient = msg.get("recipient")
if recipient.startswith("code_interpreter") or recipient == "python":
mcp_recipient_messages.append(msg)
elif msg["author"]["role"] == "tool":
tool_role_messages.append(msg)
assert len(mcp_recipient_messages) > 0, (
"Expected message(s) to MCP recipient on analysis channel"
)
assert len(tool_role_messages) > 0, (
"Expected message(s) from tool role after MCP call"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_output_messages_enabled(client: OpenAI, model_name: str, server): async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
......
...@@ -9,6 +9,7 @@ from collections import deque ...@@ -9,6 +9,7 @@ from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from copy import copy from copy import copy
from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Final from typing import Final
...@@ -27,6 +28,10 @@ from openai.types.responses import ( ...@@ -27,6 +28,10 @@ from openai.types.responses import (
ResponseFunctionCallArgumentsDoneEvent, ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall, ResponseFunctionToolCall,
ResponseFunctionWebSearch, ResponseFunctionWebSearch,
ResponseMcpCallArgumentsDeltaEvent,
ResponseMcpCallArgumentsDoneEvent,
ResponseMcpCallCompletedEvent,
ResponseMcpCallInProgressEvent,
ResponseOutputItem, ResponseOutputItem,
ResponseOutputItemAddedEvent, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponseOutputItemDoneEvent,
...@@ -44,6 +49,7 @@ from openai.types.responses import ( ...@@ -44,6 +49,7 @@ from openai.types.responses import (
response_function_web_search, response_function_web_search,
response_text_delta_event, response_text_delta_event,
) )
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent, Content as ResponseReasoningTextContent,
...@@ -119,6 +125,23 @@ from vllm.utils import random_uuid ...@@ -119,6 +125,23 @@ from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class HarmonyStreamingState:
"""Mutable state for harmony streaming event processing."""
current_content_index: int = -1
current_output_index: int = 0
current_item_id: str = ""
sent_output_item_added: bool = False
is_first_function_call_delta: bool = False
def reset_for_new_item(self) -> None:
"""Reset state when expecting a new output item."""
self.current_output_index += 1
self.sent_output_item_added = False
self.is_first_function_call_delta = False
def _extract_allowed_tools_from_mcp_requests( def _extract_allowed_tools_from_mcp_requests(
tools: list[Tool], tools: list[Tool],
) -> dict[str, list[str] | None]: ) -> dict[str, list[str] | None]:
...@@ -740,6 +763,26 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -740,6 +763,26 @@ class OpenAIServingResponses(OpenAIServing):
self.response_store[response.id] = response self.response_store[response.id] = response
return response return response
def _is_mcp_tool_by_namespace(self, recipient: str | None) -> bool:
"""
Determine if a tool call is an MCP tool based on recipient prefix.
- Tools starting with "functions." are function calls
- Everything else is an MCP tool
"""
if recipient is None:
return False
# Function calls have "functions." prefix
# Everything else is an MCP tool
return not recipient.startswith("functions.")
_TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = {
"python": "code_interpreter",
"container": "container",
"browser": "web_search_preview",
}
def _topk_logprobs( def _topk_logprobs(
self, self,
logprobs: dict[int, SampleLogprob], logprobs: dict[int, SampleLogprob],
...@@ -1036,7 +1079,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1036,7 +1079,6 @@ class OpenAIServingResponses(OpenAIServing):
del prev_msgs[prev_final_msg_idx + 1 :] del prev_msgs[prev_final_msg_idx + 1 :]
for msg in recent_turn_msgs: for msg in recent_turn_msgs:
assert isinstance(msg, OpenAIHarmonyMessage) assert isinstance(msg, OpenAIHarmonyMessage)
if msg.channel != "analysis":
prev_msgs.append(msg) prev_msgs.append(msg)
messages.extend(prev_msgs) messages.extend(prev_msgs)
# Append the new input. # Append the new input.
...@@ -1520,48 +1562,21 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1520,48 +1562,21 @@ class OpenAIServingResponses(OpenAIServing):
) )
) )
async def _process_harmony_streaming_events( def _emit_function_call_done_events(
self, self,
request: ResponsesRequest, previous_item,
sampling_params: SamplingParams, state: HarmonyStreamingState,
result_generator: AsyncIterator[ConversationContext | None], ) -> list[StreamingResponsesResponse]:
context: ConversationContext, """Emit events when a function call completes."""
model_name: str,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
[StreamingResponsesResponse], StreamingResponsesResponse
],
) -> AsyncGenerator[StreamingResponsesResponse, None]:
current_content_index = -1
current_output_index = 0
current_item_id: str = ""
sent_output_item_added = False
is_first_function_call_delta = False
async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext)
# finish_reason='error' indicates a retryable error
self._raise_if_error(ctx.finish_reason, request.request_id)
if ctx.is_expecting_start():
current_output_index += 1
sent_output_item_added = False
is_first_function_call_delta = False
if len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
if previous_item.recipient is not None:
# Deal with tool call
if previous_item.recipient.startswith("functions."):
function_name = previous_item.recipient[len("functions.") :] function_name = previous_item.recipient[len("functions.") :]
yield _increment_sequence_number_and_return( events = []
events.append(
ResponseFunctionCallArgumentsDoneEvent( ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done", type="response.function_call_arguments.done",
arguments=previous_item.content[0].text, arguments=previous_item.content[0].text,
name=function_name, name=function_name,
item_id=current_item_id, item_id=state.current_item_id,
output_index=current_output_index, output_index=state.current_output_index,
sequence_number=-1, sequence_number=-1,
) )
) )
...@@ -1569,21 +1584,73 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1569,21 +1584,73 @@ class OpenAIServingResponses(OpenAIServing):
type="function_call", type="function_call",
arguments=previous_item.content[0].text, arguments=previous_item.content[0].text,
name=function_name, name=function_name,
item_id=current_item_id, item_id=state.current_item_id,
output_index=current_output_index, output_index=state.current_output_index,
sequence_number=-1, sequence_number=-1,
call_id=f"fc_{random_uuid()}", call_id=f"fc_{random_uuid()}",
status="completed", status="completed",
) )
yield _increment_sequence_number_and_return( events.append(
ResponseOutputItemDoneEvent( ResponseOutputItemDoneEvent(
type="response.output_item.done", type="response.output_item.done",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=function_call_item, item=function_call_item,
) )
) )
elif previous_item.channel == "analysis": return events
def _emit_mcp_call_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP tool call completes."""
server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(
previous_item.recipient, previous_item.recipient
)
events = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
arguments=previous_item.content[0].text,
name=previous_item.recipient,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
arguments=previous_item.content[0].text,
name=previous_item.recipient,
id=state.current_item_id,
server_label=server_label,
status="completed",
),
)
)
return events
def _emit_reasoning_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a reasoning (analysis) item completes."""
content = ResponseReasoningTextContent( content = ResponseReasoningTextContent(
text=previous_item.content[0].text, text=previous_item.content[0].text,
type="reasoning_text", type="reasoning_text",
...@@ -1592,71 +1659,80 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1592,71 +1659,80 @@ class OpenAIServingResponses(OpenAIServing):
type="reasoning", type="reasoning",
content=[content], content=[content],
status="completed", status="completed",
id=current_item_id, id=state.current_item_id,
summary=[], summary=[],
) )
yield _increment_sequence_number_and_return( events = []
events.append(
ResponseReasoningTextDoneEvent( ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done", type="response.reasoning_text.done",
item_id=current_item_id, item_id=state.current_item_id,
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
content_index=current_content_index, content_index=state.current_content_index,
text=previous_item.content[0].text, text=previous_item.content[0].text,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseReasoningPartDoneEvent( ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done", type="response.reasoning_part.done",
sequence_number=-1, sequence_number=-1,
item_id=current_item_id, item_id=state.current_item_id,
output_index=current_output_index, output_index=state.current_output_index,
content_index=current_content_index, content_index=state.current_content_index,
part=content, part=content,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseOutputItemDoneEvent( ResponseOutputItemDoneEvent(
type="response.output_item.done", type="response.output_item.done",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=reasoning_item, item=reasoning_item,
) )
) )
elif previous_item.channel == "final": return events
def _emit_text_output_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a final text output item completes."""
text_content = ResponseOutputText( text_content = ResponseOutputText(
type="output_text", type="output_text",
text=previous_item.content[0].text, text=previous_item.content[0].text,
annotations=[], annotations=[],
) )
yield _increment_sequence_number_and_return( events = []
events.append(
ResponseTextDoneEvent( ResponseTextDoneEvent(
type="response.output_text.done", type="response.output_text.done",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
content_index=current_content_index, content_index=state.current_content_index,
text=previous_item.content[0].text, text=previous_item.content[0].text,
logprobs=[], logprobs=[],
item_id=current_item_id, item_id=state.current_item_id,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseContentPartDoneEvent( ResponseContentPartDoneEvent(
type="response.content_part.done", type="response.content_part.done",
sequence_number=-1, sequence_number=-1,
item_id=current_item_id, item_id=state.current_item_id,
output_index=current_output_index, output_index=state.current_output_index,
content_index=current_content_index, content_index=state.current_content_index,
part=text_content, part=text_content,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseOutputItemDoneEvent( ResponseOutputItemDoneEvent(
type="response.output_item.done", type="response.output_item.done",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=ResponseOutputMessage( item=ResponseOutputMessage(
id=current_item_id, id=state.current_item_id,
type="message", type="message",
role="assistant", role="assistant",
content=[text_content], content=[text_content],
...@@ -1664,23 +1740,47 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1664,23 +1740,47 @@ class OpenAIServingResponses(OpenAIServing):
), ),
) )
) )
return events
# stream the output of a harmony message def _emit_previous_item_done_events(
if ctx.parser.last_content_delta: self,
if ( previous_item,
ctx.parser.current_channel == "final" state: HarmonyStreamingState,
and ctx.parser.current_recipient is None ) -> list[StreamingResponsesResponse]:
"""Emit done events for the previous item when expecting a new start."""
if previous_item.recipient is not None:
# Deal with tool call
if previous_item.recipient.startswith("functions."):
return self._emit_function_call_done_events(previous_item, state)
elif (
self._is_mcp_tool_by_namespace(previous_item.recipient)
and state.current_item_id is not None
and state.current_item_id.startswith("mcp_")
): ):
if not sent_output_item_added: return self._emit_mcp_call_done_events(previous_item, state)
sent_output_item_added = True elif previous_item.channel == "analysis":
current_item_id = f"msg_{random_uuid()}" return self._emit_reasoning_done_events(previous_item, state)
yield _increment_sequence_number_and_return( elif previous_item.channel == "final":
return self._emit_text_output_done_events(previous_item, state)
return []
def _emit_final_channel_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for final channel text delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"msg_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent( ResponseOutputItemAddedEvent(
type="response.output_item.added", type="response.output_item.added",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=ResponseOutputMessage( item=ResponseOutputMessage(
id=current_item_id, id=state.current_item_id,
type="message", type="message",
role="assistant", role="assistant",
content=[], content=[],
...@@ -1688,14 +1788,14 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1688,14 +1788,14 @@ class OpenAIServingResponses(OpenAIServing):
), ),
) )
) )
current_content_index += 1 state.current_content_index += 1
yield _increment_sequence_number_and_return( events.append(
ResponseContentPartAddedEvent( ResponseContentPartAddedEvent(
type="response.content_part.added", type="response.content_part.added",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
content_index=current_content_index, content_index=state.current_content_index,
part=ResponseOutputText( part=ResponseOutputText(
type="output_text", type="output_text",
text="", text="",
...@@ -1704,80 +1804,133 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1704,80 +1804,133 @@ class OpenAIServingResponses(OpenAIServing):
), ),
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseTextDeltaEvent( ResponseTextDeltaEvent(
type="response.output_text.delta", type="response.output_text.delta",
sequence_number=-1, sequence_number=-1,
content_index=current_content_index, content_index=state.current_content_index,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
delta=ctx.parser.last_content_delta, delta=ctx.parser.last_content_delta,
# TODO, use logprobs from ctx.last_request_output # TODO, use logprobs from ctx.last_request_output
logprobs=[], logprobs=[],
) )
) )
elif ( return events
ctx.parser.current_channel == "analysis"
and ctx.parser.current_recipient is None def _emit_analysis_channel_delta_events(
): self,
if not sent_output_item_added: ctx: StreamingHarmonyContext,
sent_output_item_added = True state: HarmonyStreamingState,
current_item_id = f"msg_{random_uuid()}" ) -> list[StreamingResponsesResponse]:
yield _increment_sequence_number_and_return( """Emit events for analysis channel reasoning delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"msg_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent( ResponseOutputItemAddedEvent(
type="response.output_item.added", type="response.output_item.added",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=ResponseReasoningItem( item=ResponseReasoningItem(
type="reasoning", type="reasoning",
id=current_item_id, id=state.current_item_id,
summary=[], summary=[],
status="in_progress", status="in_progress",
), ),
) )
) )
current_content_index += 1 state.current_content_index += 1
yield _increment_sequence_number_and_return( events.append(
ResponseReasoningPartAddedEvent( ResponseReasoningPartAddedEvent(
type="response.reasoning_part.added", type="response.reasoning_part.added",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
content_index=current_content_index, content_index=state.current_content_index,
part=ResponseReasoningTextContent( part=ResponseReasoningTextContent(
text="", text="",
type="reasoning_text", type="reasoning_text",
), ),
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseReasoningTextDeltaEvent( ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta", type="response.reasoning_text.delta",
item_id=current_item_id, item_id=state.current_item_id,
output_index=current_output_index, output_index=state.current_output_index,
content_index=current_content_index, content_index=state.current_content_index,
delta=ctx.parser.last_content_delta, delta=ctx.parser.last_content_delta,
sequence_number=-1, sequence_number=-1,
) )
) )
# built-in tools will be triggered on the analysis channel return events
# However, occasionally built-in tools will
# still be output to commentary. def _emit_mcp_tool_delta_events(
elif ( self,
ctx.parser.current_channel == "commentary" ctx: StreamingHarmonyContext,
or ctx.parser.current_channel == "analysis" state: HarmonyStreamingState,
) and ctx.parser.current_recipient == "python": recipient: str,
if not sent_output_item_added: ) -> list[StreamingResponsesResponse]:
sent_output_item_added = True """Emit events for MCP tool delta streaming."""
current_item_id = f"tool_{random_uuid()}" server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
yield _increment_sequence_number_and_return( events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"mcp_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent( ResponseOutputItemAddedEvent(
type="response.output_item.added", type="response.output_item.added",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=recipient,
arguments="",
server_label=server_label,
status="in_progress",
),
)
)
events.append(
ResponseMcpCallInProgressEvent(
type="response.mcp_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseMcpCallArgumentsDeltaEvent(
type="response.mcp_call_arguments.delta",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
)
)
return events
def _emit_code_interpreter_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for code interpreter delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"tool_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseCodeInterpreterToolCallParam( item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call", type="code_interpreter_call",
id=current_item_id, id=state.current_item_id,
code=None, code=None,
container_id="auto", container_id="auto",
outputs=None, outputs=None,
...@@ -1785,36 +1938,129 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1785,36 +1938,129 @@ class OpenAIServingResponses(OpenAIServing):
), ),
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseCodeInterpreterCallInProgressEvent( ResponseCodeInterpreterCallInProgressEvent(
type="response.code_interpreter_call.in_progress", type="response.code_interpreter_call.in_progress",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseCodeInterpreterCallCodeDeltaEvent( ResponseCodeInterpreterCallCodeDeltaEvent(
type="response.code_interpreter_call_code.delta", type="response.code_interpreter_call_code.delta",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
delta=ctx.parser.last_content_delta, delta=ctx.parser.last_content_delta,
) )
) )
return events
def _emit_mcp_prefix_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for MCP prefix (mcp.*) delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"mcp_{random_uuid()}"
mcp_name = ctx.parser.current_recipient[len("mcp.") :]
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=mcp_name,
arguments="",
server_label=mcp_name,
status="in_progress",
),
)
)
events.append(
ResponseMcpCallInProgressEvent(
type="response.mcp_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseMcpCallArgumentsDeltaEvent(
type="response.mcp_call_arguments.delta",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
)
)
return events
def _emit_content_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for content delta streaming based on channel type."""
if not ctx.parser.last_content_delta:
return []
# stream tool call outputs
if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
if ( if (
self.tool_server is not None ctx.parser.current_channel == "final"
and self.tool_server.has_tool("browser") and ctx.parser.current_recipient is None
and previous_item.recipient is not None ):
and previous_item.recipient.startswith("browser.") return self._emit_final_channel_delta_events(ctx, state)
elif (
ctx.parser.current_channel == "analysis"
and ctx.parser.current_recipient is None
): ):
return self._emit_analysis_channel_delta_events(ctx, state)
# built-in tools will be triggered on the analysis channel
# However, occasionally built-in tools will
# still be output to commentary.
elif (
ctx.parser.current_channel == "commentary"
or ctx.parser.current_channel == "analysis"
) and ctx.parser.current_recipient is not None:
recipient = ctx.parser.current_recipient
# Check for function calls first - they have their own event handling
if recipient.startswith("functions."):
return self._emit_function_call_delta_events(ctx, state)
is_mcp_tool = self._is_mcp_tool_by_namespace(recipient)
if is_mcp_tool:
return self._emit_mcp_tool_delta_events(ctx, state, recipient)
else:
return self._emit_code_interpreter_delta_events(ctx, state)
elif (
(
ctx.parser.current_channel == "commentary"
or ctx.parser.current_channel == "analysis"
)
and ctx.parser.current_recipient is not None
and ctx.parser.current_recipient.startswith("mcp.")
):
return self._emit_mcp_prefix_delta_events(ctx, state)
return []
def _emit_browser_tool_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for browser tool calls (web search)."""
function_name = previous_item.recipient[len("browser.") :] function_name = previous_item.recipient[len("browser.") :]
action = None
parsed_args = json.loads(previous_item.content[0].text) parsed_args = json.loads(previous_item.content[0].text)
action = None
if function_name == "search": if function_name == "search":
action = response_function_web_search.ActionSearch( action = response_function_web_search.ActionSearch(
type="search", type="search",
...@@ -1836,145 +2082,334 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1836,145 +2082,334 @@ class OpenAIServingResponses(OpenAIServing):
else: else:
raise ValueError(f"Unknown function name: {function_name}") raise ValueError(f"Unknown function name: {function_name}")
current_item_id = f"tool_{random_uuid()}" state.current_item_id = f"tool_{random_uuid()}"
yield _increment_sequence_number_and_return( events = []
events.append(
ResponseOutputItemAddedEvent( ResponseOutputItemAddedEvent(
type="response.output_item.added", type="response.output_item.added",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=response_function_web_search.ResponseFunctionWebSearch( item=response_function_web_search.ResponseFunctionWebSearch(
# TODO: generate a unique id for web search call # TODO: generate a unique id for web search call
type="web_search_call", type="web_search_call",
id=current_item_id, id=state.current_item_id,
action=action, action=action,
status="in_progress", status="in_progress",
), ),
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseWebSearchCallInProgressEvent( ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress", type="response.web_search_call.in_progress",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseWebSearchCallSearchingEvent( ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching", type="response.web_search_call.searching",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
) )
) )
# enqueue # enqueue
yield _increment_sequence_number_and_return( events.append(
ResponseWebSearchCallCompletedEvent( ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed", type="response.web_search_call.completed",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseOutputItemDoneEvent( ResponseOutputItemDoneEvent(
type="response.output_item.done", type="response.output_item.done",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=ResponseFunctionWebSearch( item=ResponseFunctionWebSearch(
type="web_search_call", type="web_search_call",
id=current_item_id, id=state.current_item_id,
action=action, action=action,
status="completed", status="completed",
), ),
) )
) )
return events
if ( def _emit_mcp_tool_completion_events(
self.tool_server is not None self,
and self.tool_server.has_tool("python") previous_item,
and previous_item.recipient is not None state: HarmonyStreamingState,
and previous_item.recipient.startswith("python") ) -> list[StreamingResponsesResponse]:
): """Emit events when an MCP tool completes during assistant action turn."""
yield _increment_sequence_number_and_return( recipient = previous_item.recipient
server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
events = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
arguments=previous_item.content[0].text,
name=recipient,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=recipient,
arguments=previous_item.content[0].text,
server_label=server_label,
status="completed",
),
)
)
return events
def _emit_code_interpreter_completion_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when code interpreter completes."""
events = []
events.append(
ResponseCodeInterpreterCallCodeDoneEvent( ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done", type="response.code_interpreter_call_code.done",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
code=previous_item.content[0].text, code=previous_item.content[0].text,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseCodeInterpreterCallInterpretingEvent( ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting", type="response.code_interpreter_call.interpreting",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseCodeInterpreterCallCompletedEvent( ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed", type="response.code_interpreter_call.completed",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item_id=current_item_id, item_id=state.current_item_id,
) )
) )
yield _increment_sequence_number_and_return( events.append(
ResponseOutputItemDoneEvent( ResponseOutputItemDoneEvent(
type="response.output_item.done", type="response.output_item.done",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=ResponseCodeInterpreterToolCallParam( item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call", type="code_interpreter_call",
id=current_item_id, id=state.current_item_id,
code=previous_item.content[0].text, code=previous_item.content[0].text,
container_id="auto", container_id="auto",
# TODO: add outputs here
outputs=[], outputs=[],
status="completed", status="completed",
), ),
) )
) )
# developer tools will be triggered on the commentary channel return events
# and recipient starts with "functions.TOOL_NAME"
def _emit_mcp_prefix_completion_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP prefix tool (mcp.*) completes."""
mcp_name = previous_item.recipient[len("mcp.") :]
events = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
arguments=previous_item.content[0].text,
name=mcp_name,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=mcp_name,
arguments=previous_item.content[0].text,
server_label=mcp_name,
status="completed",
),
)
)
return events
def _emit_tool_action_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for tool action turn."""
if not ctx.is_assistant_action_turn() or len(ctx.parser.messages) == 0:
return []
events = []
previous_item = ctx.parser.messages[-1]
# Handle browser tool
if ( if (
self.tool_server is not None
and self.tool_server.has_tool("browser")
and previous_item.recipient is not None
and previous_item.recipient.startswith("browser.")
):
events.extend(self._emit_browser_tool_events(previous_item, state))
# Handle tool completion
if (
self.tool_server is not None
and previous_item.recipient is not None
and state.current_item_id is not None
and state.sent_output_item_added
):
recipient = previous_item.recipient
# Handle MCP prefix tool completion first
if recipient.startswith("mcp."):
events.extend(
self._emit_mcp_prefix_completion_events(previous_item, state)
)
else:
# Handle other MCP tool and code interpreter completion
is_mcp_tool = self._is_mcp_tool_by_namespace(
recipient
) and state.current_item_id.startswith("mcp_")
if is_mcp_tool:
events.extend(
self._emit_mcp_tool_completion_events(previous_item, state)
)
else:
events.extend(
self._emit_code_interpreter_completion_events(
previous_item, state
)
)
return events
def _emit_function_call_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for developer function calls on commentary channel."""
if not (
ctx.parser.current_channel == "commentary" ctx.parser.current_channel == "commentary"
and ctx.parser.current_recipient and ctx.parser.current_recipient
and ctx.parser.current_recipient.startswith("functions.") and ctx.parser.current_recipient.startswith("functions.")
): ):
if is_first_function_call_delta is False: return []
is_first_function_call_delta = True
events = []
if state.is_first_function_call_delta is False:
state.is_first_function_call_delta = True
fc_name = ctx.parser.current_recipient[len("functions.") :] fc_name = ctx.parser.current_recipient[len("functions.") :]
state.current_item_id = f"fc_{random_uuid()}"
tool_call_item = ResponseFunctionToolCall( tool_call_item = ResponseFunctionToolCall(
name=fc_name, name=fc_name,
type="function_call", type="function_call",
id=current_item_id, id=state.current_item_id,
call_id=f"call_{random_uuid()}", call_id=f"call_{random_uuid()}",
arguments="", arguments="",
status="in_progress", status="in_progress",
) )
current_item_id = f"fc_{random_uuid()}" events.append(
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent( ResponseOutputItemAddedEvent(
type="response.output_item.added", type="response.output_item.added",
sequence_number=-1, sequence_number=-1,
output_index=current_output_index, output_index=state.current_output_index,
item=tool_call_item, item=tool_call_item,
) )
) )
else: # Always emit the delta (including on first call)
yield _increment_sequence_number_and_return( events.append(
ResponseFunctionCallArgumentsDeltaEvent( ResponseFunctionCallArgumentsDeltaEvent(
item_id=current_item_id, item_id=state.current_item_id,
delta=ctx.parser.last_content_delta, delta=ctx.parser.last_content_delta,
output_index=current_output_index, output_index=state.current_output_index,
sequence_number=-1, sequence_number=-1,
type="response.function_call_arguments.delta", type="response.function_call_arguments.delta",
) )
) )
return events
async def _process_harmony_streaming_events(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
[StreamingResponsesResponse], StreamingResponsesResponse
],
) -> AsyncGenerator[StreamingResponsesResponse, None]:
state = HarmonyStreamingState()
async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext)
# finish_reason='error' indicates a retryable error
self._raise_if_error(ctx.finish_reason, request.request_id)
if ctx.is_expecting_start():
if len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
for event in self._emit_previous_item_done_events(
previous_item, state
):
yield _increment_sequence_number_and_return(event)
state.reset_for_new_item()
# Stream the output of a harmony message
for event in self._emit_content_delta_events(ctx, state):
yield _increment_sequence_number_and_return(event)
# Stream tool call outputs
for event in self._emit_tool_action_events(ctx, state):
yield _increment_sequence_number_and_return(event)
async def responses_stream_generator( async def responses_stream_generator(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment