Unverified Commit ec1d30c0 authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Responses] Decouple SSE event helpers from Harmony context (#35148)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent e3b2324e
...@@ -39,6 +39,7 @@ def pairs_of_event_types() -> dict[str, str]: ...@@ -39,6 +39,7 @@ def pairs_of_event_types() -> dict[str, str]:
"response.mcp_call.completed": "response.mcp_call.in_progress", "response.mcp_call.completed": "response.mcp_call.in_progress",
"response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa: E501 "response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa: E501
"response.code_interpreter_call_code.done": "response.code_interpreter_call_code.delta", # noqa: E501 "response.code_interpreter_call_code.done": "response.code_interpreter_call_code.delta", # noqa: E501
"response.code_interpreter_call.completed": "response.code_interpreter_call.in_progress", # noqa: E501
"response.web_search_call.completed": "response.web_search_call.in_progress", "response.web_search_call.completed": "response.web_search_call.in_progress",
} }
# fmt: on # fmt: on
...@@ -108,29 +109,19 @@ def events_contain_type(events: list, type_substring: str) -> bool: ...@@ -108,29 +109,19 @@ def events_contain_type(events: list, type_substring: str) -> bool:
return any(type_substring in getattr(e, "type", "") for e in events) return any(type_substring in getattr(e, "type", "") for e in events)
def validate_streaming_event_stack( def _validate_event_pairing(events: list, pairs_of_event_types: dict[str, str]) -> None:
events: list, pairs_of_event_types: dict[str, str] """Validate that streaming events are properly nested/paired.
) -> None:
"""Validate that streaming events are properly nested/paired.""" Derives push/pop sets from *pairs_of_event_types* so that every
start/end pair in the dict is handled automatically.
"""
start_events = set(pairs_of_event_types.values())
end_events = set(pairs_of_event_types.keys())
stack: list[str] = [] stack: list[str] = []
for event in events: for event in events:
etype = event.type etype = event.type
if etype == "response.created": if etype in end_events:
stack.append(etype)
elif etype == "response.completed":
assert stack and stack[-1] == pairs_of_event_types[etype], (
f"Unexpected stack top for {etype}: "
f"got {stack[-1] if stack else '<empty>'}"
)
stack.pop()
elif etype.endswith("added") or etype == "response.mcp_call.in_progress":
stack.append(etype)
elif etype.endswith("delta"):
if stack and stack[-1] == etype:
continue
stack.append(etype)
elif etype.endswith("done") or etype == "response.mcp_call.completed":
assert etype in pairs_of_event_types, f"Unknown done event: {etype}"
expected_start = pairs_of_event_types[etype] expected_start = pairs_of_event_types[etype]
assert stack and stack[-1] == expected_start, ( assert stack and stack[-1] == expected_start, (
f"Stack mismatch for {etype}: " f"Stack mismatch for {etype}: "
...@@ -138,9 +129,180 @@ def validate_streaming_event_stack( ...@@ -138,9 +129,180 @@ def validate_streaming_event_stack(
f"got {stack[-1] if stack else '<empty>'}" f"got {stack[-1] if stack else '<empty>'}"
) )
stack.pop() stack.pop()
elif etype in start_events:
# Consecutive deltas of the same type share a single stack slot.
if etype.endswith("delta") and stack and stack[-1] == etype:
continue
stack.append(etype)
# else: passthrough event (e.g. response.in_progress,
# web_search_call.searching, code_interpreter_call.interpreting)
assert len(stack) == 0, f"Unclosed events on stack: {stack}" assert len(stack) == 0, f"Unclosed events on stack: {stack}"
def _validate_event_ordering(events: list) -> None:
"""Validate that envelope events appear in the correct positions."""
assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"
# First event must be response.created
assert events[0].type == "response.created", (
f"First event must be response.created, got {events[0].type}"
)
# Last event must be response.completed
assert events[-1].type == "response.completed", (
f"Last event must be response.completed, got {events[-1].type}"
)
# response.in_progress, if present, must be the second event
in_progress_indices = [
i for i, e in enumerate(events) if e.type == "response.in_progress"
]
if in_progress_indices:
assert in_progress_indices == [1], (
f"response.in_progress must be the second event, "
f"found at indices {in_progress_indices}"
)
# Exactly one created and one completed
created_count = sum(1 for e in events if e.type == "response.created")
completed_count = sum(1 for e in events if e.type == "response.completed")
assert created_count == 1, (
f"Expected exactly 1 response.created, got {created_count}"
)
assert completed_count == 1, (
f"Expected exactly 1 response.completed, got {completed_count}"
)
def _validate_field_consistency(events: list) -> None:
"""Validate item_id, output_index, and content_index consistency.
Tracks the active output item established by ``output_item.added``
and verifies that all subsequent events for that item carry matching
identifiers until ``output_item.done`` closes it.
"""
_SESSION_EVENTS = {
"response.created",
"response.in_progress",
"response.completed",
}
active_item_id: str | None = None
active_output_index: int | None = None
last_output_index: int = -1
active_content_index: int | None = None
for event in events:
etype = event.type
if etype in _SESSION_EVENTS:
continue
# --- output_item.added: opens a new item ------------------
if etype == "response.output_item.added":
item = getattr(event, "item", None)
output_index = getattr(event, "output_index", None)
assert item is not None, "output_item.added must have an item"
item_id = getattr(item, "id", None)
assert item_id, "output_item.added item must have an id"
# output_index must be non-decreasing across items
if output_index is not None:
assert output_index >= last_output_index, (
f"output_index went backwards: {output_index} < {last_output_index}"
)
last_output_index = output_index
active_item_id = item_id
active_output_index = output_index
active_content_index = None
continue
# --- output_item.done: closes the active item -------------
if etype == "response.output_item.done":
item = getattr(event, "item", None)
output_index = getattr(event, "output_index", None)
assert item is not None, "output_item.done must have an item"
done_item_id = getattr(item, "id", None)
if active_item_id is not None and done_item_id:
assert done_item_id == active_item_id, (
f"output_item.done item.id mismatch: "
f"expected {active_item_id}, got {done_item_id}"
)
if active_output_index is not None and output_index is not None:
assert output_index == active_output_index, (
f"output_item.done output_index mismatch: "
f"expected {active_output_index}, got {output_index}"
)
active_item_id = None
active_output_index = None
active_content_index = None
continue
# --- content_part / reasoning_part added: sets content_index
if etype in (
"response.content_part.added",
"response.reasoning_part.added",
):
_assert_item_fields(event, etype, active_item_id, active_output_index)
active_content_index = getattr(event, "content_index", None)
continue
# --- all other item-level events --------------------------
_assert_item_fields(event, etype, active_item_id, active_output_index)
# content_index (only meaningful on events that carry it)
content_index = getattr(event, "content_index", None)
if content_index is not None and active_content_index is not None:
assert content_index == active_content_index, (
f"{etype} content_index mismatch: "
f"expected {active_content_index}, got {content_index}"
)
def _assert_item_fields(
event,
etype: str,
active_item_id: str | None,
active_output_index: int | None,
) -> None:
"""Check that *event*'s item_id and output_index match the active item."""
event_item_id = getattr(event, "item_id", None)
output_index = getattr(event, "output_index", None)
if active_item_id is not None and event_item_id is not None:
assert event_item_id == active_item_id, (
f"{etype} item_id mismatch: expected {active_item_id}, got {event_item_id}"
)
if active_output_index is not None and output_index is not None:
assert output_index == active_output_index, (
f"{etype} output_index mismatch: "
f"expected {active_output_index}, got {output_index}"
)
def validate_streaming_event_stack(
events: list, pairs_of_event_types: dict[str, str]
) -> None:
"""Validate streaming events: pairing, ordering, and field consistency.
Checks three aspects:
1. **Event pairing** — start/end events are properly nested
(stack-based matching derived from *pairs_of_event_types*).
2. **Event ordering** — envelope events (``created``,
``in_progress``, ``completed``) appear at the correct positions.
3. **Field consistency** — ``item_id``, ``output_index``, and
``content_index`` are consistent across related events within
each output item's lifecycle.
"""
_validate_event_pairing(events, pairs_of_event_types)
_validate_event_ordering(events)
_validate_field_consistency(events)
def log_response_diagnostics( def log_response_diagnostics(
response, response,
*, *,
......
...@@ -910,21 +910,25 @@ async def test_function_calling_no_code_interpreter_events( ...@@ -910,21 +910,25 @@ async def test_function_calling_no_code_interpreter_events(
reason="This test is flaky in CI, needs investigation and " reason="This test is flaky in CI, needs investigation and "
"potential fixes in the code interpreter MCP implementation." "potential fixes in the code interpreter MCP implementation."
) )
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server): async def test_code_interpreter_streaming(
tools = [{"type": "mcp", "server_label": "code_interpreter"}] client: OpenAI,
model_name: str,
pairs_of_event_types: dict[str, str],
):
tools = [{"type": "code_interpreter", "container": {"type": "auto"}}]
input_text = ( input_text = (
"Calculate 123 * 456 using python. " "Calculate 123 * 456 using python. "
"The python interpreter is not stateful and you must " "The python interpreter is not stateful and you must "
"print to see the output." "print to see the output."
) )
def _has_mcp_call(evts: list) -> bool: def _has_code_interpreter(evts: list) -> bool:
return events_contain_type(evts, "mcp_call") return events_contain_type(evts, "code_interpreter")
events = await retry_streaming_for( events = await retry_streaming_for(
client, client,
model=model_name, model=model_name,
validate_events=_has_mcp_call, validate_events=_has_code_interpreter,
input=input_text, input=input_text,
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
...@@ -936,59 +940,36 @@ async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, s ...@@ -936,59 +940,36 @@ async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, s
event_types = [e.type for e in events] event_types = [e.type for e in events]
event_types_set = set(event_types) event_types_set = set(event_types)
logger.info( logger.info(
"\n====== MCP Streaming Diagnostics ======\n" "\n====== Code Interpreter Streaming Diagnostics ======\n"
"Event count: %d\n" "Event count: %d\n"
"Event types (in order): %s\n" "Event types (in order): %s\n"
"Unique event types: %s\n" "Unique event types: %s\n"
"=======================================", "====================================================",
len(events), len(events),
event_types, event_types,
sorted(event_types_set), sorted(event_types_set),
) )
# Verify the full MCP streaming lifecycle # Structural validation (pairing, ordering, field consistency)
assert "response.output_item.added" in event_types_set, ( validate_streaming_event_stack(events, pairs_of_event_types)
f"MCP call was not added. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.in_progress" in event_types_set, (
f"MCP call in_progress not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.delta" in event_types_set, (
f"MCP arguments delta not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.done" in event_types_set, (
f"MCP arguments done not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.completed" in event_types_set, (
f"MCP call completed not seen. Events: {sorted(event_types_set)}"
)
assert "response.output_item.done" in event_types_set, (
f"MCP item done not seen. Events: {sorted(event_types_set)}"
)
# Validate specific MCP event details # Validate code interpreter item fields
for event in events: for event in events:
if event.type == "response.output_item.added": if (
if hasattr(event.item, "type") and event.item.type == "mcp_call": event.type == "response.output_item.added"
assert event.item.name == "python" and hasattr(event.item, "type")
assert event.item.server_label == "code_interpreter" and event.item.type == "code_interpreter_call"
elif event.type == "response.mcp_call_arguments.done": ):
assert event.name == "python" assert event.item.status == "in_progress"
assert event.arguments is not None elif event.type == "response.code_interpreter_call_code.done":
assert event.code is not None
elif ( elif (
event.type == "response.output_item.done" event.type == "response.output_item.done"
and hasattr(event.item, "type") and hasattr(event.item, "type")
and event.item.type == "mcp_call" and event.item.type == "code_interpreter_call"
): ):
assert event.item.name == "python"
assert event.item.status == "completed" assert event.item.status == "completed"
assert event.item.code is not None
# code_interpreter events should NOT appear when using MCP type
code_interp_events = [e.type for e in events if "code_interpreter" in e.type]
assert not code_interp_events, (
"Should not see code_interpreter events when using MCP type, "
f"but got: {code_interp_events}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -241,81 +241,3 @@ class TestMCPEnabled: ...@@ -241,81 +241,3 @@ class TestMCPEnabled:
) )
validate_streaming_event_stack(events, pairs_of_event_types) validate_streaming_event_stack(events, pairs_of_event_types)
assert events_contain_type(events, "mcp_call"), (
f"No mcp_call events after retries. "
f"Event types: {sorted({e.type for e in events})}"
)
class TestMCPDisabled:
"""Tests that MCP tools are not executed when the env flag is unset."""
@pytest.fixture(scope="class")
def mcp_disabled_server(self):
env_dict = {
**BASE_TEST_ENV,
"VLLM_ENABLE_RESPONSES_API_STORE": "1",
"PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
}
with RemoteOpenAIServer(
MODEL_NAME, list(_BASE_SERVER_ARGS), env_dict=env_dict
) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(self, mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_disabled_server_does_not_execute(
self, client: OpenAI, model_name: str
):
"""When MCP is disabled the model may still attempt tool calls
(tool descriptions can remain in the prompt), but the server
must NOT execute them."""
response = await client.responses.create(
model=model_name,
input=(
"Execute the following code if the tool is present: "
"import random; print(random.randint(1, 1000000))"
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
}
],
temperature=0.0,
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
log_response_diagnostics(response, label="MCP Disabled")
# Server must not have executed any tool calls
for message in response.output_messages:
author = message.get("author", {})
assert not (
author.get("role") == "tool"
and (author.get("name") or "").startswith("python")
), (
"Server executed a python tool call even though MCP is "
f"disabled. Message: {message}"
)
# No completed mcp_call output items
for item in response.output:
if getattr(item, "type", None) == "mcp_call":
assert getattr(item, "status", None) != "completed", (
"MCP call should not be completed when MCP is disabled"
)
# No developer messages injected
for message in response.input_messages:
assert message.get("author", {}).get("role") != "developer"
...@@ -89,7 +89,7 @@ from vllm.entrypoints.openai.responses.protocol import ( ...@@ -89,7 +89,7 @@ from vllm.entrypoints.openai.responses.protocol import (
StreamingResponsesResponse, StreamingResponsesResponse,
) )
from vllm.entrypoints.openai.responses.streaming_events import ( from vllm.entrypoints.openai.responses.streaming_events import (
HarmonyStreamingState, StreamingState,
emit_content_delta_events, emit_content_delta_events,
emit_previous_item_done_events, emit_previous_item_done_events,
emit_tool_action_events, emit_tool_action_events,
...@@ -1591,7 +1591,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1591,7 +1591,7 @@ class OpenAIServingResponses(OpenAIServing):
[StreamingResponsesResponse], StreamingResponsesResponse [StreamingResponsesResponse], StreamingResponsesResponse
], ],
) -> AsyncGenerator[StreamingResponsesResponse, None]: ) -> AsyncGenerator[StreamingResponsesResponse, None]:
state = HarmonyStreamingState() state = StreamingState()
async for ctx in result_generator: async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext) assert isinstance(ctx, StreamingHarmonyContext)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment