Unverified Commit a4ec0c55 authored by daniel-salib's avatar daniel-salib Committed by GitHub
Browse files

[Frontend] Add MCP tool streaming support to Responses API (#31761)


Signed-off-by: default avatarDaniel Salib <danielsalib@meta.com>
parent 0fa8dd24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from openai import OpenAI
......@@ -13,199 +14,6 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv("VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str):
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: Tool calls and responses on analysis channel
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert tool_call_found, "Should have found at least one Python tool call"
assert tool_response_found, "Should have found at least one Python tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present with valid mcp tool"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_with_allowed_tools_star(
mcp_enabled_client: OpenAI, model_name: str
):
"""Test MCP tool with allowed_tools=['*'] to select all available tools.
This E2E test verifies that the "*" wildcard works end-to-end.
See test_serving_responses.py for detailed unit tests of "*" normalization.
"""
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# Using "*" to allow all tools from this MCP server
"allowed_tools": ["*"],
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify tool calls work with allowed_tools=["*"]
tool_call_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
break
assert tool_call_found, "Should have found at least one Python tool call with '*'"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str):
response = await mcp_disabled_client.responses.create(
model=model_name,
input=(
"Execute the following code if the tool is present: "
"import random; print(random.randint(1, 1000000))"
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: No tool calls and responses
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert not tool_call_found, "Should not have a python call"
assert not tool_response_found, "Should not have a tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present without a valid tool"
)
def test_get_tool_description():
"""Test MCPToolServer.get_tool_description filtering logic.
......@@ -249,7 +57,10 @@ def test_get_tool_description():
assert result.tools[1].name == "tool3"
# Single tool
result = server.get_tool_description("test_server", allowed_tools=["tool2"])
result = server.get_tool_description(
"test_server",
allowed_tools=["tool2"],
)
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
......@@ -259,3 +70,283 @@ def test_get_tool_description():
# Empty list - returns None
assert server.get_tool_description("test_server", allowed_tools=[]) is None
class TestMCPEnabled:
"""Tests that require MCP tools to be enabled via environment variable."""
@pytest.fixture(scope="class")
def monkeypatch_class(self):
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="class")
def mcp_enabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_class.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv(
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container"
)
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_enabled_client(self, mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_enabled(
self, mcp_enabled_client: OpenAI, model_name: str
):
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. "
"Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: Tool calls and responses on analysis channel
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert tool_call_found, "Should have found at least one Python tool call"
assert tool_response_found, (
"Should have found at least one Python tool response"
)
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present with valid mcp tool"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_with_allowed_tools_star(
self, mcp_enabled_client: OpenAI, model_name: str
):
"""Test MCP tool with allowed_tools=['*'] to select all available
tools.
This E2E test verifies that the "*" wildcard works end-to-end.
See test_serving_responses.py for detailed unit tests of "*"
normalization.
"""
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. "
"Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# Using "*" to allow all tools from this MCP server
"allowed_tools": ["*"],
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify tool calls work with allowed_tools=["*"]
tool_call_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
break
assert tool_call_found, (
"Should have found at least one Python tool call with '*'"
)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_calling_streaming_types(
self, mcp_enabled_client: OpenAI, model_name: str
):
pairs_of_event_types = {
"response.completed": "response.created",
"response.output_item.done": "response.output_item.added",
"response.content_part.done": "response.content_part.added",
"response.output_text.done": "response.output_text.delta",
"response.reasoning_text.done": "response.reasoning_text.delta",
"response.reasoning_part.done": "response.reasoning_part.added",
"response.mcp_call_arguments.done": ("response.mcp_call_arguments.delta"),
"response.mcp_call.completed": "response.mcp_call.in_progress",
}
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
input_text = "What is 13 * 24? Use python to calculate the result."
stream_response = await mcp_enabled_client.responses.create(
model=model_name,
input=input_text,
tools=tools,
stream=True,
instructions=(
"You must use the Python tool to execute code. "
"Never simulate execution."
),
)
stack_of_event_types = []
saw_mcp_type = False
async for event in stream_response:
if event.type == "response.created":
stack_of_event_types.append(event.type)
elif event.type == "response.completed":
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
elif (
event.type.endswith("added")
or event.type == "response.mcp_call.in_progress"
):
stack_of_event_types.append(event.type)
elif event.type.endswith("delta"):
if stack_of_event_types[-1] == event.type:
continue
stack_of_event_types.append(event.type)
elif (
event.type.endswith("done")
or event.type == "response.mcp_call.completed"
):
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
if "mcp_call" in event.type:
saw_mcp_type = True
stack_of_event_types.pop()
assert len(stack_of_event_types) == 0
assert saw_mcp_type, "Should have seen at least one mcp call"
class TestMCPDisabled:
"""Tests that verify behavior when MCP tools are disabled."""
@pytest.fixture(scope="class")
def monkeypatch_class(self):
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="class")
def mcp_disabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_class.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_disabled_client(self, mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_disabled(
self, mcp_disabled_client: OpenAI, model_name: str
):
response = await mcp_disabled_client.responses.create(
model=model_name,
input=(
"Execute the following code if the tool is present: "
"import random; print(random.randint(1, 1000000))"
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: No tool calls and responses
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert not tool_call_found, "Should not have a python call"
assert not tool_response_found, "Should not have a tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present without a valid tool"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import importlib.util
import json
import time
......@@ -44,6 +43,8 @@ def server():
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS="code_interpreter,container,web_search_preview",
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS="1",
)
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
......@@ -855,6 +856,237 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str):
assert event.response.output_text is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_no_code_interpreter_events(
client: OpenAI, model_name: str
):
"""Verify that function calls don't trigger code_interpreter events.
This test ensures that function calls (functions.*) use their own
function_call event types and don't incorrectly emit code_interpreter
events during streaming.
"""
tools = [GET_WEATHER_SCHEMA]
input_list = [
{
"role": "user",
"content": "What's the weather like in Paris today?",
}
]
stream_response = await client.responses.create(
model=model_name,
input=input_list,
tools=tools,
stream=True,
)
# Track which event types we see
event_types_seen = set()
function_call_found = False
async for event in stream_response:
event_types_seen.add(event.type)
if (
event.type == "response.output_item.added"
and event.item.type == "function_call"
):
function_call_found = True
# Ensure NO code_interpreter events are emitted for function calls
assert "code_interpreter" not in event.type, (
"Found code_interpreter event "
f"'{event.type}' during function call. Function calls should only "
"emit function_call events, not code_interpreter events."
)
# Verify we actually saw a function call
assert function_call_found, "Expected to see a function_call in the stream"
# Verify we saw the correct function call event types
assert (
"response.function_call_arguments.delta" in event_types_seen
or "response.function_call_arguments.done" in event_types_seen
), "Expected to see function_call_arguments events"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server):
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
input_text = (
"Calculate 15 * 32 using python. "
"The python interpreter is not stateful and you must print to see the output."
)
stream_response = await client.responses.create(
model=model_name,
input=input_text,
tools=tools,
stream=True,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
)
mcp_call_added = False
mcp_call_in_progress = False
mcp_arguments_delta_seen = False
mcp_arguments_done = False
mcp_call_completed = False
mcp_item_done = False
code_interpreter_events_seen = False
async for event in stream_response:
if "code_interpreter" in event.type:
code_interpreter_events_seen = True
if event.type == "response.output_item.added":
if hasattr(event.item, "type") and event.item.type == "mcp_call":
mcp_call_added = True
assert event.item.name == "python"
assert event.item.server_label == "code_interpreter"
elif event.type == "response.mcp_call.in_progress":
mcp_call_in_progress = True
elif event.type == "response.mcp_call_arguments.delta":
mcp_arguments_delta_seen = True
assert event.delta is not None
elif event.type == "response.mcp_call_arguments.done":
mcp_arguments_done = True
assert event.name == "python"
assert event.arguments is not None
elif event.type == "response.mcp_call.completed":
mcp_call_completed = True
elif (
event.type == "response.output_item.done"
and hasattr(event.item, "type")
and event.item.type == "mcp_call"
):
mcp_item_done = True
assert event.item.name == "python"
assert event.item.status == "completed"
assert mcp_call_added, "MCP call was not added"
assert mcp_call_in_progress, "MCP call in_progress event not seen"
assert mcp_arguments_delta_seen, "MCP arguments delta event not seen"
assert mcp_arguments_done, "MCP arguments done event not seen"
assert mcp_call_completed, "MCP call completed event not seen"
assert mcp_item_done, "MCP item done event not seen"
assert not code_interpreter_events_seen, (
"Should not see code_interpreter events when using MCP type"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server):
"""Test MCP tool calling across multiple turns.
This test verifies that MCP tools work correctly in multi-turn conversations,
maintaining state across turns via the previous_response_id mechanism.
"""
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
# First turn - make a calculation
response1 = await client.responses.create(
model=model_name,
input="Calculate 123 * 456 using python and print the result.",
tools=tools,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
extra_body={"enable_response_messages": True},
)
assert response1 is not None
assert response1.status == "completed"
# Verify MCP call in first response by checking output_messages
tool_call_found = False
tool_response_found = False
for message in response1.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
# Verify MCP tools were actually used
assert tool_call_found, "MCP tool call not found in output_messages"
assert tool_response_found, "MCP tool response not found in output_messages"
# Verify input messages: Should have system message with tool, NO developer message
developer_messages = [
msg for msg in response1.input_messages if msg["author"]["role"] == "developer"
]
assert len(developer_messages) == 0, (
"No developer message expected for elevated tools"
)
# Second turn - reference previous calculation
response2 = await client.responses.create(
model=model_name,
input="Now divide that result by 2.",
tools=tools,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
previous_response_id=response1.id,
extra_body={"enable_response_messages": True},
)
assert response2 is not None
assert response2.status == "completed"
# Verify input messages are correct: should have two messages -
# one to the python recipient on analysis channel and one from tool role
mcp_recipient_messages = []
tool_role_messages = []
for msg in response2.input_messages:
if msg["author"]["role"] == "assistant":
# Check if this is a message to MCP recipient on analysis channel
if msg.get("channel") == "analysis" and msg.get("recipient"):
recipient = msg.get("recipient")
if recipient.startswith("code_interpreter") or recipient == "python":
mcp_recipient_messages.append(msg)
elif msg["author"]["role"] == "tool":
tool_role_messages.append(msg)
assert len(mcp_recipient_messages) > 0, (
"Expected message(s) to MCP recipient on analysis channel"
)
assert len(tool_role_messages) > 0, (
"Expected message(s) from tool role after MCP call"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
......
......@@ -9,6 +9,7 @@ from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence
from contextlib import AsyncExitStack
from copy import copy
from dataclasses import dataclass
from http import HTTPStatus
from typing import Final
......@@ -27,6 +28,10 @@ from openai.types.responses import (
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseMcpCallArgumentsDeltaEvent,
ResponseMcpCallArgumentsDoneEvent,
ResponseMcpCallCompletedEvent,
ResponseMcpCallInProgressEvent,
ResponseOutputItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
......@@ -44,6 +49,7 @@ from openai.types.responses import (
response_function_web_search,
response_text_delta_event,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
......@@ -119,6 +125,23 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
@dataclass
class HarmonyStreamingState:
"""Mutable state for harmony streaming event processing."""
current_content_index: int = -1
current_output_index: int = 0
current_item_id: str = ""
sent_output_item_added: bool = False
is_first_function_call_delta: bool = False
def reset_for_new_item(self) -> None:
"""Reset state when expecting a new output item."""
self.current_output_index += 1
self.sent_output_item_added = False
self.is_first_function_call_delta = False
def _extract_allowed_tools_from_mcp_requests(
tools: list[Tool],
) -> dict[str, list[str] | None]:
......@@ -740,6 +763,26 @@ class OpenAIServingResponses(OpenAIServing):
self.response_store[response.id] = response
return response
def _is_mcp_tool_by_namespace(self, recipient: str | None) -> bool:
"""
Determine if a tool call is an MCP tool based on recipient prefix.
- Tools starting with "functions." are function calls
- Everything else is an MCP tool
"""
if recipient is None:
return False
# Function calls have "functions." prefix
# Everything else is an MCP tool
return not recipient.startswith("functions.")
_TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = {
"python": "code_interpreter",
"container": "container",
"browser": "web_search_preview",
}
def _topk_logprobs(
self,
logprobs: dict[int, SampleLogprob],
......@@ -1036,8 +1079,7 @@ class OpenAIServingResponses(OpenAIServing):
del prev_msgs[prev_final_msg_idx + 1 :]
for msg in recent_turn_msgs:
assert isinstance(msg, OpenAIHarmonyMessage)
if msg.channel != "analysis":
prev_msgs.append(msg)
prev_msgs.append(msg)
messages.extend(prev_msgs)
# Append the new input.
# Responses API supports simple text inputs without chat format.
......@@ -1520,6 +1562,816 @@ class OpenAIServingResponses(OpenAIServing):
)
)
def _emit_function_call_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a function call completes."""
function_name = previous_item.recipient[len("functions.") :]
events = []
events.append(
ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done",
arguments=previous_item.content[0].text,
name=function_name,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
function_call_item = ResponseFunctionToolCall(
type="function_call",
arguments=previous_item.content[0].text,
name=function_name,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
call_id=f"fc_{random_uuid()}",
status="completed",
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=function_call_item,
)
)
return events
def _emit_mcp_call_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP tool call completes."""
server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(
previous_item.recipient, previous_item.recipient
)
events = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
arguments=previous_item.content[0].text,
name=previous_item.recipient,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
arguments=previous_item.content[0].text,
name=previous_item.recipient,
id=state.current_item_id,
server_label=server_label,
status="completed",
),
)
)
return events
def _emit_reasoning_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a reasoning (analysis) item completes."""
content = ResponseReasoningTextContent(
text=previous_item.content[0].text,
type="reasoning_text",
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[content],
status="completed",
id=state.current_item_id,
summary=[],
)
events = []
events.append(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=state.current_item_id,
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=previous_item.content[0].text,
)
)
events.append(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=content,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=reasoning_item,
)
)
return events
def _emit_text_output_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a final text output item completes."""
text_content = ResponseOutputText(
type="output_text",
text=previous_item.content[0].text,
annotations=[],
)
events = []
events.append(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=previous_item.content[0].text,
logprobs=[],
item_id=state.current_item_id,
)
)
events.append(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=text_content,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseOutputMessage(
id=state.current_item_id,
type="message",
role="assistant",
content=[text_content],
status="completed",
),
)
)
return events
def _emit_previous_item_done_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit done events for the previous item when expecting a new start."""
if previous_item.recipient is not None:
# Deal with tool call
if previous_item.recipient.startswith("functions."):
return self._emit_function_call_done_events(previous_item, state)
elif (
self._is_mcp_tool_by_namespace(previous_item.recipient)
and state.current_item_id is not None
and state.current_item_id.startswith("mcp_")
):
return self._emit_mcp_call_done_events(previous_item, state)
elif previous_item.channel == "analysis":
return self._emit_reasoning_done_events(previous_item, state)
elif previous_item.channel == "final":
return self._emit_text_output_done_events(previous_item, state)
return []
def _emit_final_channel_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for final channel text delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"msg_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseOutputMessage(
id=state.current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
)
)
state.current_content_index += 1
events.append(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
content_index=state.current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
)
)
events.append(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=state.current_content_index,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
)
)
return events
def _emit_analysis_channel_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for analysis channel reasoning delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"msg_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseReasoningItem(
type="reasoning",
id=state.current_item_id,
summary=[],
status="in_progress",
),
)
)
state.current_content_index += 1
events.append(
ResponseReasoningPartAddedEvent(
type="response.reasoning_part.added",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
content_index=state.current_content_index,
part=ResponseReasoningTextContent(
text="",
type="reasoning_text",
),
)
)
events.append(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
delta=ctx.parser.last_content_delta,
sequence_number=-1,
)
)
return events
def _emit_mcp_tool_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
recipient: str,
) -> list[StreamingResponsesResponse]:
"""Emit events for MCP tool delta streaming."""
server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"mcp_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=recipient,
arguments="",
server_label=server_label,
status="in_progress",
),
)
)
events.append(
ResponseMcpCallInProgressEvent(
type="response.mcp_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseMcpCallArgumentsDeltaEvent(
type="response.mcp_call_arguments.delta",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
)
)
return events
def _emit_code_interpreter_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for code interpreter delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"tool_{random_uuid()}"
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=state.current_item_id,
code=None,
container_id="auto",
outputs=None,
status="in_progress",
),
)
)
events.append(
ResponseCodeInterpreterCallInProgressEvent(
type="response.code_interpreter_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseCodeInterpreterCallCodeDeltaEvent(
type="response.code_interpreter_call_code.delta",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
)
)
return events
def _emit_mcp_prefix_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for MCP prefix (mcp.*) delta streaming."""
events = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"mcp_{random_uuid()}"
mcp_name = ctx.parser.current_recipient[len("mcp.") :]
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=mcp_name,
arguments="",
server_label=mcp_name,
status="in_progress",
),
)
)
events.append(
ResponseMcpCallInProgressEvent(
type="response.mcp_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseMcpCallArgumentsDeltaEvent(
type="response.mcp_call_arguments.delta",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
)
)
return events
def _emit_content_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for content delta streaming based on channel type."""
if not ctx.parser.last_content_delta:
return []
if (
ctx.parser.current_channel == "final"
and ctx.parser.current_recipient is None
):
return self._emit_final_channel_delta_events(ctx, state)
elif (
ctx.parser.current_channel == "analysis"
and ctx.parser.current_recipient is None
):
return self._emit_analysis_channel_delta_events(ctx, state)
# built-in tools will be triggered on the analysis channel
# However, occasionally built-in tools will
# still be output to commentary.
elif (
ctx.parser.current_channel == "commentary"
or ctx.parser.current_channel == "analysis"
) and ctx.parser.current_recipient is not None:
recipient = ctx.parser.current_recipient
# Check for function calls first - they have their own event handling
if recipient.startswith("functions."):
return self._emit_function_call_delta_events(ctx, state)
is_mcp_tool = self._is_mcp_tool_by_namespace(recipient)
if is_mcp_tool:
return self._emit_mcp_tool_delta_events(ctx, state, recipient)
else:
return self._emit_code_interpreter_delta_events(ctx, state)
elif (
(
ctx.parser.current_channel == "commentary"
or ctx.parser.current_channel == "analysis"
)
and ctx.parser.current_recipient is not None
and ctx.parser.current_recipient.startswith("mcp.")
):
return self._emit_mcp_prefix_delta_events(ctx, state)
return []
def _emit_browser_tool_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for browser tool calls (web search)."""
function_name = previous_item.recipient[len("browser.") :]
parsed_args = json.loads(previous_item.content[0].text)
action = None
if function_name == "search":
action = response_function_web_search.ActionSearch(
type="search",
query=parsed_args["query"],
)
elif function_name == "open":
action = response_function_web_search.ActionOpenPage(
type="open_page",
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
elif function_name == "find":
action = response_function_web_search.ActionFind(
type="find",
pattern=parsed_args["pattern"],
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
else:
raise ValueError(f"Unknown function name: {function_name}")
state.current_item_id = f"tool_{random_uuid()}"
events = []
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=response_function_web_search.ResponseFunctionWebSearch(
# TODO: generate a unique id for web search call
type="web_search_call",
id=state.current_item_id,
action=action,
status="in_progress",
),
)
)
events.append(
ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
# enqueue
events.append(
ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseFunctionWebSearch(
type="web_search_call",
id=state.current_item_id,
action=action,
status="completed",
),
)
)
return events
def _emit_mcp_tool_completion_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP tool completes during assistant action turn."""
recipient = previous_item.recipient
server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
events = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
arguments=previous_item.content[0].text,
name=recipient,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=recipient,
arguments=previous_item.content[0].text,
server_label=server_label,
status="completed",
),
)
)
return events
def _emit_code_interpreter_completion_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when code interpreter completes."""
events = []
events.append(
ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
code=previous_item.content[0].text,
)
)
events.append(
ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=state.current_item_id,
code=previous_item.content[0].text,
container_id="auto",
outputs=[],
status="completed",
),
)
)
return events
def _emit_mcp_prefix_completion_events(
self,
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP prefix tool (mcp.*) completes."""
mcp_name = previous_item.recipient[len("mcp.") :]
events = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
arguments=previous_item.content[0].text,
name=mcp_name,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=mcp_name,
arguments=previous_item.content[0].text,
server_label=mcp_name,
status="completed",
),
)
)
return events
def _emit_tool_action_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for tool action turn."""
if not ctx.is_assistant_action_turn() or len(ctx.parser.messages) == 0:
return []
events = []
previous_item = ctx.parser.messages[-1]
# Handle browser tool
if (
self.tool_server is not None
and self.tool_server.has_tool("browser")
and previous_item.recipient is not None
and previous_item.recipient.startswith("browser.")
):
events.extend(self._emit_browser_tool_events(previous_item, state))
# Handle tool completion
if (
self.tool_server is not None
and previous_item.recipient is not None
and state.current_item_id is not None
and state.sent_output_item_added
):
recipient = previous_item.recipient
# Handle MCP prefix tool completion first
if recipient.startswith("mcp."):
events.extend(
self._emit_mcp_prefix_completion_events(previous_item, state)
)
else:
# Handle other MCP tool and code interpreter completion
is_mcp_tool = self._is_mcp_tool_by_namespace(
recipient
) and state.current_item_id.startswith("mcp_")
if is_mcp_tool:
events.extend(
self._emit_mcp_tool_completion_events(previous_item, state)
)
else:
events.extend(
self._emit_code_interpreter_completion_events(
previous_item, state
)
)
return events
def _emit_function_call_delta_events(
self,
ctx: StreamingHarmonyContext,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for developer function calls on commentary channel."""
if not (
ctx.parser.current_channel == "commentary"
and ctx.parser.current_recipient
and ctx.parser.current_recipient.startswith("functions.")
):
return []
events = []
if state.is_first_function_call_delta is False:
state.is_first_function_call_delta = True
fc_name = ctx.parser.current_recipient[len("functions.") :]
state.current_item_id = f"fc_{random_uuid()}"
tool_call_item = ResponseFunctionToolCall(
name=fc_name,
type="function_call",
id=state.current_item_id,
call_id=f"call_{random_uuid()}",
arguments="",
status="in_progress",
)
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=tool_call_item,
)
)
# Always emit the delta (including on first call)
events.append(
ResponseFunctionCallArgumentsDeltaEvent(
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
output_index=state.current_output_index,
sequence_number=-1,
type="response.function_call_arguments.delta",
)
)
return events
async def _process_harmony_streaming_events(
self,
request: ResponsesRequest,
......@@ -1534,11 +2386,8 @@ class OpenAIServingResponses(OpenAIServing):
[StreamingResponsesResponse], StreamingResponsesResponse
],
) -> AsyncGenerator[StreamingResponsesResponse, None]:
current_content_index = -1
current_output_index = 0
current_item_id: str = ""
sent_output_item_added = False
is_first_function_call_delta = False
state = HarmonyStreamingState()
async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext)
......@@ -1546,435 +2395,21 @@ class OpenAIServingResponses(OpenAIServing):
self._raise_if_error(ctx.finish_reason, request.request_id)
if ctx.is_expecting_start():
current_output_index += 1
sent_output_item_added = False
is_first_function_call_delta = False
if len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
if previous_item.recipient is not None:
# Deal with tool call
if previous_item.recipient.startswith("functions."):
function_name = previous_item.recipient[len("functions.") :]
yield _increment_sequence_number_and_return(
ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done",
arguments=previous_item.content[0].text,
name=function_name,
item_id=current_item_id,
output_index=current_output_index,
sequence_number=-1,
)
)
function_call_item = ResponseFunctionToolCall(
type="function_call",
arguments=previous_item.content[0].text,
name=function_name,
item_id=current_item_id,
output_index=current_output_index,
sequence_number=-1,
call_id=f"fc_{random_uuid()}",
status="completed",
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=function_call_item,
)
)
elif previous_item.channel == "analysis":
content = ResponseReasoningTextContent(
text=previous_item.content[0].text,
type="reasoning_text",
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[content],
status="completed",
id=current_item_id,
summary=[],
)
yield _increment_sequence_number_and_return(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=content,
)
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=reasoning_item,
)
)
elif previous_item.channel == "final":
text_content = ResponseOutputText(
type="output_text",
text=previous_item.content[0].text,
annotations=[],
)
yield _increment_sequence_number_and_return(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
logprobs=[],
item_id=current_item_id,
)
)
yield _increment_sequence_number_and_return(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=text_content,
)
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[text_content],
status="completed",
),
)
)
# stream the output of a harmony message
if ctx.parser.last_content_delta:
if (
ctx.parser.current_channel == "final"
and ctx.parser.current_recipient is None
):
if not sent_output_item_added:
sent_output_item_added = True
current_item_id = f"msg_{random_uuid()}"
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
)
)
current_content_index += 1
yield _increment_sequence_number_and_return(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
)
)
yield _increment_sequence_number_and_return(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
output_index=current_output_index,
item_id=current_item_id,
delta=ctx.parser.last_content_delta,
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
)
)
elif (
ctx.parser.current_channel == "analysis"
and ctx.parser.current_recipient is None
):
if not sent_output_item_added:
sent_output_item_added = True
current_item_id = f"msg_{random_uuid()}"
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
status="in_progress",
),
)
)
current_content_index += 1
yield _increment_sequence_number_and_return(
ResponseReasoningPartAddedEvent(
type="response.reasoning_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text="",
type="reasoning_text",
),
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
delta=ctx.parser.last_content_delta,
sequence_number=-1,
)
)
# built-in tools will be triggered on the analysis channel
# However, occasionally built-in tools will
# still be output to commentary.
elif (
ctx.parser.current_channel == "commentary"
or ctx.parser.current_channel == "analysis"
) and ctx.parser.current_recipient == "python":
if not sent_output_item_added:
sent_output_item_added = True
current_item_id = f"tool_{random_uuid()}"
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code=None,
container_id="auto",
outputs=None,
status="in_progress",
),
)
)
yield _increment_sequence_number_and_return(
ResponseCodeInterpreterCallInProgressEvent(
type="response.code_interpreter_call.in_progress",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _increment_sequence_number_and_return(
ResponseCodeInterpreterCallCodeDeltaEvent(
type="response.code_interpreter_call_code.delta",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
delta=ctx.parser.last_content_delta,
)
)
# stream tool call outputs
if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
if (
self.tool_server is not None
and self.tool_server.has_tool("browser")
and previous_item.recipient is not None
and previous_item.recipient.startswith("browser.")
):
function_name = previous_item.recipient[len("browser.") :]
action = None
parsed_args = json.loads(previous_item.content[0].text)
if function_name == "search":
action = response_function_web_search.ActionSearch(
type="search",
query=parsed_args["query"],
)
elif function_name == "open":
action = response_function_web_search.ActionOpenPage(
type="open_page",
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
elif function_name == "find":
action = response_function_web_search.ActionFind(
type="find",
pattern=parsed_args["pattern"],
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
else:
raise ValueError(f"Unknown function name: {function_name}")
current_item_id = f"tool_{random_uuid()}"
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=response_function_web_search.ResponseFunctionWebSearch(
# TODO: generate a unique id for web search call
type="web_search_call",
id=current_item_id,
action=action,
status="in_progress",
),
)
)
yield _increment_sequence_number_and_return(
ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _increment_sequence_number_and_return(
ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
# enqueue
yield _increment_sequence_number_and_return(
ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=ResponseFunctionWebSearch(
type="web_search_call",
id=current_item_id,
action=action,
status="completed",
),
)
)
if (
self.tool_server is not None
and self.tool_server.has_tool("python")
and previous_item.recipient is not None
and previous_item.recipient.startswith("python")
):
yield _increment_sequence_number_and_return(
ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
code=previous_item.content[0].text,
)
)
yield _increment_sequence_number_and_return(
ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _increment_sequence_number_and_return(
ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code=previous_item.content[0].text,
container_id="auto",
# TODO: add outputs here
outputs=[],
status="completed",
),
)
)
# developer tools will be triggered on the commentary channel
# and recipient starts with "functions.TOOL_NAME"
if (
ctx.parser.current_channel == "commentary"
and ctx.parser.current_recipient
and ctx.parser.current_recipient.startswith("functions.")
):
if is_first_function_call_delta is False:
is_first_function_call_delta = True
fc_name = ctx.parser.current_recipient[len("functions.") :]
tool_call_item = ResponseFunctionToolCall(
name=fc_name,
type="function_call",
id=current_item_id,
call_id=f"call_{random_uuid()}",
arguments="",
status="in_progress",
)
current_item_id = f"fc_{random_uuid()}"
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=tool_call_item,
)
)
else:
yield _increment_sequence_number_and_return(
ResponseFunctionCallArgumentsDeltaEvent(
item_id=current_item_id,
delta=ctx.parser.last_content_delta,
output_index=current_output_index,
sequence_number=-1,
type="response.function_call_arguments.delta",
)
)
for event in self._emit_previous_item_done_events(
previous_item, state
):
yield _increment_sequence_number_and_return(event)
state.reset_for_new_item()
# Stream the output of a harmony message
for event in self._emit_content_delta_events(ctx, state):
yield _increment_sequence_number_and_return(event)
# Stream tool call outputs
for event in self._emit_tool_action_events(ctx, state):
yield _increment_sequence_number_and_return(event)
async def responses_stream_generator(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment