Unverified Commit db37422c authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] move to mcp sdk instead (#10057)

parent ab62b135
...@@ -55,6 +55,15 @@ tiktoken-rs = { version = "0.7.0" } ...@@ -55,6 +55,15 @@ tiktoken-rs = { version = "0.7.0" }
minijinja = { version = "2.0" } minijinja = { version = "2.0" }
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
hf-hub = { version = "0.4.3", features = ["tokio"] } hf-hub = { version = "0.4.3", features = ["tokio"] }
rmcp = { version = "0.6.3", features = ["client", "server",
"transport-child-process",
"transport-sse-client-reqwest",
"transport-streamable-http-client-reqwest",
"transport-streamable-http-server",
"transport-streamable-http-server-session",
"reqwest",
"auth"] }
serde_yaml = "0.9"
# gRPC and Protobuf dependencies # gRPC and Protobuf dependencies
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] } tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
......
use backoff::ExponentialBackoffBuilder;
use dashmap::DashMap;
use rmcp::{
model::{
CallToolRequestParam, GetPromptRequestParam, GetPromptResult, Prompt,
ReadResourceRequestParam, ReadResourceResult, Resource, Tool as McpTool,
},
service::RunningService,
transport::{
sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig,
ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess,
},
RoleClient, ServiceExt,
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, time::Duration};
use crate::mcp::{
config::{McpConfig, McpServerConfig, McpTransport},
error::{McpError, McpResult},
};
/// Information about an available tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
pub name: String,
pub description: String,
pub server: String,
pub parameters: Option<serde_json::Value>,
}
/// Information about an available prompt
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptInfo {
pub name: String,
pub description: Option<String>,
pub server: String,
pub arguments: Option<Vec<serde_json::Value>>,
}
/// Information about an available resource
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceInfo {
pub uri: String,
pub name: String,
pub description: Option<String>,
pub mime_type: Option<String>,
pub server: String,
}
/// Manages MCP client connections and tool execution
pub struct McpClientManager {
/// Map of server_name -> MCP client
clients: HashMap<String, RunningService<RoleClient, ()>>,
/// Map of tool_name -> (server_name, tool_definition)
tools: DashMap<String, (String, McpTool)>,
/// Map of prompt_name -> (server_name, prompt_definition)
prompts: DashMap<String, (String, Prompt)>,
/// Map of resource_uri -> (server_name, resource_definition)
resources: DashMap<String, (String, Resource)>,
}
impl McpClientManager {
/// Create a new manager and connect to all configured servers
pub async fn new(config: McpConfig) -> McpResult<Self> {
let mut mgr = Self {
clients: HashMap::new(),
tools: DashMap::new(),
prompts: DashMap::new(),
resources: DashMap::new(),
};
for server_config in config.servers {
match Self::connect_server(&server_config).await {
Ok(client) => {
mgr.load_server_inventory(&server_config.name, &client)
.await;
mgr.clients.insert(server_config.name.clone(), client);
}
Err(e) => {
tracing::error!(
"Failed to connect to server '{}': {}",
server_config.name,
e
);
}
}
}
if mgr.clients.is_empty() {
return Err(McpError::ConnectionFailed(
"Failed to connect to any MCP servers".to_string(),
));
}
Ok(mgr)
}
/// Discover and cache tools/prompts/resources for a connected server
async fn load_server_inventory(
&self,
server_name: &str,
client: &RunningService<RoleClient, ()>,
) {
// Tools
match client.peer().list_all_tools().await {
Ok(ts) => {
tracing::info!("Discovered {} tools from '{}'", ts.len(), server_name);
for t in ts {
if self.tools.contains_key(t.name.as_ref()) {
tracing::warn!(
"Tool '{}' from server '{}' is overwriting an existing tool.",
&t.name,
server_name
);
}
self.tools
.insert(t.name.to_string(), (server_name.to_string(), t));
}
}
Err(e) => tracing::warn!("Failed to list tools from '{}': {}", server_name, e),
}
// Prompts
match client.peer().list_all_prompts().await {
Ok(ps) => {
tracing::info!("Discovered {} prompts from '{}'", ps.len(), server_name);
for p in ps {
if self.prompts.contains_key(&p.name) {
tracing::warn!(
"Prompt '{}' from server '{}' is overwriting an existing prompt.",
&p.name,
server_name
);
}
self.prompts
.insert(p.name.clone(), (server_name.to_string(), p));
}
}
Err(e) => tracing::debug!("No prompts or failed to list on '{}': {}", server_name, e),
}
// Resources
match client.peer().list_all_resources().await {
Ok(rs) => {
tracing::info!("Discovered {} resources from '{}'", rs.len(), server_name);
for r in rs {
if self.resources.contains_key(&r.uri) {
tracing::warn!(
"Resource '{}' from server '{}' is overwriting an existing resource.",
&r.uri,
server_name
);
}
self.resources
.insert(r.uri.clone(), (server_name.to_string(), r));
}
}
Err(e) => tracing::debug!("No resources or failed to list on '{}': {}", server_name, e),
}
}
/// Connect to a single MCP server with retry logic for remote transports
async fn connect_server(config: &McpServerConfig) -> McpResult<RunningService<RoleClient, ()>> {
let needs_retry = matches!(
&config.transport,
McpTransport::Sse { .. } | McpTransport::Streamable { .. }
);
if needs_retry {
Self::connect_server_with_retry(config).await
} else {
Self::connect_server_impl(config).await
}
}
/// Connect with exponential backoff retry for remote servers
async fn connect_server_with_retry(
config: &McpServerConfig,
) -> McpResult<RunningService<RoleClient, ()>> {
let backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(1))
.with_max_interval(Duration::from_secs(30))
.with_max_elapsed_time(Some(Duration::from_secs(120)))
.build();
backoff::future::retry(backoff, || async {
match Self::connect_server_impl(config).await {
Ok(client) => Ok(client),
Err(e) => {
tracing::warn!("Failed to connect to '{}', retrying: {}", config.name, e);
Err(backoff::Error::transient(e))
}
}
})
.await
}
/// Internal implementation of server connection
async fn connect_server_impl(
config: &McpServerConfig,
) -> McpResult<RunningService<RoleClient, ()>> {
tracing::info!(
"Connecting to MCP server '{}' via {:?}",
config.name,
config.transport
);
match &config.transport {
McpTransport::Stdio {
command,
args,
envs,
} => {
let transport = TokioChildProcess::new(
tokio::process::Command::new(command).configure(|cmd| {
cmd.args(args)
.envs(envs.iter())
.stderr(std::process::Stdio::inherit());
}),
)
.map_err(|e| McpError::Transport(format!("create stdio transport: {}", e)))?;
let client = ().serve(transport).await.map_err(|e| {
McpError::ConnectionFailed(format!("initialize stdio client: {}", e))
})?;
tracing::info!("Connected to stdio server '{}'", config.name);
Ok(client)
}
McpTransport::Sse { url, token } => {
let transport = if let Some(tok) = token {
let client = reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", tok).parse().map_err(|e| {
McpError::Transport(format!("auth token: {}", e))
})?,
);
headers
})
.build()
.map_err(|e| McpError::Transport(format!("build HTTP client: {}", e)))?;
let cfg = SseClientConfig {
sse_endpoint: url.clone().into(),
..Default::default()
};
SseClientTransport::start_with_client(client, cfg)
.await
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
} else {
SseClientTransport::start(url.as_str())
.await
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
};
let client = ().serve(transport).await.map_err(|e| {
McpError::ConnectionFailed(format!("initialize SSE client: {}", e))
})?;
tracing::info!("Connected to SSE server '{}' at {}", config.name, url);
Ok(client)
}
McpTransport::Streamable { url, token } => {
let transport = if let Some(tok) = token {
let mut cfg = StreamableHttpClientTransportConfig::with_uri(url.as_str());
cfg.auth_header = Some(format!("Bearer {}", tok));
StreamableHttpClientTransport::from_config(cfg)
} else {
StreamableHttpClientTransport::from_uri(url.as_str())
};
let client = ().serve(transport).await.map_err(|e| {
McpError::ConnectionFailed(format!("initialize streamable client: {}", e))
})?;
tracing::info!(
"Connected to streamable HTTP server '{}' at {}",
config.name,
url
);
Ok(client)
}
}
}
// ===== Helpers =====
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
self.clients
.get(server_name)
.ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))
}
fn tool_entry(&self, name: &str) -> McpResult<(String, McpTool)> {
self.tools
.get(name)
.map(|e| e.value().clone())
.ok_or_else(|| McpError::ToolNotFound(name.to_string()))
}
fn prompt_entry(&self, name: &str) -> McpResult<(String, Prompt)> {
self.prompts
.get(name)
.map(|e| e.value().clone())
.ok_or_else(|| McpError::PromptNotFound(name.to_string()))
}
fn resource_entry(&self, uri: &str) -> McpResult<(String, Resource)> {
self.resources
.get(uri)
.map(|e| e.value().clone())
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
}
// ===== Tool Methods =====
/// Call a tool by name
pub async fn call_tool(
&self,
tool_name: &str,
arguments: Option<serde_json::Map<String, serde_json::Value>>,
) -> McpResult<rmcp::model::CallToolResult> {
let (server_name, _tool) = self.tool_entry(tool_name)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Calling tool '{}' on '{}'", tool_name, server_name);
client
.peer()
.call_tool(CallToolRequestParam {
name: Cow::Owned(tool_name.to_string()),
arguments,
})
.await
.map_err(|e| McpError::ToolExecution(format!("Tool call failed: {}", e)))
}
/// Get all available tools
pub fn list_tools(&self) -> Vec<ToolInfo> {
self.tools
.iter()
.map(|entry| {
let tool_name = entry.key().clone();
let (server_name, tool) = entry.value();
ToolInfo {
name: tool_name,
description: tool.description.as_deref().unwrap_or_default().to_string(),
server: server_name.clone(),
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
}
})
.collect()
}
/// Get a specific tool by name
pub fn get_tool(&self, name: &str) -> Option<ToolInfo> {
self.tools.get(name).map(|entry| {
let (server_name, tool) = entry.value();
ToolInfo {
name: name.to_string(),
description: tool.description.as_deref().unwrap_or_default().to_string(),
server: server_name.clone(),
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
}
})
}
/// Check if a tool exists
pub fn has_tool(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
/// Get list of connected servers
pub fn list_servers(&self) -> Vec<String> {
self.clients.keys().cloned().collect()
}
// ===== Prompt Methods =====
/// Get a prompt by name with arguments
pub async fn get_prompt(
&self,
prompt_name: &str,
arguments: Option<serde_json::Map<String, serde_json::Value>>,
) -> McpResult<GetPromptResult> {
let (server_name, _prompt) = self.prompt_entry(prompt_name)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Getting prompt '{}' from '{}'", prompt_name, server_name);
client
.peer()
.get_prompt(GetPromptRequestParam {
name: prompt_name.to_string(),
arguments,
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to get prompt: {}", e)))
}
/// List all available prompts
pub fn list_prompts(&self) -> Vec<PromptInfo> {
self.prompts
.iter()
.map(|entry| {
let name = entry.key().clone();
let (server_name, prompt) = entry.value();
PromptInfo {
name,
description: prompt.description.clone(),
server: server_name.clone(),
arguments: prompt
.arguments
.clone()
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
}
})
.collect()
}
/// Get a specific prompt info by name
pub fn get_prompt_info(&self, name: &str) -> Option<PromptInfo> {
self.prompts.get(name).map(|entry| {
let (server_name, prompt) = entry.value();
PromptInfo {
name: name.to_string(),
description: prompt.description.clone(),
server: server_name.clone(),
arguments: prompt
.arguments
.clone()
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
}
})
}
// ===== Resource Methods =====
/// Read a resource by URI
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
let (server_name, _resource) = self.resource_entry(uri)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Reading resource '{}' from '{}'", uri, server_name);
client
.peer()
.read_resource(ReadResourceRequestParam {
uri: uri.to_string(),
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to read resource: {}", e)))
}
/// List all available resources
pub fn list_resources(&self) -> Vec<ResourceInfo> {
self.resources
.iter()
.map(|entry| {
let uri = entry.key().clone();
let (server_name, resource) = entry.value();
ResourceInfo {
uri,
name: resource.name.clone(),
description: resource.description.clone(),
mime_type: resource.mime_type.clone(),
server: server_name.clone(),
}
})
.collect()
}
/// Get a specific resource info by URI
pub fn get_resource_info(&self, uri: &str) -> Option<ResourceInfo> {
self.resources.get(uri).map(|entry| {
let (server_name, resource) = entry.value();
ResourceInfo {
uri: uri.to_string(),
name: resource.name.clone(),
description: resource.description.clone(),
mime_type: resource.mime_type.clone(),
server: server_name.clone(),
}
})
}
/// Subscribe to resource changes
pub async fn subscribe_resource(&self, uri: &str) -> McpResult<()> {
let (server_name, _resource) = self.resource_entry(uri)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Subscribing to '{}' on '{}'", uri, server_name);
client
.peer()
.subscribe(rmcp::model::SubscribeRequestParam {
uri: uri.to_string(),
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to subscribe: {}", e)))
}
/// Unsubscribe from resource changes
pub async fn unsubscribe_resource(&self, uri: &str) -> McpResult<()> {
let (server_name, _resource) = self.resource_entry(uri)?;
let client = self.client_for(&server_name)?;
tracing::debug!("Unsubscribing from '{}' on '{}'", uri, server_name);
client
.peer()
.unsubscribe(rmcp::model::UnsubscribeRequestParam {
uri: uri.to_string(),
})
.await
.map_err(|e| McpError::ToolExecution(format!("Failed to unsubscribe: {}", e)))
}
/// Disconnect from all servers (for cleanup)
pub async fn shutdown(&mut self) {
for (name, client) in self.clients.drain() {
if let Err(e) = client.cancel().await {
tracing::warn!("Error disconnecting from '{}': {}", name, e);
}
}
self.tools.clear();
self.prompts.clear();
self.resources.clear();
}
}
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig {
pub servers: Vec<McpServerConfig>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpServerConfig {
pub name: String,
#[serde(flatten)]
pub transport: McpTransport,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "protocol", rename_all = "lowercase")]
pub enum McpTransport {
Stdio {
command: String,
#[serde(default)]
args: Vec<String>,
#[serde(default)]
envs: HashMap<String, String>,
},
Sse {
url: String,
#[serde(skip_serializing_if = "Option::is_none")]
token: Option<String>,
},
Streamable {
url: String,
#[serde(skip_serializing_if = "Option::is_none")]
token: Option<String>,
},
}
impl McpConfig {
/// Load configuration from a YAML file
pub async fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
let content = tokio::fs::read_to_string(path).await?;
let config: Self = serde_yaml::from_str(&content)?;
Ok(config)
}
/// Load configuration from environment variables (optional)
pub fn from_env() -> Option<Self> {
// This could be expanded to read from env vars
// For now, return None to indicate env config not implemented
None
}
}
use thiserror::Error;
pub type McpResult<T> = Result<T, McpError>;
#[derive(Debug, Error)]
pub enum McpError {
#[error("Server not found: {0}")]
ServerNotFound(String),
#[error("Tool not found: {0}")]
ToolNotFound(String),
#[error("Transport error: {0}")]
Transport(String),
#[error("Tool execution failed: {0}")]
ToolExecution(String),
#[error("Connection failed: {0}")]
ConnectionFailed(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("Authentication error: {0}")]
Auth(String),
#[error("Resource not found: {0}")]
ResourceNotFound(String),
#[error("Prompt not found: {0}")]
PromptNotFound(String),
#[error(transparent)]
Sdk(#[from] Box<rmcp::RmcpError>),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
Http(#[from] reqwest::Error),
}
// mod.rs - MCP module exports // MCP Client for SGLang Router
pub mod tool_server; //
pub mod types; // This module provides a complete MCP (Model Context Protocol) client implementation
// supporting multiple transport types (stdio, SSE, HTTP) and all MCP features:
// - Tools: Discovery and execution
// - Prompts: Reusable templates for LLM interactions
// - Resources: File/data access with subscription support
// - OAuth: Secure authentication for remote servers
pub use tool_server::{parse_sse_event, MCPToolServer, ToolStats}; pub mod client_manager;
pub use types::{ pub mod config;
HttpConnection, MCPError, MCPResult, MultiToolSessionManager, SessionStats, ToolCall, pub mod error;
ToolResult, ToolSession, pub mod oauth;
};
// Re-export the main types for convenience
pub use client_manager::{McpClientManager, PromptInfo, ResourceInfo, ToolInfo};
pub use config::{McpConfig, McpServerConfig, McpTransport};
pub use error::{McpError, McpResult};
// OAuth authentication support for MCP servers
use axum::{
extract::{Query, State},
response::Html,
routing::get,
Router,
};
use rmcp::transport::auth::OAuthState;
use serde::Deserialize;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{oneshot, Mutex};
use crate::mcp::error::{McpError, McpResult};
/// OAuth callback parameters
#[derive(Debug, Deserialize)]
struct CallbackParams {
code: String,
#[allow(dead_code)]
state: Option<String>,
}
/// State for the callback server
#[derive(Clone)]
struct CallbackState {
code_receiver: Arc<Mutex<Option<oneshot::Sender<String>>>>,
}
/// HTML page returned after successful OAuth callback
const CALLBACK_HTML: &str = r#"
<!DOCTYPE html>
<html>
<head>
<title>OAuth Success</title>
<style>
body {
font-family: Arial, sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.container {
background: white;
padding: 40px;
border-radius: 10px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
text-align: center;
}
h1 { color: #333; }
p { color: #666; margin: 20px 0; }
.success { color: #4CAF50; font-size: 48px; }
</style>
</head>
<body>
<div class="container">
<div class="success">✓</div>
<h1>Authentication Successful!</h1>
<p>You can now close this window and return to your application.</p>
</div>
</body>
</html>
"#;
/// OAuth authentication helper for MCP servers
pub struct OAuthHelper {
server_url: String,
redirect_uri: String,
callback_port: u16,
}
impl OAuthHelper {
/// Create a new OAuth helper
pub fn new(server_url: String, redirect_uri: String, callback_port: u16) -> Self {
Self {
server_url,
redirect_uri,
callback_port,
}
}
/// Perform OAuth authentication flow
pub async fn authenticate(
&self,
scopes: &[&str],
) -> McpResult<rmcp::transport::auth::AuthorizationManager> {
// Initialize OAuth state machine
let mut oauth_state = OAuthState::new(&self.server_url, None)
.await
.map_err(|e| McpError::Auth(format!("Failed to initialize OAuth: {}", e)))?;
oauth_state
.start_authorization(scopes, &self.redirect_uri)
.await
.map_err(|e| McpError::Auth(format!("Failed to start authorization: {}", e)))?;
// Get authorization URL
let auth_url = oauth_state
.get_authorization_url()
.await
.map_err(|e| McpError::Auth(format!("Failed to get authorization URL: {}", e)))?;
tracing::info!("OAuth authorization URL: {}", auth_url);
// Start callback server and wait for code
let auth_code = self.start_callback_server().await?;
// Exchange code for token
oauth_state
.handle_callback(&auth_code)
.await
.map_err(|e| McpError::Auth(format!("Failed to handle OAuth callback: {}", e)))?;
// Get authorization manager
oauth_state
.into_authorization_manager()
.ok_or_else(|| McpError::Auth("Failed to get authorization manager".to_string()))
}
/// Start a local HTTP server to receive the OAuth callback
async fn start_callback_server(&self) -> McpResult<String> {
let (code_sender, code_receiver) = oneshot::channel::<String>();
let state = CallbackState {
code_receiver: Arc::new(Mutex::new(Some(code_sender))),
};
// Create router for callback
let app = Router::new()
.route("/callback", get(Self::callback_handler))
.with_state(state);
let addr = SocketAddr::from(([127, 0, 0, 1], self.callback_port));
// Start server in background
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
McpError::Auth(format!(
"Failed to bind to callback port {}: {}",
self.callback_port, e
))
})?;
tokio::spawn(async move {
let _ = axum::serve(listener, app).await;
});
tracing::info!(
"OAuth callback server started on port {}",
self.callback_port
);
// Wait for authorization code
code_receiver
.await
.map_err(|_| McpError::Auth("Failed to receive authorization code".to_string()))
}
/// Handle OAuth callback
async fn callback_handler(
Query(params): Query<CallbackParams>,
State(state): State<CallbackState>,
) -> Html<String> {
tracing::debug!("Received OAuth callback with code");
// Send code to waiting task
if let Some(sender) = state.code_receiver.lock().await.take() {
let _ = sender.send(params.code);
}
Html(CALLBACK_HTML.to_string())
}
}
/// Create an OAuth-authenticated client
pub async fn create_oauth_client(
server_url: String,
_sse_url: String,
redirect_uri: String,
callback_port: u16,
scopes: &[&str],
) -> McpResult<rmcp::transport::auth::AuthClient<reqwest::Client>> {
let helper = OAuthHelper::new(server_url, redirect_uri, callback_port);
let auth_manager = helper.authenticate(scopes).await?;
let client = rmcp::transport::auth::AuthClient::new(reqwest::Client::default(), auth_manager);
Ok(client)
}
// tool_server.rs - Main MCP implementation (matching Python's tool_server.py)
use crate::mcp::types::*;
use serde_json::{json, Value};
use std::collections::HashMap;
/// Main MCP Tool Server
pub struct MCPToolServer {
/// Tool descriptions by server
tool_descriptions: HashMap<String, Value>,
/// Server URLs
urls: HashMap<String, String>,
}
impl Default for MCPToolServer {
fn default() -> Self {
Self::new()
}
}
impl MCPToolServer {
/// Create new MCPToolServer
pub fn new() -> Self {
Self {
tool_descriptions: HashMap::new(),
urls: HashMap::new(),
}
}
/// Clears all existing tool servers and adds new ones from the provided URL(s).
/// URLs can be a single string or multiple comma-separated strings.
pub async fn add_tool_server(&mut self, server_url: String) -> MCPResult<()> {
let tool_urls: Vec<&str> = server_url.split(",").collect();
let mut successful_connections = 0;
let mut errors = Vec::new();
// Clear existing
self.tool_descriptions = HashMap::new();
self.urls = HashMap::new();
for url_str in tool_urls {
let url_str = url_str.trim();
// Format URL for MCP-compliant connection
let formatted_url = if url_str.starts_with("http://") || url_str.starts_with("https://")
{
url_str.to_string()
} else {
// Default to MCP endpoint if no protocol specified
format!("http://{}", url_str)
};
// Server connection with retry and error recovery
match self.connect_to_server(&formatted_url).await {
Ok((_init_response, tools_response)) => {
// Process tools with validation
let tools_obj = post_process_tools_description(tools_response);
// Tool storage with conflict detection
for tool in &tools_obj.tools {
let tool_name = &tool.name;
// Check for duplicate tools
if self.tool_descriptions.contains_key(tool_name) {
tracing::warn!(
"Tool {} already exists. Ignoring duplicate tool from server {}",
tool_name,
formatted_url
);
continue;
}
// Store individual tool descriptions
let tool_json = json!(tool);
self.tool_descriptions
.insert(tool_name.clone(), tool_json.clone());
self.urls.insert(tool_name.clone(), formatted_url.clone());
}
successful_connections += 1;
}
Err(e) => {
errors.push(format!("Failed to connect to {}: {}", formatted_url, e));
tracing::warn!("Failed to connect to MCP server {}: {}", formatted_url, e);
}
}
}
// Error handling - succeed if at least one server connects
if successful_connections == 0 {
let combined_error = errors.join("; ");
return Err(MCPError::ConnectionError(format!(
"Failed to connect to any MCP servers: {}",
combined_error
)));
}
if !errors.is_empty() {
tracing::warn!("Some MCP servers failed to connect: {}", errors.join("; "));
}
tracing::info!(
"Successfully connected to {} MCP server(s), discovered {} tool(s)",
successful_connections,
self.tool_descriptions.len()
);
Ok(())
}
/// Server connection with retries (internal helper)
async fn connect_to_server(
&self,
url: &str,
) -> MCPResult<(InitializeResponse, ListToolsResponse)> {
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 1000;
let mut last_error = None;
for attempt in 1..=MAX_RETRIES {
match list_server_and_tools(url).await {
Ok(result) => return Ok(result),
Err(e) => {
last_error = Some(e);
if attempt < MAX_RETRIES {
tracing::debug!(
"MCP server connection attempt {}/{} failed for {}: {}. Retrying...",
attempt,
MAX_RETRIES,
url,
last_error.as_ref().unwrap()
);
tokio::time::sleep(tokio::time::Duration::from_millis(
RETRY_DELAY_MS * attempt as u64,
))
.await;
}
}
}
}
Err(last_error.unwrap())
}
/// Check if tool exists (matching Python's has_tool)
pub fn has_tool(&self, tool_name: &str) -> bool {
self.tool_descriptions.contains_key(tool_name)
}
/// Get tool description (matching Python's get_tool_description)
pub fn get_tool_description(&self, tool_name: &str) -> Option<&Value> {
self.tool_descriptions.get(tool_name)
}
/// Get tool session (matching Python's get_tool_session)
pub async fn get_tool_session(&self, tool_name: &str) -> MCPResult<ToolSession> {
let url = self
.urls
.get(tool_name)
.ok_or_else(|| MCPError::ToolNotFound(tool_name.to_string()))?;
// Create session
ToolSession::new(url.clone()).await
}
/// Create multi-tool session manager
pub async fn create_multi_tool_session(
&self,
tool_names: Vec<String>,
) -> MCPResult<MultiToolSessionManager> {
let mut session_manager = MultiToolSessionManager::new();
// Group tools by server URL for efficient session creation
let mut server_tools: std::collections::HashMap<String, Vec<String>> =
std::collections::HashMap::new();
for tool_name in tool_names {
if let Some(url) = self.urls.get(&tool_name) {
server_tools.entry(url.clone()).or_default().push(tool_name);
} else {
return Err(MCPError::ToolNotFound(format!(
"Tool not found: {}",
tool_name
)));
}
}
// Create sessions for each server
for (server_url, tools) in server_tools {
session_manager
.add_tools_from_server(server_url, tools)
.await?;
}
Ok(session_manager)
}
/// List all available tools
pub fn list_tools(&self) -> Vec<String> {
self.tool_descriptions.keys().cloned().collect()
}
/// Get tool statistics
pub fn get_tool_stats(&self) -> ToolStats {
ToolStats {
total_tools: self.tool_descriptions.len(),
total_servers: self
.urls
.values()
.collect::<std::collections::HashSet<_>>()
.len(),
}
}
/// List all connected servers
pub fn list_servers(&self) -> Vec<String> {
self.urls
.values()
.cloned()
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect()
}
/// Check if a specific server is connected
pub fn has_server(&self, server_url: &str) -> bool {
self.urls.values().any(|url| url == server_url)
}
/// Execute a tool directly (convenience method for simple usage)
pub async fn call_tool(
&self,
tool_name: &str,
arguments: serde_json::Value,
) -> MCPResult<serde_json::Value> {
let session = self.get_tool_session(tool_name).await?;
session.call_tool(tool_name, arguments).await
}
/// Create a tool session from server URL (convenience method)
pub async fn create_session_from_url(&self, server_url: &str) -> MCPResult<ToolSession> {
ToolSession::new(server_url.to_string()).await
}
}
/// Tool statistics for monitoring
#[derive(Debug, Clone)]
pub struct ToolStats {
pub total_tools: usize,
pub total_servers: usize,
}
/// MCP-compliant server connection using JSON-RPC over SSE
async fn list_server_and_tools(
server_url: &str,
) -> MCPResult<(InitializeResponse, ListToolsResponse)> {
// MCP specification:
// 1. Connect to MCP endpoint with GET (SSE) or POST (JSON-RPC)
// 2. Send initialize request
// 3. Send tools/list request
// 4. Parse JSON-RPC responses
let client = reqwest::Client::new();
// Step 1: Send initialize request
let init_request = MCPRequest {
jsonrpc: "2.0".to_string(),
id: "1".to_string(),
method: "initialize".to_string(),
params: Some(json!({
"protocolVersion": "2024-11-05",
"capabilities": {}
})),
};
let init_response = send_mcp_request(&client, server_url, init_request).await?;
let init_result: InitializeResponse = serde_json::from_value(init_response).map_err(|e| {
MCPError::SerializationError(format!("Failed to parse initialize response: {}", e))
})?;
// Step 2: Send tools/list request
let tools_request = MCPRequest {
jsonrpc: "2.0".to_string(),
id: "2".to_string(),
method: "tools/list".to_string(),
params: Some(json!({})),
};
let tools_response = send_mcp_request(&client, server_url, tools_request).await?;
let tools_result: ListToolsResponse = serde_json::from_value(tools_response).map_err(|e| {
MCPError::SerializationError(format!("Failed to parse tools/list response: {}", e))
})?;
Ok((init_result, tools_result))
}
/// Send MCP JSON-RPC request (supports both HTTP POST and SSE)
async fn send_mcp_request(
client: &reqwest::Client,
url: &str,
request: MCPRequest,
) -> MCPResult<Value> {
// Use HTTP POST for JSON-RPC requests
let response = client
.post(url)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&request)
.send()
.await
.map_err(|e| MCPError::ConnectionError(format!("MCP request failed: {}", e)))?;
if !response.status().is_success() {
return Err(MCPError::ProtocolError(format!(
"HTTP {}",
response.status()
)));
}
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
MCPError::SerializationError(format!("Failed to parse MCP response: {}", e))
})?;
if let Some(error) = mcp_response.error {
return Err(MCPError::ProtocolError(format!(
"MCP error: {}",
error.message
)));
}
mcp_response
.result
.ok_or_else(|| MCPError::ProtocolError("No result in MCP response".to_string()))
}
// Removed old send_http_request - now using send_mcp_request with proper MCP protocol
/// Parse SSE event format (MCP-compliant JSON-RPC only)
pub fn parse_sse_event(event: &str) -> MCPResult<Option<Value>> {
let mut data_lines = Vec::new();
for line in event.lines() {
if let Some(stripped) = line.strip_prefix("data: ") {
data_lines.push(stripped);
}
}
if data_lines.is_empty() {
return Ok(None);
}
let json_data = data_lines.join("\n");
if json_data.trim().is_empty() {
return Ok(None);
}
// Parse as MCP JSON-RPC response only (no custom events)
let mcp_response: MCPResponse = serde_json::from_str(&json_data).map_err(|e| {
MCPError::SerializationError(format!(
"Failed to parse JSON-RPC response: {} - Data: {}",
e, json_data
))
})?;
if let Some(error) = mcp_response.error {
return Err(MCPError::ProtocolError(error.message));
}
Ok(mcp_response.result)
}
/// Schema adaptation matching Python's trim_schema()
fn trim_schema(schema: &mut Value) {
if let Some(obj) = schema.as_object_mut() {
// Remove title and null defaults
obj.remove("title");
if obj.get("default") == Some(&Value::Null) {
obj.remove("default");
}
// Convert anyOf to type arrays
if let Some(any_of) = obj.remove("anyOf") {
if let Some(array) = any_of.as_array() {
let types: Vec<String> = array
.iter()
.filter_map(|item| {
item.get("type")
.and_then(|t| t.as_str())
.filter(|t| *t != "null")
.map(|t| t.to_string())
})
.collect();
// Handle single type vs array of types
match types.len() {
0 => {} // No valid types found
1 => {
obj.insert("type".to_string(), json!(types[0]));
}
_ => {
obj.insert("type".to_string(), json!(types));
}
}
}
}
// Handle oneOf similar to anyOf
if let Some(one_of) = obj.remove("oneOf") {
if let Some(array) = one_of.as_array() {
let types: Vec<String> = array
.iter()
.filter_map(|item| {
item.get("type")
.and_then(|t| t.as_str())
.filter(|t| *t != "null")
.map(|t| t.to_string())
})
.collect();
if !types.is_empty() {
obj.insert("type".to_string(), json!(types));
}
}
}
// Recursive processing for properties
if let Some(properties) = obj.get_mut("properties") {
if let Some(props_obj) = properties.as_object_mut() {
for (_, value) in props_obj.iter_mut() {
trim_schema(value);
}
}
}
// Handle nested schemas in items (for arrays)
if let Some(items) = obj.get_mut("items") {
trim_schema(items);
}
// Handle nested schemas in additionalProperties
if let Some(additional_props) = obj.get_mut("additionalProperties") {
if additional_props.is_object() {
trim_schema(additional_props);
}
}
// Handle patternProperties (for dynamic property names)
if let Some(pattern_props) = obj.get_mut("patternProperties") {
if let Some(pattern_obj) = pattern_props.as_object_mut() {
for (_, value) in pattern_obj.iter_mut() {
trim_schema(value);
}
}
}
// Handle allOf in nested contexts
if let Some(all_of) = obj.get_mut("allOf") {
if let Some(array) = all_of.as_array_mut() {
for item in array.iter_mut() {
trim_schema(item);
}
}
}
}
}
/// Tool processing with filtering
fn post_process_tools_description(mut tools_response: ListToolsResponse) -> ListToolsResponse {
// Adapt schemas for Harmony
for tool in &mut tools_response.tools {
trim_schema(&mut tool.input_schema);
}
// Tool filtering based on annotations
let initial_count = tools_response.tools.len();
tools_response.tools.retain(|tool| {
// Check include_in_prompt annotation (Python behavior)
let include_in_prompt = tool
.annotations
.as_ref()
.and_then(|a| a.get("include_in_prompt"))
.and_then(|v| v.as_bool())
.unwrap_or(true);
if !include_in_prompt {
tracing::debug!(
"Filtering out tool '{}' due to include_in_prompt=false",
tool.name
);
return false;
}
// Check if tool is explicitly disabled
let disabled = tool
.annotations
.as_ref()
.and_then(|a| a.get("disabled"))
.and_then(|v| v.as_bool())
.unwrap_or(false);
if disabled {
tracing::debug!("Filtering out disabled tool '{}'", tool.name);
return false;
}
// Validate tool has required fields
if tool.name.trim().is_empty() {
tracing::warn!("Filtering out tool with empty name");
return false;
}
// Check for valid input schema
if tool.input_schema.is_null() {
tracing::warn!("Tool '{}' has null input schema, but keeping it", tool.name);
}
true
});
let filtered_count = tools_response.tools.len();
if filtered_count != initial_count {
tracing::info!(
"Filtered tools: {} -> {} ({} removed)",
initial_count,
filtered_count,
initial_count - filtered_count
);
}
tools_response
}
// Tests moved to tests/mcp_comprehensive_test.rs for better organization
// types.rs - All MCP data structures
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
use uuid;
// ===== Errors =====
#[derive(Error, Debug)]
pub enum MCPError {
#[error("Connection failed: {0}")]
ConnectionError(String),
#[error("Invalid URL: {0}")]
InvalidURL(String),
#[error("Protocol error: {0}")]
ProtocolError(String),
#[error("Tool execution failed: {0}")]
ToolExecutionError(String),
#[error("Tool not found: {0}")]
ToolNotFound(String),
#[error("Serialization error: {0}")]
SerializationError(String),
#[error("Configuration error: {0}")]
ConfigurationError(String),
}
pub type MCPResult<T> = Result<T, MCPError>;
// Add From implementations for common error types
impl From<serde_json::Error> for MCPError {
fn from(err: serde_json::Error) -> Self {
MCPError::SerializationError(err.to_string())
}
}
impl From<reqwest::Error> for MCPError {
fn from(err: reqwest::Error) -> Self {
MCPError::ConnectionError(err.to_string())
}
}
// ===== MCP Protocol Types =====
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPRequest {
pub jsonrpc: String,
pub id: String,
pub method: String,
pub params: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPResponse {
pub jsonrpc: String,
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<MCPErrorResponse>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPErrorResponse {
pub code: i32,
pub message: String,
pub data: Option<serde_json::Value>,
}
// ===== MCP Server Response Types =====
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeResponse {
#[serde(rename = "serverInfo")]
pub server_info: ServerInfo,
pub instructions: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInfo {
pub name: String,
pub version: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListToolsResponse {
pub tools: Vec<ToolInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
pub name: String,
pub description: Option<String>,
#[serde(rename = "inputSchema")]
pub input_schema: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub annotations: Option<serde_json::Value>,
}
// ===== Types =====
pub type ToolCall = serde_json::Value; // Python uses dict
pub type ToolResult = serde_json::Value; // Python uses dict
// ===== Connection Types =====
#[derive(Debug, Clone)]
pub struct HttpConnection {
pub url: String,
}
// ===== Tool Session =====
pub struct ToolSession {
pub connection: HttpConnection,
pub client: reqwest::Client,
pub session_initialized: bool,
}
impl ToolSession {
pub async fn new(connection_str: String) -> MCPResult<Self> {
if !connection_str.starts_with("http://") && !connection_str.starts_with("https://") {
return Err(MCPError::InvalidURL(format!(
"Only HTTP/HTTPS URLs are supported: {}",
connection_str
)));
}
let mut session = Self {
connection: HttpConnection {
url: connection_str,
},
client: reqwest::Client::new(),
session_initialized: false,
};
// Initialize the session
session.initialize().await?;
Ok(session)
}
pub async fn new_http(url: String) -> MCPResult<Self> {
Self::new(url).await
}
/// Initialize the session
pub async fn initialize(&mut self) -> MCPResult<()> {
if self.session_initialized {
return Ok(());
}
let init_request = MCPRequest {
jsonrpc: "2.0".to_string(),
id: "init".to_string(),
method: "initialize".to_string(),
params: Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {}
})),
};
let response = self
.client
.post(&self.connection.url)
.header("Content-Type", "application/json")
.json(&init_request)
.send()
.await
.map_err(|e| MCPError::ConnectionError(format!("Initialize failed: {}", e)))?;
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
MCPError::SerializationError(format!("Failed to parse initialize response: {}", e))
})?;
if let Some(error) = mcp_response.error {
return Err(MCPError::ProtocolError(format!(
"Initialize error: {}",
error.message
)));
}
self.session_initialized = true;
Ok(())
}
/// Call a tool using MCP tools/call
pub async fn call_tool(
&self,
name: &str,
arguments: serde_json::Value,
) -> MCPResult<serde_json::Value> {
if !self.session_initialized {
return Err(MCPError::ProtocolError(
"Session not initialized. Call initialize() first.".to_string(),
));
}
use serde_json::json;
let request = MCPRequest {
jsonrpc: "2.0".to_string(),
id: format!("call_{}", uuid::Uuid::new_v4()),
method: "tools/call".to_string(),
params: Some(json!({
"name": name,
"arguments": arguments
})),
};
let response = self
.client
.post(&self.connection.url)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| MCPError::ConnectionError(format!("Tool call failed: {}", e)))?;
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
MCPError::SerializationError(format!("Failed to parse tool response: {}", e))
})?;
if let Some(error) = mcp_response.error {
return Err(MCPError::ToolExecutionError(format!(
"Tool '{}' failed: {}",
name, error.message
)));
}
mcp_response
.result
.ok_or_else(|| MCPError::ProtocolError("No result in tool response".to_string()))
}
/// Check if session is ready for tool calls
pub fn is_ready(&self) -> bool {
self.session_initialized
}
/// Get connection info
pub fn connection_info(&self) -> String {
format!("HTTP: {}", self.connection.url)
}
}
// ===== Multi-Tool Session Manager =====
pub struct MultiToolSessionManager {
sessions: HashMap<String, ToolSession>, // server_url -> session
tool_to_server: HashMap<String, String>, // tool_name -> server_url mapping
}
impl Default for MultiToolSessionManager {
fn default() -> Self {
Self::new()
}
}
impl MultiToolSessionManager {
/// Create new multi-tool session manager
pub fn new() -> Self {
Self {
sessions: HashMap::new(),
tool_to_server: HashMap::new(),
}
}
/// Add tools from an MCP server (optimized to share sessions per server)
pub async fn add_tools_from_server(
&mut self,
server_url: String,
tool_names: Vec<String>,
) -> MCPResult<()> {
// Create one session per server URL (if not already exists)
if !self.sessions.contains_key(&server_url) {
let session = ToolSession::new(server_url.clone()).await?;
self.sessions.insert(server_url.clone(), session);
}
// Map all tools to this server URL
for tool_name in tool_names {
self.tool_to_server.insert(tool_name, server_url.clone());
}
Ok(())
}
/// Get session for a specific tool
pub fn get_session(&self, tool_name: &str) -> Option<&ToolSession> {
let server_url = self.tool_to_server.get(tool_name)?;
self.sessions.get(server_url)
}
/// Execute tool with automatic session management
pub async fn call_tool(
&self,
tool_name: &str,
arguments: serde_json::Value,
) -> MCPResult<serde_json::Value> {
let server_url = self
.tool_to_server
.get(tool_name)
.ok_or_else(|| MCPError::ToolNotFound(format!("No mapping for tool: {}", tool_name)))?;
let session = self.sessions.get(server_url).ok_or_else(|| {
MCPError::ToolNotFound(format!("No session for server: {}", server_url))
})?;
session.call_tool(tool_name, arguments).await
}
/// Execute multiple tools concurrently
pub async fn call_tools_concurrent(
&self,
tool_calls: Vec<(String, serde_json::Value)>,
) -> Vec<MCPResult<serde_json::Value>> {
let futures: Vec<_> = tool_calls
.into_iter()
.map(|(tool_name, args)| async move { self.call_tool(&tool_name, args).await })
.collect();
futures::future::join_all(futures).await
}
/// Get all available tool names
pub fn list_tools(&self) -> Vec<String> {
self.tool_to_server.keys().cloned().collect()
}
/// Check if tool is available
pub fn has_tool(&self, tool_name: &str) -> bool {
self.tool_to_server.contains_key(tool_name)
}
/// Get session statistics
pub fn session_stats(&self) -> SessionStats {
let total_sessions = self.sessions.len();
let ready_sessions = self.sessions.values().filter(|s| s.is_ready()).count();
let unique_servers = self.sessions.len(); // Now sessions = servers
SessionStats {
total_sessions,
ready_sessions,
unique_servers,
}
}
}
#[derive(Debug, Clone)]
pub struct SessionStats {
pub total_sessions: usize,
pub ready_sessions: usize,
pub unique_servers: usize,
}
// tests/common/mock_mcp_server.rs - Mock MCP server for testing // tests/common/mock_mcp_server.rs - Mock MCP server for testing
use rmcp::{
use axum::{ handler::server::{router::tool::ToolRouter, wrapper::Parameters},
extract::Json, http::StatusCode, response::Json as ResponseJson, routing::post, Router, model::*,
service::RequestContext,
tool, tool_handler, tool_router,
transport::streamable_http_server::{
session::local::LocalSessionManager, StreamableHttpService,
},
ErrorData as McpError, RoleServer, ServerHandler,
}; };
use serde_json::{json, Value};
use tokio::net::TcpListener; use tokio::net::TcpListener;
/// Mock MCP server that returns hardcoded responses for testing /// Mock MCP server that returns hardcoded responses for testing
...@@ -12,6 +17,69 @@ pub struct MockMCPServer { ...@@ -12,6 +17,69 @@ pub struct MockMCPServer {
pub server_handle: Option<tokio::task::JoinHandle<()>>, pub server_handle: Option<tokio::task::JoinHandle<()>>,
} }
/// Simple test server with mock search tools
#[derive(Clone)]
pub struct MockSearchServer {
tool_router: ToolRouter<MockSearchServer>,
}
#[tool_router]
impl MockSearchServer {
pub fn new() -> Self {
Self {
tool_router: Self::tool_router(),
}
}
#[tool(description = "Mock web search tool")]
fn brave_web_search(
&self,
Parameters(params): Parameters<serde_json::Map<String, serde_json::Value>>,
) -> Result<CallToolResult, McpError> {
let query = params
.get("query")
.and_then(|v| v.as_str())
.unwrap_or("test");
Ok(CallToolResult::success(vec![Content::text(format!(
"Mock search results for: {}",
query
))]))
}
#[tool(description = "Mock local search tool")]
fn brave_local_search(
&self,
Parameters(_params): Parameters<serde_json::Map<String, serde_json::Value>>,
) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text(
"Mock local search results",
)]))
}
}
#[tool_handler]
impl ServerHandler for MockSearchServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ServerCapabilities::builder().enable_tools().build(),
server_info: Implementation {
name: "Mock MCP Server".to_string(),
version: "1.0.0".to_string(),
},
instructions: Some("Mock server for testing".to_string()),
}
}
async fn initialize(
&self,
_request: InitializeRequestParam,
_context: RequestContext<RoleServer>,
) -> Result<InitializeResult, McpError> {
Ok(self.get_info())
}
}
impl MockMCPServer { impl MockMCPServer {
/// Start a mock MCP server on an available port /// Start a mock MCP server on an available port
pub async fn start() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { pub async fn start() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
...@@ -19,7 +87,14 @@ impl MockMCPServer { ...@@ -19,7 +87,14 @@ impl MockMCPServer {
let listener = TcpListener::bind("127.0.0.1:0").await?; let listener = TcpListener::bind("127.0.0.1:0").await?;
let port = listener.local_addr()?.port(); let port = listener.local_addr()?.port();
let app = Router::new().route("/mcp", post(handle_mcp_request)); // Create the MCP service using rmcp's StreamableHttpService
let service = StreamableHttpService::new(
|| Ok(MockSearchServer::new()),
LocalSessionManager::default().into(),
Default::default(),
);
let app = axum::Router::new().nest_service("/mcp", service);
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
axum::serve(listener, app) axum::serve(listener, app)
...@@ -59,142 +134,10 @@ impl Drop for MockMCPServer { ...@@ -59,142 +134,10 @@ impl Drop for MockMCPServer {
} }
} }
/// Handle MCP requests and return mock responses
async fn handle_mcp_request(Json(request): Json<Value>) -> Result<ResponseJson<Value>, StatusCode> {
// Parse the JSON-RPC request
let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
let id = request
.get("id")
.and_then(|i| i.as_str())
.unwrap_or("unknown");
let response = match method {
"initialize" => {
// Mock initialize response
json!({
"jsonrpc": "2.0",
"id": id,
"result": {
"serverInfo": {
"name": "Mock MCP Server",
"version": "1.0.0"
},
"instructions": "Mock server for testing"
}
})
}
"tools/list" => {
// Mock tools list response
json!({
"jsonrpc": "2.0",
"id": id,
"result": {
"tools": [
{
"name": "brave_web_search",
"description": "Mock web search tool",
"inputSchema": {
"type": "object",
"properties": {
"query": {"type": "string"},
"count": {"type": "integer"}
},
"required": ["query"]
}
},
{
"name": "brave_local_search",
"description": "Mock local search tool",
"inputSchema": {
"type": "object",
"properties": {
"query": {"type": "string"}
},
"required": ["query"]
}
}
]
}
})
}
"tools/call" => {
// Mock tool call response
let empty_json = json!({});
let params = request.get("params").unwrap_or(&empty_json);
let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
let empty_args = json!({});
let arguments = params.get("arguments").unwrap_or(&empty_args);
match tool_name {
"brave_web_search" => {
let query = arguments
.get("query")
.and_then(|q| q.as_str())
.unwrap_or("test");
json!({
"jsonrpc": "2.0",
"id": id,
"result": {
"content": [
{
"type": "text",
"text": format!("Mock search results for: {}", query)
}
],
"isError": false
}
})
}
"brave_local_search" => {
json!({
"jsonrpc": "2.0",
"id": id,
"result": {
"content": [
{
"type": "text",
"text": "Mock local search results"
}
],
"isError": false
}
})
}
_ => {
// Unknown tool
json!({
"jsonrpc": "2.0",
"id": id,
"error": {
"code": -1,
"message": format!("Unknown tool: {}", tool_name)
}
})
}
}
}
_ => {
// Unknown method
json!({
"jsonrpc": "2.0",
"id": id,
"error": {
"code": -32601,
"message": format!("Method not found: {}", method)
}
})
}
};
Ok(ResponseJson(response))
}
#[cfg(test)] #[cfg(test)]
#[allow(unused_imports)]
mod tests { mod tests {
#[allow(unused_imports)]
use super::MockMCPServer; use super::MockMCPServer;
use serde_json::{json, Value};
#[tokio::test] #[tokio::test]
async fn test_mock_server_startup() { async fn test_mock_server_startup() {
...@@ -205,32 +148,32 @@ mod tests { ...@@ -205,32 +148,32 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_mock_server_responses() { async fn test_mock_server_with_rmcp_client() {
let mut server = MockMCPServer::start().await.unwrap(); let mut server = MockMCPServer::start().await.unwrap();
let client = reqwest::Client::new();
// Test that we can connect with rmcp client
// Test initialize use rmcp::transport::StreamableHttpClientTransport;
let init_request = json!({ use rmcp::ServiceExt;
"jsonrpc": "2.0",
"id": "1", let transport = StreamableHttpClientTransport::from_uri(server.url().as_str());
"method": "initialize", let client = ().serve(transport).await;
"params": {
"protocolVersion": "2024-11-05", assert!(client.is_ok(), "Should be able to connect to mock server");
"capabilities": {}
if let Ok(client) = client {
// Test listing tools
let tools = client.peer().list_all_tools().await;
assert!(tools.is_ok(), "Should be able to list tools");
if let Ok(tools) = tools {
assert_eq!(tools.len(), 2, "Should have 2 tools");
assert!(tools.iter().any(|t| t.name == "brave_web_search"));
assert!(tools.iter().any(|t| t.name == "brave_local_search"));
} }
});
let response = client // Shutdown by dropping the client
.post(server.url()) drop(client);
.json(&init_request) }
.send()
.await
.unwrap();
assert!(response.status().is_success());
let json: Value = response.json().await.unwrap();
assert_eq!(json["jsonrpc"], "2.0");
assert_eq!(json["result"]["serverInfo"]["name"], "Mock MCP Server");
server.stop().await; server.stop().await;
} }
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment