Unverified commit d717e73e authored by Chang Su, committed by GitHub

[router] refactor mcp to use LRU and fix pooling bug (#12346)

parent a1816187
@@ -38,6 +38,7 @@ futures-util = "0.3"
futures = "0.3"
pyo3 = { version = "0.27.1", features = ["extension-module", "abi3-py38"] }
dashmap = "6.1.0"
lru = "0.16.2"
blake3 = "1.5"
http = "1.1.0"
tokio = { version = "1.42.0", features = ["full"] }
......
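The refactor replaces the hand-rolled TTL bookkeeping with the `lru` crate added above. A minimal, hedged sketch of the two calls the new pool relies on (standalone illustration, not part of the commit; assumes a recent `lru` release with the `NonZeroUsize` constructor):

// Illustrative only: `get` promotes an entry to most-recently-used,
// `push` returns the evicted least-recently-used pair once capacity is full.
use std::num::NonZeroUsize;

use lru::LruCache;

fn main() {
    let mut cache: LruCache<String, u32> = LruCache::new(NonZeroUsize::new(2).unwrap());
    cache.push("a".to_string(), 1);
    cache.push("b".to_string(), 2);

    // Touching "a" makes "b" the least-recently-used entry.
    cache.get(&"a".to_string());

    // A third insert evicts "b"; `push` hands back the evicted pair.
    if let Some((evicted_key, _)) = cache.push("c".to_string(), 3) {
        println!("evicted: {}", evicted_key); // prints "evicted: b"
    }
}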
//! MCP configuration types and utilities.
//!
//! Defines configuration structures for MCP servers, transports, proxies, and inventory.
use std::collections::HashMap;

// Re-export rmcp types for convenient access
pub use rmcp::model::{Prompt, RawResource, Tool};

use serde::{Deserialize, Serialize};

// ============================================================================
// Configuration Structures
// ============================================================================

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig {
    /// Static MCP servers (loaded at startup)
......
/// MCP Connection Pool
///
/// This module provides connection pooling for dynamic MCP servers (per-request).
use std::sync::Arc;

use lru::LruCache;
use parking_lot::Mutex;
use rmcp::{service::RunningService, RoleClient};

use crate::mcp::{
@@ -24,13 +15,14 @@ use crate::mcp::{
/// Type alias for MCP client
type McpClient = RunningService<RoleClient, ()>;

/// Type alias for eviction callback
type EvictionCallback = Arc<dyn Fn(&str) + Send + Sync>;

/// Cached MCP connection with metadata
#[derive(Clone)]
pub struct CachedConnection {
    /// The MCP client instance
    pub client: Arc<McpClient>,
    /// Server configuration used to create this connection
    pub config: McpServerConfig,
}
@@ -38,89 +30,88 @@ pub struct CachedConnection {
impl CachedConnection {
    /// Create a new cached connection
    pub fn new(client: Arc<McpClient>, config: McpServerConfig) -> Self {
        Self { client, config }
    }
}
/// Connection pool for dynamic MCP servers
///
/// Provides thread-safe connection pooling with LRU eviction.
/// Connections are keyed by server URL and reused across requests.
pub struct McpConnectionPool {
    /// LRU cache of server_url -> cached connection
    connections: Arc<Mutex<LruCache<String, CachedConnection>>>,
    /// Maximum number of cached connections (LRU capacity)
    max_connections: usize,
    /// Global proxy configuration (applied to all dynamic servers)
    /// Can be overridden per-server via McpServerConfig.proxy
    global_proxy: Option<McpProxyConfig>,
    /// Optional eviction callback (called when LRU evicts a connection)
    /// Used to clean up tools from inventory
    eviction_callback: Option<EvictionCallback>,
}
impl McpConnectionPool {
    /// Default max connections for pool
    const DEFAULT_MAX_CONNECTIONS: usize = 200;

    /// Create a new connection pool with default settings
    ///
    /// Default settings:
    /// - max_connections: 200
    /// - global_proxy: Loaded from environment variables (MCP_HTTP_PROXY, etc.)
    pub fn new() -> Self {
        Self {
            connections: Arc::new(Mutex::new(LruCache::new(
                std::num::NonZeroUsize::new(Self::DEFAULT_MAX_CONNECTIONS).unwrap(),
            ))),
            max_connections: Self::DEFAULT_MAX_CONNECTIONS,
            global_proxy: McpProxyConfig::from_env(),
            eviction_callback: None,
        }
    }
    /// Create a new connection pool with custom capacity
    pub fn with_capacity(max_connections: usize) -> Self {
        Self {
            connections: Arc::new(Mutex::new(LruCache::new(
                std::num::NonZeroUsize::new(max_connections).unwrap(),
            ))),
            max_connections,
            global_proxy: McpProxyConfig::from_env(),
            eviction_callback: None,
        }
    }
    /// Create a new connection pool with full custom configuration
    pub fn with_full_config(max_connections: usize, global_proxy: Option<McpProxyConfig>) -> Self {
        Self {
            connections: Arc::new(Mutex::new(LruCache::new(
                std::num::NonZeroUsize::new(max_connections).unwrap(),
            ))),
            max_connections,
            global_proxy,
            eviction_callback: None,
        }
    }

    /// Set the eviction callback (called when LRU evicts a connection)
    pub fn set_eviction_callback<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.eviction_callback = Some(Arc::new(callback));
    }
    /// Get an existing connection or create a new one
    ///
    /// This method:
    /// 1. Checks if a connection exists for the given URL (fast path <1ms)
    /// 2. If exists, promotes it in LRU and returns it
    /// 3. If not exists, creates new connection (slow path 70-650ms)
    ///
    /// # Arguments
    /// * `server_url` - The MCP server URL (used as cache key)
@@ -139,97 +130,78 @@ impl McpConnectionPool {
        F: FnOnce(McpServerConfig, Option<McpProxyConfig>) -> Fut,
        Fut: std::future::Future<Output = McpResult<McpClient>>,
    {
        // Fast path: Check if connection exists in LRU cache
        {
            let mut connections = self.connections.lock();
            if let Some(cached) = connections.get(server_url) {
                // LRU get() promotes the entry
                return Ok(Arc::clone(&cached.client));
            }
        }

        // Slow path: Create new connection
        let client = connect_fn(server_config.clone(), self.global_proxy.clone()).await?;
        let client_arc = Arc::new(client);

        // Cache the new connection (LRU will automatically evict oldest if at capacity)
        let cached = CachedConnection::new(Arc::clone(&client_arc), server_config);
        {
            let mut connections = self.connections.lock();
            if let Some((evicted_key, _evicted_conn)) =
                connections.push(server_url.to_string(), cached)
            {
                // Call eviction callback if set
                if let Some(callback) = &self.eviction_callback {
                    callback(&evicted_key);
                }
            }
        }

        Ok(client_arc)
    }
    /// Get current number of cached connections
    pub fn len(&self) -> usize {
        self.connections.lock().len()
    }

    /// Check if pool is empty
    pub fn is_empty(&self) -> bool {
        self.connections.lock().is_empty()
    }

    /// Clear all connections
    pub fn clear(&self) {
        self.connections.lock().clear();
    }
    /// Get connection statistics
    pub fn stats(&self) -> PoolStats {
        let total = self.connections.lock().len();

        PoolStats {
            total_connections: total,
            capacity: self.max_connections,
        }
    }
    /// List all server keys in the pool
    pub fn list_server_keys(&self) -> Vec<String> {
        self.connections
            .lock()
            .iter()
            .map(|(key, _)| key.clone())
            .collect()
    }

    /// Get a connection by server key without creating it
    /// Promotes the entry in LRU cache if found
    pub fn get(&self, server_key: &str) -> Option<Arc<McpClient>> {
        self.connections
            .lock()
            .get(server_key)
            .map(|cached| Arc::clone(&cached.client))
    }
}
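For reference, a compressed standalone sketch of the Mutex-wrapped LRU plus eviction-callback pattern this pool uses. `MiniPool` and `FakeConn` are hypothetical stand-ins, not types from this crate; a single `Mutex` (rather than `DashMap`) is used because `LruCache::get` mutates recency order, so every access needs exclusive access to one shared ordering:

// Illustrative sketch, not part of the commit.
use std::num::NonZeroUsize;
use std::sync::Arc;

use lru::LruCache;
use parking_lot::Mutex;

struct FakeConn;

struct MiniPool {
    connections: Arc<Mutex<LruCache<String, FakeConn>>>,
    eviction_callback: Option<Arc<dyn Fn(&str) + Send + Sync>>,
}

impl MiniPool {
    fn insert(&self, key: &str, conn: FakeConn) {
        let mut connections = self.connections.lock();
        // `push` reports the evicted least-recently-used entry, if any,
        // so the caller can clean up state tied to that key.
        if let Some((evicted_key, _)) = connections.push(key.to_string(), conn) {
            if let Some(callback) = &self.eviction_callback {
                callback(&evicted_key);
            }
        }
    }
}

fn main() {
    let pool = MiniPool {
        connections: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(1).unwrap()))),
        eviction_callback: Some(Arc::new(|key: &str| println!("evicted {key}"))),
    };
    pool.insert("http://a", FakeConn);
    pool.insert("http://b", FakeConn); // evicts "http://a", callback fires
}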
impl Default for McpConnectionPool {
@@ -242,8 +214,7 @@ impl Default for McpConnectionPool {
#[derive(Debug, Clone)]
pub struct PoolStats {
    pub total_connections: usize,
    pub capacity: usize,
}
#[cfg(test)]
@@ -271,114 +242,13 @@ mod tests {
        assert!(pool.is_empty());
    }
    #[test]
    fn test_pool_stats() {
        let pool = McpConnectionPool::with_capacity(10);
        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.capacity, 10);
    }
    #[test]
@@ -392,7 +262,8 @@ mod tests {
            Arc::new(unsafe { std::mem::MaybeUninit::zeroed().assume_init() });
        let cached = CachedConnection::new(client.clone(), config);
        pool.connections
            .lock()
            .push("http://localhost:3000".to_string(), cached);

        assert_eq!(pool.len(), 1);
@@ -418,8 +289,7 @@ mod tests {
        };

        // Create pool with proxy
        let pool = McpConnectionPool::with_full_config(100, Some(proxy.clone()));

        // Verify proxy is stored
        assert!(pool.global_proxy.is_some());
......
//! MCP error types.
//!
//! Defines error variants for MCP operations including connection, tool execution,
//! and configuration errors.
use thiserror::Error;

pub type McpResult<T> = Result<T, McpError>;
......
//! MCP tool, prompt, and resource inventory.
//!
//! Thread-safe cache for MCP capabilities across all connected servers.

use dashmap::DashMap;
@@ -15,7 +11,6 @@ use crate::mcp::config::{Prompt, RawResource, Tool};
pub struct CachedTool {
    pub server_name: String,
    pub tool: Tool,
}

/// Cached prompt with metadata
@@ -23,7 +18,6 @@ pub struct CachedTool {
pub struct CachedPrompt {
    pub server_name: String,
    pub prompt: Prompt,
}

/// Cached resource with metadata
@@ -31,13 +25,12 @@ pub struct CachedPrompt {
pub struct CachedResource {
    pub server_name: String,
    pub resource: RawResource,
}
/// Tool inventory with periodic refresh
///
/// Provides thread-safe caching of MCP tools, prompts, and resources.
/// Entries are refreshed periodically by background tasks.
pub struct ToolInventory {
    /// Map of tool_name -> cached tool
    tools: DashMap<String, CachedTool>,
@@ -47,80 +40,59 @@ pub struct ToolInventory {
    /// Map of resource_uri -> cached resource
    resources: DashMap<String, CachedResource>,
}
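A standalone sketch of the DashMap-backed lookup and per-server cleanup pattern the inventory implements below; `Entry` and its payload are placeholders rather than the crate's real `CachedTool`:

// Illustrative sketch, not part of the commit.
use dashmap::DashMap;

struct Entry {
    server_name: String,
    payload: String,
}

fn main() {
    let tools: DashMap<String, Entry> = DashMap::new();

    tools.insert(
        "search".to_string(),
        Entry {
            server_name: "server1".to_string(),
            payload: "tool definition".to_string(),
        },
    );

    // Lookup clones out of the map guard, mirroring get_tool.
    let hit = tools
        .get("search")
        .map(|e| (e.server_name.clone(), e.payload.clone()));
    assert!(hit.is_some());

    // When the LRU pool evicts "server1", the eviction callback drops
    // everything that server contributed, mirroring clear_server_tools.
    tools.retain(|_, e| e.server_name != "server1");
    assert!(tools.get("search").is_none());
}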
impl ToolInventory {
    /// Create a new tool inventory
    pub fn new() -> Self {
        Self {
            tools: DashMap::new(),
            prompts: DashMap::new(),
            resources: DashMap::new(),
        }
    }
}

impl Default for ToolInventory {
    fn default() -> Self {
        Self::new()
    }
}

impl ToolInventory {
    // ============================================================================
    // Tool Methods
    // ============================================================================
    /// Get a tool if it exists
    pub fn get_tool(&self, tool_name: &str) -> Option<(String, Tool)> {
        self.tools
            .get(tool_name)
            .map(|entry| (entry.server_name.clone(), entry.tool.clone()))
    }

    /// Check if tool exists
    pub fn has_tool(&self, tool_name: &str) -> bool {
        self.tools.contains_key(tool_name)
    }

    /// Insert or update a tool
    pub fn insert_tool(&self, tool_name: String, server_name: String, tool: Tool) {
        self.tools
            .insert(tool_name, CachedTool { server_name, tool });
    }

    /// Get all tools
    pub fn list_tools(&self) -> Vec<(String, String, Tool)> {
        self.tools
            .iter()
            .map(|entry| {
                let (name, cached) = entry.pair();
                (
                    name.clone(),
                    cached.server_name.clone(),
                    cached.tool.clone(),
                )
            })
            .collect()
    }
@@ -129,21 +101,14 @@ impl ToolInventory {
    // Prompt Methods
    // ============================================================================

    /// Get a prompt if it exists
    pub fn get_prompt(&self, prompt_name: &str) -> Option<(String, Prompt)> {
        self.prompts
            .get(prompt_name)
            .map(|entry| (entry.server_name.clone(), entry.prompt.clone()))
    }

    /// Check if prompt exists
    pub fn has_prompt(&self, prompt_name: &str) -> bool {
        self.prompts.contains_key(prompt_name)
    }
@@ -155,27 +120,21 @@ impl ToolInventory {
            CachedPrompt {
                server_name,
                prompt,
            },
        );
    }

    /// Get all prompts
    pub fn list_prompts(&self) -> Vec<(String, String, Prompt)> {
        self.prompts
            .iter()
            .map(|entry| {
                let (name, cached) = entry.pair();
                (
                    name.clone(),
                    cached.server_name.clone(),
                    cached.prompt.clone(),
                )
            })
            .collect()
    }
@@ -184,21 +143,14 @@ impl ToolInventory {
    // Resource Methods
    // ============================================================================

    /// Get a resource if it exists
    pub fn get_resource(&self, resource_uri: &str) -> Option<(String, RawResource)> {
        self.resources
            .get(resource_uri)
            .map(|entry| (entry.server_name.clone(), entry.resource.clone()))
    }

    /// Check if resource exists
    pub fn has_resource(&self, resource_uri: &str) -> bool {
        self.resources.contains_key(resource_uri)
    }
@@ -215,27 +167,21 @@ impl ToolInventory {
            CachedResource {
                server_name,
                resource,
            },
        );
    }

    /// Get all resources
    pub fn list_resources(&self) -> Vec<(String, String, RawResource)> {
        self.resources
            .iter()
            .map(|entry| {
                let (uri, cached) = entry.pair();
                (
                    uri.clone(),
                    cached.server_name.clone(),
                    cached.resource.clone(),
                )
            })
            .collect()
    }
@@ -244,7 +190,7 @@ impl ToolInventory {
    // Server Management Methods
    // ============================================================================

    /// Clear all cached items for a specific server (called when LRU evicts client)
    pub fn clear_server_tools(&self, server_name: &str) {
        self.tools
            .retain(|_, cached| cached.server_name != server_name);
@@ -254,51 +200,6 @@ impl ToolInventory {
            .retain(|_, cached| cached.server_name != server_name);
    }

    /// Get count of cached items
    pub fn counts(&self) -> (usize, usize, usize) {
        (self.tools.len(), self.prompts.len(), self.resources.len())
@@ -309,7 +210,6 @@ impl ToolInventory {
        self.tools.clear();
        self.prompts.clear();
        self.resources.clear();
    }
}
@@ -370,7 +270,7 @@ mod tests {
    #[test]
    fn test_tool_insert_and_get() {
        let inventory = ToolInventory::new();
        let tool = create_test_tool("test_tool");

        inventory.insert_tool("test_tool".to_string(), "server1".to_string(), tool.clone());
@@ -383,30 +283,9 @@ mod tests {
        assert_eq!(retrieved_tool.name, "test_tool");
    }

    #[test]
    fn test_has_tool() {
        let inventory = ToolInventory::new();
        let tool = create_test_tool("check_tool");

        assert!(!inventory.has_tool("check_tool"));
@@ -418,7 +297,7 @@ mod tests {
    #[test]
    fn test_list_tools() {
        let inventory = ToolInventory::new();

        inventory.insert_tool(
            "tool1".to_string(),
@@ -440,29 +319,9 @@ mod tests {
        assert_eq!(tools.len(), 3);
    }

    #[test]
    fn test_clear_server_tools() {
        let inventory = ToolInventory::new();

        inventory.insert_tool(
            "tool1".to_string(),
@@ -484,55 +343,9 @@ mod tests {
        assert_eq!(tools[0].0, "tool2");
    }

    #[test]
    fn test_prompt_operations() {
        let inventory = ToolInventory::new();
        let prompt = create_test_prompt("test_prompt");

        inventory.insert_prompt(
@@ -553,7 +366,7 @@ mod tests {
    #[test]
    fn test_resource_operations() {
        let inventory = ToolInventory::new();
        let resource = create_test_resource("file:///test.txt");

        inventory.insert_resource(
@@ -576,7 +389,7 @@ mod tests {
    async fn test_concurrent_access() {
        use std::sync::Arc;

        let inventory = Arc::new(ToolInventory::new());

        // Spawn multiple tasks that insert tools concurrently
        let mut handles = vec![];
@@ -601,7 +414,7 @@ mod tests {
    #[test]
    fn test_clear_all() {
        let inventory = ToolInventory::new();

        inventory.insert_tool(
            "tool1".to_string(),
@@ -619,8 +432,6 @@ mod tests {
            create_test_resource("res1"),
        );

        let (tools, prompts, resources) = inventory.counts();
        assert_eq!(tools, 1);
        assert_eq!(prompts, 1);
@@ -632,6 +443,5 @@ mod tests {
        assert_eq!(tools, 0);
        assert_eq!(prompts, 0);
        assert_eq!(resources, 0);
    }
}
//! MCP client management and orchestration.
//!
//! Manages both static MCP servers (from config) and dynamic MCP servers (from requests).
//! Static clients are never evicted; dynamic clients use LRU eviction via connection pool.

use std::{borrow::Cow, sync::Arc, time::Duration};
@@ -40,68 +34,47 @@ use crate::mcp::{
/// Type alias for MCP client
type McpClient = RunningService<RoleClient, ()>;

pub struct McpManager {
    static_clients: Arc<DashMap<String, Arc<McpClient>>>,
    inventory: Arc<ToolInventory>,
    connection_pool: Arc<McpConnectionPool>,
    _config: McpConfig,
}
impl McpManager {
    const MAX_DYNAMIC_CLIENTS: usize = 200;

    pub async fn new(config: McpConfig, pool_max_connections: usize) -> McpResult<Self> {
        let inventory = Arc::new(ToolInventory::new());

        let mut connection_pool =
            McpConnectionPool::with_full_config(pool_max_connections, config.proxy.clone());

        let inventory_clone = Arc::clone(&inventory);
        connection_pool.set_eviction_callback(move |server_key: &str| {
            debug!(
                "LRU evicted dynamic server '{}' - clearing tools from inventory",
                server_key
            );
            inventory_clone.clear_server_tools(server_key);
        });

        let connection_pool = Arc::new(connection_pool);

        // Create storage for static clients
        let static_clients = Arc::new(DashMap::new());

        // Get global proxy config for all servers
        let global_proxy = config.proxy.as_ref();

        // Connect to all static servers from config
        for server_config in &config.servers {
            match Self::connect_server(server_config, global_proxy).await {
                Ok(client) => {
                    let client_arc = Arc::new(client);
                    // Load inventory for this server
                    Self::load_server_inventory(&inventory, &server_config.name, &client_arc).await;

                    static_clients.insert(server_config.name.clone(), client_arc);
                    info!("Connected to static server '{}'", server_config.name);
                }
                Err(e) => {
@@ -113,52 +86,40 @@ impl McpManager {
            }
        }

        if static_clients.is_empty() {
            warn!("No static MCP servers connected");
        }

        Ok(Self {
            static_clients,
            inventory,
            connection_pool,
            _config: config,
        })
    }
    pub async fn with_defaults(config: McpConfig) -> McpResult<Self> {
        Self::new(config, Self::MAX_DYNAMIC_CLIENTS).await
    }
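Hedged usage sketch of the new constructor surface (assumes the surrounding router crate's types; `init_mcp` and `mcp_config` are hypothetical names, mirroring what the server startup below does):

async fn init_mcp(mcp_config: McpConfig) -> McpResult<Arc<McpManager>> {
    // 200-entry LRU for dynamic clients; proxy settings come from the config.
    let manager = Arc::new(McpManager::with_defaults(mcp_config).await?);

    // Periodic refresh of the static servers' tool/prompt/resource inventory.
    let _handle =
        Arc::clone(&manager).spawn_background_refresh_all(Duration::from_secs(600));

    Ok(manager)
}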
    pub async fn get_client(&self, server_name: &str) -> Option<Arc<McpClient>> {
        if let Some(client) = self.static_clients.get(server_name) {
            return Some(Arc::clone(client.value()));
        }
        self.connection_pool.get(server_name)
    }

    pub async fn get_or_create_client(
        &self,
        server_config: McpServerConfig,
    ) -> McpResult<Arc<McpClient>> {
        let server_name = server_config.name.clone();

        if let Some(client) = self.static_clients.get(&server_name) {
            return Ok(Arc::clone(client.value()));
        }

        let server_key = Self::server_key(&server_config);
        let client = self
            .connection_pool
            .get_or_create(
@@ -170,47 +131,27 @@ impl McpManager {
            )
            .await?;

        self.inventory.clear_server_tools(&server_key);
        Self::load_server_inventory(&self.inventory, &server_key, &client).await;

        Ok(client)
    }
    pub fn list_static_servers(&self) -> Vec<String> {
        self.static_clients
            .iter()
            .map(|e| e.key().clone())
            .collect()
    }

    pub fn is_static_server(&self, server_name: &str) -> bool {
        self.static_clients.contains_key(server_name)
    }

    pub fn register_static_server(&self, name: String, client: Arc<McpClient>) {
        self.static_clients.insert(name.clone(), client);
        info!("Registered static MCP server: {}", name);
    }

    /// List all available tools from all servers
    pub fn list_tools(&self) -> Vec<Tool> {
        self.inventory
@@ -267,10 +208,6 @@ impl McpManager {
            .map(|(_server_name, tool_info)| tool_info)
    }
    /// Get a prompt by name
    pub async fn get_prompt(
        &self,
@@ -310,10 +247,6 @@ impl McpManager {
            .collect()
    }

    /// Read a resource by URI
    pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
        // Get server that owns this resource
@@ -348,10 +281,6 @@ impl McpManager {
            .collect()
    }

    /// Refresh inventory for a specific server
    pub async fn refresh_server_inventory(&self, server_name: &str) -> McpResult<()> {
        let client = self
@@ -365,7 +294,8 @@ impl McpManager {
        Ok(())
    }
    /// Start background refresh for ALL servers (static + dynamic)
    /// Refreshes every 10-15 minutes to keep tool inventory up-to-date
    pub fn spawn_background_refresh_all(
        self: Arc<Self>,
        refresh_interval: Duration,
@@ -377,17 +307,24 @@ impl McpManager {
            loop {
                interval.tick().await;

                // Get all static server keys
                // Note: Dynamic clients in the connection pool are refreshed on-demand
                // when they are accessed via get_or_create_client()
                let server_keys: Vec<String> = self
                    .static_clients
                    .iter()
                    .map(|e| e.key().clone())
                    .collect();

                if !server_keys.is_empty() {
                    debug!(
                        "Background refresh: Refreshing {} static server(s)",
                        server_keys.len()
                    );

                    for server_key in server_keys {
                        if let Err(e) = self.refresh_server_inventory(&server_key).await {
                            warn!("Background refresh failed for '{}': {}", server_key, e);
                        }
                    }
@@ -397,10 +334,6 @@ impl McpManager {
        })
    }
    /// Check if a tool exists
    pub fn has_tool(&self, name: &str) -> bool {
        self.inventory.has_tool(name)
@@ -462,41 +395,56 @@ impl McpManager {
            .map_err(|e| McpError::ToolExecution(format!("Failed to unsubscribe: {}", e)))
    }
    /// List all connected servers (static + dynamic)
    pub fn list_servers(&self) -> Vec<String> {
        let mut servers = Vec::new();

        // Add static servers
        servers.extend(self.static_clients.iter().map(|e| e.key().clone()));

        // Add dynamic servers from connection pool
        servers.extend(self.connection_pool.list_server_keys());

        servers
    }

    /// Disconnect from all servers (for cleanup)
    pub async fn shutdown(&self) {
        // Shutdown static servers
        let static_keys: Vec<String> = self
            .static_clients
            .iter()
            .map(|e| e.key().clone())
            .collect();

        for name in static_keys {
            if let Some((_key, client)) = self.static_clients.remove(&name) {
                // Try to unwrap Arc to call cancel
                match Arc::try_unwrap(client) {
                    Ok(client) => {
                        if let Err(e) = client.cancel().await {
                            warn!("Error disconnecting from static server '{}': {}", name, e);
                        }
                    }
                    Err(_) => {
                        warn!(
                            "Could not shutdown static server '{}': client still in use",
                            name
                        );
                    }
                }
            }
        }

        // Clear dynamic clients from connection pool
        // The pool will handle cleanup on drop
        self.connection_pool.clear();
    }
    /// Get statistics about the manager
    pub fn stats(&self) -> McpManagerStats {
        let (tools, prompts, resources) = self.inventory.counts();
        McpManagerStats {
            static_server_count: self.static_clients.len(),
            pool_stats: self.connection_pool.stats(),
            tool_count: tools,
            prompt_count: prompts,
@@ -525,44 +473,41 @@ impl McpManager {
    /// It discovers all tools, prompts, and resources from the client and caches them in the inventory.
    pub async fn load_server_inventory(
        inventory: &Arc<ToolInventory>,
        server_key: &str,
        client: &Arc<McpClient>,
    ) {
        // Tools
        match client.peer().list_all_tools().await {
            Ok(ts) => {
                info!("Discovered {} tools from '{}'", ts.len(), server_key);
                for t in ts {
                    inventory.insert_tool(t.name.to_string(), server_key.to_string(), t);
                }
            }
            Err(e) => warn!("Failed to list tools from '{}': {}", server_key, e),
        }

        // Prompts
        match client.peer().list_all_prompts().await {
            Ok(ps) => {
                info!("Discovered {} prompts from '{}'", ps.len(), server_key);
                for p in ps {
                    inventory.insert_prompt(p.name.clone(), server_key.to_string(), p);
                }
            }
            Err(e) => debug!("No prompts or failed to list on '{}': {}", server_key, e),
        }

        // Resources
        match client.peer().list_all_resources().await {
            Ok(rs) => {
                info!("Discovered {} resources from '{}'", rs.len(), server_key);
                for r in rs {
                    inventory.insert_resource(r.uri.clone(), server_key.to_string(), r.raw);
                }
            }
            Err(e) => debug!("No resources or failed to list on '{}': {}", server_key, e),
        }
    }
    /// Discover and cache tools/prompts/resources for a connected server (internal wrapper)
@@ -602,9 +547,6 @@ impl McpManager {
            }
            Err(e) => debug!("No resources or failed to list on '{}': {}", server_name, e),
        }
    }
    // ========================================================================
@@ -834,14 +776,7 @@ mod tests {
            inventory: Default::default(),
        };

        let manager = McpManager::new(config, 100).await.unwrap();

        assert_eq!(manager.list_static_servers().len(), 0);
    }
}
//! Model Context Protocol (MCP) client implementation.
//!
//! Provides MCP client functionality including tools, prompts, resources, and OAuth.
//! Supports stdio, SSE, and HTTP transports with connection pooling and caching.

pub mod config;
pub mod connection_pool;
......
//! OAuth authentication for MCP servers.
//!
//! Handles OAuth flow including callback server and token exchange.

use std::{net::SocketAddr, sync::Arc};
......
//! HTTP proxy configuration for MCP connections.
//!
//! Resolves proxy settings and creates HTTP clients for MCP server connections.

use std::time::Duration;
......
//! MCP tool argument handling.
//!
//! Supports both JSON strings and parsed Maps with automatic type coercion.

use serde_json::Map;
......
@@ -782,12 +782,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
        info!("No MCP config provided, skipping MCP server initialization");
    }

    // Start background refresh for ALL MCP servers (static + dynamic in LRU cache)
    if let Some(mcp_manager) = app_context.mcp_manager.get() {
        let refresh_interval = Duration::from_secs(600); // 10 minutes
        let _refresh_handle =
            Arc::clone(mcp_manager).spawn_background_refresh_all(refresh_interval);
        info!("Started background refresh for all MCP servers (every 10 minutes)");
    }

    let worker_stats = app_context.worker_registry.stats();
......