Unverified Commit a28b394f authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Add multi-turn tool calling loop support for MCP integration (#11143)

parent 96fe2d0f
...@@ -723,7 +723,10 @@ pub enum ResponseToolType { ...@@ -723,7 +723,10 @@ pub enum ResponseToolType {
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam { pub struct ResponseReasoningParam {
#[serde(default = "default_reasoning_effort")] #[serde(default = "default_reasoning_effort")]
#[serde(skip_serializing_if = "Option::is_none")]
pub effort: Option<ReasoningEffort>, pub effort: Option<ReasoningEffort>,
#[serde(skip_serializing_if = "Option::is_none")]
pub summary: Option<ReasoningSummary>,
} }
fn default_reasoning_effort() -> Option<ReasoningEffort> { fn default_reasoning_effort() -> Option<ReasoningEffort> {
...@@ -738,6 +741,14 @@ pub enum ReasoningEffort { ...@@ -738,6 +741,14 @@ pub enum ReasoningEffort {
High, High,
} }
/// Requested verbosity for reasoning summaries on a Responses request.
///
/// Serialized in snake_case (`"auto"`, `"concise"`, `"detailed"`); used as
/// the optional `summary` field of `ResponseReasoningParam`, which is
/// skipped entirely when `None`.
// NOTE(review): presumably mirrors the OpenAI `reasoning.summary` values —
// confirm against the upstream Responses API spec.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningSummary {
    /// Let the backend choose an appropriate summary level.
    Auto,
    /// Short summary.
    Concise,
    /// Full, detailed summary.
    Detailed,
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
......
...@@ -26,7 +26,6 @@ use std::{ ...@@ -26,7 +26,6 @@ use std::{
collections::HashMap, collections::HashMap,
io, io,
sync::{atomic::AtomicBool, Arc}, sync::{atomic::AtomicBool, Arc},
time::SystemTime,
}; };
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
...@@ -57,6 +56,69 @@ impl std::fmt::Debug for OpenAIRouter { ...@@ -57,6 +56,69 @@ impl std::fmt::Debug for OpenAIRouter {
} }
} }
/// Configuration for MCP tool calling loops
///
/// Internal-only knob for `execute_tool_loop`; not exposed to clients.
#[derive(Debug, Clone)]
struct McpLoopConfig {
    /// Maximum iterations as safety limit (internal only, default: 10)
    /// Prevents infinite loops when max_tool_calls is not set
    max_iterations: usize,
}

impl Default for McpLoopConfig {
    /// Default safety cap of 10 iterations per request.
    fn default() -> Self {
        Self { max_iterations: 10 }
    }
}
/// State for tracking multi-turn tool calling loop
///
/// Accumulated by `execute_tool_loop` across upstream round-trips; the
/// conversation history is replayed into each resume payload via
/// `build_resume_payload`.
struct ToolLoopState {
    /// Current iteration number (starts at 0, increments with each tool call)
    iteration: usize,
    /// Total number of tool calls executed
    total_calls: usize,
    /// Conversation history (function_call and function_call_output items)
    conversation_history: Vec<Value>,
    /// Original user input (preserved for building resume payloads)
    original_input: ResponseInput,
}
impl ToolLoopState {
fn new(original_input: ResponseInput) -> Self {
Self {
iteration: 0,
total_calls: 0,
conversation_history: Vec::new(),
original_input,
}
}
/// Record a tool call in the loop state
fn record_call(
&mut self,
call_id: String,
tool_name: String,
args_json_str: String,
output_str: String,
) {
// Add function_call item to history
let func_item = json!({
"type": "function_call",
"call_id": call_id,
"name": tool_name,
"arguments": args_json_str
});
self.conversation_history.push(func_item);
// Add function_call_output item to history
let output_item = json!({
"type": "function_call_output",
"call_id": call_id,
"output": output_str
});
self.conversation_history.push(output_item);
}
}
/// Helper that parses SSE frames from the OpenAI responses stream and /// Helper that parses SSE frames from the OpenAI responses stream and
/// accumulates enough information to persist the final response locally. /// accumulates enough information to persist the final response locally.
struct StreamingResponseAccumulator { struct StreamingResponseAccumulator {
...@@ -388,126 +450,32 @@ impl OpenAIRouter { ...@@ -388,126 +450,32 @@ impl OpenAIRouter {
obj.insert("store".to_string(), Value::Bool(original_body.store)); obj.insert("store".to_string(), Value::Bool(original_body.store));
} }
let mut final_response_json = openai_response_json; // If MCP is active and we detect a function call, enter the tool loop
let mut final_response_json = if let Some(mcp) = active_mcp {
if let Some(mcp) = active_mcp { if Self::extract_function_call(&openai_response_json).is_some() {
if let Some((call_id, tool_name, args_json_str)) = // Use the loop to handle potentially multiple tool calls
Self::extract_function_call(&final_response_json) let loop_config = McpLoopConfig::default();
{
info!(
"Detected function call: name={}, call_id={}, args={}",
tool_name, call_id, args_json_str
);
let call_started = SystemTime::now();
let call_result =
Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await;
let call_duration_ms =
call_started.elapsed().unwrap_or_default().as_millis();
let (output_payload, call_ok, call_error) = match call_result {
Ok((server, out)) => {
info!(
call_id = %call_id,
tool_name = %tool_name,
server = %server,
duration_ms = call_duration_ms,
"MCP tool call succeeded"
);
(out, true, None)
}
Err(err) => {
warn!(
call_id = %call_id,
tool_name = %tool_name,
duration_ms = call_duration_ms,
error = %err,
"MCP tool call failed"
);
(
serde_json::json!({
"error": err
})
.to_string(),
false,
Some(err),
)
}
};
match self match self
.resume_with_tool_result(ResumeWithToolArgs { .execute_tool_loop(
url: &url, &url,
headers, headers,
original_payload: &payload, payload.clone(),
call_id: &call_id,
tool_name: &tool_name,
args_json_str: &args_json_str,
output_str: &output_payload,
original_body, original_body,
})
.await
{
Ok(mut resumed_json) => {
// Inject MCP output items (mcp_list_tools and mcp_call)
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.unwrap_or("mcp");
if let Err(inject_err) = Self::inject_mcp_output_items(
&mut resumed_json,
mcp, mcp,
McpOutputItemsArgs { &loop_config,
tool_name: &tool_name, )
args_json: &args_json_str, .await
output: &output_payload,
server_label,
success: call_ok,
error: call_error.as_deref(),
},
) {
warn!(
"Failed to inject MCP output items: {}",
inject_err
);
}
if !call_ok {
if let Some(obj) = resumed_json.as_object_mut() {
let metadata_value =
obj.entry("metadata").or_insert_with(|| {
Value::Object(serde_json::Map::new())
});
if let Some(metadata) =
metadata_value.as_object_mut()
{ {
if let Some(err_msg) = call_error.as_ref() { Ok(loop_result) => loop_result,
metadata.insert(
"mcp_error".to_string(),
Value::String(err_msg.clone()),
);
}
}
}
}
final_response_json = resumed_json;
}
Err(err) => { Err(err) => {
warn!("Failed to resume with tool result: {}", err); warn!("Tool loop failed: {}", err);
let error_body = json!({ let error_body = json!({
"error": { "error": {
"message": format!( "message": format!("Tool loop failed: {}", err),
"Failed to resume with tool result: {}",
err
),
"type": "internal_error", "type": "internal_error",
} }
}) })
.to_string(); .to_string();
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
[("content-type", "application/json")], [("content-type", "application/json")],
...@@ -517,10 +485,14 @@ impl OpenAIRouter { ...@@ -517,10 +485,14 @@ impl OpenAIRouter {
} }
} }
} else { } else {
info!("No function call found in upstream response; skipping MCP"); // No function call detected, use response as-is
} openai_response_json
} }
} else {
openai_response_json
};
// Mask tools back to MCP format for client
Self::mask_tools_as_mcp(&mut final_response_json, original_body); Self::mask_tools_as_mcp(&mut final_response_json, original_body);
if original_body.store { if original_body.store {
if let Err(e) = self if let Err(e) = self
...@@ -1040,26 +1012,6 @@ impl OpenAIRouter { ...@@ -1040,26 +1012,6 @@ impl OpenAIRouter {
} }
} }
struct ResumeWithToolArgs<'a> {
url: &'a str,
headers: Option<&'a HeaderMap>,
original_payload: &'a Value,
call_id: &'a str,
tool_name: &'a str,
args_json_str: &'a str,
output_str: &'a str,
original_body: &'a ResponsesRequest,
}
struct McpOutputItemsArgs<'a> {
tool_name: &'a str,
args_json: &'a str,
output: &'a str,
server_label: &'a str,
success: bool,
error: Option<&'a str>,
}
impl OpenAIRouter { impl OpenAIRouter {
fn extract_function_call(resp: &Value) -> Option<(String, String, String)> { fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
let output = resp.get("output")?.as_array()?; let output = resp.get("output")?.as_array()?;
...@@ -1150,6 +1102,375 @@ impl OpenAIRouter { ...@@ -1150,6 +1102,375 @@ impl OpenAIRouter {
Ok((server_name, output_str)) Ok((server_name, output_str))
} }
/// Build a resume payload with conversation history
///
/// Clones `base_payload` (already stripped of router-specific fields) and
/// rebuilds its `input` array as: the original user turn(s), followed by
/// every recorded `function_call` / `function_call_output` pair. The
/// transformed function tools are carried over, and `stream` / `store`
/// are forced off for the upstream request.
fn build_resume_payload(
    base_payload: &Value,
    conversation_history: &[Value],
    original_input: &ResponseInput,
    tools_json: &Value,
) -> Result<Value, String> {
    let mut payload = base_payload.clone();
    let obj = match payload.as_object_mut() {
        Some(map) => map,
        None => return Err("payload not an object".to_string()),
    };

    // Original user input comes first.
    let mut input_items: Vec<Value> = match original_input {
        ResponseInput::Text(text) => vec![json!({
            "type": "message",
            "role": "user",
            "content": [{ "type": "input_text", "text": text }]
        })],
        // Structured input items serialize directly; a serialization
        // failure degrades to an empty prefix (same as before).
        ResponseInput::Items(items) => serde_json::to_value(items)
            .ok()
            .and_then(|v| v.as_array().cloned())
            .unwrap_or_default(),
    };

    // Then the full tool-call transcript so far.
    input_items.extend_from_slice(conversation_history);
    obj.insert("input".to_string(), Value::Array(input_items));

    // Use the transformed function tools (not the client-facing MCP tools),
    // skipping the insert when there are none.
    match tools_json.as_array() {
        Some(arr) if !arr.is_empty() => {
            obj.insert("tools".to_string(), tools_json.clone());
        }
        _ => {}
    }

    // Upstream calls inside the loop are always non-streaming and unstored.
    obj.insert("stream".to_string(), Value::Bool(false));
    obj.insert("store".to_string(), Value::Bool(false));

    Ok(payload)
}
/// Helper function to build mcp_call items from executed tool calls in conversation history
fn build_executed_mcp_call_items(
conversation_history: &[Value],
server_label: &str,
) -> Vec<Value> {
let mut mcp_call_items = Vec::new();
for item in conversation_history {
if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
let args = item
.get("arguments")
.and_then(|v| v.as_str())
.unwrap_or("{}");
// Find corresponding output
let output_item = conversation_history.iter().find(|o| {
o.get("type").and_then(|t| t.as_str()) == Some("function_call_output")
&& o.get("call_id").and_then(|c| c.as_str()) == Some(call_id)
});
let output_str = output_item
.and_then(|o| o.get("output").and_then(|v| v.as_str()))
.unwrap_or("{}");
// Check if output contains error by parsing JSON
let is_error = serde_json::from_str::<serde_json::Value>(output_str)
.map(|v| v.get("error").is_some())
.unwrap_or(false);
let mcp_call_item = Self::build_mcp_call_item(
tool_name,
args,
output_str,
server_label,
!is_error,
if is_error {
Some("Tool execution failed")
} else {
None
},
);
mcp_call_items.push(mcp_call_item);
}
}
mcp_call_items
}
/// Build an incomplete response when limits are exceeded
///
/// Takes the last upstream response as-is, marks it `completed` with
/// `incomplete_details.reason = reason`, rewrites pending (unexecuted)
/// tool calls in `output` as failed `mcp_call` items, prepends
/// `mcp_list_tools` plus `mcp_call` items for the calls that did run, and
/// records a truncation warning under `metadata.mcp` (only when that
/// object already exists).
fn build_incomplete_response(
    mut response: Value,
    state: ToolLoopState,
    reason: &str,
    active_mcp: &Arc<crate::mcp::McpClientManager>,
    original_body: &ResponsesRequest,
) -> Result<Value, String> {
    let obj = response
        .as_object_mut()
        .ok_or_else(|| "response not an object".to_string())?;
    // Set status to completed (not failed - partial success)
    obj.insert("status".to_string(), Value::String("completed".to_string()));
    // Set incomplete_details
    obj.insert(
        "incomplete_details".to_string(),
        json!({ "reason": reason }),
    );
    // Convert any function_call in output to mcp_call format
    if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
        // Server label comes from the request's MCP tool entry, "mcp" otherwise.
        let server_label = original_body
            .tools
            .iter()
            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
            .and_then(|t| t.server_label.as_deref())
            .unwrap_or("mcp");
        // Find any function_call items and convert them to mcp_call (incomplete)
        // NOTE(review): this matches type "function_tool_call", while the
        // loop's history items (and build_executed_mcp_call_items) use
        // "function_call" — confirm the upstream pending-call item type.
        let mut mcp_call_items = Vec::new();
        for item in output_array.iter() {
            if item.get("type").and_then(|t| t.as_str()) == Some("function_tool_call") {
                let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
                let args = item
                    .get("arguments")
                    .and_then(|v| v.as_str())
                    .unwrap_or("{}");
                // Mark as incomplete - not executed
                let mcp_call_item = Self::build_mcp_call_item(
                    tool_name,
                    args,
                    "", // No output - wasn't executed
                    server_label,
                    false, // Not successful
                    Some("Not executed - response stopped due to limit"),
                );
                mcp_call_items.push(mcp_call_item);
            }
        }
        // Add mcp_list_tools and executed mcp_call items at the beginning
        if state.total_calls > 0 || !mcp_call_items.is_empty() {
            let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
            output_array.insert(0, list_tools_item);
            // Add mcp_call items for executed calls using helper
            let executed_items =
                Self::build_executed_mcp_call_items(&state.conversation_history, server_label);
            // Insert executed calls right after mcp_list_tools, then the
            // unexecuted (incomplete) ones, preserving relative order.
            let mut insert_pos = 1;
            for item in executed_items {
                output_array.insert(insert_pos, item);
                insert_pos += 1;
            }
            // Add incomplete mcp_call items
            for item in mcp_call_items {
                output_array.insert(insert_pos, item);
                insert_pos += 1;
            }
        }
    }
    // Add warning to metadata (best-effort: only when metadata.mcp exists)
    if let Some(metadata_val) = obj.get_mut("metadata") {
        if let Some(metadata_obj) = metadata_val.as_object_mut() {
            if let Some(mcp_val) = metadata_obj.get_mut("mcp") {
                if let Some(mcp_obj) = mcp_val.as_object_mut() {
                    mcp_obj.insert(
                        "truncation_warning".to_string(),
                        Value::String(format!(
                            "Loop terminated at {} iterations, {} total calls (reason: {})",
                            state.iteration, state.total_calls, reason
                        )),
                    );
                }
            }
        }
    }
    Ok(response)
}
/// Execute the tool calling loop
///
/// Repeatedly POSTs the payload to upstream `url`, executing any MCP tool
/// call the model emits and resuming with the tool result, until a
/// response arrives with no function call. The effective call cap is the
/// minimum of the user's `max_tool_calls` (when set) and the internal
/// `config.max_iterations`; hitting it returns an incomplete-style
/// response via `build_incomplete_response` instead of an `Err`. On normal
/// completion, `mcp_list_tools` / `mcp_call` items are injected into the
/// final response's `output` when at least one tool ran.
///
/// # Errors
/// Returns `Err` on transport failures, non-2xx upstream status, malformed
/// upstream JSON, or resume-payload construction failure.
async fn execute_tool_loop(
    &self,
    url: &str,
    headers: Option<&HeaderMap>,
    initial_payload: Value,
    original_body: &ResponsesRequest,
    active_mcp: &Arc<crate::mcp::McpClientManager>,
    config: &McpLoopConfig,
) -> Result<Value, String> {
    let mut state = ToolLoopState::new(original_body.input.clone());
    // Get max_tool_calls from request (None means no user-specified limit)
    let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
    // Keep initial_payload as base template (already has fields cleaned)
    let base_payload = initial_payload.clone();
    let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
    let mut current_payload = initial_payload;
    info!(
        "Starting tool loop: max_tool_calls={:?}, max_iterations={}",
        max_tool_calls, config.max_iterations
    );
    loop {
        // Make request to upstream
        let request_builder = self.client.post(url).json(&current_payload);
        let request_builder = if let Some(headers) = headers {
            apply_request_headers(headers, request_builder, true)
        } else {
            request_builder
        };
        let response = request_builder
            .send()
            .await
            .map_err(|e| format!("upstream request failed: {}", e))?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(format!("upstream error {}: {}", status, body));
        }
        let mut response_json = response
            .json::<Value>()
            .await
            .map_err(|e| format!("parse response: {}", e))?;
        // Check for function call
        if let Some((call_id, tool_name, args_json_str)) =
            Self::extract_function_call(&response_json)
        {
            // Counters are bumped before the limit check, so the call that
            // would exceed the limit is counted but never executed.
            state.iteration += 1;
            state.total_calls += 1;
            info!(
                "Tool loop iteration {}: calling {} (call_id: {})",
                state.iteration, tool_name, call_id
            );
            // Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations
            let effective_limit = match max_tool_calls {
                Some(user_max) => user_max.min(config.max_iterations),
                None => config.max_iterations,
            };
            if state.total_calls > effective_limit {
                // Log which of the two limits actually fired.
                if let Some(user_max) = max_tool_calls {
                    if state.total_calls > user_max {
                        warn!("Reached user-specified max_tool_calls limit: {}", user_max);
                    } else {
                        warn!(
                            "Reached safety max_iterations limit: {}",
                            config.max_iterations
                        );
                    }
                } else {
                    warn!(
                        "Reached safety max_iterations limit: {}",
                        config.max_iterations
                    );
                }
                return Self::build_incomplete_response(
                    response_json,
                    state,
                    "max_tool_calls",
                    active_mcp,
                    original_body,
                );
            }
            // Execute tool
            let call_result =
                Self::execute_mcp_call(active_mcp, &tool_name, &args_json_str).await;
            let output_str = match call_result {
                Ok((_, output)) => output,
                Err(err) => {
                    warn!("Tool execution failed: {}", err);
                    // Return error as output, let model decide how to proceed
                    json!({ "error": err }).to_string()
                }
            };
            // Record the call
            state.record_call(call_id, tool_name, args_json_str, output_str);
            // Build resume payload
            current_payload = Self::build_resume_payload(
                &base_payload,
                &state.conversation_history,
                &state.original_input,
                &tools_json,
            )?;
        } else {
            // No more tool calls, we're done
            info!(
                "Tool loop completed: {} iterations, {} total calls",
                state.iteration, state.total_calls
            );
            // Inject MCP output items if we executed any tools
            if state.total_calls > 0 {
                let server_label = original_body
                    .tools
                    .iter()
                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
                    .and_then(|t| t.server_label.as_deref())
                    .unwrap_or("mcp");
                // Build mcp_list_tools item
                let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
                // Insert at beginning of output array
                if let Some(output_array) = response_json
                    .get_mut("output")
                    .and_then(|v| v.as_array_mut())
                {
                    output_array.insert(0, list_tools_item);
                    // Build mcp_call items using helper function
                    let mcp_call_items = Self::build_executed_mcp_call_items(
                        &state.conversation_history,
                        server_label,
                    );
                    // Insert mcp_call items after mcp_list_tools using mutable position
                    let mut insert_pos = 1;
                    for item in mcp_call_items {
                        output_array.insert(insert_pos, item);
                        insert_pos += 1;
                    }
                }
            }
            return Ok(response_json);
        }
    }
}
/// Generate a unique ID for MCP output items (similar to OpenAI format) /// Generate a unique ID for MCP output items (similar to OpenAI format)
fn generate_mcp_id(prefix: &str) -> String { fn generate_mcp_id(prefix: &str) -> String {
use rand::RngCore; use rand::RngCore;
...@@ -1213,113 +1534,6 @@ impl OpenAIRouter { ...@@ -1213,113 +1534,6 @@ impl OpenAIRouter {
"server_label": server_label "server_label": server_label
}) })
} }
/// Inject mcp_list_tools and mcp_call items into the response output array
fn inject_mcp_output_items(
response_json: &mut Value,
mcp: &Arc<crate::mcp::McpClientManager>,
args: McpOutputItemsArgs,
) -> Result<(), String> {
let output_array = response_json
.get_mut("output")
.and_then(|v| v.as_array_mut())
.ok_or("missing output array")?;
// Build MCP output items
let list_tools_item = Self::build_mcp_list_tools_item(mcp, args.server_label);
let call_item = Self::build_mcp_call_item(
args.tool_name,
args.args_json,
args.output,
args.server_label,
args.success,
args.error,
);
// Find the index of the last message item to insert mcp_call before it
let call_insertion_index = output_array
.iter()
.rposition(|item| item.get("type").and_then(|v| v.as_str()) == Some("message"))
.unwrap_or(output_array.len());
// Insert items in-place for efficiency
output_array.insert(call_insertion_index, call_item);
output_array.insert(0, list_tools_item);
Ok(())
}
async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result<Value, String> {
let mut payload2 = args.original_payload.clone();
let obj = payload2
.as_object_mut()
.ok_or_else(|| "payload not an object".to_string())?;
// Build function_call and tool result items per OpenAI Responses spec
let user_item = serde_json::json!({
"type": "message",
"role": "user",
"content": args.original_body.input.clone()
});
// temp system message since currently only support 1 turn of mcp function call
let system_item = serde_json::json!({
"type": "message",
"role": "system",
"content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls"
});
let func_item = serde_json::json!({
"type": "function_call",
"call_id": args.call_id,
"name": args.tool_name,
"arguments": args.args_json_str
});
// Build tool result item as function_call_output per OpenAI Responses spec
let tool_item = serde_json::json!({
"type": "function_call_output",
"call_id": args.call_id,
"output": args.output_str
});
obj.insert(
"input".to_string(),
Value::Array(vec![user_item, system_item, func_item, tool_item]),
);
// Ensure non-streaming and no store to upstream
obj.insert("stream".to_string(), Value::Bool(false));
obj.insert("store".to_string(), Value::Bool(false));
let mut req = self.client.post(args.url).json(&payload2);
if let Some(headers) = args.headers {
req = apply_request_headers(headers, req, true);
}
let resp = req
.send()
.await
.map_err(|e| format!("resume request failed: {}", e))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("resume upstream error {}: {}", status, body));
}
let mut v = resp
.json::<Value>()
.await
.map_err(|e| format!("parse resume response: {}", e))?;
if let Some(instr) = &args.original_body.instructions {
if let Some(obj) = v.as_object_mut() {
obj.entry("instructions")
.or_insert(Value::String(instr.clone()));
}
}
// After resume, mask tools as MCP if request used MCP
Self::mask_tools_as_mcp(&mut v, args.original_body);
if let Some(obj) = v.as_object_mut() {
obj.insert("store".to_string(), Value::Bool(args.original_body.store));
}
Ok(v)
}
} }
#[async_trait] #[async_trait]
......
...@@ -252,6 +252,7 @@ fn test_responses_request_creation() { ...@@ -252,6 +252,7 @@ fn test_responses_request_creation() {
previous_response_id: None, previous_response_id: None,
reasoning: Some(ResponseReasoningParam { reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium), effort: Some(ReasoningEffort::Medium),
summary: None,
}), }),
service_tier: ServiceTier::Auto, service_tier: ServiceTier::Auto,
store: true, store: true,
...@@ -380,6 +381,7 @@ fn test_usage_conversion() { ...@@ -380,6 +381,7 @@ fn test_usage_conversion() {
fn test_reasoning_param_default() { fn test_reasoning_param_default() {
let param = ResponseReasoningParam { let param = ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium), effort: Some(ReasoningEffort::Medium),
summary: None,
}; };
let json = serde_json::to_string(&param).unwrap(); let json = serde_json::to_string(&param).unwrap();
...@@ -403,6 +405,7 @@ fn test_json_serialization() { ...@@ -403,6 +405,7 @@ fn test_json_serialization() {
previous_response_id: None, previous_response_id: None,
reasoning: Some(ResponseReasoningParam { reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::High), effort: Some(ReasoningEffort::High),
summary: None,
}), }),
service_tier: ServiceTier::Priority, service_tier: ServiceTier::Priority,
store: false, store: false,
...@@ -437,3 +440,328 @@ fn test_json_serialization() { ...@@ -437,3 +440,328 @@ fn test_json_serialization() {
assert!(parsed.stream); assert!(parsed.stream);
assert_eq!(parsed.tools.len(), 1); assert_eq!(parsed.tools.len(), 1);
} }
#[tokio::test]
async fn test_multi_turn_loop_with_mcp() {
    // This test verifies the multi-turn loop functionality:
    // 1. Initial request with MCP tools
    // 2. Mock worker returns function_call
    // 3. Router executes MCP tool and resumes
    // 4. Mock worker returns final answer
    // 5. Verify the complete flow worked
    // Start mock MCP server
    let mut mcp = MockMCPServer::start().await.expect("start mcp");
    // Write a temp MCP config file
    let mcp_yaml = format!(
        "servers:\n - name: mock\n protocol: streamable\n url: {}\n",
        mcp.url()
    );
    let dir = tempfile::tempdir().expect("tmpdir");
    let cfg_path = dir.path().join("mcp.yaml");
    std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
    // NOTE(review): env vars are process-global — parallel tests touching
    // SGLANG_MCP_CONFIG may race; consider serializing these tests.
    std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
    // Start mock OpenAI worker
    let mut worker = MockWorker::new(MockWorkerConfig {
        port: 0,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    });
    let worker_url = worker.start().await.expect("start worker");
    // Build router config (OpenAI mode pointing at the mock worker)
    let router_cfg = RouterConfig {
        mode: RoutingMode::OpenAI {
            worker_urls: vec![worker_url],
        },
        connection_mode: ConnectionMode::Http,
        policy: PolicyConfig::Random,
        host: "127.0.0.1".to_string(),
        port: 0,
        max_payload_size: 8 * 1024 * 1024,
        request_timeout_secs: 60,
        worker_startup_timeout_secs: 5,
        worker_startup_check_interval_secs: 1,
        dp_aware: false,
        api_key: None,
        discovery: None,
        metrics: None,
        log_dir: None,
        log_level: Some("info".to_string()),
        request_id_headers: None,
        max_concurrent_requests: 32,
        queue_size: 0,
        queue_timeout_secs: 5,
        rate_limit_tokens_per_second: None,
        cors_allowed_origins: vec![],
        retry: RetryConfig::default(),
        circuit_breaker: CircuitBreakerConfig::default(),
        disable_retries: false,
        disable_circuit_breaker: false,
        health_check: HealthCheckConfig::default(),
        enable_igw: false,
        model_path: None,
        tokenizer_path: None,
        history_backend: sglang_router_rs::config::HistoryBackend::Memory,
        oracle: None,
    };
    let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
    let router = RouterFactory::create_router(&Arc::new(ctx))
        .await
        .expect("router");
    // Build request with MCP tools
    let req = ResponsesRequest {
        background: false,
        include: None,
        input: ResponseInput::Text("search for SGLang".to_string()),
        instructions: Some("Be helpful".to_string()),
        max_output_tokens: Some(128),
        max_tool_calls: None, // No limit - test unlimited
        metadata: None,
        model: Some("mock-model".to_string()),
        parallel_tool_calls: true,
        previous_response_id: None,
        reasoning: None,
        service_tier: ServiceTier::Auto,
        store: true,
        stream: false,
        temperature: Some(0.7),
        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
        tools: vec![ResponseTool {
            r#type: ResponseToolType::Mcp,
            server_url: Some(mcp.url()),
            server_label: Some("mock".to_string()),
            server_description: Some("Mock MCP server for testing".to_string()),
            require_approval: Some("never".to_string()),
            ..Default::default()
        }],
        top_logprobs: 0,
        top_p: Some(1.0),
        truncation: Truncation::Disabled,
        user: None,
        request_id: "resp_multi_turn_test".to_string(),
        priority: 0,
        frequency_penalty: 0.0,
        presence_penalty: 0.0,
        stop: None,
        top_k: 50,
        min_p: 0.0,
        repetition_penalty: 1.0,
    };
    // Execute the request (this should trigger the multi-turn loop)
    let response = router.route_responses(None, &req, None).await;
    // Check status
    assert_eq!(
        response.status(),
        axum::http::StatusCode::OK,
        "Request should succeed"
    );
    // Read the response body
    use axum::body::to_bytes;
    let response_body = response.into_body();
    let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
    let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    println!(
        "Multi-turn response: {}",
        serde_json::to_string_pretty(&response_json).unwrap()
    );
    // Verify the response structure
    assert_eq!(response_json["object"], "response");
    assert_eq!(response_json["status"], "completed");
    // Note: mock worker generates its own ID, so we just verify it exists
    assert!(
        response_json["id"].is_string(),
        "Response should have an id"
    );
    // Check that output contains final message
    let output = response_json["output"]
        .as_array()
        .expect("output should be array");
    assert!(!output.is_empty(), "output should not be empty");
    // Find the final message with text
    let has_final_text = output.iter().any(|item| {
        item.get("type")
            .and_then(|t| t.as_str())
            .map(|t| t == "message")
            .unwrap_or(false)
            && item
                .get("content")
                .and_then(|c| c.as_array())
                .map(|arr| {
                    arr.iter().any(|part| {
                        part.get("type")
                            .and_then(|t| t.as_str())
                            .map(|t| t == "output_text")
                            .unwrap_or(false)
                    })
                })
                .unwrap_or(false)
    });
    assert!(has_final_text, "Should have final text output");
    // Verify tools are masked back to MCP format
    let tools = response_json["tools"]
        .as_array()
        .expect("tools should be array");
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0]["type"], "mcp");
    assert_eq!(tools[0]["server_label"], "mock");
    // Clean up
    std::env::remove_var("SGLANG_MCP_CONFIG");
    worker.stop().await;
    mcp.stop().await;
}
#[tokio::test]
async fn test_max_tool_calls_limit() {
    // This test verifies that max_tool_calls is respected
    // Note: The mock worker returns a final answer after one tool call,
    // so with max_tool_calls=1, it completes normally (doesn't exceed the limit)
    let mut mcp = MockMCPServer::start().await.expect("start mcp");
    let mcp_yaml = format!(
        "servers:\n - name: mock\n protocol: streamable\n url: {}\n",
        mcp.url()
    );
    let dir = tempfile::tempdir().expect("tmpdir");
    let cfg_path = dir.path().join("mcp.yaml");
    std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
    // NOTE(review): process-global env var — may race with other tests
    // that also set SGLANG_MCP_CONFIG.
    std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
    let mut worker = MockWorker::new(MockWorkerConfig {
        port: 0,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    });
    let worker_url = worker.start().await.expect("start worker");
    // Same OpenAI-mode router config as the multi-turn test.
    let router_cfg = RouterConfig {
        mode: RoutingMode::OpenAI {
            worker_urls: vec![worker_url],
        },
        connection_mode: ConnectionMode::Http,
        policy: PolicyConfig::Random,
        host: "127.0.0.1".to_string(),
        port: 0,
        max_payload_size: 8 * 1024 * 1024,
        request_timeout_secs: 60,
        worker_startup_timeout_secs: 5,
        worker_startup_check_interval_secs: 1,
        dp_aware: false,
        api_key: None,
        discovery: None,
        metrics: None,
        log_dir: None,
        log_level: Some("info".to_string()),
        request_id_headers: None,
        max_concurrent_requests: 32,
        queue_size: 0,
        queue_timeout_secs: 5,
        rate_limit_tokens_per_second: None,
        cors_allowed_origins: vec![],
        retry: RetryConfig::default(),
        circuit_breaker: CircuitBreakerConfig::default(),
        disable_retries: false,
        disable_circuit_breaker: false,
        health_check: HealthCheckConfig::default(),
        enable_igw: false,
        model_path: None,
        tokenizer_path: None,
        history_backend: sglang_router_rs::config::HistoryBackend::Memory,
        oracle: None,
    };
    let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
    let router = RouterFactory::create_router(&Arc::new(ctx))
        .await
        .expect("router");
    let req = ResponsesRequest {
        background: false,
        include: None,
        input: ResponseInput::Text("test max calls".to_string()),
        instructions: None,
        max_output_tokens: Some(128),
        max_tool_calls: Some(1), // Limit to 1 call
        metadata: None,
        model: Some("mock-model".to_string()),
        parallel_tool_calls: true,
        previous_response_id: None,
        reasoning: None,
        service_tier: ServiceTier::Auto,
        store: false,
        stream: false,
        temperature: Some(0.7),
        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
        tools: vec![ResponseTool {
            r#type: ResponseToolType::Mcp,
            server_url: Some(mcp.url()),
            server_label: Some("mock".to_string()),
            ..Default::default()
        }],
        top_logprobs: 0,
        top_p: Some(1.0),
        truncation: Truncation::Disabled,
        user: None,
        request_id: "resp_max_calls_test".to_string(),
        priority: 0,
        frequency_penalty: 0.0,
        presence_penalty: 0.0,
        stop: None,
        top_k: 50,
        min_p: 0.0,
        repetition_penalty: 1.0,
    };
    let response = router.route_responses(None, &req, None).await;
    assert_eq!(response.status(), axum::http::StatusCode::OK);
    use axum::body::to_bytes;
    let response_body = response.into_body();
    let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
    let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    println!(
        "Max calls response: {}",
        serde_json::to_string_pretty(&response_json).unwrap()
    );
    // With max_tool_calls=1, the mock returns a final answer after 1 call
    // So it completes normally without exceeding the limit
    assert_eq!(response_json["status"], "completed");
    // Verify the basic response structure
    assert!(response_json["id"].is_string());
    assert_eq!(response_json["object"], "response");
    // The response should have tools masked back to MCP format
    let tools = response_json["tools"]
        .as_array()
        .expect("tools should be array");
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0]["type"], "mcp");
    // Note: To test actual limit exceeding, we would need a mock that keeps
    // calling tools indefinitely, which would hit max_iterations (safety limit)
    std::env::remove_var("SGLANG_MCP_CONFIG");
    worker.stop().await;
    mcp.stop().await;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment