//! MCP Connection Pool
//!
//! This module provides connection pooling for dynamic (per-request) MCP servers.
use std::{num::NonZeroUsize, sync::Arc};

use lru::LruCache;
use parking_lot::Mutex;
use rmcp::{service::RunningService, RoleClient};

use crate::mcp::{
    config::{McpProxyConfig, McpServerConfig},
    error::McpResult,
};

/// Type alias for MCP client
type McpClient = RunningService<RoleClient, ()>;

18
19
20
/// Type alias for eviction callback
type EvictionCallback = Arc<dyn Fn(&str) + Send + Sync>;

21
22
23
24
25
26
27
28
29
30
31
32
/// A pooled MCP client paired with the server configuration it was built from.
#[derive(Clone)]
pub struct CachedConnection {
    /// Shared handle to the client; cloning a `CachedConnection` clones this `Arc`, not the client.
    pub client: Arc<McpClient>,
    /// The server configuration this connection was created with.
    pub config: McpServerConfig,
}

impl CachedConnection {
    /// Create a new cached connection
    pub fn new(client: Arc<McpClient>, config: McpServerConfig) -> Self {
33
        Self { client, config }
34
35
36
37
38
    }
}

/// Connection pool for dynamic MCP servers
///
39
/// Provides thread-safe connection pooling with LRU eviction.
40
41
/// Connections are keyed by server URL and reused across requests.
pub struct McpConnectionPool {
42
43
    /// LRU cache of server_url -> cached connection
    connections: Arc<Mutex<LruCache<String, CachedConnection>>>,
44

45
    /// Maximum number of cached connections (LRU capacity)
46
47
48
49
50
    max_connections: usize,

    /// Global proxy configuration (applied to all dynamic servers)
    /// Can be overridden per-server via McpServerConfig.proxy
    global_proxy: Option<McpProxyConfig>,
51
52
53
54

    /// Optional eviction callback (called when LRU evicts a connection)
    /// Used to clean up tools from inventory
    eviction_callback: Option<EvictionCallback>,
55
56
57
}

impl McpConnectionPool {
58
59
60
    /// Default max connections for pool
    const DEFAULT_MAX_CONNECTIONS: usize = 200;

61
62
63
    /// Create a new connection pool with default settings
    ///
    /// Default settings:
64
    /// - max_connections: 200
65
66
67
    /// - global_proxy: Loaded from environment variables (MCP_HTTP_PROXY, etc.)
    pub fn new() -> Self {
        Self {
68
69
70
71
            connections: Arc::new(Mutex::new(LruCache::new(
                std::num::NonZeroUsize::new(Self::DEFAULT_MAX_CONNECTIONS).unwrap(),
            ))),
            max_connections: Self::DEFAULT_MAX_CONNECTIONS,
72
            global_proxy: McpProxyConfig::from_env(),
73
            eviction_callback: None,
74
75
76
        }
    }

77
78
    /// Create a new connection pool with custom capacity
    pub fn with_capacity(max_connections: usize) -> Self {
79
        Self {
80
81
82
            connections: Arc::new(Mutex::new(LruCache::new(
                std::num::NonZeroUsize::new(max_connections).unwrap(),
            ))),
83
84
            max_connections,
            global_proxy: McpProxyConfig::from_env(),
85
            eviction_callback: None,
86
87
88
89
        }
    }

    /// Create a new connection pool with full custom configuration
90
    pub fn with_full_config(max_connections: usize, global_proxy: Option<McpProxyConfig>) -> Self {
91
        Self {
92
93
94
            connections: Arc::new(Mutex::new(LruCache::new(
                std::num::NonZeroUsize::new(max_connections).unwrap(),
            ))),
95
96
            max_connections,
            global_proxy,
97
            eviction_callback: None,
98
99
100
        }
    }

101
102
103
104
105
106
107
108
    /// Set the eviction callback (called when LRU evicts a connection)
    pub fn set_eviction_callback<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.eviction_callback = Some(Arc::new(callback));
    }

109
110
111
    /// Get an existing connection or create a new one
    ///
    /// This method:
112
113
114
    /// 1. Checks if a connection exists for the given URL (fast path <1ms)
    /// 2. If exists, promotes it in LRU and returns it
    /// 3. If not exists, creates new connection (slow path 70-650ms)
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    ///
    /// # Arguments
    /// * `server_url` - The MCP server URL (used as cache key)
    /// * `server_config` - Server configuration (used to create new connection if needed)
    /// * `connect_fn` - Async function to create a new client connection
    ///
    /// # Returns
    /// Arc to the MCP client, either from cache or newly created
    pub async fn get_or_create<F, Fut>(
        &self,
        server_url: &str,
        server_config: McpServerConfig,
        connect_fn: F,
    ) -> McpResult<Arc<McpClient>>
    where
        F: FnOnce(McpServerConfig, Option<McpProxyConfig>) -> Fut,
        Fut: std::future::Future<Output = McpResult<McpClient>>,
    {
133
134
135
136
137
        // Fast path: Check if connection exists in LRU cache
        {
            let mut connections = self.connections.lock();
            if let Some(cached) = connections.get(server_url) {
                // LRU get() promotes the entry
138
139
140
141
142
143
144
145
                return Ok(Arc::clone(&cached.client));
            }
        }

        // Slow path: Create new connection
        let client = connect_fn(server_config.clone(), self.global_proxy.clone()).await?;
        let client_arc = Arc::new(client);

146
        // Cache the new connection (LRU will automatically evict oldest if at capacity)
147
        let cached = CachedConnection::new(Arc::clone(&client_arc), server_config);
148
149
150
151
152
153
154
155
156
157
158
        {
            let mut connections = self.connections.lock();
            if let Some((evicted_key, _evicted_conn)) =
                connections.push(server_url.to_string(), cached)
            {
                // Call eviction callback if set
                if let Some(callback) = &self.eviction_callback {
                    callback(&evicted_key);
                }
            }
        }
159
160
161
162
163
164

        Ok(client_arc)
    }

    /// Get current number of cached connections
    pub fn len(&self) -> usize {
165
        self.connections.lock().len()
166
167
168
169
    }

    /// Check if pool is empty
    pub fn is_empty(&self) -> bool {
170
        self.connections.lock().is_empty()
171
172
    }

173
    /// Clear all connections
174
    pub fn clear(&self) {
175
        self.connections.lock().clear();
176
177
178
179
    }

    /// Get connection statistics
    pub fn stats(&self) -> PoolStats {
180
        let total = self.connections.lock().len();
181
182
183

        PoolStats {
            total_connections: total,
184
            capacity: self.max_connections,
185
186
        }
    }
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204

    /// List all server keys in the pool
    pub fn list_server_keys(&self) -> Vec<String> {
        self.connections
            .lock()
            .iter()
            .map(|(key, _)| key.clone())
            .collect()
    }

    /// Get a connection by server key without creating it
    /// Promotes the entry in LRU cache if found
    pub fn get(&self, server_key: &str) -> Option<Arc<McpClient>> {
        self.connections
            .lock()
            .get(server_key)
            .map(|cached| Arc::clone(&cached.client))
    }
205
206
207
208
209
210
211
212
213
214
215
216
}

impl Default for McpConnectionPool {
    fn default() -> Self {
        Self::new()
    }
}

/// Point-in-time snapshot of connection pool usage, as returned by
/// [`McpConnectionPool::stats`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Number of connections currently cached.
    pub total_connections: usize,
    /// Maximum number of connections the pool keeps before evicting (LRU capacity).
    pub capacity: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_creation() {
        let pool = McpConnectionPool::new();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_pool_stats() {
        let pool = McpConnectionPool::with_capacity(10);

        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.capacity, 10);
    }

    #[test]
    #[should_panic]
    fn test_zero_capacity_panics() {
        // The LRU cache needs a non-zero capacity, so zero is rejected at construction.
        let _ = McpConnectionPool::with_capacity(0);
    }

    #[test]
    fn test_pool_clear() {
        // Putting an entry in the pool requires a real `McpClient`. Fabricating one from
        // zeroed memory (`MaybeUninit::zeroed().assume_init()`) is undefined behavior
        // unless every field of `RunningService` accepts an all-zero bit pattern, which is
        // not guaranteed. So this test only checks that `clear` leaves the pool empty and
        // is safe to repeat.
        let pool = McpConnectionPool::new();

        pool.clear();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
        assert!(pool.list_server_keys().is_empty());

        pool.clear();
        assert!(pool.is_empty());
    }

    #[test]
    fn test_pool_with_global_proxy() {
        // Create proxy config
        let proxy = McpProxyConfig {
            http: Some("http://proxy.example.com:8080".to_string()),
            https: None,
            no_proxy: Some("localhost,127.0.0.1".to_string()),
            username: None,
            password: None,
        };

        // Create pool with proxy
        let pool = McpConnectionPool::with_full_config(100, Some(proxy));

        // Verify proxy is stored
        let stored_proxy = pool
            .global_proxy
            .as_ref()
            .expect("global proxy should be stored");
        assert_eq!(
            stored_proxy.http.as_deref(),
            Some("http://proxy.example.com:8080")
        );
        assert_eq!(
            stored_proxy.no_proxy.as_deref(),
            Some("localhost,127.0.0.1")
        );
    }

    #[test]
    fn test_pool_proxy_from_env() {
        // `new()` should pick up whatever `McpProxyConfig::from_env()` reports for the
        // environment the tests run in (MCP_HTTP_PROXY, HTTP_PROXY, ...).
        let pool = McpConnectionPool::new();
        assert_eq!(
            pool.global_proxy.is_some(),
            McpProxyConfig::from_env().is_some()
        );
    }
}