Unverified Commit 9e949e58 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] centralize mcp tool args handling (#12155)

parent 6dbb569b
...@@ -31,6 +31,9 @@ pub enum McpError { ...@@ -31,6 +31,9 @@ pub enum McpError {
#[error("Prompt not found: {0}")] #[error("Prompt not found: {0}")]
PromptNotFound(String), PromptNotFound(String),
#[error("Invalid arguments: {0}")]
InvalidArguments(String),
#[error(transparent)] #[error(transparent)]
Sdk(#[from] Box<rmcp::RmcpError>), Sdk(#[from] Box<rmcp::RmcpError>),
......
...@@ -37,7 +37,9 @@ use crate::mcp::{ ...@@ -37,7 +37,9 @@ use crate::mcp::{
connection_pool::McpConnectionPool, connection_pool::McpConnectionPool,
error::{McpError, McpResult}, error::{McpError, McpResult},
inventory::ToolInventory, inventory::ToolInventory,
tool_args::ToolArgs,
}; };
/// Type alias for MCP client /// Type alias for MCP client
type McpClient = RunningService<RoleClient, ()>; type McpClient = RunningService<RoleClient, ()>;
...@@ -221,18 +223,28 @@ impl McpManager { ...@@ -221,18 +223,28 @@ impl McpManager {
.collect() .collect()
} }
/// Call a tool by name /// Call a tool by name with automatic type coercion
///
/// Accepts either JSON string or parsed Map as arguments.
/// Automatically converts string numbers to actual numbers based on tool schema.
pub async fn call_tool( pub async fn call_tool(
&self, &self,
tool_name: &str, tool_name: &str,
args: Option<Map<String, serde_json::Value>>, args: impl Into<ToolArgs>,
) -> McpResult<CallToolResult> { ) -> McpResult<CallToolResult> {
// Get server that owns this tool // Get tool info for schema and server
let (server_name, _tool_info) = self let (server_name, tool_info) = self
.inventory .inventory
.get_tool(tool_name) .get_tool(tool_name)
.ok_or_else(|| McpError::ToolNotFound(tool_name.to_string()))?; .ok_or_else(|| McpError::ToolNotFound(tool_name.to_string()))?;
// Convert args with type coercion based on schema
let tool_schema = tool_info.parameters.as_ref();
let args_map = args
.into()
.into_map(tool_schema)
.map_err(McpError::InvalidArguments)?;
// Get client for that server // Get client for that server
let client = self let client = self
.get_client(&server_name) .get_client(&server_name)
...@@ -242,7 +254,7 @@ impl McpManager { ...@@ -242,7 +254,7 @@ impl McpManager {
// Call the tool // Call the tool
let request = CallToolRequestParam { let request = CallToolRequestParam {
name: Cow::Owned(tool_name.to_string()), name: Cow::Owned(tool_name.to_string()),
arguments: args, arguments: args_map,
}; };
client client
......
...@@ -14,6 +14,7 @@ pub mod inventory; ...@@ -14,6 +14,7 @@ pub mod inventory;
pub mod manager; pub mod manager;
pub mod oauth; pub mod oauth;
pub mod proxy; pub mod proxy;
pub mod tool_args;
// Re-export the main types for convenience // Re-export the main types for convenience
pub use config::{ pub use config::{
...@@ -25,3 +26,4 @@ pub use error::{McpError, McpResult}; ...@@ -25,3 +26,4 @@ pub use error::{McpError, McpResult};
pub use inventory::ToolInventory; pub use inventory::ToolInventory;
pub use manager::{McpManager, McpManagerStats}; pub use manager::{McpManager, McpManagerStats};
pub use proxy::{create_http_client, resolve_proxy_config}; pub use proxy::{create_http_client, resolve_proxy_config};
pub use tool_args::ToolArgs;
//! Tool arguments handling and type coercion
//!
//! This module provides utilities for handling MCP tool arguments,
//! supporting both JSON strings and parsed Maps with automatic type coercion.
use serde_json::Map;
/// Tool arguments input - supports both JSON strings and parsed Maps
pub enum ToolArgs {
/// JSON string that needs parsing
JsonString(String),
/// Already parsed map
Map(Option<Map<String, serde_json::Value>>),
}
impl ToolArgs {
/// Convert to Map with type coercion based on tool schema
pub(crate) fn into_map(
self,
tool_schema: Option<&serde_json::Value>,
) -> Result<Option<Map<String, serde_json::Value>>, String> {
match self {
ToolArgs::JsonString(s) => {
if s.is_empty() || s == "{}" {
return Ok(None);
}
let mut value: serde_json::Value =
serde_json::from_str(&s).map_err(|e| format!("parse tool args: {}", e))?;
Self::coerce_types(&mut value, tool_schema)?;
let result = match value {
serde_json::Value::Object(m) => Some(m),
_ => None,
};
Ok(result)
}
ToolArgs::Map(map) => {
if let Some(m) = map {
let mut value = serde_json::Value::Object(m);
Self::coerce_types(&mut value, tool_schema)?;
let result = match value {
serde_json::Value::Object(m) => Some(m),
_ => None,
};
Ok(result)
} else {
Ok(None)
}
}
}
}
/// Coerce string numbers to actual numbers based on schema
/// LLMs often output numbers as strings, so we need to convert them
fn coerce_types(
value: &mut serde_json::Value,
tool_schema: Option<&serde_json::Value>,
) -> Result<(), String> {
let Some(params) = tool_schema else {
return Ok(());
};
let Some(props) = params.get("properties").and_then(|p| p.as_object()) else {
return Ok(());
};
let Some(args) = value.as_object_mut() else {
return Ok(());
};
for (key, val) in args.iter_mut() {
let should_be_number = props
.get(key)
.and_then(|s| s.get("type"))
.and_then(|t| t.as_str())
.is_some_and(|t| matches!(t, "number" | "integer"));
if should_be_number {
if let Some(s) = val.as_str() {
if let Ok(num) = s.parse::<f64>() {
*val = serde_json::json!(num);
}
}
}
}
Ok(())
}
}
// Implement From traits for convenient conversion
impl From<String> for ToolArgs {
fn from(s: String) -> Self {
ToolArgs::JsonString(s)
}
}
impl From<&str> for ToolArgs {
fn from(s: &str) -> Self {
ToolArgs::JsonString(s.to_string())
}
}
impl From<Option<Map<String, serde_json::Value>>> for ToolArgs {
fn from(map: Option<Map<String, serde_json::Value>>) -> Self {
ToolArgs::Map(map)
}
}
...@@ -99,63 +99,6 @@ fn extract_all_tool_calls_from_chat( ...@@ -99,63 +99,6 @@ fn extract_all_tool_calls_from_chat(
} }
} }
/// Execute an MCP tool call
async fn execute_mcp_call(
mcp_mgr: &Arc<crate::mcp::McpManager>,
tool_name: &str,
args_json_str: &str,
) -> Result<String, String> {
// Parse arguments JSON string to Value
let mut args_value: serde_json::Value =
serde_json::from_str::<serde_json::Value>(args_json_str)
.map_err(|e| format!("parse tool args: {}", e))?;
// Get tool info to access schema for type coercion
let tool_info = mcp_mgr
.get_tool(tool_name)
.ok_or_else(|| format!("tool not found: {}", tool_name))?;
// Coerce string numbers to actual numbers based on schema (LLMs often output numbers as strings)
if let Some(params) = &tool_info.parameters {
let properties = params.get("properties").and_then(|p| p.as_object());
let args_obj = args_value.as_object_mut();
if let (Some(props), Some(args)) = (properties, args_obj) {
for (key, val) in args.iter_mut() {
let should_be_number = props
.get(key)
.and_then(|s| s.get("type"))
.and_then(|t| t.as_str())
.is_some_and(|t| matches!(t, "number" | "integer"));
if should_be_number {
if let Some(s) = val.as_str() {
if let Ok(num) = s.parse::<f64>() {
*val = json!(num);
}
}
}
}
}
}
let args_obj = args_value.as_object().cloned();
debug!(
"Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let result = mcp_mgr
.call_tool(tool_name, args_obj)
.await
.map_err(|e| format!("tool call failed: {}", e))?;
let output_str = serde_json::to_string(&result)
.map_err(|e| format!("Failed to serialize tool result: {}", e))?;
Ok(output_str)
}
/// State for tracking multi-turn tool calling loop /// State for tracking multi-turn tool calling loop
struct ToolLoopState { struct ToolLoopState {
iteration: usize, iteration: usize,
...@@ -381,18 +324,32 @@ pub(super) async fn execute_tool_loop( ...@@ -381,18 +324,32 @@ pub(super) async fn execute_tool_loop(
// Increment after check // Increment after check
state.total_calls += 1; state.total_calls += 1;
// Execute the MCP tool // Execute the MCP tool - manager handles parsing and type coercion
let (output_str, success, error) = debug!(
match execute_mcp_call(&mcp_manager, &tool_name, &args_json_str).await { "Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let (output_str, success, error) = match mcp_manager
.call_tool(tool_name.as_str(), args_json_str.as_str())
.await
{
Ok(result) => match serde_json::to_string(&result) {
Ok(output) => (output, true, None), Ok(output) => (output, true, None),
Err(err) => { Err(e) => {
warn!("Tool execution failed: {}", err); let err = format!("Failed to serialize tool result: {}", e);
let error_msg = err.clone(); warn!("{}", err);
// Return error as output, let model decide how to proceed let error_json = json!({ "error": &err }).to_string();
let error_json = json!({ "error": err }).to_string(); (error_json, false, Some(err))
(error_json, false, Some(error_msg))
} }
}; },
Err(err) => {
let err_str = format!("tool call failed: {}", err);
warn!("Tool execution failed: {}", err_str);
// Return error as output, let model decide how to proceed
let error_json = json!({ "error": &err_str }).to_string();
(error_json, false, Some(err_str))
}
};
// Record the call in state // Record the call in state
state.record_call( state.record_call(
...@@ -796,9 +753,16 @@ async fn execute_tool_loop_streaming_internal( ...@@ -796,9 +753,16 @@ async fn execute_tool_loop_streaming_internal(
emitter.emit_mcp_call_arguments_done(output_index, &item_id, &args_json_str); emitter.emit_mcp_call_arguments_done(output_index, &item_id, &args_json_str);
emitter.send_event(&event, &tx)?; emitter.send_event(&event, &tx)?;
// Execute the MCP tool // Execute the MCP tool - manager handles parsing and type coercion
let (output_str, success, error) = debug!(
match execute_mcp_call(&mcp_manager, &tool_name, &args_json_str).await { "Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let (output_str, success, error) = match mcp_manager
.call_tool(tool_name.as_str(), args_json_str.as_str())
.await
{
Ok(result) => match serde_json::to_string(&result) {
Ok(output) => { Ok(output) => {
// Emit mcp_call.completed // Emit mcp_call.completed
let event = emitter.emit_mcp_call_completed(output_index, &item_id); let event = emitter.emit_mcp_call_completed(output_index, &item_id);
...@@ -822,8 +786,9 @@ async fn execute_tool_loop_streaming_internal( ...@@ -822,8 +786,9 @@ async fn execute_tool_loop_streaming_internal(
emitter.complete_output_item(output_index); emitter.complete_output_item(output_index);
(output, true, None) (output, true, None)
} }
Err(err) => { Err(e) => {
warn!("Tool execution failed: {}", err); let err = format!("Failed to serialize tool result: {}", e);
warn!("{}", err);
// Emit mcp_call.failed // Emit mcp_call.failed
let event = emitter.emit_mcp_call_failed(output_index, &item_id, &err); let event = emitter.emit_mcp_call_failed(output_index, &item_id, &err);
emitter.send_event(&event, &tx)?; emitter.send_event(&event, &tx)?;
...@@ -836,7 +801,7 @@ async fn execute_tool_loop_streaming_internal( ...@@ -836,7 +801,7 @@ async fn execute_tool_loop_streaming_internal(
"server_label": state.server_label, "server_label": state.server_label,
"status": "failed", "status": "failed",
"arguments": args_json_str, "arguments": args_json_str,
"error": err "error": &err
}); });
// Emit output_item.done // Emit output_item.done
...@@ -844,11 +809,37 @@ async fn execute_tool_loop_streaming_internal( ...@@ -844,11 +809,37 @@ async fn execute_tool_loop_streaming_internal(
emitter.send_event(&event, &tx)?; emitter.send_event(&event, &tx)?;
emitter.complete_output_item(output_index); emitter.complete_output_item(output_index);
let error_msg = err.clone(); let error_json = json!({ "error": &err }).to_string();
let error_json = json!({ "error": err }).to_string(); (error_json, false, Some(err))
(error_json, false, Some(error_msg))
} }
}; },
Err(err) => {
let err_str = format!("tool call failed: {}", err);
warn!("Tool execution failed: {}", err_str);
// Emit mcp_call.failed
let event = emitter.emit_mcp_call_failed(output_index, &item_id, &err_str);
emitter.send_event(&event, &tx)?;
// Build failed item
let item_done = json!({
"id": item_id,
"type": "mcp_call",
"name": tool_name,
"server_label": state.server_label,
"status": "failed",
"arguments": args_json_str,
"error": &err_str
});
// Emit output_item.done
let event = emitter.emit_output_item_done(output_index, &item_done);
emitter.send_event(&event, &tx)?;
emitter.complete_output_item(output_index);
let error_json = json!({ "error": &err_str }).to_string();
(error_json, false, Some(err_str))
}
};
// Record the call in state // Record the call in state
state.record_call( state.record_call(
......
...@@ -14,7 +14,7 @@ use axum::http::HeaderMap; ...@@ -14,7 +14,7 @@ use axum::http::HeaderMap;
use bytes::Bytes; use bytes::Bytes;
use serde_json::{json, to_value, Value}; use serde_json::{json, to_value, Value};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{info, warn}; use tracing::{debug, info, warn};
use super::utils::{event_types, generate_id}; use super::utils::{event_types, generate_id};
use crate::{ use crate::{
...@@ -191,31 +191,6 @@ pub async fn ensure_request_mcp_client( ...@@ -191,31 +191,6 @@ pub async fn ensure_request_mcp_client(
// Tool Execution // Tool Execution
// ============================================================================ // ============================================================================
/// Execute an MCP tool call
pub(super) async fn execute_mcp_call(
mcp_mgr: &Arc<mcp::McpManager>,
tool_name: &str,
args_json_str: &str,
) -> Result<(String, String), String> {
let args_value: Value =
serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?;
let args_obj = args_value.as_object().cloned();
let server_name = mcp_mgr
.get_tool(tool_name)
.map(|t| t.server)
.ok_or_else(|| format!("tool not found: {}", tool_name))?;
let result = mcp_mgr
.call_tool(tool_name, args_obj)
.await
.map_err(|e| format!("tool call failed: {}", e))?;
let output_str = serde_json::to_string(&result)
.map_err(|e| format!("Failed to serialize tool result: {}", e))?;
Ok((server_name, output_str))
}
/// Execute detected tool calls and send completion events to client /// Execute detected tool calls and send completion events to client
/// Returns false if client disconnected during execution /// Returns false if client disconnected during execution
pub(super) async fn execute_streaming_tool_calls( pub(super) async fn execute_streaming_tool_calls(
...@@ -249,12 +224,26 @@ pub(super) async fn execute_streaming_tool_calls( ...@@ -249,12 +224,26 @@ pub(super) async fn execute_streaming_tool_calls(
&call.arguments_buffer &call.arguments_buffer
}; };
let call_result = execute_mcp_call(active_mcp, &call.name, args_str).await; // Call tool directly - manager handles parsing and type coercion
debug!("Calling MCP tool '{}' with args: {}", call.name, args_str);
let call_result = active_mcp.call_tool(&call.name, args_str).await;
let (output_str, success, error_msg) = match call_result { let (output_str, success, error_msg) = match call_result {
Ok((_, output)) => (output, true, None), Ok(result) => match serde_json::to_string(&result) {
Ok(output) => (output, true, None),
Err(e) => {
let err = format!("Failed to serialize tool result: {}", e);
warn!("{}", err);
(json!({ "error": &err }).to_string(), false, Some(err))
}
},
Err(err) => { Err(err) => {
warn!("Tool execution failed during streaming: {}", err); let err_str = format!("tool call failed: {}", err);
(json!({ "error": &err }).to_string(), false, Some(err)) warn!("Tool execution failed during streaming: {}", err_str);
(
json!({ "error": &err_str }).to_string(),
false,
Some(err_str),
)
} }
}; };
...@@ -674,15 +663,27 @@ pub(super) async fn execute_tool_loop( ...@@ -674,15 +663,27 @@ pub(super) async fn execute_tool_loop(
); );
} }
// Execute tool // Execute tool - manager handles parsing and type coercion
let call_result = execute_mcp_call(active_mcp, &tool_name, &args_json_str).await; debug!(
"Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let call_result = active_mcp
.call_tool(&tool_name, args_json_str.as_str())
.await;
let output_str = match call_result { let output_str = match call_result {
Ok((_, output)) => output, Ok(result) => match serde_json::to_string(&result) {
Ok(output) => output,
Err(e) => {
warn!("Failed to serialize tool result: {}", e);
json!({ "error": format!("Serialization error: {}", e) }).to_string()
}
},
Err(err) => { Err(err) => {
warn!("Tool execution failed: {}", err); warn!("Tool execution failed: {}", err);
// Return error as output, let model decide how to proceed // Return error as output, let model decide how to proceed
json!({ "error": err }).to_string() json!({ "error": format!("tool call failed: {}", err) }).to_string()
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment