Unverified Commit 7ef1964a authored by Keyang Ru, committed by GitHub

[router] add basic ci tests for gpt-oss model support (#12651)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent b0476a06
......@@ -23,6 +23,7 @@ _TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR.parent))
from fixtures import popen_launch_workers_and_router
from util import (
DEFAULT_GPT_OSS_MODEL_PATH,
DEFAULT_MODEL_PATH,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
......@@ -293,5 +294,51 @@ The SmartHome Mini is a compact smart home assistant available in black or white
client.models.retrieve("non-existent-model")
class TestOpenAIServerGptOss(TestOpenAIServer):
"""
Test OpenAI API through gRPC router with openai/gpt-oss-20b model.
Extends TestOpenAIServer and only changes the model.
"""
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_GPT_OSS_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
num_workers=1,
tp_size=2,
policy="round_robin",
api_key=cls.api_key,
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
def test_chat_completion(self):
for parallel_sample_num in [1, 2]:
self.run_chat_completion(None, parallel_sample_num)
def test_chat_completion_stream(self):
for parallel_sample_num in [1, 2]:
self.run_chat_completion_stream(None, parallel_sample_num)
@unittest.skip("Skipping for OSS models")
def test_regex(self):
super().test_regex()
@unittest.skip("Skipping for OSS models")
def test_response_prefill(self):
super().test_response_prefill()
@unittest.skip("Skipping for OSS models")
def test_penalty(self):
super().test_penalty()
if __name__ == "__main__":
unittest.main()
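For reference, the new class can be selected by name when running the router test suite locally; a minimal invocation sketch (the diff does not name the test module, so the keyword filter below is an assumption):

```python
# Hypothetical local runner for the gpt-oss tests added above; assumes pytest
# is installed and is invoked from the directory containing this test module.
import sys

import pytest

if __name__ == "__main__":
    # -k selects the new class by name; -x stops at the first failure.
    sys.exit(pytest.main(["-v", "-x", "-k", "TestOpenAIServerGptOss"]))
```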
......@@ -202,8 +202,6 @@ def popen_launch_workers_and_router(
"--grpc-mode", # Enable gRPC for this worker
"--mem-fraction-static",
"0.8",
"--attention-backend",
"fa3",
]
# Add TP size
......
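The hunk above only shows the fixture's worker arguments changing (the `--attention-backend fa3` flag is dropped), but the tests rely on `popen_launch_workers_and_router` returning a dict of live process handles. A minimal sketch of that assumed contract, not the repository's implementation:

```python
# Sketch of the return contract the tests rely on: cluster["router"].pid,
# cluster["workers"], and cluster["base_url"] are all accessed in setUpClass
# and tearDownClass. The placeholder commands stand in for the real worker and
# router launch lines, which this diff only shows in part.
import subprocess
import sys


def launch_workers_and_router_sketch(model, base_url, *, timeout=90, num_workers=1,
                                     tp_size=1, policy="round_robin",
                                     api_key=None, router_args=None):
    placeholder = [sys.executable, "-c", "import time; time.sleep(3600)"]
    workers = [subprocess.Popen(placeholder) for _ in range(num_workers)]
    router = subprocess.Popen(placeholder)
    return {"router": router, "workers": workers, "base_url": base_url}
```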
......@@ -89,6 +89,9 @@ DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH = _get_model_path(
"mistralai/Mistral-7B-Instruct-v0.3"
)
# GPT-OSS models
DEFAULT_GPT_OSS_MODEL_PATH = _get_model_path("openai/gpt-oss-20b")
# ============================================================================
# Process Management
......
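`_get_model_path` itself is not shown in this diff; presumably it prefers a pre-downloaded local copy of the model and falls back to the Hugging Face id otherwise (other tests in this commit hard-code paths under /home/ubuntu/models). A sketch under that assumption:

```python
import os


def _get_model_path_sketch(model_id: str, local_root: str = "/home/ubuntu/models") -> str:
    # Hypothetical resolution: use the local checkout if it exists, otherwise
    # return the Hugging Face identifier unchanged.
    local = os.path.join(local_root, model_id)
    return local if os.path.isdir(local) else model_id
```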
"""
gRPC backend tests for Response API (including Harmony).
Run with:
python3 -m pytest py_test/e2e_response_api/backends/test_grpc_backend.py -v
python3 -m unittest e2e_response_api.backends.test_grpc_backend.TestGrpcBackend
"""
import sys
import unittest
from pathlib import Path
# Add e2e_response_api directory for imports
_TEST_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(_TEST_DIR))
# Import local modules
from mixins.function_call import FunctionCallingBaseTest
from mixins.mcp import MCPTests
from mixins.state_management import StateManagementTests
from router_fixtures import popen_launch_workers_and_router
from util import kill_process_tree
class TestGrpcBackend(StateManagementTests, MCPTests):
"""End to end tests for gRPC backend."""
@classmethod
def setUpClass(cls):
cls.model = "/home/ubuntu/models/meta-llama/Llama-3.1-8B-Instruct"
cls.base_url_port = "http://127.0.0.1:30030"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url_port,
timeout=90,
num_workers=1,
tp_size=2,
policy="round_robin",
router_args=["--history-backend", "memory"],
)
cls.base_url = cls.cluster["base_url"]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
@unittest.skip(
"TODO: transport error, details: [], metadata: MetadataMap { headers: {} }"
)
def test_previous_response_id_chaining(self):
super().test_previous_response_id_chaining()
@unittest.skip("TODO: return 501 Not Implemented")
def test_conversation_with_multiple_turns(self):
super().test_conversation_with_multiple_turns()
@unittest.skip("TODO: decode error message")
def test_mutually_exclusive_parameters(self):
super().test_mutually_exclusive_parameters()
@unittest.skip(
"TODO: Pipeline execution failed: Pipeline stage WorkerSelection failed"
)
def test_mcp_basic_tool_call(self):
super().test_mcp_basic_tool_call()
@unittest.skip("TODO: no event fields")
def test_mcp_basic_tool_call_streaming(self):
return super().test_mcp_basic_tool_call_streaming()
class TestHarmonyBackend(StateManagementTests, MCPTests, FunctionCallingBaseTest):
"""End to end tests for Harmony backend."""
@classmethod
def setUpClass(cls):
cls.model = "/home/ubuntu/models/openai/gpt-oss-20b"
cls.base_url_port = "http://127.0.0.1:30030"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url_port,
timeout=90,
num_workers=1,
tp_size=2,
policy="round_robin",
router_args=["--history-backend", "memory"],
)
cls.base_url = cls.cluster["base_url"]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
def test_previous_response_id_chaining(self):
super().test_previous_response_id_chaining()
@unittest.skip(
"TODO: fix requests.exceptions.JSONDecodeError: Expecting value: line 1 column 1 (char 0)"
)
def test_mutually_exclusive_parameters(self):
super().test_mutually_exclusive_parameters()
def test_mcp_basic_tool_call(self):
"""Test basic MCP tool call (non-streaming)."""
tools = [
{
"type": "mcp",
"server_label": "deepwiki",
"server_url": "https://mcp.deepwiki.com/mcp",
"require_approval": "never",
}
]
resp = self.create_response(
"What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?",
tools=tools,
stream=False,
)
# Should successfully make the request
self.assertEqual(resp.status_code, 200)
data = resp.json()
# Basic response structure
self.assertIn("id", data)
self.assertIn("status", data)
self.assertEqual(data["status"], "completed")
self.assertIn("output", data)
self.assertIn("model", data)
# Verify output array is not empty
output = data["output"]
self.assertIsInstance(output, list)
self.assertGreater(len(output), 0)
# Check for MCP-specific output types
output_types = [item.get("type") for item in output]
# Should have mcp_list_tools - tools are listed before calling
self.assertIn(
"mcp_list_tools", output_types, "Response should contain mcp_list_tools"
)
# Should have at least one mcp_call
mcp_calls = [item for item in output if item.get("type") == "mcp_call"]
self.assertGreater(
len(mcp_calls), 0, "Response should contain at least one mcp_call"
)
# Verify mcp_call structure
for mcp_call in mcp_calls:
self.assertIn("id", mcp_call)
self.assertIn("status", mcp_call)
self.assertEqual(mcp_call["status"], "completed")
self.assertIn("server_label", mcp_call)
self.assertEqual(mcp_call["server_label"], "deepwiki")
self.assertIn("name", mcp_call)
self.assertIn("arguments", mcp_call)
self.assertIn("output", mcp_call)
def test_mcp_basic_tool_call_streaming(self):
"""Test basic MCP tool call (streaming)."""
tools = [
{
"type": "mcp",
"server_label": "deepwiki",
"server_url": "https://mcp.deepwiki.com/mcp",
"require_approval": "never",
}
]
resp = self.create_response(
"What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?",
tools=tools,
stream=True,
)
# Should successfully make the request
self.assertEqual(resp.status_code, 200)
events = self.parse_sse_events(resp)
self.assertGreater(len(events), 0)
event_types = [e.get("event") for e in events]
# Check for lifecycle events
self.assertIn(
"response.created", event_types, "Should have response.created event"
)
self.assertIn(
"response.completed", event_types, "Should have response.completed event"
)
# Check for MCP list tools events
self.assertIn(
"response.output_item.added",
event_types,
"Should have output_item.added events",
)
self.assertIn(
"response.mcp_list_tools.in_progress",
event_types,
"Should have mcp_list_tools.in_progress event",
)
self.assertIn(
"response.mcp_list_tools.completed",
event_types,
"Should have mcp_list_tools.completed event",
)
# Check for MCP call events
self.assertIn(
"response.mcp_call.in_progress",
event_types,
"Should have mcp_call.in_progress event",
)
self.assertIn(
"response.mcp_call_arguments.delta",
event_types,
"Should have mcp_call_arguments.delta event",
)
self.assertIn(
"response.mcp_call_arguments.done",
event_types,
"Should have mcp_call_arguments.done event",
)
self.assertIn(
"response.mcp_call.completed",
event_types,
"Should have mcp_call.completed event",
)
# Verify final completed event has full response
completed_events = [e for e in events if e.get("event") == "response.completed"]
self.assertEqual(len(completed_events), 1)
final_response = completed_events[0].get("data", {}).get("response", {})
self.assertIn("id", final_response)
self.assertEqual(final_response.get("status"), "completed")
self.assertIn("output", final_response)
# Verify final output contains expected items
final_output = final_response.get("output", [])
final_output_types = [item.get("type") for item in final_output]
self.assertIn("mcp_list_tools", final_output_types)
self.assertIn("mcp_call", final_output_types)
# Verify mcp_call items in final output
mcp_calls = [item for item in final_output if item.get("type") == "mcp_call"]
self.assertGreater(len(mcp_calls), 0)
for mcp_call in mcp_calls:
self.assertEqual(mcp_call.get("status"), "completed")
self.assertEqual(mcp_call.get("server_label"), "deepwiki")
self.assertIn("name", mcp_call)
self.assertIn("arguments", mcp_call)
self.assertIn("output", mcp_call)
@unittest.skip("TODO: 501 Not Implemented")
def test_conversation_with_multiple_turns(self):
super().test_conversation_with_multiple_turns()
if __name__ == "__main__":
unittest.main()
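The MCP assertions above depend on `self.create_response` and `self.parse_sse_events`, neither of which appears in this diff (they presumably live in the shared Response API mixins). A minimal sketch of what such helpers might look like, as an assumption rather than the repository's code:

```python
import json

import requests


def create_response(base_url, prompt, tools=None, stream=False, api_key=None):
    # POST /v1/responses with an optional tools array; stream=True keeps the
    # connection open so SSE events can be read incrementally.
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = {"input": prompt, "stream": stream}
    if tools:
        payload["tools"] = tools
    return requests.post(f"{base_url}/v1/responses", json=payload,
                         headers=headers, stream=stream)


def parse_sse_events(resp):
    # Collect `event:` / `data:` pairs from a text/event-stream body into
    # dicts shaped like {"event": ..., "data": ...}, matching how the tests
    # above index into them.
    events, current = [], {}
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw:
            if current:
                events.append(current)
                current = {}
            continue
        if raw.startswith("event:"):
            current["event"] = raw[len("event:"):].strip()
        elif raw.startswith("data:"):
            data = raw[len("data:"):].strip()
            try:
                current["data"] = json.loads(data)
            except json.JSONDecodeError:
                current["data"] = data
    if current:
        events.append(current)
    return events
```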
"""
OpenAI backend tests for Response API.
HTTP backend tests for Response API (OpenAI and XAI).
Run with:
export OPENAI_API_KEY=your_key
python3 -m pytest py_test/e2e_response_api/test_openai_backend.py -v
python3 -m unittest e2e_response_api.test_openai_backend.TestOpenAIStateManagement
export XAI_API_KEY=your_key
python3 -m pytest py_test/e2e_response_api/backends/test_http_backend.py -v
python3 -m unittest e2e_response_api.backends.test_http_backend.TestOpenaiBackend
"""
import os
......@@ -12,23 +13,25 @@ import sys
import unittest
from pathlib import Path
# Add current directory for imports
_TEST_DIR = Path(__file__).parent
# Add e2e_response_api directory for imports
_TEST_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(_TEST_DIR))
# Import local modules
from base import ConversationCRUDBaseTest, ResponseCRUDBaseTest
from mcp import MCPTests
from router_fixtures import (
popen_launch_openai_xai_router,
popen_launch_workers_and_router,
)
from state_management import StateManagementTests
from mixins.basic_crud import ConversationCRUDBaseTest, ResponseCRUDBaseTest
from mixins.function_call import FunctionCallingBaseTest
from mixins.mcp import MCPTests
from mixins.state_management import StateManagementTests
from router_fixtures import popen_launch_openai_xai_router
from util import kill_process_tree
class TestOpenaiBackend(
ResponseCRUDBaseTest, ConversationCRUDBaseTest, StateManagementTests, MCPTests
ResponseCRUDBaseTest,
ConversationCRUDBaseTest,
StateManagementTests,
MCPTests,
FunctionCallingBaseTest,
):
"""End to end tests for OpenAI backend."""
......@@ -75,79 +78,5 @@ class TestXaiBackend(StateManagementTests):
kill_process_tree(cls.cluster["router"].pid)
class TestOracleStore(ResponseCRUDBaseTest, ConversationCRUDBaseTest):
"""End to end tests for Oracle database storage backend."""
api_key = os.environ.get("OPENAI_API_KEY")
@classmethod
def setUpClass(cls):
cls.model = "gpt-5-nano"
cls.base_url_port = "http://127.0.0.1:30040"
cls.cluster = popen_launch_openai_xai_router(
backend="openai",
base_url=cls.base_url_port,
history_backend="oracle",
)
cls.base_url = cls.cluster["base_url"]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.cluster["router"].pid)
class TestGrpcBackend(StateManagementTests, MCPTests):
"""End to end tests for gRPC backend."""
@classmethod
def setUpClass(cls):
cls.model = "/home/ubuntu/models/meta-llama/Llama-3.1-8B-Instruct"
cls.base_url_port = "http://127.0.0.1:30030"
cls.cluster = popen_launch_workers_and_router(
cls.model,
cls.base_url_port,
timeout=90,
num_workers=1,
tp_size=2,
policy="round_robin",
router_args=["--history-backend", "memory"],
)
cls.base_url = cls.cluster["base_url"]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.cluster["router"].pid)
for worker in cls.cluster.get("workers", []):
kill_process_tree(worker.pid)
@unittest.skip(
"TODO: transport error, details: [], metadata: MetadataMap { headers: {} }"
)
def test_previous_response_id_chaining(self):
super().test_previous_response_id_chaining()
@unittest.skip("TODO: return 501 Not Implemented")
def test_conversation_with_multiple_turns(self):
super().test_conversation_with_multiple_turns()
@unittest.skip("TODO: decode error message")
def test_mutually_exclusive_parameters(self):
super().test_mutually_exclusive_parameters()
@unittest.skip(
"TODO: Pipeline execution failed: Pipeline stage WorkerSelection failed"
)
def test_mcp_basic_tool_call(self):
super().test_mcp_basic_tool_call()
@unittest.skip("TODO: no event fields")
def test_mcp_basic_tool_call_streaming(self):
return super().test_mcp_basic_tool_call_streaming()
if __name__ == "__main__":
unittest.main()
......@@ -18,6 +18,7 @@ def pytest_collection_modifyitems(config, items):
- ConversationCRUDBaseTest
- MCPTests
- StateManagementTests
- FunctionCallingBaseTest
"""
base_class_names = {
"StateManagementBaseTest",
......@@ -25,6 +26,7 @@ def pytest_collection_modifyitems(config, items):
"ConversationCRUDBaseTest",
"MCPTests",
"StateManagementTests",
"FunctionCallingBaseTest",
}
# Filter out tests from base classes
......
"""
Base test class for function calling tests.
This module provides test cases for function calling functionality
across different backends.
"""
import json
import sys
from pathlib import Path
# Add current directory for local imports
_TEST_DIR = Path(__file__).parent
sys.path.insert(0, str(_TEST_DIR))
from util import CustomTestCase
class ResponseAPIBaseTest(CustomTestCase):
"""Base class for Response API tests with common utilities."""
# To be set by subclasses
base_url: str = None
api_key: str = None
model: str = None
def make_request(
self,
endpoint: str,
method: str = "POST",
json_data: dict = None,
params: dict = None,
):
"""
Make HTTP request to router.
This is a minimal implementation - subclasses should import from basic_crud.
"""
import requests
url = f"{self.base_url}{endpoint}"
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
if method == "POST":
resp = requests.post(url, json=json_data, headers=headers, params=params)
elif method == "GET":
resp = requests.get(url, headers=headers, params=params)
elif method == "DELETE":
resp = requests.delete(url, headers=headers, params=params)
else:
raise ValueError(f"Unsupported method: {method}")
return resp
class FunctionCallingBaseTest(ResponseAPIBaseTest):
"""Base class for function calling tests."""
def test_basic_function_call(self):
"""
Test basic function calling workflow.
This test follows the pattern from function_call_test.py:
1. Define a function tool (get_horoscope)
2. Send user message asking for horoscope
3. Model should return function_call
4. Execute function locally and provide output
5. Model should generate final response using the function output
"""
# 1. Define a list of callable tools for the model
tools = [
{
"type": "function",
"name": "get_horoscope",
"description": "Get today's horoscope for an astrological sign.",
"parameters": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "An astrological sign like Taurus or Aquarius",
},
},
"required": ["sign"],
},
},
]
# Create a running input list we will add to over time
input_list = [
{"role": "user", "content": "What is my horoscope? I am an Aquarius."}
]
# 2. Prompt the model with tools defined
resp = self.make_request(
"/v1/responses",
"POST",
{
"model": self.model,
"tools": tools,
"input": input_list,
},
)
# Should successfully make the request
self.assertEqual(resp.status_code, 200)
data = resp.json()
# Basic response structure
self.assertIn("id", data)
self.assertIn("status", data)
self.assertEqual(data["status"], "completed")
self.assertIn("output", data)
# Verify output array is not empty
output = data["output"]
self.assertIsInstance(output, list)
self.assertGreater(len(output), 0)
# Check for function_call in output
function_calls = [
item for item in output if item.get("type") == "function_call"
]
self.assertGreater(
len(function_calls), 0, "Response should contain at least one function_call"
)
# Verify function_call structure
function_call = function_calls[0]
self.assertIn("call_id", function_call)
self.assertIn("name", function_call)
self.assertEqual(function_call["name"], "get_horoscope")
self.assertIn("arguments", function_call)
# Parse arguments
args = json.loads(function_call["arguments"])
self.assertIn("sign", args)
self.assertEqual(args["sign"].lower(), "aquarius")
# 3. Save function call outputs for subsequent requests
input_list += output
# 4. Execute the function logic for get_horoscope
horoscope = f"{args['sign']}: Next Tuesday you will befriend a baby otter."
# 5. Provide function call results to the model
input_list.append(
{
"type": "function_call_output",
"call_id": function_call["call_id"],
"output": json.dumps({"horoscope": horoscope}),
}
)
# 6. Make second request with function output
resp2 = self.make_request(
"/v1/responses",
"POST",
{
"model": self.model,
"instructions": "Respond only with a horoscope generated by a tool.",
"tools": tools,
"input": input_list,
},
)
self.assertEqual(resp2.status_code, 200)
data2 = resp2.json()
self.assertEqual(data2["status"], "completed")
# The model should be able to give a response using the function output
output2 = data2["output"]
self.assertGreater(len(output2), 0)
# Find message output
messages = [item for item in output2 if item.get("type") == "message"]
self.assertGreater(
len(messages), 0, "Response should contain at least one message"
)
# Verify message contains the horoscope
message = messages[0]
self.assertIn("content", message)
content_parts = message["content"]
self.assertGreater(len(content_parts), 0)
# Get text from content
text_parts = [
part.get("text", "")
for part in content_parts
if part.get("type") == "output_text"
]
full_text = " ".join(text_parts).lower()
# Should mention the horoscope or baby otter
self.assertTrue(
"baby otter" in full_text or "aquarius" in full_text,
"Response should reference the horoscope content",
)
......@@ -5,7 +5,7 @@ Tests MCP tool calling in both streaming and non-streaming modes.
These tests should work across all backends that support MCP (OpenAI, XAI).
"""
from base import ResponseAPIBaseTest
from basic_crud import ResponseAPIBaseTest
class MCPTests(ResponseAPIBaseTest):
......
......@@ -7,7 +7,7 @@ These tests should work across all backends (OpenAI, XAI, gRPC).
import unittest
from base import ResponseAPIBaseTest
from basic_crud import ResponseAPIBaseTest
class StateManagementTests(ResponseAPIBaseTest):
......
"""
Oracle database storage backend tests for Response API.
Run with:
export OPENAI_API_KEY=your_key
python3 -m pytest py_test/e2e_response_api/persistence/test_oracle_store.py -v
python3 -m unittest e2e_response_api.persistence.test_oracle_store.TestOracleStore
"""
import os
import sys
import unittest
from pathlib import Path
# Add e2e_response_api directory for imports
_TEST_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(_TEST_DIR))
# Import local modules
from mixins.basic_crud import ConversationCRUDBaseTest, ResponseCRUDBaseTest
from router_fixtures import popen_launch_openai_xai_router
from util import kill_process_tree
class TestOracleStore(ResponseCRUDBaseTest, ConversationCRUDBaseTest):
"""End to end tests for Oracle database storage backend."""
api_key = os.environ.get("OPENAI_API_KEY")
@classmethod
def setUpClass(cls):
cls.model = "gpt-5-nano"
cls.base_url_port = "http://127.0.0.1:30040"
cls.cluster = popen_launch_openai_xai_router(
backend="openai",
base_url=cls.base_url_port,
history_backend="oracle",
)
cls.base_url = cls.cluster["base_url"]
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.cluster["router"].pid)
if __name__ == "__main__":
unittest.main()
......@@ -425,8 +425,6 @@ def popen_launch_workers_and_router(
"--grpc-mode", # Enable gRPC for this worker
"--mem-fraction-static",
"0.8",
"--attention-backend",
"fa3",
]
# Add TP size
......
......@@ -128,6 +128,7 @@ pub enum ResponseInputOutputItem {
id: String,
summary: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
#[serde(default)]
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
......@@ -168,6 +169,7 @@ pub enum ResponseContentPart {
#[serde(rename = "output_text")]
OutputText {
text: String,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
annotations: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
......
......@@ -6,4 +6,4 @@ pub mod utils;
pub use handlers::{cancel_response_impl, get_response_impl};
pub use streaming::{build_sse_response, OutputItemType, ResponseStreamEventEmitter};
pub use utils::ensure_mcp_connection;
pub use utils::{ensure_mcp_connection, persist_response_if_needed};
......@@ -11,7 +11,14 @@ use uuid::Uuid;
use crate::{
mcp,
protocols::{chat::ChatCompletionStreamResponse, responses::ResponsesRequest},
protocols::{
chat::ChatCompletionStreamResponse,
common::{Usage, UsageInfo},
responses::{
ResponseOutputItem, ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage,
},
},
routers::grpc::harmony::responses::ToolResult,
};
pub enum OutputItemType {
......@@ -69,6 +76,7 @@ pub struct ResponseStreamEventEmitter {
has_emitted_content_part_added: bool,
// MCP call tracking
mcp_call_accumulated_args: HashMap<String, String>,
pub(crate) mcp_server_label: Option<String>, // Server label for MCP tools
// Output item tracking
output_items: Vec<OutputItemState>,
next_output_index: usize,
......@@ -93,6 +101,7 @@ impl ResponseStreamEventEmitter {
has_emitted_output_item_added: false,
has_emitted_content_part_added: false,
mcp_call_accumulated_args: HashMap::new(),
mcp_server_label: None,
output_items: Vec::new(),
next_output_index: 0,
current_message_output_index: None,
......@@ -106,6 +115,41 @@ impl ResponseStreamEventEmitter {
self.original_request = Some(request);
}
/// Set the MCP server label for MCP tool calls
pub fn set_mcp_server_label(&mut self, server_label: String) {
self.mcp_server_label = Some(server_label);
}
/// Update mcp_call output items with tool execution results
///
/// After MCP tools are executed, this updates the stored output items
/// to include the output field from the tool results.
pub(crate) fn update_mcp_call_outputs(&mut self, tool_results: &[ToolResult]) {
for tool_result in tool_results {
// Find the output item with matching call_id
for item_state in self.output_items.iter_mut() {
if let Some(ref mut item_data) = item_state.item_data {
// Check if this is an mcp_call item with matching call_id
if item_data.get("type").and_then(|t| t.as_str()) == Some("mcp_call")
&& item_data.get("call_id").and_then(|c| c.as_str())
== Some(&tool_result.call_id)
{
// Add output field
let output_str = serde_json::to_string(&tool_result.output)
.unwrap_or_else(|_| "{}".to_string());
item_data["output"] = json!(output_str);
// Update status based on success
if tool_result.is_error {
item_data["status"] = json!("failed");
}
break;
}
}
}
}
}
fn next_sequence(&mut self) -> u64 {
let seq = self.sequence_number;
self.sequence_number += 1;
......@@ -542,6 +586,78 @@ impl ResponseStreamEventEmitter {
}
}
/// Finalize and return the complete ResponsesResponse
///
/// This constructs the final ResponsesResponse from all accumulated output items
/// for persistence. Should be called after streaming is complete.
pub fn finalize(&self, usage: Option<Usage>) -> ResponsesResponse {
// Build output array from tracked items
let output: Vec<ResponseOutputItem> = self
.output_items
.iter()
.filter_map(|item| {
item.item_data
.as_ref()
.and_then(|data| serde_json::from_value(data.clone()).ok())
})
.collect();
// Convert Usage to ResponsesUsage
let responses_usage = usage.map(|u| {
let usage_info = UsageInfo {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
reasoning_tokens: u
.completion_tokens_details
.as_ref()
.and_then(|d| d.reasoning_tokens),
prompt_tokens_details: None,
};
ResponsesUsage::Classic(usage_info)
});
// Get original request fields or use defaults
let req = self.original_request.as_ref();
// Convert tool_choice to String
let tool_choice = req
.and_then(|r| r.tool_choice.as_ref())
.map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
.unwrap_or_else(|| "auto".to_string());
ResponsesResponse {
id: self.response_id.clone(),
object: "response".to_string(),
created_at: self.created_at as i64,
status: ResponseStatus::Completed,
error: None,
incomplete_details: None,
instructions: req.and_then(|r| r.instructions.clone()),
max_output_tokens: req.and_then(|r| r.max_output_tokens),
model: self.model.clone(),
output,
parallel_tool_calls: req.and_then(|r| r.parallel_tool_calls).unwrap_or(true),
previous_response_id: req.and_then(|r| r.previous_response_id.clone()),
reasoning: None, // TODO: Extract from output items if needed
store: req.and_then(|r| r.store).unwrap_or(true),
temperature: req.and_then(|r| r.temperature),
text: None,
tool_choice,
tools: req
.map(|r| r.tools.clone().unwrap_or_default())
.unwrap_or_default(),
top_p: req.and_then(|r| r.top_p),
truncation: None, // Convert from Truncation to String if needed
usage: responses_usage,
metadata: req
.map(|r| r.metadata.clone().unwrap_or_default())
.unwrap_or_default(),
user: req.and_then(|r| r.user.clone()),
safety_identifier: None,
}
}
/// Emit reasoning item wrapper events (added + done)
///
/// Reasoning items in OpenAI format are simple placeholders emitted between tool iterations.
......
......@@ -6,16 +6,21 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use serde_json::json;
use serde_json::{json, to_value};
use tracing::{debug, warn};
use crate::{
core::WorkerRegistry,
data_connector::{ConversationItemStorage, ConversationStorage, ResponseStorage},
mcp::McpManager,
protocols::{
common::Tool,
responses::{ResponseTool, ResponseToolType},
responses::{ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse},
},
routers::{
grpc::error,
openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
},
routers::{grpc::error, openai::mcp::ensure_request_mcp_client},
};
/// Ensure MCP connection succeeds if MCP tools are declared
......@@ -123,3 +128,35 @@ pub fn extract_tools_from_response_tools(
})
.collect()
}
/// Persist response to storage if store=true
///
/// Common helper function to avoid duplication across sync and streaming paths
/// in both harmony and regular responses implementations.
pub async fn persist_response_if_needed(
conversation_storage: Arc<dyn ConversationStorage>,
conversation_item_storage: Arc<dyn ConversationItemStorage>,
response_storage: Arc<dyn ResponseStorage>,
response: &ResponsesResponse,
original_request: &ResponsesRequest,
) {
if !original_request.store.unwrap_or(true) {
return;
}
if let Ok(response_json) = to_value(response) {
if let Err(e) = persist_conversation_items(
conversation_storage,
conversation_item_storage,
response_storage,
&response_json,
original_request,
)
.await
{
warn!("Failed to persist response: {}", e);
} else {
debug!("Persisted response: {}", response.id);
}
}
}
......@@ -49,7 +49,7 @@ use tracing::{debug, warn};
use uuid::Uuid;
use crate::{
data_connector::{ResponseId, ResponseStorage},
data_connector::{ConversationItemStorage, ConversationStorage, ResponseId, ResponseStorage},
mcp::{self, McpManager},
protocols::{
common::{Function, ToolCall, ToolChoice, ToolChoiceValue, Usage},
......@@ -62,7 +62,7 @@ use crate::{
},
routers::grpc::{
common::responses::{
build_sse_response, ensure_mcp_connection,
build_sse_response, ensure_mcp_connection, persist_response_if_needed,
streaming::{OutputItemType, ResponseStreamEventEmitter},
},
context::SharedComponents,
......@@ -157,6 +157,12 @@ pub struct HarmonyResponsesContext {
/// Response storage for loading conversation history
pub response_storage: Arc<dyn ResponseStorage>,
/// Conversation storage for persisting conversations
pub conversation_storage: Arc<dyn ConversationStorage>,
/// Conversation item storage for persisting conversation items
pub conversation_item_storage: Arc<dyn ConversationItemStorage>,
/// Optional streaming sender (for future streaming support)
pub stream_tx: Option<mpsc::UnboundedSender<Result<String, String>>>,
}
......@@ -168,12 +174,16 @@ impl HarmonyResponsesContext {
components: Arc<SharedComponents>,
mcp_manager: Arc<McpManager>,
response_storage: Arc<dyn ResponseStorage>,
conversation_storage: Arc<dyn ConversationStorage>,
conversation_item_storage: Arc<dyn ConversationItemStorage>,
) -> Self {
Self {
pipeline,
components,
mcp_manager,
response_storage,
conversation_storage,
conversation_item_storage,
stream_tx: None,
}
}
......@@ -184,6 +194,8 @@ impl HarmonyResponsesContext {
components: Arc<SharedComponents>,
mcp_manager: Arc<McpManager>,
response_storage: Arc<dyn ResponseStorage>,
conversation_storage: Arc<dyn ConversationStorage>,
conversation_item_storage: Arc<dyn ConversationItemStorage>,
stream_tx: mpsc::UnboundedSender<Result<String, String>>,
) -> Self {
Self {
......@@ -191,6 +203,8 @@ impl HarmonyResponsesContext {
components,
mcp_manager,
response_storage,
conversation_storage,
conversation_item_storage,
stream_tx: Some(stream_tx),
}
}
......@@ -237,6 +251,9 @@ pub async fn serve_harmony_responses(
ctx: &HarmonyResponsesContext,
request: ResponsesRequest,
) -> Result<ResponsesResponse, Response> {
// Clone request for persistence
let original_request = request.clone();
// Load previous conversation history if previous_response_id is set
let current_request = load_previous_messages(ctx, request).await?;
......@@ -244,12 +261,24 @@ pub async fn serve_harmony_responses(
let has_mcp_tools =
ensure_mcp_connection(&ctx.mcp_manager, current_request.tools.as_deref()).await?;
if has_mcp_tools {
execute_with_mcp_loop(ctx, current_request).await
let response = if has_mcp_tools {
execute_with_mcp_loop(ctx, current_request).await?
} else {
// No MCP tools - execute pipeline once (may have function tools or no tools)
execute_without_mcp_loop(ctx, current_request).await
}
execute_without_mcp_loop(ctx, current_request).await?
};
// Persist response to storage if store=true
persist_response_if_needed(
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&response,
&original_request,
)
.await;
Ok(response)
}
/// Execute Harmony Responses with MCP tool loop
......@@ -260,7 +289,10 @@ async fn execute_with_mcp_loop(
mut current_request: ResponsesRequest,
) -> Result<ResponsesResponse, Response> {
let mut iteration_count = 0;
let mut mcp_tracking = McpCallTracking::new("sglang-mcp".to_string());
// Extract server_label from request tools
let server_label = extract_mcp_server_label(current_request.tools.as_deref());
let mut mcp_tracking = McpCallTracking::new(server_label.clone());
// Extract user's max_tool_calls limit (if set)
let max_tool_calls = current_request.max_tool_calls.map(|n| n as usize);
......@@ -472,7 +504,7 @@ pub async fn serve_harmony_responses_stream(
request: ResponsesRequest,
) -> Response {
// Load previous conversation history if previous_response_id is set
let current_request = match load_previous_messages(ctx, request).await {
let current_request = match load_previous_messages(ctx, request.clone()).await {
Ok(req) => req,
Err(err_response) => return err_response,
};
......@@ -517,9 +549,10 @@ pub async fn serve_harmony_responses_stream(
}
if has_mcp_tools {
execute_mcp_tool_loop_streaming(ctx, current_request, &mut emitter, &tx).await;
execute_mcp_tool_loop_streaming(ctx, current_request, &request, &mut emitter, &tx)
.await;
} else {
execute_without_mcp_streaming(ctx, &current_request, &mut emitter, &tx).await;
execute_without_mcp_streaming(ctx, &current_request, &request, &mut emitter, &tx).await;
}
});
......@@ -534,14 +567,22 @@ pub async fn serve_harmony_responses_stream(
/// - Emits mcp_list_tools events
/// - Loops through tool execution iterations
/// - Emits final response.completed event
/// - Persists response internally
async fn execute_mcp_tool_loop_streaming(
ctx: &HarmonyResponsesContext,
mut current_request: ResponsesRequest,
original_request: &ResponsesRequest,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) {
// Extract server_label from request tools
let server_label = extract_mcp_server_label(current_request.tools.as_deref());
// Set server label in emitter for MCP call items
emitter.set_mcp_server_label(server_label.clone());
// Initialize MCP call tracking
let mut mcp_tracking = McpCallTracking::new("sglang-mcp".to_string());
let mut mcp_tracking = McpCallTracking::new(server_label.clone());
// Extract user's max_tool_calls limit (if set)
let max_tool_calls = current_request.max_tool_calls.map(|n| n as usize);
......@@ -576,11 +617,26 @@ async fn execute_mcp_tool_loop_streaming(
})
.collect();
// Build final item with completed status and tools
let item_done = json!({
"id": item_id,
"type": "mcp_list_tools",
"server_label": server_label,
"status": "completed",
"tools": tool_items
});
// Store the completed item data and mark as completed FIRST
// This ensures it appears in the final response even if event sending fails
emitter.emit_output_item_done(output_index, &item_done);
emitter.complete_output_item(output_index);
// Now emit all the events (failures won't affect the stored data)
// Emit output_item.added
let item = json!({
"id": item_id,
"type": "mcp_list_tools",
"server_label": "sglang-mcp",
"server_label": server_label,
"status": "in_progress",
"tools": []
});
......@@ -602,20 +658,11 @@ async fn execute_mcp_tool_loop_streaming(
}
// Emit output_item.done
let item_done = json!({
"id": item_id,
"type": "mcp_list_tools",
"server_label": "sglang-mcp",
"status": "completed",
"tools": tool_items
});
let event = emitter.emit_output_item_done(output_index, &item_done);
if emitter.send_event(&event, tx).is_err() {
return;
}
emitter.complete_output_item(output_index);
debug!(
tool_count = mcp_tools.len(),
"Emitted mcp_list_tools on first iteration"
......@@ -736,6 +783,9 @@ async fn execute_mcp_tool_loop_streaming(
}
};
// Update mcp_call output items with execution results
emitter.update_mcp_call_outputs(&tool_results);
// Build next request with appended history
current_request = match build_next_request_with_tools(
current_request,
......@@ -765,6 +815,19 @@ async fn execute_mcp_tool_loop_streaming(
"Harmony Responses streaming completed - no more tool calls"
);
// Finalize response from emitter's accumulated data
let final_response = emitter.finalize(Some(usage.clone()));
// Persist response to storage if store=true
persist_response_if_needed(
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&final_response,
original_request,
)
.await;
// Emit response.completed with usage
let usage_json = json!({
"input_tokens": usage.prompt_tokens,
......@@ -786,6 +849,7 @@ async fn execute_mcp_tool_loop_streaming(
async fn execute_without_mcp_streaming(
ctx: &HarmonyResponsesContext,
current_request: &ResponsesRequest,
original_request: &ResponsesRequest,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) {
......@@ -830,6 +894,19 @@ async fn execute_without_mcp_streaming(
ResponsesIterationResult::Completed { usage, .. } => usage,
};
// Finalize response from emitter's accumulated data
let final_response = emitter.finalize(Some(usage.clone()));
// Persist response to storage if store=true
persist_response_if_needed(
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&final_response,
original_request,
)
.await;
// Emit response.completed with usage
let usage_json = json!({
"input_tokens": usage.prompt_tokens,
......@@ -1191,19 +1268,19 @@ fn build_next_request_with_tools(
/// Tool execution result
///
/// Contains the result of executing a single MCP tool.
struct ToolResult {
pub(crate) struct ToolResult {
/// Tool call ID (for matching with request)
call_id: String,
pub(crate) call_id: String,
/// Tool name
#[allow(dead_code)] // Kept for documentation and future use
tool_name: String,
pub(crate) tool_name: String,
/// Tool output (JSON value)
output: Value,
pub(crate) output: Value,
/// Whether this is an error result
is_error: bool,
pub(crate) is_error: bool,
}
/// Convert MCP tools to Responses API tool format
......@@ -1297,6 +1374,24 @@ fn inject_mcp_metadata(
response.output.extend(mcp_call_items);
}
/// Extract MCP server label from request tools
///
/// Searches for the first MCP tool in the tools array and returns its server_label.
/// Falls back to "sglang-mcp" if no MCP tool with server_label is found.
fn extract_mcp_server_label(tools: Option<&[ResponseTool]>) -> String {
tools
.and_then(|tools| {
tools.iter().find_map(|tool| {
if matches!(tool.r#type, ResponseToolType::Mcp) {
tool.server_label.clone()
} else {
None
}
})
})
.unwrap_or_else(|| "sglang-mcp".to_string())
}
/// Load previous conversation messages from storage
///
/// If the request has `previous_response_id`, loads the response chain from storage
......
......@@ -733,7 +733,7 @@ impl HarmonyStreamingProcessor {
// Emit output_item.added wrapper event
let call_id = tc_delta.id.as_ref().unwrap();
let item = json!({
let mut item = json!({
"id": item_id,
"type": mode.type_str(),
"name": tool_name,
......@@ -741,6 +741,14 @@ impl HarmonyStreamingProcessor {
"arguments": "",
"status": "in_progress"
});
// Add server_label for MCP calls
if mode.emits_status_events() {
if let Some(ref server_label) = emitter.mcp_server_label {
item["server_label"] = json!(server_label);
}
}
let event = emitter.emit_output_item_added(output_index, &item);
emitter.send_event_best_effort(&event, tx);
......@@ -836,7 +844,7 @@ impl HarmonyStreamingProcessor {
}
// Emit output_item.done wrapper event
let item = json!({
let mut item = json!({
"id": item_id,
"type": mode.type_str(),
"name": tool_name,
......@@ -844,11 +852,21 @@ impl HarmonyStreamingProcessor {
"arguments": args_str,
"status": "completed"
});
// Add server_label for MCP calls
if mode.emits_status_events() {
// MCP mode - include server_label
if let Some(ref server_label) = emitter.mcp_server_label {
item["server_label"] = json!(server_label);
}
}
let event = emitter.emit_output_item_done(*output_index, &item);
emitter.send_event_best_effort(&event, tx);
// Mark output item as completed
// Mark output item as completed before sending
emitter.complete_output_item(*output_index);
emitter.send_event_best_effort(&event, tx);
}
}
}
......@@ -878,9 +896,11 @@ impl HarmonyStreamingProcessor {
}]
});
let event = emitter.emit_output_item_done(output_index, &item);
emitter.send_event_best_effort(&event, tx);
// Mark as completed before sending (so it's included in final output even if send fails)
emitter.complete_output_item(output_index);
emitter.send_event_best_effort(&event, tx);
}
}
Some(proto::generate_response::Response::Error(err)) => {
......@@ -939,7 +959,7 @@ impl HarmonyStreamingProcessor {
}
// Emit output_item.done wrapper event
let item = json!({
let mut item = json!({
"id": item_id,
"type": mode.type_str(),
"name": tool_name,
......@@ -947,11 +967,20 @@ impl HarmonyStreamingProcessor {
"arguments": args_str,
"status": "completed"
});
// Add server_label for MCP calls
if mode.emits_status_events() {
if let Some(ref server_label) = emitter.mcp_server_label {
item["server_label"] = json!(server_label);
}
}
let event = emitter.emit_output_item_done(*output_index, &item);
emitter.send_event_best_effort(&event, tx);
// Mark output item as completed
// Mark output item as completed before sending
emitter.complete_output_item(*output_index);
emitter.send_event_best_effort(&event, tx);
}
}
}
......
......@@ -65,14 +65,12 @@ use crate::{
ResponsesUsage,
},
},
routers::{
grpc::{
common::responses::{
build_sse_response, ensure_mcp_connection, streaming::ResponseStreamEventEmitter,
},
error,
routers::grpc::{
common::responses::{
build_sse_response, ensure_mcp_connection, persist_response_if_needed,
streaming::ResponseStreamEventEmitter,
},
openai::conversations::persist_conversation_items,
error,
},
};
......@@ -221,21 +219,14 @@ async fn route_responses_internal(
};
// 5. Persist response to storage if store=true
if request.store.unwrap_or(true) {
if let Ok(response_json) = serde_json::to_value(&responses_response) {
if let Err(e) = persist_conversation_items(
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&response_json,
&request,
)
.await
{
warn!("Failed to persist response: {}", e);
}
}
}
persist_response_if_needed(
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&responses_response,
&request,
)
.await;
Ok(responses_response)
}
......@@ -454,25 +445,15 @@ async fn process_and_transform_sse_stream(
event_emitter.send_event(&completed_event, &tx)?;
// Finalize and persist accumulated response
if original_request.store.unwrap_or(true) {
let final_response = accumulator.finalize();
if let Ok(response_json) = serde_json::to_value(&final_response) {
if let Err(e) = persist_conversation_items(
conversation_storage.clone(),
conversation_item_storage.clone(),
response_storage.clone(),
&response_json,
&original_request,
)
.await
{
warn!("Failed to persist streaming response: {}", e);
} else {
debug!("Persisted streaming response: {}", final_response.id);
}
}
}
let final_response = accumulator.finalize();
persist_response_if_needed(
conversation_storage,
conversation_item_storage,
response_storage,
&final_response,
&original_request,
)
.await;
Ok(())
}
......
......@@ -216,6 +216,10 @@ impl GrpcRouter {
self.shared_components.clone(),
self.harmony_responses_context.mcp_manager.clone(),
self.harmony_responses_context.response_storage.clone(),
self.harmony_responses_context.conversation_storage.clone(),
self.harmony_responses_context
.conversation_item_storage
.clone(),
);
if body.stream.unwrap_or(false) {
......