Unverified commit ec1d30c0, authored by Flora Feng and committed by GitHub
Browse files

[Responses] Decouple SSE event helpers from Harmony context (#35148)


Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent e3b2324e
...@@ -39,6 +39,7 @@ def pairs_of_event_types() -> dict[str, str]: ...@@ -39,6 +39,7 @@ def pairs_of_event_types() -> dict[str, str]:
"response.mcp_call.completed": "response.mcp_call.in_progress", "response.mcp_call.completed": "response.mcp_call.in_progress",
"response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa: E501 "response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa: E501
"response.code_interpreter_call_code.done": "response.code_interpreter_call_code.delta", # noqa: E501 "response.code_interpreter_call_code.done": "response.code_interpreter_call_code.delta", # noqa: E501
"response.code_interpreter_call.completed": "response.code_interpreter_call.in_progress", # noqa: E501
"response.web_search_call.completed": "response.web_search_call.in_progress", "response.web_search_call.completed": "response.web_search_call.in_progress",
} }
# fmt: on # fmt: on
...@@ -108,29 +109,19 @@ def events_contain_type(events: list, type_substring: str) -> bool: ...@@ -108,29 +109,19 @@ def events_contain_type(events: list, type_substring: str) -> bool:
return any(type_substring in getattr(e, "type", "") for e in events) return any(type_substring in getattr(e, "type", "") for e in events)
def validate_streaming_event_stack( def _validate_event_pairing(events: list, pairs_of_event_types: dict[str, str]) -> None:
events: list, pairs_of_event_types: dict[str, str] """Validate that streaming events are properly nested/paired.
) -> None:
"""Validate that streaming events are properly nested/paired.""" Derives push/pop sets from *pairs_of_event_types* so that every
start/end pair in the dict is handled automatically.
"""
start_events = set(pairs_of_event_types.values())
end_events = set(pairs_of_event_types.keys())
stack: list[str] = [] stack: list[str] = []
for event in events: for event in events:
etype = event.type etype = event.type
if etype == "response.created": if etype in end_events:
stack.append(etype)
elif etype == "response.completed":
assert stack and stack[-1] == pairs_of_event_types[etype], (
f"Unexpected stack top for {etype}: "
f"got {stack[-1] if stack else '<empty>'}"
)
stack.pop()
elif etype.endswith("added") or etype == "response.mcp_call.in_progress":
stack.append(etype)
elif etype.endswith("delta"):
if stack and stack[-1] == etype:
continue
stack.append(etype)
elif etype.endswith("done") or etype == "response.mcp_call.completed":
assert etype in pairs_of_event_types, f"Unknown done event: {etype}"
expected_start = pairs_of_event_types[etype] expected_start = pairs_of_event_types[etype]
assert stack and stack[-1] == expected_start, ( assert stack and stack[-1] == expected_start, (
f"Stack mismatch for {etype}: " f"Stack mismatch for {etype}: "
...@@ -138,9 +129,180 @@ def validate_streaming_event_stack( ...@@ -138,9 +129,180 @@ def validate_streaming_event_stack(
f"got {stack[-1] if stack else '<empty>'}" f"got {stack[-1] if stack else '<empty>'}"
) )
stack.pop() stack.pop()
elif etype in start_events:
# Consecutive deltas of the same type share a single stack slot.
if etype.endswith("delta") and stack and stack[-1] == etype:
continue
stack.append(etype)
# else: passthrough event (e.g. response.in_progress,
# web_search_call.searching, code_interpreter_call.interpreting)
assert len(stack) == 0, f"Unclosed events on stack: {stack}" assert len(stack) == 0, f"Unclosed events on stack: {stack}"
def _validate_event_ordering(events: list) -> None:
    """Check that the response envelope events sit where they belong.

    The stream must open with ``response.created``, close with
    ``response.completed``, contain each of those exactly once, and, when
    ``response.in_progress`` is emitted at all, carry it as the second event.

    Raises:
        AssertionError: if any of the above does not hold.
    """
    total = len(events)
    assert total >= 2, f"Expected at least 2 events, got {total}"

    # Envelope opener and closer.
    opener = events[0].type
    assert opener == "response.created", (
        f"First event must be response.created, got {opener}"
    )
    closer = events[-1].type
    assert closer == "response.completed", (
        f"Last event must be response.completed, got {closer}"
    )

    seen_types = [evt.type for evt in events]

    # An in_progress notification is optional, but only index 1 is legal.
    progress_positions = [
        pos for pos, name in enumerate(seen_types) if name == "response.in_progress"
    ]
    assert not progress_positions or progress_positions == [1], (
        "response.in_progress must be the second event, "
        f"found at indices {progress_positions}"
    )

    # Neither envelope event may repeat anywhere in the stream.
    for envelope in ("response.created", "response.completed"):
        occurrences = seen_types.count(envelope)
        assert occurrences == 1, (
            f"Expected exactly 1 {envelope}, got {occurrences}"
        )
def _validate_field_consistency(events: list) -> None:
    """Check item_id / output_index / content_index agreement per output item.

    An ``output_item.added`` event opens an item and records its id and
    output index. Every later item-level event is compared against those
    values until ``output_item.done`` closes the item. Within an open item,
    a ``content_part.added`` or ``reasoning_part.added`` event records the
    content index that following events are compared against.

    Fields that an event does not carry (``None`` via ``getattr``) are
    skipped rather than treated as mismatches.

    Raises:
        AssertionError: on a missing item/id, an output index that goes
            backwards between items, or any identifier mismatch.
    """
    envelope_types = {
        "response.created",
        "response.in_progress",
        "response.completed",
    }
    part_openers = (
        "response.content_part.added",
        "response.reasoning_part.added",
    )

    open_item_id: str | None = None
    open_output_index: int | None = None
    open_content_index: int | None = None
    previous_output_index: int = -1

    for evt in events:
        kind = evt.type

        # Envelope events carry no per-item fields.
        if kind in envelope_types:
            continue

        if kind == "response.output_item.added":
            # Opens a new item and resets the tracked content index.
            opened = getattr(evt, "item", None)
            opened_index = getattr(evt, "output_index", None)
            assert opened is not None, "output_item.added must have an item"
            opened_id = getattr(opened, "id", None)
            assert opened_id, "output_item.added item must have an id"
            if opened_index is not None:
                # Equal indices are tolerated; only a decrease is an error.
                assert opened_index >= previous_output_index, (
                    f"output_index went backwards: {opened_index} < {previous_output_index}"
                )
                previous_output_index = opened_index
            open_item_id = opened_id
            open_output_index = opened_index
            open_content_index = None

        elif kind == "response.output_item.done":
            # Closes the open item; compare only when both sides are known.
            closed = getattr(evt, "item", None)
            closed_index = getattr(evt, "output_index", None)
            assert closed is not None, "output_item.done must have an item"
            closed_id = getattr(closed, "id", None)
            if open_item_id is not None and closed_id:
                assert closed_id == open_item_id, (
                    f"output_item.done item.id mismatch: "
                    f"expected {open_item_id}, got {closed_id}"
                )
            if open_output_index is not None and closed_index is not None:
                assert closed_index == open_output_index, (
                    f"output_item.done output_index mismatch: "
                    f"expected {open_output_index}, got {closed_index}"
                )
            open_item_id = None
            open_output_index = None
            open_content_index = None

        elif kind in part_openers:
            # A part opener defines the content index for what follows.
            _assert_item_fields(evt, kind, open_item_id, open_output_index)
            open_content_index = getattr(evt, "content_index", None)

        else:
            # Any other item-level event must match the open item.
            _assert_item_fields(evt, kind, open_item_id, open_output_index)
            seen_content_index = getattr(evt, "content_index", None)
            if seen_content_index is not None and open_content_index is not None:
                assert seen_content_index == open_content_index, (
                    f"{kind} content_index mismatch: "
                    f"expected {open_content_index}, got {seen_content_index}"
                )
def _assert_item_fields(
event,
etype: str,
active_item_id: str | None,
active_output_index: int | None,
) -> None:
"""Check that *event*'s item_id and output_index match the active item."""
event_item_id = getattr(event, "item_id", None)
output_index = getattr(event, "output_index", None)
if active_item_id is not None and event_item_id is not None:
assert event_item_id == active_item_id, (
f"{etype} item_id mismatch: expected {active_item_id}, got {event_item_id}"
)
if active_output_index is not None and output_index is not None:
assert output_index == active_output_index, (
f"{etype} output_index mismatch: "
f"expected {active_output_index}, got {output_index}"
)
def validate_streaming_event_stack(
    events: list, pairs_of_event_types: dict[str, str]
) -> None:
    """Run every structural check on a recorded list of streaming events.

    The checks run in this order, and the first failing one raises:

    1. Pairing: start/end events nest correctly, using the mapping in
       *pairs_of_event_types*.
    2. Ordering: the envelope events (``created``, ``in_progress``,
       ``completed``) are in their required positions.
    3. Field consistency: ``item_id``, ``output_index`` and
       ``content_index`` agree across the events of each output item.

    Args:
        events: Streaming events in the order they were received.
        pairs_of_event_types: Mapping handed to the pairing check.
    """
    # Pairing is the only check that needs the start/end mapping.
    _validate_event_pairing(events, pairs_of_event_types)
    # The remaining checks look at the event list alone.
    for run_check in (_validate_event_ordering, _validate_field_consistency):
        run_check(events)
def log_response_diagnostics( def log_response_diagnostics(
response, response,
*, *,
......
...@@ -910,21 +910,25 @@ async def test_function_calling_no_code_interpreter_events( ...@@ -910,21 +910,25 @@ async def test_function_calling_no_code_interpreter_events(
reason="This test is flaky in CI, needs investigation and " reason="This test is flaky in CI, needs investigation and "
"potential fixes in the code interpreter MCP implementation." "potential fixes in the code interpreter MCP implementation."
) )
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server): async def test_code_interpreter_streaming(
tools = [{"type": "mcp", "server_label": "code_interpreter"}] client: OpenAI,
model_name: str,
pairs_of_event_types: dict[str, str],
):
tools = [{"type": "code_interpreter", "container": {"type": "auto"}}]
input_text = ( input_text = (
"Calculate 123 * 456 using python. " "Calculate 123 * 456 using python. "
"The python interpreter is not stateful and you must " "The python interpreter is not stateful and you must "
"print to see the output." "print to see the output."
) )
def _has_mcp_call(evts: list) -> bool: def _has_code_interpreter(evts: list) -> bool:
return events_contain_type(evts, "mcp_call") return events_contain_type(evts, "code_interpreter")
events = await retry_streaming_for( events = await retry_streaming_for(
client, client,
model=model_name, model=model_name,
validate_events=_has_mcp_call, validate_events=_has_code_interpreter,
input=input_text, input=input_text,
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
...@@ -936,59 +940,36 @@ async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, s ...@@ -936,59 +940,36 @@ async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, s
event_types = [e.type for e in events] event_types = [e.type for e in events]
event_types_set = set(event_types) event_types_set = set(event_types)
logger.info( logger.info(
"\n====== MCP Streaming Diagnostics ======\n" "\n====== Code Interpreter Streaming Diagnostics ======\n"
"Event count: %d\n" "Event count: %d\n"
"Event types (in order): %s\n" "Event types (in order): %s\n"
"Unique event types: %s\n" "Unique event types: %s\n"
"=======================================", "====================================================",
len(events), len(events),
event_types, event_types,
sorted(event_types_set), sorted(event_types_set),
) )
# Verify the full MCP streaming lifecycle # Structural validation (pairing, ordering, field consistency)
assert "response.output_item.added" in event_types_set, ( validate_streaming_event_stack(events, pairs_of_event_types)
f"MCP call was not added. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.in_progress" in event_types_set, (
f"MCP call in_progress not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.delta" in event_types_set, (
f"MCP arguments delta not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.done" in event_types_set, (
f"MCP arguments done not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.completed" in event_types_set, (
f"MCP call completed not seen. Events: {sorted(event_types_set)}"
)
assert "response.output_item.done" in event_types_set, (
f"MCP item done not seen. Events: {sorted(event_types_set)}"
)
# Validate specific MCP event details # Validate code interpreter item fields
for event in events: for event in events:
if event.type == "response.output_item.added": if (
if hasattr(event.item, "type") and event.item.type == "mcp_call": event.type == "response.output_item.added"
assert event.item.name == "python" and hasattr(event.item, "type")
assert event.item.server_label == "code_interpreter" and event.item.type == "code_interpreter_call"
elif event.type == "response.mcp_call_arguments.done": ):
assert event.name == "python" assert event.item.status == "in_progress"
assert event.arguments is not None elif event.type == "response.code_interpreter_call_code.done":
assert event.code is not None
elif ( elif (
event.type == "response.output_item.done" event.type == "response.output_item.done"
and hasattr(event.item, "type") and hasattr(event.item, "type")
and event.item.type == "mcp_call" and event.item.type == "code_interpreter_call"
): ):
assert event.item.name == "python"
assert event.item.status == "completed" assert event.item.status == "completed"
assert event.item.code is not None
# code_interpreter events should NOT appear when using MCP type
code_interp_events = [e.type for e in events if "code_interpreter" in e.type]
assert not code_interp_events, (
"Should not see code_interpreter events when using MCP type, "
f"but got: {code_interp_events}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -241,81 +241,3 @@ class TestMCPEnabled: ...@@ -241,81 +241,3 @@ class TestMCPEnabled:
) )
validate_streaming_event_stack(events, pairs_of_event_types) validate_streaming_event_stack(events, pairs_of_event_types)
assert events_contain_type(events, "mcp_call"), (
f"No mcp_call events after retries. "
f"Event types: {sorted({e.type for e in events})}"
)
class TestMCPDisabled:
    """Tests that MCP tools are not executed when the env flag is unset."""

    @pytest.fixture(scope="class")
    def mcp_disabled_server(self):
        # One server is started per test class and torn down when the
        # ``with`` block exits after the last test.
        # NOTE(review): no MCP-specific variable is added below; presumably
        # BASE_TEST_ENV does not enable MCP either -- confirm.
        env_dict = {
            **BASE_TEST_ENV,
            "VLLM_ENABLE_RESPONSES_API_STORE": "1",
            "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
            "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
        }
        # list(...) passes a fresh copy of the shared base arguments.
        with RemoteOpenAIServer(
            MODEL_NAME, list(_BASE_SERVER_ARGS), env_dict=env_dict
        ) as remote_server:
            yield remote_server

    @pytest_asyncio.fixture
    async def client(self, mcp_disabled_server):
        # Async client bound to the class-scoped server; closed on exit.
        async with mcp_disabled_server.get_async_client() as async_client:
            yield async_client

    @pytest.mark.asyncio
    @pytest.mark.parametrize("model_name", [MODEL_NAME])
    async def test_mcp_disabled_server_does_not_execute(
        self, client: OpenAI, model_name: str
    ):
        """When MCP is disabled the model may still attempt tool calls
        (tool descriptions can remain in the prompt), but the server
        must NOT execute them."""
        # enable_response_messages is requested so that the response exposes
        # the input_messages / output_messages inspected below.
        response = await client.responses.create(
            model=model_name,
            input=(
                "Execute the following code if the tool is present: "
                "import random; print(random.randint(1, 1000000))"
            ),
            tools=[
                {
                    "type": "mcp",
                    "server_label": "code_interpreter",
                    "server_url": "http://localhost:8888",
                }
            ],
            temperature=0.0,
            extra_body={"enable_response_messages": True},
        )
        assert response is not None
        assert response.status == "completed"
        log_response_diagnostics(response, label="MCP Disabled")

        # Server must not have executed any tool calls
        for message in response.output_messages:
            author = message.get("author", {})
            # A message authored by role "tool" whose name starts with
            # "python" is taken as evidence that the tool was executed.
            # ``or ""`` guards against a missing or None name.
            assert not (
                author.get("role") == "tool"
                and (author.get("name") or "").startswith("python")
            ), (
                "Server executed a python tool call even though MCP is "
                f"disabled. Message: {message}"
            )

        # No completed mcp_call output items
        for item in response.output:
            if getattr(item, "type", None) == "mcp_call":
                assert getattr(item, "status", None) != "completed", (
                    "MCP call should not be completed when MCP is disabled"
                )

        # No developer messages injected
        for message in response.input_messages:
            assert message.get("author", {}).get("role") != "developer"
...@@ -89,7 +89,7 @@ from vllm.entrypoints.openai.responses.protocol import ( ...@@ -89,7 +89,7 @@ from vllm.entrypoints.openai.responses.protocol import (
StreamingResponsesResponse, StreamingResponsesResponse,
) )
from vllm.entrypoints.openai.responses.streaming_events import ( from vllm.entrypoints.openai.responses.streaming_events import (
HarmonyStreamingState, StreamingState,
emit_content_delta_events, emit_content_delta_events,
emit_previous_item_done_events, emit_previous_item_done_events,
emit_tool_action_events, emit_tool_action_events,
...@@ -1591,7 +1591,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1591,7 +1591,7 @@ class OpenAIServingResponses(OpenAIServing):
[StreamingResponsesResponse], StreamingResponsesResponse [StreamingResponsesResponse], StreamingResponsesResponse
], ],
) -> AsyncGenerator[StreamingResponsesResponse, None]: ) -> AsyncGenerator[StreamingResponsesResponse, None]:
state = HarmonyStreamingState() state = StreamingState()
async for ctx in result_generator: async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext) assert isinstance(ctx, StreamingHarmonyContext)
......
...@@ -6,6 +6,13 @@ Streaming SSE event builders for the Responses API. ...@@ -6,6 +6,13 @@ Streaming SSE event builders for the Responses API.
Pure functions that translate streaming state + delta data into Pure functions that translate streaming state + delta data into
OpenAI Response API SSE events. Used by the streaming event OpenAI Response API SSE events. Used by the streaming event
processors in serving.py. processors in serving.py.
The file is organized as:
1. StreamingState dataclass + utility helpers
2. Shared leaf helpers — delta events (take plain strings, no context)
3. Shared leaf helpers — done events (take plain strings, no context)
4. Harmony-specific dispatchers (route ctx/previous_item → leaf helpers)
5. Harmony-specific tool lifecycle helpers
""" """
import json import json
...@@ -47,6 +54,7 @@ from openai.types.responses.response_output_item import McpCall ...@@ -47,6 +54,7 @@ from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent, Content as ResponseReasoningTextContent,
) )
from openai_harmony import Message as HarmonyMessage
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.responses.context import StreamingHarmonyContext from vllm.entrypoints.openai.responses.context import StreamingHarmonyContext
...@@ -64,13 +72,28 @@ TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = { ...@@ -64,13 +72,28 @@ TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = {
} }
def _resolve_mcp_name_label(recipient: str) -> tuple[str, str]:
    """Return the ``(name, server_label)`` pair for an MCP recipient string.

    - A recipient starting with ``mcp.`` has that prefix removed; the
      remainder is returned as both the name and the server label.
    - Any other recipient is returned unchanged as the name, and the
      server label is looked up in ``TOOL_NAME_TO_MCP_SERVER_LABEL``,
      falling back to the recipient itself when there is no entry.
    """
    namespace_prefix = "mcp."
    if not recipient.startswith(namespace_prefix):
        label = TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
        return recipient, label
    bare_name = recipient[len(namespace_prefix):]
    return bare_name, bare_name
@dataclass @dataclass
class HarmonyStreamingState: class StreamingState:
"""Mutable state for harmony streaming event processing.""" """Mutable state for streaming event processing."""
current_content_index: int = -1 current_content_index: int = -1
current_output_index: int = 0 current_output_index: int = 0
current_item_id: str = "" current_item_id: str = ""
current_call_id: str = ""
sent_output_item_added: bool = False sent_output_item_added: bool = False
is_first_function_call_delta: bool = False is_first_function_call_delta: bool = False
...@@ -79,6 +102,7 @@ class HarmonyStreamingState: ...@@ -79,6 +102,7 @@ class HarmonyStreamingState:
self.current_output_index += 1 self.current_output_index += 1
self.sent_output_item_added = False self.sent_output_item_added = False
self.is_first_function_call_delta = False self.is_first_function_call_delta = False
self.current_call_id = ""
def is_mcp_tool_by_namespace(recipient: str | None) -> bool: def is_mcp_tool_by_namespace(recipient: str | None) -> bool:
...@@ -96,213 +120,16 @@ def is_mcp_tool_by_namespace(recipient: str | None) -> bool: ...@@ -96,213 +120,16 @@ def is_mcp_tool_by_namespace(recipient: str | None) -> bool:
return not recipient.startswith("functions.") return not recipient.startswith("functions.")
def emit_function_call_done_events( # =====================================================================
previous_item, # Shared leaf helpers — delta events
state: HarmonyStreamingState, # =====================================================================
) -> list[StreamingResponsesResponse]:
"""Emit events when a function call completes."""
function_name = previous_item.recipient[len("functions.") :]
events: list[StreamingResponsesResponse] = []
events.append(
ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done",
arguments=previous_item.content[0].text,
name=function_name,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
function_call_item = ResponseFunctionToolCall(
type="function_call",
arguments=previous_item.content[0].text,
name=function_name,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
call_id=f"fc_{random_uuid()}",
status="completed",
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=function_call_item,
)
)
return events
def emit_mcp_call_done_events(
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP tool call completes."""
server_label = TOOL_NAME_TO_MCP_SERVER_LABEL.get(
previous_item.recipient, previous_item.recipient
)
events: list[StreamingResponsesResponse] = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
arguments=previous_item.content[0].text,
name=previous_item.recipient,
item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
arguments=previous_item.content[0].text,
name=previous_item.recipient,
id=state.current_item_id,
server_label=server_label,
status="completed",
),
)
)
return events
def emit_reasoning_done_events(
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a reasoning (analysis) item completes."""
content = ResponseReasoningTextContent(
text=previous_item.content[0].text,
type="reasoning_text",
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[content],
status="completed",
id=state.current_item_id,
summary=[],
)
events: list[StreamingResponsesResponse] = []
events.append(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=state.current_item_id,
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=previous_item.content[0].text,
)
)
events.append(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=content,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=reasoning_item,
)
)
return events
def emit_text_output_done_events(
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a final text output item completes."""
text_content = ResponseOutputText(
type="output_text",
text=previous_item.content[0].text,
annotations=[],
)
events: list[StreamingResponsesResponse] = []
events.append(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=previous_item.content[0].text,
logprobs=[],
item_id=state.current_item_id,
)
)
events.append(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=text_content,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseOutputMessage(
id=state.current_item_id,
type="message",
role="assistant",
content=[text_content],
status="completed",
),
)
)
return events
def emit_previous_item_done_events(
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit done events for the previous item when expecting a new start."""
if previous_item.recipient is not None:
# Deal with tool call
if previous_item.recipient.startswith("functions."):
return emit_function_call_done_events(previous_item, state)
elif (
is_mcp_tool_by_namespace(previous_item.recipient)
and state.current_item_id is not None
and state.current_item_id.startswith("mcp_")
):
return emit_mcp_call_done_events(previous_item, state)
elif previous_item.channel == "analysis":
return emit_reasoning_done_events(previous_item, state)
elif previous_item.channel == "final":
return emit_text_output_done_events(previous_item, state)
return []
def emit_final_channel_delta_events( def emit_text_delta_events(
ctx: StreamingHarmonyContext, delta: str,
state: HarmonyStreamingState, state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for final channel text delta streaming.""" """Emit events for text content delta streaming."""
events: list[StreamingResponsesResponse] = [] events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added: if not state.sent_output_item_added:
state.sent_output_item_added = True state.sent_output_item_added = True
...@@ -344,7 +171,7 @@ def emit_final_channel_delta_events( ...@@ -344,7 +171,7 @@ def emit_final_channel_delta_events(
content_index=state.current_content_index, content_index=state.current_content_index,
output_index=state.current_output_index, output_index=state.current_output_index,
item_id=state.current_item_id, item_id=state.current_item_id,
delta=ctx.last_content_delta, delta=delta,
# TODO, use logprobs from ctx.last_request_output # TODO, use logprobs from ctx.last_request_output
logprobs=[], logprobs=[],
) )
...@@ -352,11 +179,11 @@ def emit_final_channel_delta_events( ...@@ -352,11 +179,11 @@ def emit_final_channel_delta_events(
return events return events
def emit_analysis_channel_delta_events( def emit_reasoning_delta_events(
ctx: StreamingHarmonyContext, delta: str,
state: HarmonyStreamingState, state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for analysis channel reasoning delta streaming.""" """Emit events for reasoning text delta streaming."""
events: list[StreamingResponsesResponse] = [] events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added: if not state.sent_output_item_added:
state.sent_output_item_added = True state.sent_output_item_added = True
...@@ -394,20 +221,60 @@ def emit_analysis_channel_delta_events( ...@@ -394,20 +221,60 @@ def emit_analysis_channel_delta_events(
item_id=state.current_item_id, item_id=state.current_item_id,
output_index=state.current_output_index, output_index=state.current_output_index,
content_index=state.current_content_index, content_index=state.current_content_index,
delta=ctx.last_content_delta, delta=delta,
sequence_number=-1, sequence_number=-1,
) )
) )
return events return events
def emit_mcp_tool_delta_events( def emit_function_call_delta_events(
ctx: StreamingHarmonyContext, delta: str,
state: HarmonyStreamingState, function_name: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for function call argument deltas."""
events: list[StreamingResponsesResponse] = []
if state.is_first_function_call_delta is False:
state.is_first_function_call_delta = True
state.current_item_id = f"fc_{random_uuid()}"
state.current_call_id = f"call_{random_uuid()}"
tool_call_item = ResponseFunctionToolCall(
name=function_name,
type="function_call",
id=state.current_item_id,
call_id=state.current_call_id,
arguments="",
status="in_progress",
)
events.append(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=state.current_output_index,
item=tool_call_item,
)
)
# Always emit the delta (including on first call)
events.append(
ResponseFunctionCallArgumentsDeltaEvent(
item_id=state.current_item_id,
delta=delta,
output_index=state.current_output_index,
sequence_number=-1,
type="response.function_call_arguments.delta",
)
)
return events
def emit_mcp_delta_events(
delta: str,
state: StreamingState,
recipient: str, recipient: str,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for MCP tool delta streaming.""" """Emit events for MCP tool delta streaming."""
server_label = TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient) name, server_label = _resolve_mcp_name_label(recipient)
events: list[StreamingResponsesResponse] = [] events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added: if not state.sent_output_item_added:
state.sent_output_item_added = True state.sent_output_item_added = True
...@@ -420,7 +287,7 @@ def emit_mcp_tool_delta_events( ...@@ -420,7 +287,7 @@ def emit_mcp_tool_delta_events(
item=McpCall( item=McpCall(
type="mcp_call", type="mcp_call",
id=state.current_item_id, id=state.current_item_id,
name=recipient, name=name,
arguments="", arguments="",
server_label=server_label, server_label=server_label,
status="in_progress", status="in_progress",
...@@ -441,15 +308,15 @@ def emit_mcp_tool_delta_events( ...@@ -441,15 +308,15 @@ def emit_mcp_tool_delta_events(
sequence_number=-1, sequence_number=-1,
output_index=state.current_output_index, output_index=state.current_output_index,
item_id=state.current_item_id, item_id=state.current_item_id,
delta=ctx.last_content_delta, delta=delta,
) )
) )
return events return events
def emit_code_interpreter_delta_events( def emit_code_interpreter_delta_events(
ctx: StreamingHarmonyContext, delta: str,
state: HarmonyStreamingState, state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for code interpreter delta streaming.""" """Emit events for code interpreter delta streaming."""
events: list[StreamingResponsesResponse] = [] events: list[StreamingResponsesResponse] = []
...@@ -485,151 +352,274 @@ def emit_code_interpreter_delta_events( ...@@ -485,151 +352,274 @@ def emit_code_interpreter_delta_events(
sequence_number=-1, sequence_number=-1,
output_index=state.current_output_index, output_index=state.current_output_index,
item_id=state.current_item_id, item_id=state.current_item_id,
delta=ctx.last_content_delta, delta=delta,
) )
) )
return events return events
def emit_mcp_prefix_delta_events( # =====================================================================
ctx: StreamingHarmonyContext, # Shared leaf helpers — done events
state: HarmonyStreamingState, # =====================================================================
def emit_text_output_done_events(
text: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for MCP prefix (mcp.*) delta streaming.""" """Emit events when a final text output item completes."""
text_content = ResponseOutputText(
type="output_text",
text=text,
annotations=[],
)
events: list[StreamingResponsesResponse] = [] events: list[StreamingResponsesResponse] = []
if not state.sent_output_item_added:
state.sent_output_item_added = True
state.current_item_id = f"mcp_{random_uuid()}"
mcp_name = ctx.parser.current_recipient[len("mcp.") :]
events.append( events.append(
ResponseOutputItemAddedEvent( ResponseTextDoneEvent(
type="response.output_item.added", type="response.output_text.done",
sequence_number=-1, sequence_number=-1,
output_index=state.current_output_index, output_index=state.current_output_index,
item=McpCall( content_index=state.current_content_index,
type="mcp_call", text=text,
logprobs=[],
item_id=state.current_item_id,
)
)
events.append(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=text_content,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=ResponseOutputMessage(
id=state.current_item_id, id=state.current_item_id,
name=mcp_name, type="message",
arguments="", role="assistant",
server_label=mcp_name, content=[text_content],
status="in_progress", status="completed",
), ),
) )
) )
return events
def emit_reasoning_done_events(
text: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a reasoning (analysis) item completes."""
content = ResponseReasoningTextContent(
text=text,
type="reasoning_text",
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[content],
status="completed",
id=state.current_item_id,
summary=[],
)
events: list[StreamingResponsesResponse] = []
events.append( events.append(
ResponseMcpCallInProgressEvent( ResponseReasoningTextDoneEvent(
type="response.mcp_call.in_progress", type="response.reasoning_text.done",
item_id=state.current_item_id,
sequence_number=-1, sequence_number=-1,
output_index=state.current_output_index, output_index=state.current_output_index,
content_index=state.current_content_index,
text=text,
)
)
events.append(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=state.current_item_id, item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
part=content,
) )
) )
events.append( events.append(
ResponseMcpCallArgumentsDeltaEvent( ResponseOutputItemDoneEvent(
type="response.mcp_call_arguments.delta", type="response.output_item.done",
sequence_number=-1, sequence_number=-1,
output_index=state.current_output_index, output_index=state.current_output_index,
item_id=state.current_item_id, item=reasoning_item,
delta=ctx.last_content_delta,
) )
) )
return events return events
def emit_function_call_delta_events( def emit_function_call_done_events(
ctx: StreamingHarmonyContext, function_name: str,
state: HarmonyStreamingState, arguments: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for developer function calls on commentary channel.""" """Emit events when a function call completes."""
if not (
ctx.parser.current_channel == "commentary"
and ctx.parser.current_recipient
and ctx.parser.current_recipient.startswith("functions.")
):
return []
events: list[StreamingResponsesResponse] = [] events: list[StreamingResponsesResponse] = []
if state.is_first_function_call_delta is False: events.append(
state.is_first_function_call_delta = True ResponseFunctionCallArgumentsDoneEvent(
fc_name = ctx.parser.current_recipient[len("functions.") :] type="response.function_call_arguments.done",
state.current_item_id = f"fc_{random_uuid()}" arguments=arguments,
tool_call_item = ResponseFunctionToolCall( name=function_name,
name=fc_name, item_id=state.current_item_id,
output_index=state.current_output_index,
sequence_number=-1,
)
)
function_call_item = ResponseFunctionToolCall(
type="function_call", type="function_call",
id=state.current_item_id, arguments=arguments,
call_id=f"call_{random_uuid()}", name=function_name,
arguments="", item_id=state.current_item_id,
status="in_progress", output_index=state.current_output_index,
sequence_number=-1,
call_id=state.current_call_id,
status="completed",
) )
events.append( events.append(
ResponseOutputItemAddedEvent( ResponseOutputItemDoneEvent(
type="response.output_item.added", type="response.output_item.done",
sequence_number=-1, sequence_number=-1,
output_index=state.current_output_index, output_index=state.current_output_index,
item=tool_call_item, item=function_call_item,
) )
) )
# Always emit the delta (including on first call) return events
def emit_mcp_completion_events(
recipient: str,
arguments: str,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP tool call completes."""
name, server_label = _resolve_mcp_name_label(recipient)
events: list[StreamingResponsesResponse] = []
events.append( events.append(
ResponseFunctionCallArgumentsDeltaEvent( ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
arguments=arguments,
name=name,
item_id=state.current_item_id, item_id=state.current_item_id,
delta=ctx.last_content_delta,
output_index=state.current_output_index, output_index=state.current_output_index,
sequence_number=-1, sequence_number=-1,
type="response.function_call_arguments.delta", )
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
arguments=arguments,
name=name,
id=state.current_item_id,
server_label=server_label,
status="completed",
),
) )
) )
return events return events
# =====================================================================
# Harmony-specific dispatchers
# =====================================================================
def emit_content_delta_events( def emit_content_delta_events(
ctx: StreamingHarmonyContext, ctx: StreamingHarmonyContext,
state: HarmonyStreamingState, state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for content delta streaming based on channel type.""" """Emit events for content delta streaming based on channel type.
if not ctx.last_content_delta:
This is a Harmony-specific dispatcher that extracts values from the
Harmony context and delegates to shared leaf helpers.
"""
delta = ctx.last_content_delta
if not delta:
return [] return []
if ctx.parser.current_channel == "final" and ctx.parser.current_recipient is None: channel = ctx.parser.current_channel
return emit_final_channel_delta_events(ctx, state) recipient = ctx.parser.current_recipient
elif (
ctx.parser.current_channel == "analysis" if channel == "final" and recipient is None:
and ctx.parser.current_recipient is None return emit_text_delta_events(delta, state)
): elif channel == "analysis" and recipient is None:
return emit_analysis_channel_delta_events(ctx, state) return emit_reasoning_delta_events(delta, state)
# built-in tools will be triggered on the analysis channel # built-in tools will be triggered on the analysis channel
# However, occasionally built-in tools will # However, occasionally built-in tools will
# still be output to commentary. # still be output to commentary.
elif ( elif channel in ("commentary", "analysis") and recipient is not None:
ctx.parser.current_channel == "commentary"
or ctx.parser.current_channel == "analysis"
) and ctx.parser.current_recipient is not None:
recipient = ctx.parser.current_recipient
# Check for function calls first - they have their own event handling
if recipient.startswith("functions."): if recipient.startswith("functions."):
return emit_function_call_delta_events(ctx, state) function_name = recipient[len("functions.") :]
if is_mcp_tool_by_namespace(recipient): return emit_function_call_delta_events(delta, function_name, state)
return emit_mcp_tool_delta_events(ctx, state, recipient) elif recipient == "python":
else: return emit_code_interpreter_delta_events(delta, state)
return emit_code_interpreter_delta_events(ctx, state) elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
return emit_mcp_delta_events(delta, state, recipient)
return []
def emit_previous_item_done_events(
previous_item: HarmonyMessage,
state: StreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit done events for the previous item when expecting a new start.
This is a Harmony-specific dispatcher that extracts values from the
Harmony parser's message object and delegates to shared leaf helpers.
"""
text = previous_item.content[0].text
if previous_item.recipient is not None:
# Deal with tool call
if previous_item.recipient.startswith("functions."):
function_name = previous_item.recipient[len("functions.") :]
return emit_function_call_done_events(function_name, text, state)
elif previous_item.recipient == "python":
return emit_code_interpreter_completion_events(previous_item, state)
elif ( elif (
( is_mcp_tool_by_namespace(previous_item.recipient)
ctx.parser.current_channel == "commentary" and state.current_item_id is not None
or ctx.parser.current_channel == "analysis" and state.current_item_id.startswith("mcp_")
)
and ctx.parser.current_recipient is not None
and ctx.parser.current_recipient.startswith("mcp.")
): ):
return emit_mcp_prefix_delta_events(ctx, state) return emit_mcp_completion_events(previous_item.recipient, text, state)
elif previous_item.channel == "analysis":
return emit_reasoning_done_events(text, state)
elif previous_item.channel == "final":
return emit_text_output_done_events(text, state)
return [] return []
# =====================================================================
# Harmony-specific tool lifecycle helpers
# =====================================================================
def emit_browser_tool_events( def emit_browser_tool_events(
previous_item, previous_item: HarmonyMessage,
state: HarmonyStreamingState, state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for browser tool calls (web search).""" """Emit events for browser tool calls (web search)."""
function_name = previous_item.recipient[len("browser.") :] function_name = previous_item.recipient[len("browser.") :]
...@@ -714,53 +704,9 @@ def emit_browser_tool_events( ...@@ -714,53 +704,9 @@ def emit_browser_tool_events(
return events return events
def emit_mcp_tool_completion_events(
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP tool completes during assistant action turn."""
recipient = previous_item.recipient
server_label = TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient)
events: list[StreamingResponsesResponse] = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
arguments=previous_item.content[0].text,
name=recipient,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=recipient,
arguments=previous_item.content[0].text,
server_label=server_label,
status="completed",
),
)
)
return events
def emit_code_interpreter_completion_events( def emit_code_interpreter_completion_events(
previous_item, previous_item: HarmonyMessage,
state: HarmonyStreamingState, state: StreamingState,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events when code interpreter completes.""" """Emit events when code interpreter completes."""
events: list[StreamingResponsesResponse] = [] events: list[StreamingResponsesResponse] = []
...@@ -807,52 +753,9 @@ def emit_code_interpreter_completion_events( ...@@ -807,52 +753,9 @@ def emit_code_interpreter_completion_events(
return events return events
def emit_mcp_prefix_completion_events(
previous_item,
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when an MCP prefix tool (mcp.*) completes."""
mcp_name = previous_item.recipient[len("mcp.") :]
events: list[StreamingResponsesResponse] = []
events.append(
ResponseMcpCallArgumentsDoneEvent(
type="response.mcp_call_arguments.done",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
arguments=previous_item.content[0].text,
name=mcp_name,
)
)
events.append(
ResponseMcpCallCompletedEvent(
type="response.mcp_call.completed",
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
)
)
events.append(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=state.current_output_index,
item=McpCall(
type="mcp_call",
id=state.current_item_id,
name=mcp_name,
arguments=previous_item.content[0].text,
server_label=mcp_name,
status="completed",
),
)
)
return events
def emit_tool_action_events( def emit_tool_action_events(
ctx: StreamingHarmonyContext, ctx: StreamingHarmonyContext,
state: HarmonyStreamingState, state: StreamingState,
tool_server: ToolServer | None, tool_server: ToolServer | None,
) -> list[StreamingResponsesResponse]: ) -> list[StreamingResponsesResponse]:
"""Emit events for tool action turn.""" """Emit events for tool action turn."""
...@@ -879,19 +782,13 @@ def emit_tool_action_events( ...@@ -879,19 +782,13 @@ def emit_tool_action_events(
and state.sent_output_item_added and state.sent_output_item_added
): ):
recipient = previous_item.recipient recipient = previous_item.recipient
# Handle MCP prefix tool completion first if recipient == "python":
if recipient.startswith("mcp."): events.extend(emit_code_interpreter_completion_events(previous_item, state))
events.extend(emit_mcp_prefix_completion_events(previous_item, state)) elif recipient.startswith("mcp.") or is_mcp_tool_by_namespace(recipient):
else:
# Handle other MCP tool and code interpreter completion
is_mcp_tool = is_mcp_tool_by_namespace(
recipient
) and state.current_item_id.startswith("mcp_")
if is_mcp_tool:
events.extend(emit_mcp_tool_completion_events(previous_item, state))
else:
events.extend( events.extend(
emit_code_interpreter_completion_events(previous_item, state) emit_mcp_completion_events(
recipient, previous_item.content[0].text, state
)
) )
return events return events
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment