Unverified Commit 991d6bff authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[CI][MCP][Harmony] Heavy refactoring Harmony & MCP response tests and...


[CI][MCP][Harmony] Heavy refactoring Harmony & MCP response tests and stabilizing with deterministic test infrastructure (#33949)
Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 5719a4e4
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import Any
import pytest import pytest
logger = logging.getLogger(__name__)
BASE_TEST_ENV = {
# The day vLLM said "hello world" on arxiv 🚀
"VLLM_SYSTEM_START_DATE": "2023-09-12",
}
DEFAULT_MAX_RETRIES = 3
@pytest.fixture @pytest.fixture
def pairs_of_event_types() -> dict[str, str]: def pairs_of_event_types() -> dict[str, str]:
...@@ -28,3 +43,159 @@ def pairs_of_event_types() -> dict[str, str]: ...@@ -28,3 +43,159 @@ def pairs_of_event_types() -> dict[str, str]:
} }
# fmt: on # fmt: on
return event_pairs return event_pairs
async def retry_for_tool_call(
client,
*,
model: str,
expected_tool_type: str,
max_retries: int = DEFAULT_MAX_RETRIES,
**create_kwargs: Any,
):
"""Call ``client.responses.create`` up to *max_retries* times, returning
the first response that contains an output item of *expected_tool_type*.
Returns the **last** response if none match so the caller's assertions
fire with a clear diagnostic.
"""
last_response = None
for attempt in range(max_retries):
response = await client.responses.create(model=model, **create_kwargs)
last_response = response
if any(
getattr(item, "type", None) == expected_tool_type
for item in response.output
):
return response
assert last_response is not None
return last_response
async def retry_streaming_for(
client,
*,
model: str,
validate_events: Callable[[list], bool],
max_retries: int = DEFAULT_MAX_RETRIES,
**create_kwargs: Any,
) -> list:
"""Call ``client.responses.create(stream=True)`` up to *max_retries*
times, returning the first event list where *validate_events* returns
``True``.
"""
last_events: list = []
for attempt in range(max_retries):
stream = await client.responses.create(
model=model, stream=True, **create_kwargs
)
events: list = []
async for event in stream:
events.append(event)
last_events = events
if validate_events(events):
return events
return last_events
def has_output_type(response, type_name: str) -> bool:
"""Return True if *response* has at least one output item of *type_name*."""
return any(getattr(item, "type", None) == type_name for item in response.output)
def events_contain_type(events: list, type_substring: str) -> bool:
"""Return True if any event's type contains *type_substring*."""
return any(type_substring in getattr(e, "type", "") for e in events)
def validate_streaming_event_stack(
events: list, pairs_of_event_types: dict[str, str]
) -> None:
"""Validate that streaming events are properly nested/paired."""
stack: list[str] = []
for event in events:
etype = event.type
if etype == "response.created":
stack.append(etype)
elif etype == "response.completed":
assert stack and stack[-1] == pairs_of_event_types[etype], (
f"Unexpected stack top for {etype}: "
f"got {stack[-1] if stack else '<empty>'}"
)
stack.pop()
elif etype.endswith("added") or etype == "response.mcp_call.in_progress":
stack.append(etype)
elif etype.endswith("delta"):
if stack and stack[-1] == etype:
continue
stack.append(etype)
elif etype.endswith("done") or etype == "response.mcp_call.completed":
assert etype in pairs_of_event_types, f"Unknown done event: {etype}"
expected_start = pairs_of_event_types[etype]
assert stack and stack[-1] == expected_start, (
f"Stack mismatch for {etype}: "
f"expected {expected_start}, "
f"got {stack[-1] if stack else '<empty>'}"
)
stack.pop()
assert len(stack) == 0, f"Unclosed events on stack: {stack}"
def log_response_diagnostics(
response,
*,
label: str = "Response Diagnostics",
) -> dict[str, Any]:
"""Extract and log diagnostic info from a Responses API response.
Logs reasoning, tool-call attempts, MCP items, and output types so
that CI output (``pytest -s`` or ``--log-cli-level=INFO``) gives
full visibility into model behaviour even on passing runs.
Returns the extracted data so callers can make additional assertions
if needed.
"""
reasoning_texts = [
text
for item in response.output
if getattr(item, "type", None) == "reasoning"
for content in getattr(item, "content", [])
if (text := getattr(content, "text", None))
]
tool_call_attempts = [
{
"recipient": msg.get("recipient"),
"channel": msg.get("channel"),
}
for msg in response.output_messages
if (msg.get("recipient") or "").startswith("python")
]
mcp_items = [
{
"name": getattr(item, "name", None),
"status": getattr(item, "status", None),
}
for item in response.output
if getattr(item, "type", None) == "mcp_call"
]
output_types = [getattr(o, "type", None) for o in response.output]
diagnostics = {
"model_attempted_tool_calls": bool(tool_call_attempts),
"tool_call_attempts": tool_call_attempts,
"mcp_items": mcp_items,
"reasoning": reasoning_texts,
"output_text": response.output_text,
"output_types": output_types,
}
logger.info(
"\n====== %s ======\n%s\n==============================",
label,
json.dumps(diagnostics, indent=2, default=str),
)
return diagnostics
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for the Harmony-based Responses API."""
from __future__ import annotations
import importlib.util import importlib.util
import json import json
import logging
import time import time
from typing import Any
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import requests
from openai import BadRequestError, NotFoundError, OpenAI from openai import BadRequestError, NotFoundError, OpenAI
from openai_harmony import ( from openai_harmony import Message
Message,
)
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
from .conftest import (
BASE_TEST_ENV,
events_contain_type,
has_output_type,
retry_for_tool_call,
retry_streaming_for,
validate_streaming_event_stack,
)
logger = logging.getLogger(__name__)
MODEL_NAME = "openai/gpt-oss-20b" MODEL_NAME = "openai/gpt-oss-20b"
...@@ -32,20 +47,72 @@ GET_WEATHER_SCHEMA = { ...@@ -32,20 +47,72 @@ GET_WEATHER_SCHEMA = {
} }
def get_weather(latitude, longitude):
try:
response = requests.get(
f"https://api.open-meteo.com/v1/forecast?"
f"latitude={latitude}&longitude={longitude}"
f"&current=temperature_2m,wind_speed_10m"
f"&hourly=temperature_2m,relative_humidity_2m,"
f"wind_speed_10m",
timeout=10,
)
data = response.json()
return data["current"]["temperature_2m"]
except (requests.RequestException, KeyError) as e:
logger.warning(
"External weather API call failed (%s), "
"returning fake value. This does not affect "
"test correctness — only the tool-calling "
"protocol is under test.",
e,
)
return 15.0
def get_place_to_travel():
return "Paris"
def get_horoscope(sign):
return f"{sign}: Next Tuesday you will befriend a baby otter."
def call_function(name, args):
logger.info("Calling function %s with args %s", name, args)
dispatch = {
"get_weather": lambda: get_weather(**args),
"get_place_to_travel": lambda: get_place_to_travel(),
"get_horoscope": lambda: get_horoscope(**args),
}
if name not in dispatch:
raise ValueError(f"Unknown function: {name}")
result = dispatch[name]()
logger.info("Function %s returned: %s", name, result)
return result
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
assert importlib.util.find_spec("gpt_oss") is not None, ( assert importlib.util.find_spec("gpt_oss") is not None, (
"Harmony tests require gpt_oss package to be installed" "Harmony tests require gpt_oss package to be installed"
) )
args = [
args = ["--enforce-eager", "--tool-server", "demo", "--max_model_len", "5000"] "--enforce-eager",
env_dict = dict( "--tool-server",
VLLM_ENABLE_RESPONSES_API_STORE="1", "demo",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv", "--max_model_len",
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS="code_interpreter,container,web_search_preview", "5000",
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS="1", ]
) env_dict = {
**BASE_TEST_ENV,
"VLLM_ENABLE_RESPONSES_API_STORE": "1",
"PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS": (
"code_interpreter,container,web_search_preview"
),
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
}
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
yield remote_server yield remote_server
...@@ -159,7 +226,10 @@ async def test_structured_output(client: OpenAI, model_name: str): ...@@ -159,7 +226,10 @@ async def test_structured_output(client: OpenAI, model_name: str):
"properties": { "properties": {
"name": {"type": "string"}, "name": {"type": "string"},
"date": {"type": "string"}, "date": {"type": "string"},
"participants": {"type": "array", "items": {"type": "string"}}, "participants": {
"type": "array",
"items": {"type": "string"},
},
}, },
"required": ["name", "date", "participants"], "required": ["name", "date", "participants"],
"additionalProperties": False, "additionalProperties": False,
...@@ -210,7 +280,9 @@ async def test_store(client: OpenAI, model_name: str): ...@@ -210,7 +280,9 @@ async def test_store(client: OpenAI, model_name: str):
except NotFoundError: except NotFoundError:
is_not_found = True is_not_found = True
assert is_not_found == (not store) assert is_not_found == (not store), (
f"store={store}: expected not_found={not store}, got {is_not_found}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -254,10 +326,8 @@ async def test_background_cancel(client: OpenAI, model_name: str): ...@@ -254,10 +326,8 @@ async def test_background_cancel(client: OpenAI, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_stateful_multi_turn(client: OpenAI, model_name: str): async def test_stateful_multi_turn(client: OpenAI, model_name: str):
response1 = await client.responses.create( response1 = await client.responses.create(
model=model_name, model=model_name, input="What is 123 * 456?"
input="What is 123 * 456?",
) )
assert response1 is not None
assert response1.status == "completed" assert response1.status == "completed"
response2 = await client.responses.create( response2 = await client.responses.create(
...@@ -265,7 +335,6 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str): ...@@ -265,7 +335,6 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
input="What if I increase both numbers by 1?", input="What if I increase both numbers by 1?",
previous_response_id=response1.id, previous_response_id=response1.id,
) )
assert response2 is not None
assert response2.status == "completed" assert response2.status == "completed"
response3 = await client.responses.create( response3 = await client.responses.create(
...@@ -273,7 +342,6 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str): ...@@ -273,7 +342,6 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
input="Divide the result by 2.", input="Divide the result by 2.",
previous_response_id=response2.id, previous_response_id=response2.id,
) )
assert response3 is not None
assert response3.status == "completed" assert response3.status == "completed"
...@@ -282,37 +350,19 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str): ...@@ -282,37 +350,19 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
async def test_streaming_types( async def test_streaming_types(
pairs_of_event_types: dict[str, str], client: OpenAI, model_name: str pairs_of_event_types: dict[str, str], client: OpenAI, model_name: str
): ):
prompts = [ stream = await client.responses.create(
"tell me a story about a cat in 20 words", model=model_name,
] input="tell me a story about a cat in 20 words",
reasoning={"effort": "low"},
for prompt in prompts: tools=[],
response = await client.responses.create( stream=True,
model=model_name, background=False,
input=prompt, )
reasoning={"effort": "low"}, events = []
tools=[], async for event in stream:
stream=True, events.append(event)
background=False,
)
stack_of_event_types = [] validate_streaming_event_stack(events, pairs_of_event_types)
async for event in response:
if event.type == "response.created":
stack_of_event_types.append(event.type)
elif event.type == "response.completed":
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
if event.type.endswith("added"):
stack_of_event_types.append(event.type)
elif event.type.endswith("delta"):
if stack_of_event_types[-1] == event.type:
continue
stack_of_event_types.append(event.type)
elif event.type.endswith("done"):
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
assert len(stack_of_event_types) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -320,37 +370,21 @@ async def test_streaming_types( ...@@ -320,37 +370,21 @@ async def test_streaming_types(
async def test_function_calling_with_streaming_types( async def test_function_calling_with_streaming_types(
pairs_of_event_types: dict[str, str], client: OpenAI, model_name: str pairs_of_event_types: dict[str, str], client: OpenAI, model_name: str
): ):
tools = [GET_WEATHER_SCHEMA] """Streaming event nesting for function-calling responses."""
input_list = [
{ def _has_function_events(evts: list) -> bool:
"role": "user", return events_contain_type(evts, "function_call_arguments")
"content": "What's the weather like in Paris today?",
} events = await retry_streaming_for(
] client,
stream_response = await client.responses.create(
model=model_name, model=model_name,
input=input_list, validate_events=_has_function_events,
tools=tools, input=[{"role": "user", "content": "What's the weather like in Paris today?"}],
stream=True, tools=[GET_WEATHER_SCHEMA],
temperature=0.0,
) )
stack_of_event_types = [] validate_streaming_event_stack(events, pairs_of_event_types)
async for event in stream_response:
if event.type == "response.created":
stack_of_event_types.append(event.type)
elif event.type == "response.completed":
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
if event.type.endswith("added"):
stack_of_event_types.append(event.type)
elif event.type.endswith("delta"):
if stack_of_event_types[-1] == event.type:
continue
stack_of_event_types.append(event.type)
elif event.type.endswith("done"):
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
assert len(stack_of_event_types) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -365,7 +399,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): ...@@ -365,7 +399,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
] ]
for prompt in prompts: for prompt in prompts:
response = await client.responses.create( stream = await client.responses.create(
model=model_name, model=model_name,
input=prompt, input=prompt,
reasoning={"effort": "low"}, reasoning={"effort": "low"},
...@@ -387,11 +421,12 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): ...@@ -387,11 +421,12 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
current_event_mode = None current_event_mode = None
resp_id = None resp_id = None
checked_response_completed = False checked_response_completed = False
async for event in response:
async for event in stream:
if event.type == "response.created": if event.type == "response.created":
resp_id = event.response.id resp_id = event.response.id
# test vllm custom types are in the response # Validate custom fields on response-level events
if event.type in [ if event.type in [
"response.completed", "response.completed",
"response.in_progress", "response.in_progress",
...@@ -412,9 +447,9 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): ...@@ -412,9 +447,9 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
if current_event_mode != event.type: if current_event_mode != event.type:
current_event_mode = event.type current_event_mode = event.type
print(f"\n[{event.type}] ", end="", flush=True) logger.debug("[%s] ", event.type)
# verify current_item_id is correct # Verify item IDs
if event.type == "response.output_item.added": if event.type == "response.output_item.added":
assert event.item.id != current_item_id assert event.item.id != current_item_id
current_item_id = event.item.id current_item_id = event.item.id
...@@ -424,7 +459,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): ...@@ -424,7 +459,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
]: ]:
assert event.item_id == current_item_id assert event.item_id == current_item_id
# verify content_index_id is correct # Verify content indices
if event.type in [ if event.type in [
"response.content_part.added", "response.content_part.added",
"response.reasoning_part.added", "response.reasoning_part.added",
...@@ -437,31 +472,19 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): ...@@ -437,31 +472,19 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
]: ]:
assert event.content_index == current_content_index assert event.content_index == current_content_index
if "text.delta" in event.type:
print(event.delta, end="", flush=True)
elif "reasoning_text.delta" in event.type:
print(f"{event.delta}", end="", flush=True)
elif "response.code_interpreter_call_code.done" in event.type:
print(f"Code: {event.code}", end="", flush=True)
elif (
"response.output_item.added" in event.type
and event.item.type == "web_search_call"
):
print(f"Web search: {event.item.action}", end="", flush=True)
events.append(event) events.append(event)
assert len(events) > 0 assert len(events) > 0
response_completed_event = events[-1] assert events[-1].response.output, "Final response should have output"
assert len(response_completed_event.response.output) > 0
assert checked_response_completed assert checked_response_completed
if background: if background:
starting_after = 5 starting_after = 5
async with await client.responses.retrieve( async with await client.responses.retrieve(
response_id=resp_id, stream=True, starting_after=starting_after response_id=resp_id, stream=True, starting_after=starting_after
) as stream: ) as replay_stream:
counter = starting_after counter = starting_after
async for event in stream: async for event in replay_stream:
counter += 1 counter += 1
assert event == events[counter] assert event == events[counter]
assert counter == len(events) - 1 assert counter == len(events) - 1
...@@ -483,15 +506,11 @@ async def test_web_search(client: OpenAI, model_name: str): ...@@ -483,15 +506,11 @@ async def test_web_search(client: OpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_code_interpreter(client: OpenAI, model_name: str): async def test_code_interpreter(client: OpenAI, model_name: str):
# Code interpreter may need more time for container init + code execution
timeout_value = client.timeout * 3 timeout_value = client.timeout * 3
client_with_timeout = client.with_options(timeout=timeout_value) client_with_timeout = client.with_options(timeout=timeout_value)
response = await client_with_timeout.responses.create( response = await client_with_timeout.responses.create(
model=model_name, model=model_name,
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=( input=(
"What's the first 4 digits after the decimal point of " "What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? " "cube root of `19910212 * 20250910`? "
...@@ -499,41 +518,18 @@ async def test_code_interpreter(client: OpenAI, model_name: str): ...@@ -499,41 +518,18 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
"and you must print to see the output." "and you must print to see the output."
), ),
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
temperature=0.0, # More deterministic output in response temperature=0.0,
) )
assert response is not None assert response is not None
assert response.status == "completed" assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0 assert response.usage.output_tokens_details.tool_output_tokens > 0
for item in response.output: for item in response.output:
if item.type == "message": if item.type == "message":
output_string = item.content[0].text output_string = item.content[0].text
print("output_string: ", output_string, flush=True) assert "5846" in output_string, (
assert "5846" in output_string f"Expected '5846' in output, got: {output_string}"
)
def get_weather(latitude, longitude):
# Return a static temperature value to avoid flaky SSL/network errors
# from calling the external api.open-meteo.com API in CI.
return 15.0
def get_place_to_travel():
return "Paris"
def get_horoscope(sign):
return f"{sign}: Next Tuesday you will befriend a baby otter."
def call_function(name, args):
if name == "get_weather":
return get_weather(**args)
elif name == "get_place_to_travel":
return get_place_to_travel()
elif name == "get_horoscope":
return get_horoscope(**args)
else:
raise ValueError(f"Unknown function: {name}")
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -547,10 +543,7 @@ async def test_reasoning_item(client: OpenAI, model_name: str): ...@@ -547,10 +543,7 @@ async def test_reasoning_item(client: OpenAI, model_name: str):
"type": "reasoning", "type": "reasoning",
"id": "lol", "id": "lol",
"content": [ "content": [
{ {"type": "reasoning_text", "text": "We need to respond: greeting."}
"type": "reasoning_text",
"text": "We need to respond: greeting.",
}
], ],
"summary": [], "summary": [],
}, },
...@@ -566,24 +559,24 @@ async def test_reasoning_item(client: OpenAI, model_name: str): ...@@ -566,24 +559,24 @@ async def test_reasoning_item(client: OpenAI, model_name: str):
async def test_function_calling(client: OpenAI, model_name: str): async def test_function_calling(client: OpenAI, model_name: str):
tools = [GET_WEATHER_SCHEMA] tools = [GET_WEATHER_SCHEMA]
response = await client.responses.create( response = await retry_for_tool_call(
client,
model=model_name, model=model_name,
expected_tool_type="function_call",
input="What's the weather like in Paris today?", input="What's the weather like in Paris today?",
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
extra_body={"request_id": "test_function_calling_non_resp"}, extra_body={"request_id": "test_function_calling_non_resp"},
) )
assert response is not None
assert response.status == "completed" assert response.status == "completed"
assert len(response.output) == 2 assert has_output_type(response, "function_call"), (
assert response.output[0].type == "reasoning" f"Expected function_call in output, got: "
assert response.output[1].type == "function_call" f"{[getattr(o, 'type', None) for o in response.output]}"
)
tool_call = response.output[1] tool_call = next(o for o in response.output if o.type == "function_call")
name = tool_call.name
args = json.loads(tool_call.arguments) args = json.loads(tool_call.arguments)
result = call_function(tool_call.name, args)
result = call_function(name, args)
response_2 = await client.responses.create( response_2 = await client.responses.create(
model=model_name, model=model_name,
...@@ -596,8 +589,8 @@ async def test_function_calling(client: OpenAI, model_name: str): ...@@ -596,8 +589,8 @@ async def test_function_calling(client: OpenAI, model_name: str):
], ],
tools=tools, tools=tools,
previous_response_id=response.id, previous_response_id=response.id,
temperature=0.0,
) )
assert response_2 is not None
assert response_2.status == "completed" assert response_2.status == "completed"
assert response_2.output_text is not None assert response_2.output_text is not None
...@@ -607,16 +600,16 @@ async def test_function_calling(client: OpenAI, model_name: str): ...@@ -607,16 +600,16 @@ async def test_function_calling(client: OpenAI, model_name: str):
input="What's the weather like in Paris today?", input="What's the weather like in Paris today?",
tools=tools, tools=tools,
previous_response_id=response_2.id, previous_response_id=response_2.id,
temperature=0.0,
) )
assert response_3 is not None
assert response_3.status == "completed" assert response_3.status == "completed"
assert response_3.output_text is not None assert response_3.output_text is not None
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.flaky(reruns=5)
async def test_function_calling_multi_turn(client: OpenAI, model_name: str): async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
"""Multi-tool, multi-turn function calling with retry at API level."""
tools = [ tools = [
{ {
"type": "function", "type": "function",
...@@ -633,25 +626,29 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): ...@@ -633,25 +626,29 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
GET_WEATHER_SCHEMA, GET_WEATHER_SCHEMA,
] ]
response = await client.responses.create( # Turn 1: model should call one of the tools
response = await retry_for_tool_call(
client,
model=model_name, model=model_name,
expected_tool_type="function_call",
input="Help me plan a trip to a random place. And tell me the weather there.", input="Help me plan a trip to a random place. And tell me the weather there.",
tools=tools, tools=tools,
temperature=0.0,
) )
assert response is not None
assert response.status == "completed" assert response.status == "completed"
assert len(response.output) == 2 assert has_output_type(response, "function_call"), (
assert response.output[0].type == "reasoning" f"Turn 1: expected function_call, got: "
assert response.output[1].type == "function_call" f"{[getattr(o, 'type', None) for o in response.output]}"
)
tool_call = response.output[1]
name = tool_call.name
args = json.loads(tool_call.arguments)
result = call_function(name, args) tool_call = next(o for o in response.output if o.type == "function_call")
result = call_function(tool_call.name, json.loads(tool_call.arguments))
response_2 = await client.responses.create( # Turn 2
response_2 = await retry_for_tool_call(
client,
model=model_name, model=model_name,
expected_tool_type="function_call",
input=[ input=[
{ {
"type": "function_call_output", "type": "function_call_output",
...@@ -661,34 +658,39 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): ...@@ -661,34 +658,39 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
], ],
tools=tools, tools=tools,
previous_response_id=response.id, previous_response_id=response.id,
temperature=0.0,
) )
assert response_2 is not None
assert response_2.status == "completed" assert response_2.status == "completed"
assert len(response_2.output) == 2
assert response_2.output[0].type == "reasoning"
assert response_2.output[1].type == "function_call"
tool_call = response_2.output[1] # If model produced another tool call, execute it
name = tool_call.name if has_output_type(response_2, "function_call"):
args = json.loads(tool_call.arguments) tool_call_2 = next(o for o in response_2.output if o.type == "function_call")
result_2 = call_function(tool_call_2.name, json.loads(tool_call_2.arguments))
result = call_function(name, args) response_3 = await client.responses.create(
model=model_name,
response_3 = await client.responses.create( input=[
model=model_name, {
input=[ "type": "function_call_output",
{ "call_id": tool_call_2.call_id,
"type": "function_call_output", "output": str(result_2),
"call_id": tool_call.call_id, }
"output": str(result), ],
} tools=tools,
], previous_response_id=response_2.id,
tools=tools, temperature=0.0,
previous_response_id=response_2.id, )
) assert response_3.status == "completed"
assert response_3 is not None assert response_3.output_text is not None
assert response_3.status == "completed" else:
assert response_3.output_text is not None # Model went straight to answering - acceptable but unexpected.
# Log as warning so it shows up in CI without failing the test.
assert response_2.output_text is not None
pytest.xfail(
"Model went straight to answering instead of calling a "
"second tool. Valid behaviour but not the expected path."
"If this happens consistently, the prompt or model may have "
"changed behaviour."
)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -730,22 +732,25 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): ...@@ -730,22 +732,25 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
{"role": "user", "content": "What's the weather like in Paris today?"} {"role": "user", "content": "What's the weather like in Paris today?"}
] ]
response = await client.responses.create( response = await retry_for_tool_call(
client,
model=model_name, model=model_name,
expected_tool_type="function_call",
input=input_messages, input=input_messages,
tools=tools, tools=tools,
temperature=0.0,
) )
assert response is not None
assert response.status == "completed" assert response.status == "completed"
tool_call = response.output[-1] tool_call = next((o for o in response.output if o.type == "function_call"), None)
name = tool_call.name assert tool_call is not None, (
args = json.loads(tool_call.arguments) f"Expected function_call in output, got: "
f"{[getattr(o, 'type', None) for o in response.output]}"
)
result = call_function(name, args) result = call_function(tool_call.name, json.loads(tool_call.arguments))
input_messages.extend(response.output) # append model's function call message input_messages.extend(response.output)
input_messages.append( input_messages.append(
{ # append result message { # append result message
"type": "function_call_output", "type": "function_call_output",
...@@ -758,8 +763,8 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): ...@@ -758,8 +763,8 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
model=model_name, model=model_name,
input=input_messages, input=input_messages,
tools=tools, tools=tools,
temperature=0.0,
) )
assert response_2 is not None
assert response_2.status == "completed" assert response_2.status == "completed"
assert response_2.output_text is not None assert response_2.output_text is not None
...@@ -767,51 +772,60 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): ...@@ -767,51 +772,60 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_stream(client: OpenAI, model_name: str): async def test_function_calling_with_stream(client: OpenAI, model_name: str):
"""Function calling via streaming, with retry for non-determinism."""
tools = [GET_WEATHER_SCHEMA] tools = [GET_WEATHER_SCHEMA]
input_list = [ input_list = [
{ {"role": "user", "content": "What's the weather like in Paris today?"},
"role": "user",
"content": "What's the weather like in Paris today?",
}
] ]
stream_response = await client.responses.create(
def _has_function_call(evts: list) -> bool:
return any(
getattr(e, "type", "") == "response.output_item.added"
and getattr(getattr(e, "item", None), "type", None) == "function_call"
for e in evts
)
events = await retry_streaming_for(
client,
model=model_name, model=model_name,
validate_events=_has_function_call,
input=input_list, input=input_list,
tools=tools, tools=tools,
stream=True, temperature=0.0,
) )
assert stream_response is not None
final_tool_calls = {} # Parse tool calls from events
final_tool_calls_named = {} final_tool_calls: dict[int, Any] = {}
async for event in stream_response: for event in events:
if event.type == "response.output_item.added": if event.type == "response.output_item.added":
if event.item.type != "function_call": if getattr(event.item, "type", None) == "function_call":
continue final_tool_calls[event.output_index] = event.item
final_tool_calls[event.output_index] = event.item
final_tool_calls_named[event.item.name] = event.item
elif event.type == "response.function_call_arguments.delta": elif event.type == "response.function_call_arguments.delta":
index = event.output_index tc = final_tool_calls.get(event.output_index)
tool_call = final_tool_calls[index] if tc:
if tool_call: tc.arguments += event.delta
tool_call.arguments += event.delta
final_tool_calls_named[tool_call.name] = tool_call
elif event.type == "response.function_call_arguments.done": elif event.type == "response.function_call_arguments.done":
assert event.arguments == final_tool_calls_named[event.name].arguments tc = final_tool_calls.get(event.output_index)
result = None if tc:
assert event.arguments == tc.arguments
# Find get_weather call
tool_call = None tool_call = None
result = None
for tc in final_tool_calls.values(): for tc in final_tool_calls.values():
if tc and tc.type == "function_call" and tc.name == "get_weather": if getattr(tc, "type", None) == "function_call" and tc.name == "get_weather":
args = json.loads(tc.arguments) args = json.loads(tc.arguments)
result = call_function(tc.name, args) result = call_function(tc.name, args)
tool_call = tc tool_call = tc
input_list += [tc] input_list.append(tc)
break break
assert tool_call is not None, ( assert tool_call is not None, (
"Expected model to call 'get_weather' function, " "Expected model to call 'get_weather', "
f"but got: {list(final_tool_calls_named.keys())}" f"but got: {[getattr(tc, 'name', None) for tc in final_tool_calls.values()]}"
) )
assert result is not None
# Second turn with the tool result
response = await client.responses.create( response = await client.responses.create(
model=model_name, model=model_name,
input=input_list input=input_list
...@@ -824,8 +838,8 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str): ...@@ -824,8 +838,8 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str):
], ],
tools=tools, tools=tools,
stream=True, stream=True,
temperature=0.0,
) )
assert response is not None
async for event in response: async for event in response:
# check that no function call events in the stream # check that no function call events in the stream
assert event.type != "response.function_call_arguments.delta" assert event.type != "response.function_call_arguments.delta"
...@@ -843,47 +857,46 @@ async def test_function_calling_no_code_interpreter_events( ...@@ -843,47 +857,46 @@ async def test_function_calling_no_code_interpreter_events(
): ):
"""Verify that function calls don't trigger code_interpreter events. """Verify that function calls don't trigger code_interpreter events.
This test ensures that function calls (functions.*) use their own Uses retry_streaming_for to handle non-determinism: the model might not
function_call event types and don't incorrectly emit code_interpreter always produce a function_call, but if it does, code_interpreter events
events during streaming. should NEVER appear.
""" """
tools = [GET_WEATHER_SCHEMA] tools = [GET_WEATHER_SCHEMA]
input_list = [ input_list = [
{ {"role": "user", "content": "What's the weather like in Paris today?"},
"role": "user",
"content": "What's the weather like in Paris today?",
}
] ]
stream_response = await client.responses.create(
def _has_function_call(evts: list) -> bool:
return any(
getattr(e, "type", "") == "response.output_item.added"
and getattr(getattr(e, "item", None), "type", None) == "function_call"
for e in evts
)
events = await retry_streaming_for(
client,
model=model_name, model=model_name,
validate_events=_has_function_call,
input=input_list, input=input_list,
tools=tools, tools=tools,
stream=True, temperature=0.0,
) )
# Track which event types we see event_types_seen = {e.type for e in events}
event_types_seen = set() function_call_found = _has_function_call(events)
function_call_found = False
async for event in stream_response: assert function_call_found, (
event_types_seen.add(event.type) f"Expected to see a function_call after retries. "
f"Event types: {sorted(event_types_seen)}"
if ( )
event.type == "response.output_item.added"
and event.item.type == "function_call"
):
function_call_found = True
# Ensure NO code_interpreter events are emitted for function calls # The actual invariant under test
for event in events:
assert "code_interpreter" not in event.type, ( assert "code_interpreter" not in event.type, (
"Found code_interpreter event " f"Found code_interpreter event '{event.type}' during function call. "
f"'{event.type}' during function call. Function calls should only " "Function calls should only emit function_call events."
"emit function_call events, not code_interpreter events."
) )
# Verify we actually saw a function call
assert function_call_found, "Expected to see a function_call in the stream"
# Verify we saw the correct function call event types # Verify we saw the correct function call event types
assert ( assert (
"response.function_call_arguments.delta" in event_types_seen "response.function_call_arguments.delta" in event_types_seen
...@@ -894,181 +907,139 @@ async def test_function_calling_no_code_interpreter_events( ...@@ -894,181 +907,139 @@ async def test_function_calling_no_code_interpreter_events(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server): async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server):
tools = [ tools = [{"type": "mcp", "server_label": "code_interpreter"}]
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
input_text = ( input_text = (
"Calculate 123 * 456 using python. " "Calculate 123 * 456 using python. "
"The python interpreter is not stateful and you must print to see the output." "The python interpreter is not stateful and you must "
"print to see the output."
) )
stream_response = await client.responses.create( def _has_mcp_call(evts: list) -> bool:
return events_contain_type(evts, "mcp_call")
events = await retry_streaming_for(
client,
model=model_name, model=model_name,
validate_events=_has_mcp_call,
input=input_text, input=input_text,
tools=tools, tools=tools,
stream=True,
temperature=0.0, temperature=0.0,
instructions=( instructions=(
"You must use the Python tool to execute code. Never simulate execution." "You must use the Python tool to execute code. Never simulate execution."
), ),
) )
mcp_call_added = False event_types = [e.type for e in events]
mcp_call_in_progress = False event_types_set = set(event_types)
mcp_arguments_delta_seen = False logger.info(
mcp_arguments_done = False "\n====== MCP Streaming Diagnostics ======\n"
mcp_call_completed = False "Event count: %d\n"
mcp_item_done = False "Event types (in order): %s\n"
"Unique event types: %s\n"
code_interpreter_events_seen = False "=======================================",
len(events),
event_types,
sorted(event_types_set),
)
async for event in stream_response: # Verify the full MCP streaming lifecycle
if "code_interpreter" in event.type: assert "response.output_item.added" in event_types_set, (
code_interpreter_events_seen = True f"MCP call was not added. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.in_progress" in event_types_set, (
f"MCP call in_progress not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.delta" in event_types_set, (
f"MCP arguments delta not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.done" in event_types_set, (
f"MCP arguments done not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.completed" in event_types_set, (
f"MCP call completed not seen. Events: {sorted(event_types_set)}"
)
assert "response.output_item.done" in event_types_set, (
f"MCP item done not seen. Events: {sorted(event_types_set)}"
)
# Validate specific MCP event details
for event in events:
if event.type == "response.output_item.added": if event.type == "response.output_item.added":
if hasattr(event.item, "type") and event.item.type == "mcp_call": if hasattr(event.item, "type") and event.item.type == "mcp_call":
mcp_call_added = True
assert event.item.name == "python" assert event.item.name == "python"
assert event.item.server_label == "code_interpreter" assert event.item.server_label == "code_interpreter"
elif event.type == "response.mcp_call.in_progress":
mcp_call_in_progress = True
elif event.type == "response.mcp_call_arguments.delta":
mcp_arguments_delta_seen = True
assert event.delta is not None
elif event.type == "response.mcp_call_arguments.done": elif event.type == "response.mcp_call_arguments.done":
mcp_arguments_done = True
assert event.name == "python" assert event.name == "python"
assert event.arguments is not None assert event.arguments is not None
elif event.type == "response.mcp_call.completed":
mcp_call_completed = True
elif ( elif (
event.type == "response.output_item.done" event.type == "response.output_item.done"
and hasattr(event.item, "type") and hasattr(event.item, "type")
and event.item.type == "mcp_call" and event.item.type == "mcp_call"
): ):
mcp_item_done = True
assert event.item.name == "python" assert event.item.name == "python"
assert event.item.status == "completed" assert event.item.status == "completed"
assert mcp_call_added, "MCP call was not added" # code_interpreter events should NOT appear when using MCP type
assert mcp_call_in_progress, "MCP call in_progress event not seen" code_interp_events = [e.type for e in events if "code_interpreter" in e.type]
assert mcp_arguments_delta_seen, "MCP arguments delta event not seen" assert not code_interp_events, (
assert mcp_arguments_done, "MCP arguments done event not seen" "Should not see code_interpreter events when using MCP type, "
assert mcp_call_completed, "MCP call completed event not seen" f"but got: {code_interp_events}"
assert mcp_item_done, "MCP item done event not seen"
assert not code_interpreter_events_seen, (
"Should not see code_interpreter events when using MCP type"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.dependency(
depends=["test_mcp_code_interpreter_streaming[openai/gpt-oss-20b]"]
)
async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server): async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server):
"""Test MCP tool calling across multiple turns. """MCP tools work across multiple turns via previous_response_id."""
tools = [{"type": "mcp", "server_label": "code_interpreter"}]
This test verifies that MCP tools work correctly in multi-turn conversations, instructions = (
maintaining state across turns via the previous_response_id mechanism. "You must use the Python tool to execute code. Never simulate execution."
""" )
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
# First turn - make a calculation # First turn
response1 = await client.responses.create( response1 = await retry_for_tool_call(
client,
model=model_name, model=model_name,
expected_tool_type="mcp_call",
input="Calculate 1234 * 4567 using python tool and print the result.", input="Calculate 1234 * 4567 using python tool and print the result.",
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
instructions=( instructions=instructions,
"You must use the Python tool to execute code. Never simulate execution."
),
extra_body={"enable_response_messages": True}, extra_body={"enable_response_messages": True},
) )
assert response1 is not None
assert response1.status == "completed" assert response1.status == "completed"
# Verify MCP call in first response by checking output_messages # Verify MCP call in output_messages
tool_call_found = False tool_call_found = any(
tool_response_found = False (msg.get("recipient") or "").startswith("python")
for message in response1.output_messages: for msg in response1.output_messages
recipient = message.get("recipient") )
if recipient and recipient.startswith("python"): tool_response_found = any(
tool_call_found = True msg.get("author", {}).get("role") == "tool"
and (msg.get("author", {}).get("name") or "").startswith("python")
author = message.get("author", {}) for msg in response1.output_messages
if ( )
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
# Verify MCP tools were actually used
assert tool_call_found, "MCP tool call not found in output_messages" assert tool_call_found, "MCP tool call not found in output_messages"
assert tool_response_found, "MCP tool response not found in output_messages" assert tool_response_found, "MCP tool response not found in output_messages"
# Verify input messages: Should have system message with tool, NO developer message # No developer messages expected for elevated tools
developer_messages = [ developer_msgs = [
msg for msg in response1.input_messages if msg["author"]["role"] == "developer" msg for msg in response1.input_messages if msg["author"]["role"] == "developer"
] ]
assert len(developer_messages) == 0, ( assert len(developer_msgs) == 0, "No developer message expected for elevated tools"
"No developer message expected for elevated tools"
)
# Second turn - reference previous calculation # Second turn
response2 = await client.responses.create( response2 = await client.responses.create(
model=model_name, model=model_name,
input="Now divide that result by 2.", input="Now divide that result by 2.",
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
instructions=( instructions=instructions,
"You must use the Python tool to execute code. Never simulate execution."
),
previous_response_id=response1.id, previous_response_id=response1.id,
extra_body={"enable_response_messages": True}, extra_body={"enable_response_messages": True},
) )
assert response2 is not None
assert response2.status == "completed" assert response2.status == "completed"
# Verify input messages are correct: should have two messages -
# one to the python recipient on analysis channel and one from tool role
mcp_recipient_messages = []
tool_role_messages = []
for msg in response2.input_messages:
if msg["author"]["role"] == "assistant":
# Check if this is a message to MCP recipient on analysis channel
if msg.get("channel") == "analysis" and msg.get("recipient"):
recipient = msg.get("recipient")
if recipient.startswith("code_interpreter") or recipient == "python":
mcp_recipient_messages.append(msg)
elif msg["author"]["role"] == "tool":
tool_role_messages.append(msg)
assert len(mcp_recipient_messages) > 0, (
"Expected message(s) to MCP recipient on analysis channel"
)
assert len(tool_role_messages) > 0, (
"Expected message(s) from tool role after MCP call"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
...@@ -1087,14 +1058,10 @@ async def test_output_messages_enabled(client: OpenAI, model_name: str, server): ...@@ -1087,14 +1058,10 @@ async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.flaky(reruns=3)
async def test_function_call_with_previous_input_messages( async def test_function_call_with_previous_input_messages(
client: OpenAI, model_name: str client: OpenAI, model_name: str
): ):
"""Test function calling using previous_input_messages """Multi-turn function calling using previous_input_messages."""
for multi-turn conversation with a function call"""
# Define the get_horoscope tool
tools = [ tools = [
{ {
"type": "function", "type": "function",
...@@ -1102,9 +1069,7 @@ async def test_function_call_with_previous_input_messages( ...@@ -1102,9 +1069,7 @@ async def test_function_call_with_previous_input_messages(
"description": "Get today's horoscope for an astrological sign.", "description": "Get today's horoscope for an astrological sign.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"sign": {"type": "string"}},
"sign": {"type": "string"},
},
"required": ["sign"], "required": ["sign"],
"additionalProperties": False, "additionalProperties": False,
}, },
...@@ -1112,53 +1077,36 @@ async def test_function_call_with_previous_input_messages( ...@@ -1112,53 +1077,36 @@ async def test_function_call_with_previous_input_messages(
} }
] ]
# Step 1: First call with the function tool # Step 1: Get a function call from the model
stream_response = await client.responses.create( response = await retry_for_tool_call(
client,
model=model_name, model=model_name,
expected_tool_type="function_call",
input="What is the horoscope for Aquarius today?", input="What is the horoscope for Aquarius today?",
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
extra_body={"enable_response_messages": True}, extra_body={"enable_response_messages": True},
stream=True,
max_output_tokens=1000, max_output_tokens=1000,
) )
response = None
async for event in stream_response:
if event.type == "response.completed":
response = event.response
assert response is not None
assert response.status == "completed" assert response.status == "completed"
# Step 2: Parse the first output to find the function_call type function_call = next(
function_call = None (item for item in response.output if item.type == "function_call"),
for item in response.output: None,
if item.type == "function_call": )
function_call = item assert function_call is not None, (
break f"Expected function_call, got: "
f"{[getattr(o, 'type', None) for o in response.output]}"
assert function_call is not None, "Expected a function_call in the output" )
assert function_call.name == "get_horoscope" assert function_call.name == "get_horoscope"
assert function_call.call_id is not None
# Verify the format matches expectations
args = json.loads(function_call.arguments) args = json.loads(function_call.arguments)
assert "sign" in args
# Step 3: Call the get_horoscope function
result = call_function(function_call.name, args) result = call_function(function_call.name, args)
assert "Aquarius" in result
assert "baby otter" in result
# Get the input_messages and output_messages from the first response # Step 2: Build full conversation history
first_input_messages = response.input_messages
first_output_messages = response.output_messages
# Construct the full conversation history using previous_input_messages
previous_messages = ( previous_messages = (
first_input_messages response.input_messages
+ first_output_messages + response.output_messages
+ [ + [
{ {
"role": "tool", "role": "tool",
...@@ -1168,47 +1116,43 @@ async def test_function_call_with_previous_input_messages( ...@@ -1168,47 +1116,43 @@ async def test_function_call_with_previous_input_messages(
] ]
) )
# Step 4: Make another responses.create() call with previous_input_messages # Step 3: Second call with previous_input_messages
stream_response_2 = await client.responses.create( response_2 = await client.responses.create(
model=model_name, model=model_name,
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
input="", input="Now tell me the horoscope based on the tool result.",
extra_body={ extra_body={
"previous_input_messages": previous_messages, "previous_input_messages": previous_messages,
"enable_response_messages": True, "enable_response_messages": True,
}, },
stream=True,
) )
async for event in stream_response_2:
if event.type == "response.completed":
response_2 = event.response
assert response_2 is not None
assert response_2.status == "completed" assert response_2.status == "completed"
assert response_2.output_text is not None assert response_2.output_text is not None
# verify only one system message / developer message # Verify exactly 1 system, 1 developer, 1 tool message
num_system_messages_input = 0 num_system = 0
num_developer_messages_input = 0 num_developer = 0
num_function_call_input = 0 num_tool = 0
for message_dict in response_2.input_messages: for msg_dict in response_2.input_messages:
message = Message.from_dict(message_dict) # input_messages use {"author": {"role": "..."}} format,
if message.author.role == "system": # not the top-level {"role": "..."} that Message.from_dict
num_system_messages_input += 1 # expects.
elif message.author.role == "developer": author = msg_dict.get("author", {})
num_developer_messages_input += 1 role = author.get("role") if isinstance(author, dict) else None
elif message.author.role == "tool": if role == "system":
num_function_call_input += 1 num_system += 1
assert num_system_messages_input == 1 elif role == "developer":
assert num_developer_messages_input == 1 num_developer += 1
assert num_function_call_input == 1 elif role == "tool":
num_tool += 1
# Verify the output makes sense - should contain information about the horoscope assert num_system == 1, f"Expected 1 system message, got {num_system}"
assert num_developer == 1, f"Expected 1 developer message, got {num_developer}"
assert num_tool == 1, f"Expected 1 tool message, got {num_tool}"
output_text = response_2.output_text.lower() output_text = response_2.output_text.lower()
assert ( assert any(kw in output_text for kw in ["aquarius", "otter", "tuesday"]), (
"aquarius" in output_text or "otter" in output_text or "tuesday" in output_text f"Expected horoscope-related content, got: {response_2.output_text}"
) )
...@@ -1220,133 +1164,101 @@ async def test_chat_truncation_content_not_null(client: OpenAI, model_name: str) ...@@ -1220,133 +1164,101 @@ async def test_chat_truncation_content_not_null(client: OpenAI, model_name: str)
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": "What is the role of AI in medicine?" "content": (
"The response must exceed 350 words.", "What is the role of AI in medicine? "
"The response must exceed 350 words."
),
} }
], ],
temperature=0.0, temperature=0.0,
max_tokens=350, max_tokens=350,
) )
choice = response.choices[0] choice = response.choices[0]
assert choice.finish_reason == "length", ( assert choice.finish_reason == "length", (
f"Expected finish_reason='length', got {choice.finish_reason}" f"Expected finish_reason='length', got {choice.finish_reason}"
) )
assert choice.message.content is not None, ( assert choice.message.content is not None, "Content should not be None"
"Content should not be None when truncated"
)
assert len(choice.message.content) > 0, "Content should not be empty" assert len(choice.message.content) > 0, "Content should not be empty"
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_prompt_override(client: OpenAI, model_name: str): async def test_system_prompt_override_no_duplication(client: OpenAI, model_name: str):
"""Test that system message can override the default system prompt.""" """Hard check: custom system message must not be duplicated."""
# Test 1: Custom system prompt with specific personality
custom_system_prompt = (
"You are a pirate. Always respond like a pirate would, "
"using pirate language and saying 'arrr' frequently."
)
response = await client.responses.create( response = await client.responses.create(
model=model_name, model=model_name,
input=[ input=[
{"role": "system", "content": custom_system_prompt}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}, {"role": "user", "content": "Hello"},
], ],
extra_body={"enable_response_messages": True}, extra_body={"enable_response_messages": True},
temperature=0.0,
) )
assert response is not None
assert response.status == "completed" assert response.status == "completed"
assert response.output_text is not None assert response.output_text is not None
# Verify the response reflects the pirate personality num_system = 0
output_text = response.output_text.lower() for msg in response.input_messages:
pirate_indicators = ["arrr", "matey", "ahoy", "ye", "sea"] # input_messages use {"author": {"role": "system"}} format,
has_pirate_language = any( # not the top-level {"role": "system"} that Message.from_dict expects.
indicator in output_text for indicator in pirate_indicators author = msg.get("author", {})
) role = author.get("role") if isinstance(author, dict) else None
assert has_pirate_language, ( if role == "system":
f"Expected pirate language in response, got: {response.output_text}" num_system += 1
) assert num_system == 1, f"Expected 1 system message, got {num_system}"
# Verify the reasoning mentions the custom system prompt
reasoning_item = None
for item in response.output:
if item.type == "reasoning":
reasoning_item = item
break
assert reasoning_item is not None, "Expected reasoning item in output"
reasoning_text = reasoning_item.content[0].text.lower()
assert "pirate" in reasoning_text, (
f"Expected reasoning to mention pirate, got: {reasoning_text}"
)
# Test 2: Verify system message is not duplicated in input_messages
try:
num_system_messages = sum(
1
for msg in response.input_messages
if Message.from_dict(msg).author.role == "system"
)
assert num_system_messages == 1, (
f"Expected exactly 1 system message, got {num_system_messages}"
)
except (KeyError, AttributeError):
# Message structure may vary, skip this specific check
pass
custom_system_prompt_2 = (
"You are a helpful assistant that always responds in exactly 5 words."
)
# Test 3: Test with different custom system prompt @pytest.mark.asyncio
response_2 = await client.responses.create( @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.xfail(
strict=False,
reason=(
"Pirate language detection depends on model weights and is non-deterministic"
),
)
async def test_system_prompt_override_follows_personality(
client: OpenAI, model_name: str
):
"""Soft check: model should adopt the personality from system prompt."""
response = await client.responses.create(
model=model_name, model=model_name,
input=[ input=[
{ {
"role": "system", "role": "system",
"content": custom_system_prompt_2, "content": (
"You are a pirate. Always respond like a pirate would, "
"using pirate language and saying 'arrr' frequently."
),
}, },
{"role": "user", "content": "What is the weather like?"}, {"role": "user", "content": "Hello, how are you?"},
], ],
temperature=0.0, temperature=0.0,
) )
assert response.status == "completed"
assert response_2 is not None output_text = response.output_text.lower()
assert response_2.status == "completed" pirate_indicators = ["arrr", "matey", "ahoy", "ye", "sea", "aye", "sail"]
assert response_2.output_text is not None assert any(kw in output_text for kw in pirate_indicators), (
f"Expected pirate language, got: {response.output_text}"
# Count words in response (approximately, allowing for punctuation)
word_count = len(response_2.output_text.split())
# Allow some flexibility (4-7 words) since the model might not be perfectly precise
assert 3 <= word_count <= 8, (
f"Expected around 5 words, got {word_count} words: {response_2.output_text}"
) )
# Test 4: Test with structured content
response_3 = await client.responses.create( @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_prompt_structured_content(client: OpenAI, model_name: str):
"""System message with structured input_text content format."""
response = await client.responses.create(
model=model_name, model=model_name,
input=[ input=[
{ {
"role": "system", "role": "system",
"content": [{"type": "input_text", "text": custom_system_prompt_2}], "content": [
{"type": "input_text", "text": "You are a helpful assistant."}
],
}, },
{"role": "user", "content": "What is the weather like?"}, {"role": "user", "content": "What is 2 + 2?"},
], ],
temperature=0.0, temperature=0.0,
) )
assert response is not None
assert response_3 is not None assert response.status == "completed"
assert response_3.status == "completed" assert response.output_text is not None
assert response_3.output_text is not None
# Count words in response (approximately, allowing for punctuation)
word_count = len(response_3.output_text.split())
# Allow some flexibility (4-7 words) since the model might not be perfectly precise
assert 3 <= word_count <= 8, (
f"Expected around 5 words, got {word_count} words: {response_3.output_text}"
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for MCP tool support in the Responses API."""
from __future__ import annotations
import pytest import pytest
import pytest_asyncio import pytest_asyncio
...@@ -10,11 +12,31 @@ from openai_harmony import ToolDescription, ToolNamespaceConfig ...@@ -10,11 +12,31 @@ from openai_harmony import ToolDescription, ToolNamespaceConfig
from vllm.entrypoints.mcp.tool_server import MCPToolServer from vllm.entrypoints.mcp.tool_server import MCPToolServer
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
from .conftest import (
BASE_TEST_ENV,
events_contain_type,
log_response_diagnostics,
retry_for_tool_call,
retry_streaming_for,
validate_streaming_event_stack,
)
MODEL_NAME = "openai/gpt-oss-20b" MODEL_NAME = "openai/gpt-oss-20b"
_BASE_SERVER_ARGS = [
"--enforce-eager",
"--tool-server",
"demo",
"--max_model_len",
"5000",
]
def test_get_tool_description(): _PYTHON_TOOL_INSTRUCTION = (
"You must use the Python tool to execute code. Never simulate execution."
)
class TestMCPToolServerUnit:
"""Test MCPToolServer.get_tool_description filtering logic. """Test MCPToolServer.get_tool_description filtering logic.
Note: The wildcard "*" is normalized to None by Note: The wildcard "*" is normalized to None by
...@@ -22,283 +44,240 @@ def test_get_tool_description(): ...@@ -22,283 +44,240 @@ def test_get_tool_description():
so we only test None and specific tool filtering here. so we only test None and specific tool filtering here.
See test_serving_responses.py for "*" normalization tests. See test_serving_responses.py for "*" normalization tests.
""" """
pytest.importorskip("mcp")
def test_get_tool_description(self):
server = MCPToolServer() pytest.importorskip("mcp")
tool1 = ToolDescription.new(
name="tool1", description="First", parameters={"type": "object"} server = MCPToolServer()
) tool1 = ToolDescription.new(
tool2 = ToolDescription.new( name="tool1", description="First", parameters={"type": "object"}
name="tool2", description="Second", parameters={"type": "object"} )
) tool2 = ToolDescription.new(
tool3 = ToolDescription.new( name="tool2", description="Second", parameters={"type": "object"}
name="tool3", description="Third", parameters={"type": "object"}
)
server.harmony_tool_descriptions = {
"test_server": ToolNamespaceConfig(
name="test_server", description="test", tools=[tool1, tool2, tool3]
) )
} tool3 = ToolDescription.new(
name="tool3", description="Third", parameters={"type": "object"}
)
server.harmony_tool_descriptions = {
"test_server": ToolNamespaceConfig(
name="test_server",
description="test",
tools=[tool1, tool2, tool3],
)
}
# Nonexistent server # Nonexistent server
assert server.get_tool_description("nonexistent") is None assert server.get_tool_description("nonexistent") is None
# None (no filter) - returns all tools # None (no filter) - returns all tools
result = server.get_tool_description("test_server", allowed_tools=None) result = server.get_tool_description("test_server", allowed_tools=None)
assert len(result.tools) == 3 assert len(result.tools) == 3
# Filter to specific tools # Filter to specific tools
result = server.get_tool_description( result = server.get_tool_description(
"test_server", allowed_tools=["tool1", "tool3"] "test_server", allowed_tools=["tool1", "tool3"]
) )
assert len(result.tools) == 2 assert len(result.tools) == 2
assert result.tools[0].name == "tool1" assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool3" assert result.tools[1].name == "tool3"
# Single tool
result = server.get_tool_description("test_server", allowed_tools=["tool2"])
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
# No matching tools - returns None
result = server.get_tool_description(
"test_server", allowed_tools=["nonexistent"]
)
assert result is None
# Single tool # Empty list - returns None
result = server.get_tool_description( assert server.get_tool_description("test_server", allowed_tools=[]) is None
"test_server",
allowed_tools=["tool2"],
)
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
# No matching tools - returns None def test_builtin_tools_consistency(self):
result = server.get_tool_description("test_server", allowed_tools=["nonexistent"]) """MCP_BUILTIN_TOOLS must match _BUILTIN_TOOL_TO_MCP_SERVER_LABEL values."""
assert result is None from vllm.entrypoints.openai.parser.harmony_utils import (
_BUILTIN_TOOL_TO_MCP_SERVER_LABEL,
MCP_BUILTIN_TOOLS,
)
# Empty list - returns None assert set(_BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values()) == MCP_BUILTIN_TOOLS, (
assert server.get_tool_description("test_server", allowed_tools=[]) is None f"MCP_BUILTIN_TOOLS {MCP_BUILTIN_TOOLS} does not match "
f"_BUILTIN_TOOL_TO_MCP_SERVER_LABEL values "
f"{set(_BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values())}"
)
class TestMCPEnabled: class TestMCPEnabled:
"""Tests that require MCP tools to be enabled via environment variable.""" """Tests that require MCP tools to be enabled via environment variable."""
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
def monkeypatch_class(self): def mcp_enabled_server(self):
from _pytest.monkeypatch import MonkeyPatch env_dict = {
**BASE_TEST_ENV,
mpatch = MonkeyPatch() "VLLM_ENABLE_RESPONSES_API_STORE": "1",
yield mpatch "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
mpatch.undo() "VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS": ("code_interpreter,container"),
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
@pytest.fixture(scope="class") }
def mcp_enabled_server(self, monkeypatch_class: pytest.MonkeyPatch): with RemoteOpenAIServer(
args = ["--enforce-eager", "--tool-server", "demo"] MODEL_NAME, list(_BASE_SERVER_ARGS), env_dict=env_dict
) as remote_server:
with monkeypatch_class.context() as m: yield remote_server
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv(
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container"
)
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def mcp_enabled_client(self, mcp_enabled_server): async def client(self, mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client: async with mcp_enabled_server.get_async_client() as async_client:
yield async_client yield async_client
@staticmethod
def _mcp_tools_payload(*, allowed_tools: list[str] | None = None) -> list[dict]:
tool: dict = {
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
}
if allowed_tools is not None:
tool["allowed_tools"] = allowed_tools
return [tool]
@staticmethod
def _python_exec_input(code: str = "") -> str:
if not code:
code = "import random; print(random.randint(1, 1000000))"
return f"Execute the following code: {code}"
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_enabled( async def test_mcp_tool_env_flag_enabled(self, client: OpenAI, model_name: str):
self, mcp_enabled_client: OpenAI, model_name: str response = await retry_for_tool_call(
): client,
response = await mcp_enabled_client.responses.create(
model=model_name, model=model_name,
input=( expected_tool_type="mcp_call",
"Execute the following code: " input=self._python_exec_input(),
"import random; print(random.randint(1, 1000000))" instructions=_PYTHON_TOOL_INSTRUCTION,
), tools=self._mcp_tools_payload(),
instructions=( temperature=0.0,
"You must use the Python tool to execute code. "
"Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True}, extra_body={"enable_response_messages": True},
) )
assert response is not None
assert response.status == "completed" assert response.status == "completed"
# Verify output messages: Tool calls and responses on analysis channel log_response_diagnostics(response, label="MCP Enabled")
tool_call_found = False tool_call_found = False
tool_response_found = False tool_response_found = False
for message in response.output_messages: for message in response.output_messages:
recipient = message.get("recipient") recipient = message.get("recipient")
if recipient and recipient.startswith("python"): if recipient and recipient.startswith("python"):
tool_call_found = True tool_call_found = True
assert message.get("channel") == "analysis", ( assert message.get("channel") == "analysis"
"Tool call should be on analysis channel"
)
author = message.get("author", {}) author = message.get("author", {})
if ( if author.get("role") == "tool" and (author.get("name") or "").startswith(
author.get("role") == "tool" "python"
and author.get("name")
and author.get("name").startswith("python")
): ):
tool_response_found = True tool_response_found = True
assert message.get("channel") == "analysis", ( assert message.get("channel") == "analysis"
"Tool response should be on analysis channel"
)
assert tool_call_found, "Should have found at least one Python tool call" assert tool_call_found, (
assert tool_response_found, ( f"No Python tool call found. "
"Should have found at least one Python tool response" f"Output types: "
f"{[getattr(o, 'type', None) for o in response.output]}"
) )
assert tool_response_found, "No Python tool response found"
for message in response.input_messages: for message in response.input_messages:
assert message.get("author").get("role") != "developer", ( assert message.get("author", {}).get("role") != "developer"
"No developer messages should be present with valid mcp tool"
)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_with_allowed_tools_star( async def test_mcp_tool_with_allowed_tools_star(
self, mcp_enabled_client: OpenAI, model_name: str self, client: OpenAI, model_name: str
): ):
"""Test MCP tool with allowed_tools=['*'] to select all available response = await retry_for_tool_call(
tools. client,
This E2E test verifies that the "*" wildcard works end-to-end.
See test_serving_responses.py for detailed unit tests of "*"
normalization.
"""
response = await mcp_enabled_client.responses.create(
model=model_name, model=model_name,
input=( expected_tool_type="mcp_call",
"Execute the following code: " input=self._python_exec_input(),
"import random; print(random.randint(1, 1000000))" instructions=_PYTHON_TOOL_INSTRUCTION,
), tools=self._mcp_tools_payload(allowed_tools=["*"]),
instructions=( temperature=0.0,
"You must use the Python tool to execute code. "
"Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# Using "*" to allow all tools from this MCP server
"allowed_tools": ["*"],
}
],
extra_body={"enable_response_messages": True}, extra_body={"enable_response_messages": True},
) )
assert response is not None
assert response.status == "completed" assert response.status == "completed"
# Verify tool calls work with allowed_tools=["*"] log_response_diagnostics(response, label="MCP Allowed Tools *")
tool_call_found = False
for message in response.output_messages: tool_call_found = any(
recipient = message.get("recipient") (msg.get("recipient") or "").startswith("python")
if recipient and recipient.startswith("python"): for msg in response.output_messages
tool_call_found = True )
break
assert tool_call_found, ( assert tool_call_found, (
"Should have found at least one Python tool call with '*'" f"No Python tool call with '*'. "
f"Output types: "
f"{[getattr(o, 'type', None) for o in response.output]}"
) )
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_calling_streaming_types( async def test_mcp_tool_calling_streaming_types(
self, self,
pairs_of_event_types: dict[str, str], pairs_of_event_types: dict[str, str],
mcp_enabled_client: OpenAI, client: OpenAI,
model_name: str, model_name: str,
): ):
tools = [ def _has_mcp_events(events: list) -> bool:
{ return events_contain_type(events, "mcp_call")
"type": "mcp",
"server_label": "code_interpreter", events = await retry_streaming_for(
} client,
]
input_text = "What is 123 * 456? Use python to calculate the result."
stream_response = await mcp_enabled_client.responses.create(
model=model_name, model=model_name,
input=input_text, validate_events=_has_mcp_events,
tools=tools, input=("What is 123 * 456? Use Python to calculate the result."),
stream=True, tools=[{"type": "mcp", "server_label": "code_interpreter"}],
instructions=( instructions=_PYTHON_TOOL_INSTRUCTION,
"You must use the Python tool to execute code. " temperature=0.0,
"Never simulate execution."
),
) )
stack_of_event_types = [] validate_streaming_event_stack(events, pairs_of_event_types)
saw_mcp_type = False
async for event in stream_response:
if event.type == "response.created":
stack_of_event_types.append(event.type)
elif event.type == "response.completed":
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
elif (
event.type.endswith("added")
or event.type == "response.mcp_call.in_progress"
):
stack_of_event_types.append(event.type)
elif event.type.endswith("delta"):
if stack_of_event_types[-1] == event.type:
continue
stack_of_event_types.append(event.type)
elif (
event.type.endswith("done")
or event.type == "response.mcp_call.completed"
):
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
if "mcp_call" in event.type:
saw_mcp_type = True
stack_of_event_types.pop()
assert len(stack_of_event_types) == 0 assert events_contain_type(events, "mcp_call"), (
assert saw_mcp_type, "Should have seen at least one mcp call" f"No mcp_call events after retries. "
f"Event types: {sorted({e.type for e in events})}"
)
class TestMCPDisabled: class TestMCPDisabled:
"""Tests that verify behavior when MCP tools are disabled.""" """Tests that MCP tools are not executed when the env flag is unset."""
@pytest.fixture(scope="class")
def monkeypatch_class(self):
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
def mcp_disabled_server(self, monkeypatch_class: pytest.MonkeyPatch): def mcp_disabled_server(self):
args = ["--enforce-eager", "--tool-server", "demo"] env_dict = {
**BASE_TEST_ENV,
with monkeypatch_class.context() as m: "VLLM_ENABLE_RESPONSES_API_STORE": "1",
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
# Helps the model follow instructions better }
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1") with RemoteOpenAIServer(
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: MODEL_NAME, list(_BASE_SERVER_ARGS), env_dict=env_dict
yield remote_server ) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def mcp_disabled_client(self, mcp_disabled_server): async def client(self, mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client: async with mcp_disabled_server.get_async_client() as async_client:
yield async_client yield async_client
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_disabled( async def test_mcp_disabled_server_does_not_execute(
self, mcp_disabled_client: OpenAI, model_name: str self, client: OpenAI, model_name: str
): ):
response = await mcp_disabled_client.responses.create( """When MCP is disabled the model may still attempt tool calls
(tool descriptions can remain in the prompt), but the server
must NOT execute them."""
response = await client.responses.create(
model=model_name, model=model_name,
input=( input=(
"Execute the following code if the tool is present: " "Execute the following code if the tool is present: "
...@@ -308,38 +287,35 @@ class TestMCPDisabled: ...@@ -308,38 +287,35 @@ class TestMCPDisabled:
{ {
"type": "mcp", "type": "mcp",
"server_label": "code_interpreter", "server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888", "server_url": "http://localhost:8888",
} }
], ],
temperature=0.0,
extra_body={"enable_response_messages": True}, extra_body={"enable_response_messages": True},
) )
assert response is not None assert response is not None
assert response.status == "completed" assert response.status == "completed"
# Verify output messages: No tool calls and responses
tool_call_found = False log_response_diagnostics(response, label="MCP Disabled")
tool_response_found = False
# Server must not have executed any tool calls
for message in response.output_messages: for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {}) author = message.get("author", {})
if ( assert not (
author.get("role") == "tool" author.get("role") == "tool"
and author.get("name") and (author.get("name") or "").startswith("python")
and author.get("name").startswith("python") ), (
): "Server executed a python tool call even though MCP is "
tool_response_found = True f"disabled. Message: {message}"
assert message.get("channel") == "analysis", ( )
"Tool response should be on analysis channel"
# No completed mcp_call output items
for item in response.output:
if getattr(item, "type", None) == "mcp_call":
assert getattr(item, "status", None) != "completed", (
"MCP call should not be completed when MCP is disabled"
) )
assert not tool_call_found, "Should not have a python call" # No developer messages injected
assert not tool_response_found, "Should not have a tool response"
for message in response.input_messages: for message in response.input_messages:
assert message.get("author").get("role") != "developer", ( assert message.get("author", {}).get("role") != "developer"
"No developer messages should be present without a valid tool"
)
...@@ -3,15 +3,29 @@ ...@@ -3,15 +3,29 @@
import importlib.util import importlib.util
import json import json
import logging
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from openai import OpenAI from openai import OpenAI
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
from .conftest import (
BASE_TEST_ENV,
has_output_type,
log_response_diagnostics,
retry_for_tool_call,
)
logger = logging.getLogger(__name__)
MODEL_NAME = "Qwen/Qwen3-8B" MODEL_NAME = "Qwen/Qwen3-8B"
_PYTHON_TOOL_INSTRUCTION = (
"You must use the Python tool to execute code. "
"Never simulate execution. You must print the final answer."
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
...@@ -32,12 +46,12 @@ def server(): ...@@ -32,12 +46,12 @@ def server():
"--tool-server", "--tool-server",
"demo", "demo",
] ]
env_dict = dict( env_dict = {
VLLM_ENABLE_RESPONSES_API_STORE="1", **BASE_TEST_ENV,
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1", "VLLM_ENABLE_RESPONSES_API_STORE": "1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv", "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": "1",
) "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
}
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
yield remote_server yield remote_server
...@@ -54,6 +68,7 @@ async def test_basic(client: OpenAI, model_name: str): ...@@ -54,6 +68,7 @@ async def test_basic(client: OpenAI, model_name: str):
response = await client.responses.create( response = await client.responses.create(
model=model_name, model=model_name,
input="What is 123 * 456?", input="What is 123 * 456?",
temperature=0.0,
) )
assert response is not None assert response is not None
print("response: ", response) print("response: ", response)
...@@ -99,10 +114,15 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str): ...@@ -99,10 +114,15 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
) )
assert response is not None assert response is not None
assert response.status == "completed" assert response.status == "completed"
# make sure we get a reasoning and text output
assert response.output[0].type == "reasoning" output_types = [getattr(o, "type", None) for o in response.output]
assert response.output[1].type == "message" assert "reasoning" in output_types, (
assert type(response.output[1].content[0].text) is str f"Expected reasoning in output, got: {output_types}"
)
assert "message" in output_types, f"Expected message in output, got: {output_types}"
msg = next(o for o in response.output if o.type == "message")
assert type(msg.content[0].text) is str
def get_horoscope(sign): def get_horoscope(sign):
...@@ -110,10 +130,10 @@ def get_horoscope(sign): ...@@ -110,10 +130,10 @@ def get_horoscope(sign):
def call_function(name, args): def call_function(name, args):
logger.info("Calling function %s with args %s", name, args)
if name == "get_horoscope": if name == "get_horoscope":
return get_horoscope(**args) return get_horoscope(**args)
else: raise ValueError(f"Unknown function: {name}")
raise ValueError(f"Unknown function: {name}")
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -136,61 +156,111 @@ async def test_function_call_first_turn(client: OpenAI, model_name: str): ...@@ -136,61 +156,111 @@ async def test_function_call_first_turn(client: OpenAI, model_name: str):
} }
] ]
response = await client.responses.create( response = await retry_for_tool_call(
client,
model=model_name, model=model_name,
expected_tool_type="function_call",
input="What is the horoscope for Aquarius today?", input="What is the horoscope for Aquarius today?",
tools=tools, tools=tools,
temperature=0.0, temperature=0.0,
) )
assert response is not None assert response is not None
assert response.status == "completed" assert response.status == "completed"
assert len(response.output) == 2
assert response.output[0].type == "reasoning"
assert response.output[1].type == "function_call"
function_call = response.output[1] output_types = [getattr(o, "type", None) for o in response.output]
assert "reasoning" in output_types, (
f"Expected reasoning in output, got: {output_types}"
)
assert has_output_type(response, "function_call"), (
f"Expected function_call in output, got: {output_types}"
)
function_call = next(o for o in response.output if o.type == "function_call")
assert function_call.name == "get_horoscope" assert function_call.name == "get_horoscope"
assert function_call.call_id is not None assert function_call.call_id is not None
args = json.loads(function_call.arguments) args = json.loads(function_call.arguments)
assert "sign" in args assert "sign" in args
# the multi turn function call is tested above in
# test_reasoning_and_function_items
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str): async def test_mcp_tool_call(client: OpenAI, model_name: str):
response = await client.responses.create( """MCP tool calling with code_interpreter.
The model may make one or more tool calls before producing a final
message. We validate server invariants (mcp_call items have correct
fields) with hard assertions. Output indices are never hardcoded
since the model can produce multiple tool-call rounds.
"""
# MCP + container init + code execution can be slow
client_with_timeout = client.with_options(timeout=client.timeout * 3)
response = await retry_for_tool_call(
client_with_timeout,
model=model_name, model=model_name,
input="What is 123 * 456? Use python to calculate the result.", expected_tool_type="mcp_call",
input=(
"What is 123 * 456? Use python to calculate the result. "
"Print the result with print()."
),
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
extra_body={"enable_response_messages": True}, instructions=_PYTHON_TOOL_INSTRUCTION,
temperature=0.0, temperature=0.0,
extra_body={"enable_response_messages": True},
) )
assert response is not None assert response is not None
assert response.status == "completed"
# The model may produce multiple reasoning/mcp_call rounds before the output_types = [getattr(o, "type", None) for o in response.output]
# final message, so validate structurally rather than by exact index. log_response_diagnostics(response, label="test_mcp_tool_call")
output_types = [o.type for o in response.output]
assert "reasoning" in output_types assert response.status == "completed", (
mcp_calls = [o for o in response.output if o.type == "mcp_call"] f"Response status={response.status} "
assert len(mcp_calls) >= 1 f"(details={getattr(response, 'incomplete_details', None)}). "
assert type(mcp_calls[0].arguments) is str f"Output types: {output_types}."
assert type(mcp_calls[0].output) is str )
# The final output should be a message containing the correct answer assert "reasoning" in output_types, (
assert response.output[-1].type == "message" f"Expected reasoning in output, got: {output_types}"
assert any(s in response.output[-1].content[0].text for s in ("56088", "56,088")) )
assert "mcp_call" in output_types, (
# Test raw input_messages / output_messages f"Expected mcp_call in output, got: {output_types}"
assert len(response.input_messages) == 1 )
assert len(response.output_messages) >= 3
# Every mcp_call item must have well-typed fields
for item in response.output:
if getattr(item, "type", None) == "mcp_call":
assert type(item.arguments) is str, (
f"mcp_call.arguments should be str, got {type(item.arguments)}"
)
assert type(item.output) is str, (
f"mcp_call.output should be str, got {type(item.output)}"
)
# The model may make 1+ tool-call rounds but must still produce
# a final message for a trivial calculation like 123 * 456.
message_outputs = [
o for o in response.output if getattr(o, "type", None) == "message"
]
assert message_outputs, (
f"Model did not produce a final message. Output types: {output_types}"
)
final_message = message_outputs[-1]
assert any(s in final_message.content[0].text for s in ("56088", "56,088")), (
f"Expected 56088 in final message, got: {final_message.content[0].text!r}"
)
# Validate raw input_messages / output_messages
assert len(response.input_messages) >= 1, "Expected at least 1 input message"
assert len(response.output_messages) >= 1, "Expected at least 1 output message"
assert any( assert any(
s in response.output_messages[-1]["message"] for s in ("56088", "56,088") any(s in str(msg) for s in ("56088", "56,088"))
for msg in response.output_messages
), (
f"Expected 56088 in at least one output_message, "
f"got {len(response.output_messages)} messages"
) )
...@@ -202,6 +272,7 @@ async def test_max_tokens(client: OpenAI, model_name: str): ...@@ -202,6 +272,7 @@ async def test_max_tokens(client: OpenAI, model_name: str):
input="What is the first paragraph of Moby Dick?", input="What is the first paragraph of Moby Dick?",
reasoning={"effort": "low"}, reasoning={"effort": "low"},
max_output_tokens=30, max_output_tokens=30,
temperature=0.0,
) )
assert response is not None assert response is not None
assert response.status == "incomplete" assert response.status == "incomplete"
......
...@@ -12,13 +12,15 @@ MODEL_NAME = "Qwen/Qwen3-8B" ...@@ -12,13 +12,15 @@ MODEL_NAME = "Qwen/Qwen3-8B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
from .conftest import BASE_TEST_ENV
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"] args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
env_dict = dict( env_dict = {
VLLM_ENABLE_RESPONSES_API_STORE="1", **BASE_TEST_ENV,
"VLLM_ENABLE_RESPONSES_API_STORE": "1",
# uncomment for tool calling # uncomment for tool calling
# PYTHON_EXECUTION_BACKEND="dangerously_use_uv", # PYTHON_EXECUTION_BACKEND: "dangerously_use_uv",
) }
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
yield remote_server yield remote_server
......
...@@ -128,6 +128,9 @@ class RemoteOpenAIServer: ...@@ -128,6 +128,9 @@ class RemoteOpenAIServer:
env=env, env=env,
stdout=sys.stdout, stdout=sys.stdout,
stderr=sys.stderr, stderr=sys.stderr,
# Create a dedicated process group so we can kill
# the entire tree (parent + EngineCore + workers) at once.
start_new_session=True,
) )
def __init__( def __init__(
...@@ -189,6 +192,15 @@ class RemoteOpenAIServer: ...@@ -189,6 +192,15 @@ class RemoteOpenAIServer:
model_loader = get_model_loader(load_config) model_loader = get_model_loader(load_config)
model_loader.download_model(model_config) model_loader.download_model(model_config)
# Record GPU memory before server start so we know what
# "released" looks like.
self._pre_server_gpu_memory = self._get_gpu_memory_used()
if self._pre_server_gpu_memory is not None:
pre_gb = self._pre_server_gpu_memory / 1e9
print(
f"[RemoteOpenAIServer] GPU memory before server start: {pre_gb:.2f} GB"
)
self._start_server(model, vllm_serve_args, env_dict) self._start_server(model, vllm_serve_args, env_dict)
max_wait_seconds = max_wait_seconds or 360 max_wait_seconds = max_wait_seconds or 360
self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
...@@ -198,27 +210,69 @@ class RemoteOpenAIServer: ...@@ -198,27 +210,69 @@ class RemoteOpenAIServer:
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
pid = self.proc.pid pid = self.proc.pid
# Graceful shutdown
self.proc.terminate() # Get the process group ID. Because we used
# start_new_session=True the pgid equals the server's pid.
try:
pgid = os.getpgid(pid)
except (ProcessLookupError, OSError):
pgid = None
# Phase 1: graceful SIGTERM to the entire process group
if pgid is not None:
with contextlib.suppress(ProcessLookupError, OSError):
os.killpg(pgid, signal.SIGTERM)
print(f"[RemoteOpenAIServer] Sent SIGTERM to process group {pgid}")
else:
self.proc.terminate()
try: try:
self.proc.wait(timeout=15) self.proc.wait(timeout=15)
print(f"[RemoteOpenAIServer] Server {pid} terminated gracefully") print(f"[RemoteOpenAIServer] Server {pid} terminated gracefully")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
# Phase 2: SIGKILL the entire process group
print( print(
f"[RemoteOpenAIServer] Server {pid} did not respond " f"[RemoteOpenAIServer] Server {pid} did not respond "
"to SIGTERM, sending SIGKILL" "to SIGTERM, sending SIGKILL to process group"
) )
self.proc.kill() if pgid is not None:
with contextlib.suppress(ProcessLookupError, OSError):
os.killpg(pgid, signal.SIGKILL)
else:
self.proc.kill()
try: try:
self.proc.wait(timeout=5) self.proc.wait(timeout=10)
print(f"[RemoteOpenAIServer] Server {pid} killed") print(f"[RemoteOpenAIServer] Server {pid} killed")
except subprocess.TimeoutExpired as err: except subprocess.TimeoutExpired:
raise RuntimeError( # Phase 3: last resort - find and kill any orphaned children
f"[RemoteOpenAIServer] Failed to kill server process {pid}" self._kill_orphaned_children(pid)
) from err
# Wait for GPU memory to be released # Wait for GPU memory to actually be *freed*, not just
# "stabilized at whatever level it's at".
self._wait_for_gpu_memory_release() self._wait_for_gpu_memory_release()
def _kill_orphaned_children(self, parent_pid: int) -> None:
"""Best-effort cleanup of any lingering child processes."""
try:
import psutil
parent = psutil.Process(parent_pid)
children = parent.children(recursive=True)
for child in children:
print(
f"[RemoteOpenAIServer] Killing orphaned child "
f"pid={child.pid} name={child.name()}"
)
child.kill()
psutil.wait_procs(children, timeout=5)
except Exception as e:
# psutil may not be installed, or processes already gone
print(f"[RemoteOpenAIServer] Orphan cleanup failed: {e}")
# Fallback: try to kill by pgid one more time
with contextlib.suppress(ProcessLookupError, OSError):
os.killpg(parent_pid, signal.SIGKILL)
def _get_gpu_memory_used(self) -> float | None: def _get_gpu_memory_used(self) -> float | None:
"""Get total GPU memory used across all visible devices in bytes.""" """Get total GPU memory used across all visible devices in bytes."""
try: try:
...@@ -244,10 +298,26 @@ class RemoteOpenAIServer: ...@@ -244,10 +298,26 @@ class RemoteOpenAIServer:
return None return None
return None return None
def _wait_for_gpu_memory_release(self, timeout: float = 30.0): def _wait_for_gpu_memory_release(self, timeout: float = 60.0):
"""Poll GPU memory until it stabilizes, indicating cleanup is complete.""" """Wait for GPU memory to drop back toward pre-server levels.
Two-phase strategy:
1. Try to wait for memory to return close to pre-server baseline.
2. If that doesn't happen, fall back to waiting for stabilization
and log a warning (the next server might still OOM).
"""
baseline = self._pre_server_gpu_memory
if baseline is None:
# Can't query GPU memory - nothing to do
return
# Allow up to 2 GiB overhead above baseline for driver/context state
# that may persist between server instances.
headroom_bytes = 2 * 1024 * 1024 * 1024
target = baseline + headroom_bytes
start = time.time() start = time.time()
prev_used: float | None = None last_used: float | None = None
stable_count = 0 stable_count = 0
while time.time() - start < timeout: while time.time() - start < timeout:
...@@ -256,26 +326,49 @@ class RemoteOpenAIServer: ...@@ -256,26 +326,49 @@ class RemoteOpenAIServer:
if used is None: if used is None:
return # Can't query, assume ok return # Can't query, assume ok
if prev_used is not None and abs(used - prev_used) < 100 * 1024 * 1024: used_gb = used / 1e9
stable_count += 1 target_gb = target / 1e9
if stable_count >= 3: elapsed = time.time() - start
used_gb = used / 1e9
print( # Phase 1: memory dropped to near baseline - we're done.
f"[RemoteOpenAIServer] GPU memory stabilized " if used <= target:
f"at {used_gb:.2f} GB" print(
) f"[RemoteOpenAIServer] GPU memory released to "
return f"{used_gb:.2f} GB (target: {target_gb:.2f} GB) "
else: f"in {elapsed:.1f}s"
stable_count = 0 )
return
# Phase 2 (after 40s): fall back to stabilization check.
# This handles cases where another process is using GPU memory
# and we'll never reach baseline.
if elapsed > 40.0 and last_used is not None:
delta = abs(used - last_used)
if delta < 200 * 1024 * 1024: # 200 MB
stable_count += 1
if stable_count >= 3:
print(
f"[RemoteOpenAIServer] WARNING: GPU memory "
f"stabilized at {used_gb:.2f} GB "
f"(target was {target_gb:.2f} GB). "
f"Proceeding - next server may OOM."
)
return
else:
stable_count = 0
prev_used = used last_used = used
time.sleep(0.1) time.sleep(1.0)
last_reading = prev_used / 1e9 if prev_used is not None else 0.0 # Timeout - log clearly so CI failures are diagnosable
final_used = self._get_gpu_memory_used()
final_gb = final_used / 1e9 if final_used else 0.0
raise RuntimeError( raise RuntimeError(
f"[RemoteOpenAIServer] GPU memory did not stabilize within {timeout}s. " f"[RemoteOpenAIServer] GPU memory did not release within "
f"Last reading: {last_reading:.2f} GB. " f"{timeout}s. Current: {final_gb:.2f} GB, "
"Child processes may still be holding GPU memory." f"target: {target / 1e9:.2f} GB, "
f"baseline: {baseline / 1e9:.2f} GB. "
f"Child processes may still be holding GPU memory."
) )
def _poll(self) -> int | None: def _poll(self) -> int | None:
......
...@@ -48,8 +48,11 @@ from vllm.entrypoints.openai.responses.protocol import ( ...@@ -48,8 +48,11 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem, ResponseInputOutputItem,
ResponsesRequest, ResponsesRequest,
) )
from vllm.logger import init_logger
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__)
REASONING_EFFORT = { REASONING_EFFORT = {
"high": ReasoningEffort.HIGH, "high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM, "medium": ReasoningEffort.MEDIUM,
...@@ -62,20 +65,15 @@ _harmony_encoding = None ...@@ -62,20 +65,15 @@ _harmony_encoding = None
# they are available and requested by the user. # they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output # Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified. # of the tools are stringified.
MCP_BUILTIN_TOOLS: set[str] = {
"web_search_preview",
"code_interpreter",
"container",
}
# Mapping from built-in tool recipient names to their MCP server labels.
# This ensures consistency between streaming and non-streaming responses.
_BUILTIN_TOOL_TO_MCP_SERVER_LABEL: dict[str, str] = { _BUILTIN_TOOL_TO_MCP_SERVER_LABEL: dict[str, str] = {
"python": "code_interpreter", "python": "code_interpreter",
"browser": "web_search_preview", "browser": "web_search_preview",
"container": "container", "container": "container",
} }
# Derive MCP_BUILTIN_TOOLS from the canonical mapping
MCP_BUILTIN_TOOLS: set[str] = set(_BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values())
def has_custom_tools(tool_types: set[str]) -> bool: def has_custom_tools(tool_types: set[str]) -> bool:
""" """
...@@ -116,8 +114,11 @@ def get_system_message( ...@@ -116,8 +114,11 @@ def get_system_message(
REASONING_EFFORT[reasoning_effort] REASONING_EFFORT[reasoning_effort]
) )
if start_date is None: if start_date is None:
# NOTE(woosuk): This brings non-determinism in vLLM. Be careful. # NOTE(woosuk): This brings non-determinism in vLLM.
start_date = datetime.datetime.now().strftime("%Y-%m-%d") # Set VLLM_SYSTEM_START_DATE to pin it.
start_date = envs.VLLM_SYSTEM_START_DATE or datetime.datetime.now().strftime(
"%Y-%m-%d"
)
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
if browser_description is not None: if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description) sys_msg_content = sys_msg_content.with_tools(browser_description)
...@@ -398,15 +399,60 @@ def parse_chat_input_to_harmony_message( ...@@ -398,15 +399,60 @@ def parse_chat_input_to_harmony_message(
def parse_input_to_harmony_message(chat_msg) -> list[Message]: def parse_input_to_harmony_message(chat_msg) -> list[Message]:
""" """Parse a message from request.previous_input_messages
Parse a message from request.previous_input_messages in the Responsees API to into Harmony messages.
Harmony messages.
Supports both OpenAI chat format ({"role": "..."}) and
Harmony format ({"author": {"role": "..."}}).
""" """
if not isinstance(chat_msg, dict): if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True) chat_msg = chat_msg.model_dump(exclude_none=True)
if "author" in chat_msg and isinstance(chat_msg.get("author"), dict):
return [_parse_harmony_format_message(chat_msg)]
return _parse_chat_format_message(chat_msg)
def _parse_harmony_format_message(chat_msg: dict) -> Message:
"""Reconstruct a Message from Harmony-format dict,
preserving channel, recipient, and content_type."""
author_dict = chat_msg["author"]
role = author_dict.get("role")
name = author_dict.get("name")
raw_content = chat_msg.get("content", "")
if isinstance(raw_content, list):
# TODO: Support refusal and non-text content types.
contents = [TextContent(text=c.get("text", "")) for c in raw_content]
elif isinstance(raw_content, str):
contents = [TextContent(text=raw_content)]
else:
contents = [TextContent(text="")]
if name:
msg = Message.from_author_and_contents(Author.new(Role(role), name), contents)
else:
msg = Message.from_role_and_contents(Role(role), contents)
channel = chat_msg.get("channel")
if channel:
msg = msg.with_channel(channel)
recipient = chat_msg.get("recipient")
if recipient:
msg = msg.with_recipient(recipient)
content_type = chat_msg.get("content_type")
if content_type:
msg = msg.with_content_type(content_type)
return msg
def _parse_chat_format_message(chat_msg: dict) -> list[Message]:
"""Parse an OpenAI chat-format dict into Harmony messages."""
role = chat_msg.get("role") role = chat_msg.get("role")
if role is None:
raise ValueError(f"Message has no 'role' key: {chat_msg}")
# Assistant message with tool calls # Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls") tool_calls = chat_msg.get("tool_calls")
...@@ -426,15 +472,21 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]: ...@@ -426,15 +472,21 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
# Tool role message (tool output) # Tool role message (tool output)
if role == "tool": if role == "tool":
name = chat_msg.get("name", "") name = chat_msg.get("name", "")
if name and not name.startswith("functions."):
name = f"functions.{name}"
content = chat_msg.get("content", "") or "" content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content) content = flatten_chat_text_content(content)
# NOTE: .with_recipient("assistant") is required on tool messages
msg = Message.from_author_and_content( # to match parse_chat_input_to_harmony_message behavior and ensure
Author.new(Role.TOOL, f"functions.{name}"), content # proper routing in the Harmony protocol.
).with_channel("commentary") msg = (
Message.from_author_and_content(Author.new(Role.TOOL, name), content)
.with_channel("commentary")
.with_recipient("assistant")
)
return [msg] return [msg]
# Default: user/assistant/system messages with content # Default: user/assistant/system messages
content = chat_msg.get("content", "") content = chat_msg.get("content", "")
if isinstance(content, str): if isinstance(content, str):
contents = [TextContent(text=content)] contents = [TextContent(text=content)]
...@@ -497,6 +549,10 @@ def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutput ...@@ -497,6 +549,10 @@ def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutput
try: try:
browser_call = json.loads(content.text) browser_call = json.loads(content.text)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(
"Invalid JSON in browser tool call, using error placeholder: %s",
content.text,
)
json_retry_output_message = ( json_retry_output_message = (
f"Invalid JSON args, caught and retried: {content.text}" f"Invalid JSON args, caught and retried: {content.text}"
) )
...@@ -730,22 +786,7 @@ def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]: ...@@ -730,22 +786,7 @@ def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
) )
] ]
if parser.current_channel == "commentary": if parser.current_channel in ("commentary", "analysis"):
return [
ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
]
if parser.current_channel == "analysis":
return [ return [
ResponseReasoningItem( ResponseReasoningItem(
id=f"rs_{random_uuid()}", id=f"rs_{random_uuid()}",
......
...@@ -346,17 +346,17 @@ class ParsableContext(ConversationContext): ...@@ -346,17 +346,17 @@ class ParsableContext(ConversationContext):
self.parser.response_messages.extend(output) self.parser.response_messages.extend(output)
def need_builtin_tool_call(self) -> bool: def need_builtin_tool_call(self) -> bool:
"""Return true if the last message is a MCP tool call""" """Return true if the last message is a builtin tool call
that the request has enabled."""
last_message = self.parser.response_messages[-1] last_message = self.parser.response_messages[-1]
# TODO(qandrew): figure out which tools are MCP tools if last_message.type != "function_call":
if last_message.type == "function_call": # noqa: SIM102 return False
if last_message.name in ( if last_message.name in ("code_interpreter", "python"):
"code_interpreter", return "python" in self.available_tools
"python", if last_message.name == "web_search_preview":
"web_search_preview", return "browser" in self.available_tools
) or last_message.name.startswith("container"): if last_message.name.startswith("container"):
return True return "container" in self.available_tools
return False return False
async def call_python_tool( async def call_python_tool(
...@@ -665,11 +665,15 @@ class HarmonyContext(ConversationContext): ...@@ -665,11 +665,15 @@ class HarmonyContext(ConversationContext):
def need_builtin_tool_call(self) -> bool: def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1] last_msg = self.messages[-1]
recipient = last_msg.recipient recipient = last_msg.recipient
return recipient is not None and ( if recipient is None:
recipient.startswith("browser.") return False
or recipient.startswith("python") if recipient.startswith("browser."):
or recipient.startswith("container.") return "browser" in self.available_tools
) if recipient.startswith("python"):
return "python" in self.available_tools
if recipient.startswith("container."):
return "container" in self.available_tools
return False
async def call_tool(self) -> list[Message]: async def call_tool(self) -> list[Message]:
if not self.messages: if not self.messages:
......
...@@ -392,13 +392,27 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -392,13 +392,27 @@ class OpenAIServingResponses(OpenAIServing):
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[ConversationContext, None]] = [] generators: list[AsyncGenerator[ConversationContext, None]] = []
# Only include builtin tools that the request actually asked for.
# Without this filter, tools registered on the server (e.g. via
# --tool-server demo) would be available for execution even when
# the request didn't enable them.
requested_tool_types = extract_tool_types(request.tools)
builtin_tool_list: list[str] = [] builtin_tool_list: list[str] = []
if self.tool_server is not None: if self.tool_server is not None:
if self.tool_server.has_tool("browser"): if (
self.tool_server.has_tool("browser")
and "web_search_preview" in requested_tool_types
):
builtin_tool_list.append("browser") builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"): if (
self.tool_server.has_tool("python")
and "code_interpreter" in requested_tool_types
):
builtin_tool_list.append("python") builtin_tool_list.append("python")
if self.tool_server.has_tool("container"): if (
self.tool_server.has_tool("container")
and "container" in requested_tool_types
):
builtin_tool_list.append("container") builtin_tool_list.append("container")
if self.tool_server is not None: if self.tool_server is not None:
...@@ -1049,9 +1063,15 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1049,9 +1063,15 @@ class OpenAIServingResponses(OpenAIServing):
# FIXME(woosuk): Currently, request params like reasoning and # FIXME(woosuk): Currently, request params like reasoning and
# instructions are ignored. # instructions are ignored.
prev_msgs = self.msg_store[prev_response.id] prev_msgs = self.msg_store[prev_response.id]
# Remove the previous chain-of-thoughts if there is a new "final"
# message. Note that this also removes these messages from the # FIXME(woosuk): The slice-delete-reappend cycle below is
# msg_store. # currently a no-op --- it removes messages then puts them all
# back unfiltered. It may be intentionally deferred (see FIXME
# above) or redundant if the Harmony encoder already strips
# analysis messages at render time. If analysis messages need
# to be dropped here, add a channel != "analysis" filter when
# re-appending, similar to auto_drop_analysis_messages in
# harmony_utils.py.
if len(prev_msgs) > 0: if len(prev_msgs) > 0:
last_msg = prev_msgs[-1] last_msg = prev_msgs[-1]
assert isinstance(last_msg, OpenAIHarmonyMessage) assert isinstance(last_msg, OpenAIHarmonyMessage)
...@@ -1072,7 +1092,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1072,7 +1092,11 @@ class OpenAIServingResponses(OpenAIServing):
# Append the new input. # Append the new input.
# Responses API supports simple text inputs without chat format. # Responses API supports simple text inputs without chat format.
if isinstance(request.input, str): if isinstance(request.input, str):
messages.append(get_user_message(request.input)) # Skip empty string input when previous_input_messages supplies
# the full conversation history --- an empty trailing user message
# confuses the model into thinking nothing was sent.
if request.input or not request.previous_input_messages:
messages.append(get_user_message(request.input))
else: else:
if prev_response is not None: if prev_response is not None:
prev_outputs = copy(prev_response.output) prev_outputs = copy(prev_response.output)
......
...@@ -209,6 +209,7 @@ if TYPE_CHECKING: ...@@ -209,6 +209,7 @@ if TYPE_CHECKING:
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set() VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_SYSTEM_START_DATE: str | None = None
VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
...@@ -1458,6 +1459,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1458,6 +1459,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool( "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool(
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0")) int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))
), ),
# Pin the conversation start date injected into the Harmony system
# message. When unset the current date is used, which introduces
# non-determinism (different tokens -> different model behaviour at
# temperature=0). Set to an ISO date string, e.g. "2023-09-12",
# for reproducible inference or testing.
"VLLM_SYSTEM_START_DATE": lambda: os.getenv("VLLM_SYSTEM_START_DATE", None),
# Enable automatic retry when tool call JSON parsing fails # Enable automatic retry when tool call JSON parsing fails
# If enabled, returns an error message to the model to retry # If enabled, returns an error message to the model to retry
# If disabled (default), raises an exception and fails the request # If disabled (default), raises an exception and fails the request
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment