"test/srt/old/test_httpserver_decode.py" did not exist on "6f560c761b2fc2f577682d0cfda62630f37a3bb0"
Unverified Commit 212f5e48 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] MCP Manager Refactoring - Flat Architecture with Connection Pooling (#12097)

parent fe527812
...@@ -206,6 +206,121 @@ python3 -m sglang_router.launch_router \
- Provide exactly one `--worker-urls` entry per router instance.
- The Rust binary supports the same flags (`./target/release/sglang-router --backend openai ...`).
### MCP Integration
The SGL Model Gateway provides native Model Context Protocol (MCP) client integration, enabling tool calling across STDIO, SSE, and Streamable transports. MCP servers are configured via a YAML configuration file and registered at startup through the workflow engine.
#### Basic Usage
```bash
# Rust binary
./target/release/sglang-router \
--mcp-config-path /path/to/mcp-config.yaml \
--worker-urls http://worker1:8000
# Python launcher
python3 -m sglang_router.launch_router \
--mcp-config-path /path/to/mcp-config.yaml \
--worker-urls http://worker1:8000
```
#### MCP Configuration File
Create an MCP configuration file to define servers, transports, and connection settings:
```yaml
servers:
  - name: "filesystem"
    protocol: "stdio"
    command: "npx"
    args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
    required: false

  - name: "github"
    protocol: "sse"
    url: "https://api.github.com/mcp"
    token: "ghp_xxxxx"
    required: false

  - name: "custom-tools"
    protocol: "streamable"
    url: "https://tools.example.com/mcp"
    required: true

pool:
  max_connections: 100
  idle_timeout: 300 # seconds

proxy:
  http: "http://proxy.internal:8080"
  https: "https://proxy.internal:8443"
  no_proxy: "localhost,127.0.0.1,*.internal"

inventory:
  enable_refresh: true
  tool_ttl: 300 # seconds - how long tools are considered fresh
  refresh_interval: 300 # seconds - background refresh interval
```
#### Configuration Options
**Server Configuration** (`servers` array):
- `name`: Unique identifier for the MCP server
- `command` + `args`: For STDIO transport (local process execution)
- `url`: For SSE or Streamable transports (HTTP/HTTPS endpoints)
- `token`: Optional authentication token for HTTP-based transports
- `protocol`: Transport protocol (`"stdio"`, `"sse"`, or `"streamable"`)
- `required`: If `true`, router fails to start if server is unreachable (default: `false`)
- `envs`: Environment variables for STDIO processes (optional)
- `proxy`: Per-server proxy override (set to `null` to bypass global proxy)
**Connection Pool** (`pool`):
- `max_connections`: Maximum pooled connections for dynamic servers (default: 100)
- `idle_timeout`: Idle connection timeout in seconds before cleanup (default: 300)
**Proxy Configuration** (`proxy`):
- `http`/`https`: Proxy URLs for MCP server connections (not LLM traffic)
- `no_proxy`: Comma-separated hosts to exclude from proxying (supports wildcards)
- **Note**: Proxy settings are currently ignored for `streamable` transport. Use STDIO or SSE transports if proxy support is required.
**Inventory Settings** (`inventory`):
- `enable_refresh`: Enable automatic background refresh of the tool inventory (default: true)
- `tool_ttl`: Tool cache TTL in seconds - how long tools are considered fresh (default: 300)
- `refresh_interval`: Background refresh interval in seconds - proactive inventory refresh (default: 60)
- `refresh_on_error`: Refresh the inventory when a requested tool is not found (default: true)
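These defaults mirror the `Default` implementations in the router's MCP config module. As a quick check, here is a unit-test-style sketch (assuming it lives inside the router crate next to the existing config tests) that parses a config omitting `pool` and `inventory` and asserts the documented defaults:

```rust
// Sketch: omitted sections fall back to the defaults documented above.
// Assumes access to the crate-internal `McpConfig` type and the serde_yaml dependency.
#[cfg(test)]
mod mcp_defaults_doc_check {
    use crate::mcp::McpConfig;

    #[test]
    fn omitted_sections_use_defaults() {
        // Only the mandatory `servers` key is present.
        let config: McpConfig =
            serde_yaml::from_str("servers: []").expect("minimal config should parse");

        assert_eq!(config.pool.max_connections, 100);
        assert_eq!(config.pool.idle_timeout, 300);
        assert!(config.inventory.enable_refresh);
        assert_eq!(config.inventory.tool_ttl, 300);
        assert_eq!(config.inventory.refresh_interval, 60);
        assert!(config.proxy.is_none());
        assert!(config.warmup.is_empty());
    }
}
```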
#### Transport Types
**STDIO** (Local Process):
```yaml
name: "local-tools"
command: "python"
args: ["-m", "my_mcp_server"]
envs:
API_KEY: "secret"
DEBUG: "true"
```
**SSE** (Server-Sent Events):
```yaml
name: "remote-sse"
url: "https://mcp.example.com/events"
token: "bearer-token"
transport: "sse"
```
**Streamable** (Bidirectional Streaming):
```yaml
name: "streaming-tools"
url: "https://mcp.example.com/stream"
transport: "streamable"
required: true
```
#### Server Lifecycle
- MCP servers are registered via the workflow engine with retry logic (up to 100 connection attempts within a 2-hour connection timeout); see the sketch at the end of this section
- Discovery phase identifies tools, prompts, and resources
- Tool inventory is cached with configurable TTL and periodic refresh
- Failed optional servers log warnings; required servers halt startup
- Static servers (from config) are permanent; dynamic servers (per-request) use connection pooling
Check Prometheus metrics for MCP activity (`mcp_*` metrics) and workflow job status via the admin API.
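For reference, the registration flow described above is driven by the router's internal job queue: the `McpConfig` loaded from `--mcp-config-path` is submitted as a single `InitializeMcpServers` job, which fans out into one `RegisterMcpServer` job (and its registration workflow) per configured server. The snippet below is a hedged sketch of that submission using the types added in this commit; the function name and import paths are illustrative, not a stable API:

```rust
// Illustrative sketch only: submit MCP startup registration through the job queue.
// Assumes `AppContext` and `Job` (from this PR's job-queue module) are in scope.
async fn submit_mcp_startup_jobs(ctx: &AppContext) -> Result<(), String> {
    // `mcp_config` is populated by the config builder when --mcp-config-path is set.
    if let Some(mcp_config) = ctx.router_config.mcp_config.clone() {
        let queue = ctx
            .worker_job_queue
            .get()
            .ok_or_else(|| "JobQueue not available".to_string())?;
        // One job; the queue expands it into per-server RegisterMcpServer jobs,
        // each of which runs the mcp_registration workflow.
        queue
            .submit(Job::InitializeMcpServers {
                mcp_config: Box::new(mcp_config),
            })
            .await
            .map_err(|e| format!("Failed to submit InitializeMcpServers: {}", e))?;
    }
    Ok(())
}
```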
### Python Launcher (Router + Workers)
Launch router and SGLang worker processes together; `launch_server` spins up workers (HTTP or gRPC) and the router in one shot.
```bash
......
...@@ -94,6 +94,8 @@ class RouterArgs:
tokenizer_cache_l1_max_memory: int = 50 * 1024 * 1024 # 50MB
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
# MCP server configuration
mcp_config_path: Optional[str] = None
# Backend selection
backend: str = "sglang"
# History backend configuration
...@@ -512,6 +514,13 @@ class RouterArgs:
default=None,
help="Specify the parser for handling tool-call interactions",
)
# MCP server configuration
parser.add_argument(
f"--{prefix}mcp-config-path",
type=str,
default=None,
help="Path to MCP (Model Context Protocol) server configuration file",
)
# Backend selection
parser.add_argument(
f"--{prefix}backend",
......
...@@ -13,6 +13,7 @@ use crate::{
create_storage, SharedConversationItemStorage, SharedConversationStorage,
SharedResponseStorage,
},
mcp::McpManager,
middleware::TokenBucket,
policies::PolicyRegistry,
reasoning_parser::ParserFactory as ReasoningParserFactory,
...@@ -56,6 +57,7 @@ pub struct AppContext {
pub configured_tool_parser: Option<String>,
pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
pub workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
pub mcp_manager: Arc<OnceLock<Arc<McpManager>>>,
}
pub struct AppContextBuilder {
...@@ -74,6 +76,7 @@ pub struct AppContextBuilder {
load_monitor: Option<Arc<LoadMonitor>>,
worker_job_queue: Option<Arc<OnceLock<Arc<JobQueue>>>>,
workflow_engine: Option<Arc<OnceLock<Arc<WorkflowEngine>>>>,
mcp_manager: Option<Arc<OnceLock<Arc<McpManager>>>>,
}
impl AppContext {
...@@ -112,6 +115,7 @@ impl AppContextBuilder {
load_monitor: None,
worker_job_queue: None,
workflow_engine: None,
mcp_manager: None,
}
}
...@@ -196,6 +200,11 @@ impl AppContextBuilder {
self
}
pub fn mcp_manager(mut self, mcp_manager: Arc<OnceLock<Arc<McpManager>>>) -> Self {
self.mcp_manager = Some(mcp_manager);
self
}
pub fn build(self) -> Result<AppContext, AppContextBuildError> {
let router_config = self
.router_config
...@@ -235,6 +244,9 @@ impl AppContextBuilder {
workflow_engine: self
.workflow_engine
.ok_or(AppContextBuildError("workflow_engine"))?,
mcp_manager: self
.mcp_manager
.ok_or(AppContextBuildError("mcp_manager"))?,
})
}
...@@ -256,6 +268,8 @@ impl AppContextBuilder {
.with_load_monitor(&router_config)
.with_worker_job_queue()
.with_workflow_engine()
.with_mcp_manager(&router_config)
.await?
.router_config(router_config))
}
...@@ -457,6 +471,38 @@ impl AppContextBuilder {
self.workflow_engine = Some(Arc::new(OnceLock::new()));
self
}
/// Create and initialize MCP manager with empty config
///
/// This initializes the MCP manager with an empty config and default settings.
/// MCP servers will be registered later via the InitializeMcpServers job.
async fn with_mcp_manager(mut self, _router_config: &RouterConfig) -> Result<Self, String> {
// Create OnceLock container
let mcp_manager_lock = Arc::new(OnceLock::new());
// Always create with empty config and defaults
info!("Initializing MCP manager with empty config and default settings (5 min TTL, 100 max connections)");
let empty_config = crate::mcp::McpConfig {
servers: Vec::new(),
pool: Default::default(),
proxy: None,
warmup: Vec::new(),
inventory: Default::default(),
};
let manager = McpManager::with_defaults(empty_config)
.await
.map_err(|e| format!("Failed to initialize MCP manager with defaults: {}", e))?;
// Store the initialized manager in the OnceLock
mcp_manager_lock
.set(Arc::new(manager))
.map_err(|_| "Failed to set MCP manager in OnceLock".to_string())?;
self.mcp_manager = Some(mcp_manager_lock);
Ok(self)
}
}
impl Default for AppContextBuilder {
......
...@@ -3,7 +3,7 @@ use super::{
HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig,
RoutingMode, TokenizerCacheConfig,
};
use crate::{core::ConnectionMode, mcp::McpConfig};
/// Builder for RouterConfig that wraps the config itself
/// This eliminates field duplication and stays in sync automatically
...@@ -14,6 +14,7 @@ pub struct RouterConfigBuilder {
client_cert_path: Option<String>,
client_key_path: Option<String>,
ca_cert_paths: Vec<String>,
mcp_config_path: Option<String>,
}
impl RouterConfigBuilder {
...@@ -29,6 +30,7 @@ impl RouterConfigBuilder {
client_cert_path: None,
client_key_path: None,
ca_cert_paths: Vec::new(),
mcp_config_path: None,
}
}
...@@ -620,6 +622,21 @@ impl RouterConfigBuilder {
self
}
// ==================== MCP Configuration ====================
/// Set MCP server configuration file path
/// The config file will be loaded during build()
pub fn mcp_config_path<S: Into<String>>(mut self, path: S) -> Self {
self.mcp_config_path = Some(path.into());
self
}
/// Set MCP server configuration file path if Some
pub fn maybe_mcp_config_path(mut self, path: Option<impl Into<String>>) -> Self {
self.mcp_config_path = path.map(|p| p.into());
self
}
// ==================== Builder Methods ====================
/// Build the RouterConfig, validating if requested
...@@ -637,6 +654,9 @@ impl RouterConfigBuilder {
// Read mTLS certificates from paths if provided
self = self.read_mtls_certificates()?;
// Read MCP config from path if provided
self = self.read_mcp_config()?;
let config: RouterConfig = self.into();
if validate {
config.validate()?;
...@@ -695,6 +715,24 @@ impl RouterConfigBuilder {
Ok(self)
}
/// Internal method to read MCP config from path
fn read_mcp_config(mut self) -> ConfigResult<Self> {
if let Some(mcp_config_path) = &self.mcp_config_path {
let contents = std::fs::read_to_string(mcp_config_path).map_err(|e| {
ConfigError::ValidationFailed {
reason: format!("Failed to read MCP config from {}: {}", mcp_config_path, e),
}
})?;
let mcp_config: McpConfig =
serde_yaml::from_str(&contents).map_err(|e| ConfigError::ValidationFailed {
reason: format!("Failed to parse MCP config from {}: {}", mcp_config_path, e),
})?;
self.config.mcp_config = Some(mcp_config);
}
Ok(self)
}
}
impl From<RouterConfigBuilder> for RouterConfig {
......
...@@ -93,6 +93,10 @@ pub struct RouterConfig {
/// Loaded from ca_cert_paths during config creation
#[serde(default)]
pub ca_certificates: Vec<Vec<u8>>,
/// MCP server configuration (loaded from mcp_config_path during config creation)
/// This is loaded from the config file path and stored here for runtime use
#[serde(skip)]
pub mcp_config: Option<crate::mcp::McpConfig>,
}
/// Tokenizer cache configuration
...@@ -508,6 +512,7 @@ impl Default for RouterConfig {
tokenizer_cache: TokenizerCacheConfig::default(),
client_identity: None,
ca_certificates: vec![],
mcp_config: None,
}
}
}
......
...@@ -17,9 +17,10 @@ use crate::{
app_context::AppContext,
config::{RouterConfig, RoutingMode},
core::workflow::{
steps::{McpServerConfigRequest, WorkerRemovalRequest},
WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus,
},
mcp::McpConfig,
metrics::RouterMetrics,
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
};
...@@ -30,6 +31,8 @@ pub enum Job {
AddWorker { config: Box<WorkerConfigRequest> },
RemoveWorker { url: String },
InitializeWorkersFromConfig { router_config: Box<RouterConfig> },
InitializeMcpServers { mcp_config: Box<McpConfig> },
RegisterMcpServer { config: Box<McpServerConfigRequest> },
}
impl Job {
...@@ -39,15 +42,19 @@ impl Job {
Job::AddWorker { .. } => "AddWorker",
Job::RemoveWorker { .. } => "RemoveWorker",
Job::InitializeWorkersFromConfig { .. } => "InitializeWorkersFromConfig",
Job::InitializeMcpServers { .. } => "InitializeMcpServers",
Job::RegisterMcpServer { .. } => "RegisterMcpServer",
}
}
/// Get worker URL or MCP server name for logging
pub fn worker_url(&self) -> &str {
match self {
Job::AddWorker { config } => &config.url,
Job::RemoveWorker { url } => url,
Job::InitializeWorkersFromConfig { .. } => "startup",
Job::InitializeMcpServers { .. } => "startup",
Job::RegisterMcpServer { config } => &config.name,
}
}
}
...@@ -421,6 +428,64 @@ impl JobQueue {
Ok(format!("Submitted {} AddWorker jobs", worker_count))
}
Job::InitializeMcpServers { mcp_config } => {
let mut server_count = 0;
debug!(
"Creating RegisterMcpServer jobs for {} MCP servers from config",
mcp_config.servers.len()
);
// Submit RegisterMcpServer jobs for each server in the config
for server_config in &mcp_config.servers {
let mcp_server_request = McpServerConfigRequest {
name: server_config.name.clone(),
config: server_config.clone(),
};
let job = Job::RegisterMcpServer {
config: Box::new(mcp_server_request),
};
if let Some(queue) = context.worker_job_queue.get() {
queue.submit(job).await.map_err(|e| {
format!(
"Failed to submit RegisterMcpServer job for '{}': {}",
server_config.name, e
)
})?;
server_count += 1;
} else {
return Err("JobQueue not available".to_string());
}
}
Ok(format!("Submitted {} RegisterMcpServer jobs", server_count))
}
Job::RegisterMcpServer { config } => {
let engine = context
.workflow_engine
.get()
.ok_or_else(|| "Workflow engine not initialized".to_string())?;
let instance_id =
Self::start_mcp_registration_workflow(engine, config, context).await?;
debug!(
"Started MCP registration workflow for {} (instance: {})",
config.name, instance_id
);
let timeout_duration = Duration::from_secs(7200 + 30); // 2hr + margin
Self::wait_for_workflow_completion(
engine,
instance_id,
&config.name,
timeout_duration,
)
.await
}
}
}
...@@ -461,6 +526,22 @@ impl JobQueue {
.map_err(|e| format!("Failed to start worker removal workflow: {:?}", e))
}
/// Start MCP server registration workflow
async fn start_mcp_registration_workflow(
engine: &Arc<WorkflowEngine>,
config: &McpServerConfigRequest,
context: &Arc<AppContext>,
) -> Result<WorkflowInstanceId, String> {
let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new());
workflow_context.set("mcp_server_config", config.clone());
workflow_context.set_arc("app_context", Arc::clone(context));
engine
.start_workflow(WorkflowId::new("mcp_registration"), workflow_context)
.await
.map_err(|e| format!("Failed to start MCP registration workflow: {:?}", e))
}
/// Wait for workflow completion with adaptive polling
async fn wait_for_workflow_completion(
engine: &Arc<WorkflowEngine>,
......
...@@ -14,5 +14,8 @@ pub use engine::WorkflowEngine;
pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent};
pub use executor::{FunctionStep, StepExecutor};
pub use state::WorkflowStateStore;
pub use steps::{
create_mcp_registration_workflow, create_worker_registration_workflow,
create_worker_removal_workflow,
};
pub use types::*;
//! MCP server registration workflow steps
//!
//! Each step is atomic and performs a single operation in the MCP server registration process.
//! Updated for flat manager architecture - single McpManager manages all clients directly.
//!
//! Workflow order:
//! 1. ConnectMcpServer - Establish connection to MCP server using McpManager::connect_server()
//! 2. DiscoverMcpInventory - Discover and cache inventory using McpManager::load_server_inventory()
//! 3. RegisterMcpServer - Register McpClient in McpManager's client map
//! 4. ValidateRegistration - Fail the workflow if a required server failed to register
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use rmcp::{service::RunningService, RoleClient};
use tracing::{debug, error, info, warn};
use crate::{
app_context::AppContext,
core::workflow::*,
mcp::{config::McpServerConfig, manager::McpManager},
};
/// MCP server connection configuration
#[derive(Debug, Clone)]
pub struct McpServerConfigRequest {
/// Server name (unique identifier)
pub name: String,
/// Server configuration (transport, proxy, etc.)
pub config: McpServerConfig,
}
impl McpServerConfigRequest {
/// Check if this server is required for router startup
pub fn is_required(&self) -> bool {
self.config.required
}
}
/// Step 1: Connect to MCP server
///
/// This step establishes a connection to the MCP server using the flat manager architecture.
/// The connection is retried aggressively (100 attempts) with a long timeout (2 hours)
/// to handle slow-starting servers or network issues.
pub struct ConnectMcpServerStep;
#[async_trait]
impl StepExecutor for ConnectMcpServerStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
let config_request: Arc<McpServerConfigRequest> = context
.get("mcp_server_config")
.ok_or_else(|| WorkflowError::ContextValueNotFound("mcp_server_config".to_string()))?;
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
debug!("Connecting to MCP server: {}", config_request.name);
// Get proxy config from router_config if available, otherwise fall back to env
let proxy_config = app_context
.router_config
.mcp_config
.as_ref()
.and_then(|cfg| cfg.proxy.as_ref());
// Connect to MCP server
let client = McpManager::connect_server(&config_request.config, proxy_config)
.await
.map_err(|e| WorkflowError::StepFailed {
step_id: StepId::new("connect_mcp_server"),
message: format!(
"Failed to connect to MCP server {}: {}",
config_request.name, e
),
})?;
info!(
"Successfully connected to MCP server: {}",
config_request.name
);
// Store client in context (context.set() will wrap in Arc)
context.set("mcp_client", client);
Ok(StepResult::Success)
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
true // Connection failures are retryable
}
}
/// Step 2: Discover MCP inventory (tools, prompts, resources)
///
/// This step queries the MCP server for its capabilities using McpManager::load_server_inventory().
/// - Tools: Available function calls
/// - Prompts: Reusable prompt templates
/// - Resources: Accessible files/data
pub struct DiscoverMcpInventoryStep;
#[async_trait]
impl StepExecutor for DiscoverMcpInventoryStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
use rmcp::{service::RunningService, RoleClient};
let config_request: Arc<McpServerConfigRequest> = context
.get("mcp_server_config")
.ok_or_else(|| WorkflowError::ContextValueNotFound("mcp_server_config".to_string()))?;
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
let mcp_client: Arc<RunningService<RoleClient, ()>> = context
.get("mcp_client")
.ok_or_else(|| WorkflowError::ContextValueNotFound("mcp_client".to_string()))?;
debug!(
"Discovering inventory for MCP server: {}",
config_request.name
);
// Get shared ToolInventory from McpManager
let mcp_manager =
app_context
.mcp_manager
.get()
.ok_or_else(|| WorkflowError::StepFailed {
step_id: StepId::new("discover_mcp_inventory"),
message: "MCP manager not initialized".to_string(),
})?;
let inventory = mcp_manager.inventory();
// Use the public load_server_inventory method
McpManager::load_server_inventory(&inventory, &config_request.name, &mcp_client).await;
info!("Completed inventory discovery for {}", config_request.name);
Ok(StepResult::Success)
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
true // Discovery failures are retryable
}
}
/// Step 3: Register MCP server in manager
///
/// This step adds the MCP client to the McpManager's client map so it can be
/// used for tool calls and inventory management.
pub struct RegisterMcpServerStep;
#[async_trait]
impl StepExecutor for RegisterMcpServerStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
use rmcp::{service::RunningService, RoleClient};
let config_request: Arc<McpServerConfigRequest> = context
.get("mcp_server_config")
.ok_or_else(|| WorkflowError::ContextValueNotFound("mcp_server_config".to_string()))?;
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
let mcp_client: Arc<RunningService<RoleClient, ()>> = context
.get("mcp_client")
.ok_or_else(|| WorkflowError::ContextValueNotFound("mcp_client".to_string()))?;
debug!("Registering MCP server: {}", config_request.name);
// Get MCP manager from app context
let mcp_manager =
app_context
.mcp_manager
.get()
.ok_or_else(|| WorkflowError::StepFailed {
step_id: StepId::new("register_mcp_server"),
message: "MCP manager not initialized".to_string(),
})?;
// Register the client in the manager's client map
mcp_manager.register_static_server(config_request.name.clone(), mcp_client);
info!("Registered MCP server: {}", config_request.name);
Ok(StepResult::Success)
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
false // Registration is a simple operation, not retryable
}
}
/// Step 4: Validate registration based on required flag
///
/// This step checks if the server is marked as required. If the server is required
/// but wasn't successfully registered (client not in context), this step fails the workflow.
/// For optional servers, this step always succeeds, allowing the workflow to complete
/// even if earlier steps failed.
pub struct ValidateRegistrationStep;
#[async_trait]
impl StepExecutor for ValidateRegistrationStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
let config_request: Arc<McpServerConfigRequest> = context
.get("mcp_server_config")
.ok_or_else(|| WorkflowError::ContextValueNotFound("mcp_server_config".to_string()))?;
let client_registered = context
.get::<Arc<RunningService<RoleClient, ()>>>("mcp_client")
.is_some();
if client_registered {
info!(
"MCP server '{}' registered successfully",
config_request.name
);
return Ok(StepResult::Success);
}
if config_request.is_required() {
error!(
"Required MCP server '{}' failed to register",
config_request.name
);
Err(WorkflowError::StepFailed {
step_id: StepId::new("validate_registration"),
message: format!(
"Required MCP server '{}' failed to register",
config_request.name
),
})
} else {
warn!(
"Optional MCP server '{}' failed to register, continuing workflow",
config_request.name
);
Ok(StepResult::Success)
}
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
false
}
}
/// Create MCP server registration workflow
///
/// This workflow adapts its failure behavior based on the `required` field in the server config:
/// - If `required == true`: Uses FailWorkflow - router startup fails if server cannot be reached
/// - If `required == false` (default): Uses ContinueNextStep - logs warning but continues
///
/// Workflow configuration:
/// - ConnectMcpServer: 100 retries, 2hr timeout (aggressive retry for slow servers)
/// - DiscoverMcpInventory: 3 retries, 10s timeout (discovery + caching)
/// - RegisterMcpServer: No retry, 5s timeout (fast registration)
/// - ValidateRegistration: Final validation step
pub fn create_mcp_registration_workflow() -> WorkflowDefinition {
WorkflowDefinition::new("mcp_registration", "MCP Server Registration")
.add_step(
StepDefinition::new(
"connect_mcp_server",
"Connect to MCP Server",
Arc::new(ConnectMcpServerStep),
)
.with_retry(RetryPolicy {
max_attempts: 100,
backoff: BackoffStrategy::Linear {
increment: Duration::from_secs(1),
max: Duration::from_secs(5),
},
})
.with_timeout(Duration::from_secs(7200)) // 2 hours
.with_failure_action(FailureAction::ContinueNextStep),
)
.add_step(
StepDefinition::new(
"discover_mcp_inventory",
"Discover and Cache MCP Inventory",
Arc::new(DiscoverMcpInventoryStep),
)
.with_retry(RetryPolicy {
max_attempts: 3,
backoff: BackoffStrategy::Fixed(Duration::from_secs(1)),
})
.with_timeout(Duration::from_secs(10))
.with_failure_action(FailureAction::ContinueNextStep),
)
.add_step(
StepDefinition::new(
"register_mcp_server",
"Register MCP Server",
Arc::new(RegisterMcpServerStep),
)
.with_timeout(Duration::from_secs(5))
.with_failure_action(FailureAction::ContinueNextStep),
)
.add_step(
StepDefinition::new(
"validate_registration",
"Validate MCP Registration",
Arc::new(ValidateRegistrationStep),
)
.with_timeout(Duration::from_secs(1))
.with_failure_action(FailureAction::FailWorkflow),
)
}
...@@ -3,11 +3,17 @@
//! This module contains concrete step implementations for various workflows:
//! - Worker registration and activation
//! - Worker removal
//! - MCP server registration
//! - Future: Tokenizer fetching, LoRA updates, etc.
pub mod mcp_registration;
pub mod worker_registration;
pub mod worker_removal;
pub use mcp_registration::{
create_mcp_registration_workflow, ConnectMcpServerStep, DiscoverMcpInventoryStep,
McpServerConfigRequest, RegisterMcpServerStep, ValidateRegistrationStep,
};
pub use worker_registration::{
create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep,
DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep,
......
...@@ -205,6 +205,7 @@ struct Router {
tokenizer_cache_l1_max_memory: usize,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
mcp_config_path: Option<String>,
backend: BackendType,
history_backend: HistoryBackendType,
oracle_config: Option<PyOracleConfig>,
...@@ -360,6 +361,7 @@ impl Router {
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.maybe_mcp_config_path(self.mcp_config_path.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
...@@ -440,6 +442,7 @@ impl Router {
tokenizer_cache_l1_max_memory = 52428800,
reasoning_parser = None,
tool_call_parser = None,
mcp_config_path = None,
backend = BackendType::Sglang,
history_backend = HistoryBackendType::Memory,
oracle_config = None,
...@@ -512,6 +515,7 @@ impl Router {
tokenizer_cache_l1_max_memory: usize,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
mcp_config_path: Option<String>,
backend: BackendType,
history_backend: HistoryBackendType,
oracle_config: Option<PyOracleConfig>,
...@@ -598,6 +602,7 @@ impl Router {
tokenizer_cache_l1_max_memory,
reasoning_parser,
tool_call_parser,
mcp_config_path,
backend,
history_backend,
oracle_config,
......
...@@ -315,6 +315,9 @@ struct CliArgs {
#[arg(long)]
tool_call_parser: Option<String>,
#[arg(long)]
mcp_config_path: Option<String>,
}
enum OracleConnectSource {
...@@ -594,6 +597,7 @@ impl CliArgs {
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.maybe_mcp_config_path(self.mcp_config_path.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
......
...@@ -2,9 +2,63 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
// ============================================================================
// MCP Data Structures
// ============================================================================
/// Information about an available tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
pub name: String,
pub description: String,
pub server: String,
pub parameters: Option<serde_json::Value>,
}
/// Information about an available prompt
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptInfo {
pub name: String,
pub description: Option<String>,
pub server: String,
pub arguments: Option<Vec<serde_json::Value>>,
}
/// Information about an available resource
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceInfo {
pub uri: String,
pub name: String,
pub description: Option<String>,
pub mime_type: Option<String>,
pub server: String,
}
// ============================================================================
// Configuration Structures
// ============================================================================
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig {
/// Static MCP servers (loaded at startup)
pub servers: Vec<McpServerConfig>,
/// Connection pool settings
#[serde(default)]
pub pool: McpPoolConfig,
/// Global MCP proxy configuration (default for all servers)
/// Can be overridden per-server
#[serde(default)]
pub proxy: Option<McpProxyConfig>,
/// Pre-warm these connections at startup
#[serde(default)]
pub warmup: Vec<WarmupServer>,
/// Tool inventory refresh settings
#[serde(default)]
pub inventory: InventoryConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
...@@ -12,6 +66,17 @@ pub struct McpServerConfig {
pub name: String,
#[serde(flatten)]
pub transport: McpTransport,
/// Per-server proxy override (overrides global proxy)
/// Set to `null` in YAML to force direct connection (no proxy)
#[serde(default)]
pub proxy: Option<McpProxyConfig>,
/// Whether this server is required for router startup
/// - true: Router startup fails if this server cannot be reached
/// - false: Log warning but continue (default)
#[serde(default)]
pub required: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
...@@ -36,6 +101,144 @@ pub enum McpTransport {
},
}
/// MCP-specific proxy configuration (does NOT affect LLM API traffic)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpProxyConfig {
/// HTTP proxy URL (e.g., "http://proxy.internal:8080")
pub http: Option<String>,
/// HTTPS proxy URL
pub https: Option<String>,
/// Comma-separated hosts to exclude from proxying
/// Example: "localhost,127.0.0.1,*.internal,10.*"
pub no_proxy: Option<String>,
/// Custom proxy authentication (if needed)
#[serde(skip_serializing_if = "Option::is_none")]
pub username: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub password: Option<String>,
}
/// Connection pool configuration
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpPoolConfig {
/// Maximum cached connections per server URL
#[serde(default = "default_max_connections")]
pub max_connections: usize,
/// Idle timeout before closing connection (seconds)
#[serde(default = "default_idle_timeout")]
pub idle_timeout: u64,
}
/// Tool inventory refresh configuration
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InventoryConfig {
/// Enable automatic tool inventory refresh
#[serde(default = "default_true")]
pub enable_refresh: bool,
/// Tool cache TTL (seconds) - how long tools are considered fresh
#[serde(default = "default_tool_ttl")]
pub tool_ttl: u64,
/// Background refresh interval (seconds) - proactive refresh
#[serde(default = "default_refresh_interval")]
pub refresh_interval: u64,
/// Refresh on tool call failure (try refreshing if tool not found)
#[serde(default = "default_true")]
pub refresh_on_error: bool,
}
/// Pre-warm server connections at startup
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WarmupServer {
/// Server URL
pub url: String,
/// Server label/name
pub label: String,
/// Optional authentication token
#[serde(skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
}
// Default value functions
fn default_max_connections() -> usize {
100
}
fn default_idle_timeout() -> u64 {
300 // 5 minutes
}
fn default_true() -> bool {
true
}
fn default_tool_ttl() -> u64 {
300 // 5 minutes
}
fn default_refresh_interval() -> u64 {
60 // 1 minute
}
// Default implementations
impl Default for McpPoolConfig {
fn default() -> Self {
Self {
max_connections: default_max_connections(),
idle_timeout: default_idle_timeout(),
}
}
}
impl Default for InventoryConfig {
fn default() -> Self {
Self {
enable_refresh: true,
tool_ttl: default_tool_ttl(),
refresh_interval: default_refresh_interval(),
refresh_on_error: true,
}
}
}
impl McpProxyConfig {
/// Load proxy config from standard environment variables
pub fn from_env() -> Option<Self> {
let http = std::env::var("MCP_HTTP_PROXY")
.ok()
.or_else(|| std::env::var("HTTP_PROXY").ok());
let https = std::env::var("MCP_HTTPS_PROXY")
.ok()
.or_else(|| std::env::var("HTTPS_PROXY").ok());
let no_proxy = std::env::var("MCP_NO_PROXY")
.ok()
.or_else(|| std::env::var("NO_PROXY").ok());
if http.is_some() || https.is_some() {
Some(Self {
http,
https,
no_proxy,
username: None,
password: None,
})
} else {
None
}
}
}
impl McpConfig {
/// Load configuration from a YAML file
pub async fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
...@@ -50,4 +253,280 @@ impl McpConfig {
// For now, return None to indicate env config not implemented
None
}
/// Merge with environment-based proxy config
pub fn with_env_proxy(mut self) -> Self {
if self.proxy.is_none() {
self.proxy = McpProxyConfig::from_env();
}
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_pool_config() {
let config = McpPoolConfig::default();
assert_eq!(config.max_connections, 100);
assert_eq!(config.idle_timeout, 300);
}
#[test]
fn test_default_inventory_config() {
let config = InventoryConfig::default();
assert!(config.enable_refresh);
assert_eq!(config.tool_ttl, 300);
assert_eq!(config.refresh_interval, 60);
assert!(config.refresh_on_error);
}
#[test]
fn test_proxy_from_env_empty() {
// Ensure no proxy env vars are set for this test
std::env::remove_var("MCP_HTTP_PROXY");
std::env::remove_var("MCP_HTTPS_PROXY");
std::env::remove_var("HTTP_PROXY");
std::env::remove_var("HTTPS_PROXY");
let proxy = McpProxyConfig::from_env();
assert!(proxy.is_none(), "Should return None when no env vars set");
}
#[test]
fn test_proxy_from_env_with_vars() {
std::env::set_var("MCP_HTTP_PROXY", "http://test-proxy:8080");
std::env::set_var("MCP_NO_PROXY", "localhost,127.0.0.1");
let proxy = McpProxyConfig::from_env();
assert!(proxy.is_some(), "Should return Some when env vars set");
let proxy = proxy.unwrap();
assert_eq!(proxy.http.as_ref().unwrap(), "http://test-proxy:8080");
assert_eq!(proxy.no_proxy.as_ref().unwrap(), "localhost,127.0.0.1");
// Cleanup
std::env::remove_var("MCP_HTTP_PROXY");
std::env::remove_var("MCP_NO_PROXY");
}
#[tokio::test]
async fn test_yaml_minimal_config() {
let yaml = r#"
servers:
- name: "test-server"
protocol: sse
url: "http://localhost:3000/sse"
"#;
let config: McpConfig = serde_yaml::from_str(yaml).expect("Failed to parse YAML");
assert_eq!(config.servers.len(), 1);
assert_eq!(config.servers[0].name, "test-server");
assert!(!config.servers[0].required); // Should default to false
assert!(config.servers[0].proxy.is_none()); // Should default to None
assert_eq!(config.pool.max_connections, 100); // Should use default
assert_eq!(config.inventory.tool_ttl, 300); // Should use default
}
#[tokio::test]
async fn test_yaml_full_config() {
let yaml = r#"
# Global proxy configuration
proxy:
http: "http://global-proxy:8080"
https: "http://global-proxy:8080"
no_proxy: "localhost,127.0.0.1,*.internal"
# Connection pool settings
pool:
max_connections: 50
idle_timeout: 600
# Tool inventory settings
inventory:
enable_refresh: true
tool_ttl: 600
refresh_interval: 120
refresh_on_error: true
# Static servers
servers:
- name: "required-server"
protocol: sse
url: "https://api.example.com/sse"
token: "secret-token"
required: true
- name: "optional-server"
protocol: stdio
command: "mcp-server"
args: ["--port", "3000"]
required: false
proxy:
http: "http://server-specific-proxy:9090"
# Pre-warm connections
warmup:
- url: "http://localhost:3000/sse"
label: "local-dev"
"#;
let config: McpConfig = serde_yaml::from_str(yaml).expect("Failed to parse YAML");
// Check global proxy
assert!(config.proxy.is_some());
let global_proxy = config.proxy.as_ref().unwrap();
assert_eq!(
global_proxy.http.as_ref().unwrap(),
"http://global-proxy:8080"
);
// Check pool config
assert_eq!(config.pool.max_connections, 50);
assert_eq!(config.pool.idle_timeout, 600);
// Check inventory config
assert_eq!(config.inventory.tool_ttl, 600);
assert_eq!(config.inventory.refresh_interval, 120);
// Check servers
assert_eq!(config.servers.len(), 2);
// Required server
assert_eq!(config.servers[0].name, "required-server");
assert!(config.servers[0].required);
assert!(config.servers[0].proxy.is_none()); // Inherits global proxy
// Optional server with custom proxy
assert_eq!(config.servers[1].name, "optional-server");
assert!(!config.servers[1].required);
assert!(config.servers[1].proxy.is_some());
assert_eq!(
config.servers[1]
.proxy
.as_ref()
.unwrap()
.http
.as_ref()
.unwrap(),
"http://server-specific-proxy:9090"
);
// Check warmup
assert_eq!(config.warmup.len(), 1);
assert_eq!(config.warmup[0].label, "local-dev");
}
#[tokio::test]
async fn test_yaml_backward_compatibility() {
// Old config format should still work
let yaml = r#"
servers:
- name: "legacy-server"
protocol: sse
url: "http://localhost:3000/sse"
"#;
let config: McpConfig = serde_yaml::from_str(yaml).expect("Failed to parse old format");
assert_eq!(config.servers.len(), 1);
assert_eq!(config.servers[0].name, "legacy-server");
assert!(!config.servers[0].required); // New field defaults to false
assert!(config.servers[0].proxy.is_none()); // New field defaults to None
assert!(config.proxy.is_none()); // New field defaults to None
assert!(config.warmup.is_empty()); // New field defaults to empty
}
#[tokio::test]
async fn test_yaml_null_proxy_override() {
// Test that explicit null in YAML sets proxy to None
let yaml = r#"
proxy:
http: "http://global-proxy:8080"
servers:
- name: "direct-connection"
protocol: sse
url: "http://localhost:3000/sse"
proxy: null
"#;
let config: McpConfig = serde_yaml::from_str(yaml).expect("Failed to parse YAML");
assert!(config.proxy.is_some()); // Global proxy set
assert_eq!(config.servers.len(), 1);
assert!(config.servers[0].proxy.is_none()); // Explicitly set to null
}
#[test]
fn test_transport_stdio() {
let yaml = r#"
name: "test"
protocol: stdio
command: "mcp-server"
args: ["--port", "3000"]
envs:
VAR1: "value1"
VAR2: "value2"
"#;
let config: McpServerConfig = serde_yaml::from_str(yaml).expect("Failed to parse stdio");
assert_eq!(config.name, "test");
match config.transport {
McpTransport::Stdio {
command,
args,
envs,
} => {
assert_eq!(command, "mcp-server");
assert_eq!(args.len(), 2);
assert_eq!(args[0], "--port");
assert_eq!(envs.get("VAR1").unwrap(), "value1");
}
_ => panic!("Expected Stdio transport"),
}
}
#[test]
fn test_transport_sse() {
let yaml = r#"
name: "test"
protocol: sse
url: "http://localhost:3000/sse"
token: "secret"
"#;
let config: McpServerConfig = serde_yaml::from_str(yaml).expect("Failed to parse sse");
assert_eq!(config.name, "test");
match config.transport {
McpTransport::Sse { url, token } => {
assert_eq!(url, "http://localhost:3000/sse");
assert_eq!(token.unwrap(), "secret");
}
_ => panic!("Expected Sse transport"),
}
}
#[test]
fn test_transport_streamable() {
let yaml = r#"
name: "test"
protocol: streamable
url: "http://localhost:3000"
"#;
let config: McpServerConfig =
serde_yaml::from_str(yaml).expect("Failed to parse streamable");
assert_eq!(config.name, "test");
match config.transport {
McpTransport::Streamable { url, token } => {
assert_eq!(url, "http://localhost:3000");
assert!(token.is_none());
}
_ => panic!("Expected Streamable transport"),
}
}
}
// MCP Connection Pool
//
// This module provides connection pooling for dynamic MCP servers (per-request).
// Connections are cached and reused to avoid 70-650ms connection overhead on every request.
//
// Performance target:
// - First request: 70-650ms (connection establishment)
// - Subsequent requests: <1ms (cache hit)
// - 90%+ reduction in per-request overhead
use std::{
sync::Arc,
time::{Duration, Instant},
};
use dashmap::DashMap;
use rmcp::{service::RunningService, RoleClient};
use crate::mcp::{
config::{McpProxyConfig, McpServerConfig},
error::McpResult,
};
/// Type alias for MCP client
type McpClient = RunningService<RoleClient, ()>;
/// Cached MCP connection with metadata
#[derive(Clone)]
pub struct CachedConnection {
/// The MCP client instance
pub client: Arc<McpClient>,
/// Last time this connection was accessed
pub last_used: Instant,
/// Server configuration used to create this connection
pub config: McpServerConfig,
}
impl CachedConnection {
/// Create a new cached connection
pub fn new(client: Arc<McpClient>, config: McpServerConfig) -> Self {
Self {
client,
last_used: Instant::now(),
config,
}
}
/// Update last_used timestamp
pub fn touch(&mut self) {
self.last_used = Instant::now();
}
/// Check if connection has been idle for longer than TTL
pub fn is_idle(&self, idle_ttl: Duration) -> bool {
self.last_used.elapsed() > idle_ttl
}
}
/// Connection pool for dynamic MCP servers
///
/// Provides thread-safe connection pooling with automatic cleanup of idle connections.
/// Connections are keyed by server URL and reused across requests.
pub struct McpConnectionPool {
/// Map of server_url -> cached connection
connections: DashMap<String, CachedConnection>,
/// Idle connection TTL (connections unused for this duration are cleaned up)
idle_ttl: Duration,
/// Maximum number of cached connections (prevents unbounded growth)
max_connections: usize,
/// Global proxy configuration (applied to all dynamic servers)
/// Can be overridden per-server via McpServerConfig.proxy
global_proxy: Option<McpProxyConfig>,
}
impl McpConnectionPool {
/// Create a new connection pool with default settings
///
/// Default settings:
/// - idle_ttl: 300 seconds (5 minutes)
/// - max_connections: 100
/// - global_proxy: Loaded from environment variables (MCP_HTTP_PROXY, etc.)
pub fn new() -> Self {
Self {
connections: DashMap::new(),
idle_ttl: Duration::from_secs(300),
max_connections: 100,
global_proxy: McpProxyConfig::from_env(),
}
}
/// Create a new connection pool with custom settings
pub fn with_config(idle_ttl: Duration, max_connections: usize) -> Self {
Self {
connections: DashMap::new(),
idle_ttl,
max_connections,
global_proxy: McpProxyConfig::from_env(),
}
}
/// Create a new connection pool with full custom configuration
pub fn with_full_config(
idle_ttl: Duration,
max_connections: usize,
global_proxy: Option<McpProxyConfig>,
) -> Self {
Self {
connections: DashMap::new(),
idle_ttl,
max_connections,
global_proxy,
}
}
/// Get an existing connection or create a new one
///
/// This method:
/// 1. Checks if a connection exists for the given URL
/// 2. If exists and fresh, updates last_used and returns it (fast path <1ms)
/// 3. If not exists or stale, creates new connection (slow path 70-650ms)
///
/// # Arguments
/// * `server_url` - The MCP server URL (used as cache key)
/// * `server_config` - Server configuration (used to create new connection if needed)
/// * `connect_fn` - Async function to create a new client connection
///
/// # Returns
/// Arc to the MCP client, either from cache or newly created
pub async fn get_or_create<F, Fut>(
&self,
server_url: &str,
server_config: McpServerConfig,
connect_fn: F,
) -> McpResult<Arc<McpClient>>
where
F: FnOnce(McpServerConfig, Option<McpProxyConfig>) -> Fut,
Fut: std::future::Future<Output = McpResult<McpClient>>,
{
// Fast path: Check if connection exists and is still fresh
if let Some(mut entry) = self.connections.get_mut(server_url) {
let cached = entry.value_mut();
// Check if connection is still within TTL
if !cached.is_idle(self.idle_ttl) {
// Update last_used and return cached connection
cached.touch();
return Ok(Arc::clone(&cached.client));
}
// Connection is stale, drop it and create new one
drop(entry);
self.connections.remove(server_url);
}
// Slow path: Create new connection
// Enforce max_connections limit
if self.connections.len() >= self.max_connections {
self.cleanup_idle_connections();
// If still at limit after cleanup, remove oldest connection
if self.connections.len() >= self.max_connections {
if let Some(oldest_key) = self.find_oldest_connection() {
self.connections.remove(&oldest_key);
}
}
}
// Create new MCP client using the provided connect function
let client = connect_fn(server_config.clone(), self.global_proxy.clone()).await?;
let client_arc = Arc::new(client);
// Cache the new connection
let cached = CachedConnection::new(Arc::clone(&client_arc), server_config);
self.connections.insert(server_url.to_string(), cached);
Ok(client_arc)
}
/// Remove all idle connections that have exceeded the TTL
///
/// This method is called:
/// - Automatically when max_connections limit is reached
/// - Can be called manually by background cleanup task
pub fn cleanup_idle_connections(&self) {
let now = Instant::now();
self.connections
.retain(|_, cached| now.duration_since(cached.last_used) < self.idle_ttl);
}
/// Find the oldest connection (by last_used timestamp)
///
/// Used for eviction when max_connections is reached and cleanup didn't free space
fn find_oldest_connection(&self) -> Option<String> {
self.connections
.iter()
.min_by_key(|entry| entry.value().last_used)
.map(|entry| entry.key().clone())
}
/// Get current number of cached connections
pub fn len(&self) -> usize {
self.connections.len()
}
/// Check if pool is empty
pub fn is_empty(&self) -> bool {
self.connections.is_empty()
}
/// Clear all connections (useful for tests)
pub fn clear(&self) {
self.connections.clear();
}
/// Get connection statistics
pub fn stats(&self) -> PoolStats {
let total = self.connections.len();
let idle_count = self
.connections
.iter()
.filter(|entry| entry.value().is_idle(self.idle_ttl))
.count();
PoolStats {
total_connections: total,
active_connections: total - idle_count,
idle_connections: idle_count,
}
}
}
impl Default for McpConnectionPool {
fn default() -> Self {
Self::new()
}
}
/// Connection pool statistics
#[derive(Debug, Clone)]
pub struct PoolStats {
pub total_connections: usize,
pub active_connections: usize,
pub idle_connections: usize,
}
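// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this commit): how a caller might
// drive the closure-based `get_or_create` API. The `connect_streamable`
// helper below is hypothetical; any async function returning
// `McpResult<McpClient>` with this signature satisfies the `connect_fn` bound.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
async fn connect_streamable(
    _config: McpServerConfig,
    _proxy: Option<McpProxyConfig>,
) -> McpResult<McpClient> {
    // Establish the rmcp client for `_config` here (transport setup omitted).
    unimplemented!("hypothetical connector used only for illustration")
}
#[allow(dead_code)]
async fn tool_call_with_pool(
    pool: &McpConnectionPool,
    cfg: McpServerConfig,
) -> McpResult<()> {
    // The pool is keyed by URL, so only HTTP-based transports are pooled here.
    let url = match &cfg.transport {
        crate::mcp::McpTransport::Sse { url, .. }
        | crate::mcp::McpTransport::Streamable { url, .. } => url.clone(),
        _ => return Ok(()),
    };
    // Cache hit: returns the existing client quickly and refreshes `last_used`.
    // Cache miss or stale entry: `connect_streamable` creates a new client.
    let client: Arc<McpClient> = pool
        .get_or_create(&url, cfg.clone(), connect_streamable)
        .await?;
    let _ = client; // ready for tool calls
    Ok(())
}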
#[cfg(test)]
mod tests {
use super::*;
use crate::mcp::McpTransport;
// Helper to create test server config
fn create_test_config(url: &str) -> McpServerConfig {
McpServerConfig {
name: "test_server".to_string(),
transport: McpTransport::Streamable {
url: url.to_string(),
token: None,
},
proxy: None,
required: false,
}
}
#[tokio::test]
async fn test_pool_creation() {
let pool = McpConnectionPool::new();
assert_eq!(pool.len(), 0);
assert!(pool.is_empty());
}
#[test]
#[allow(invalid_value)]
fn test_cached_connection_touch() {
let config = create_test_config("http://localhost:3000");
let client: Arc<McpClient> = Arc::new(unsafe {
// SAFETY: This is only for testing the CachedConnection struct
std::mem::MaybeUninit::zeroed().assume_init()
});
let mut cached = CachedConnection::new(client.clone(), config);
let first_time = cached.last_used;
std::thread::sleep(Duration::from_millis(10));
cached.touch();
assert!(cached.last_used > first_time);
// Prevent drop of invalid Arc (would segfault)
std::mem::forget(client);
}
#[test]
#[allow(invalid_value)]
fn test_cached_connection_is_idle() {
let config = create_test_config("http://localhost:3000");
let client: Arc<McpClient> = Arc::new(unsafe {
// SAFETY: This is only for testing the CachedConnection struct
std::mem::MaybeUninit::zeroed().assume_init()
});
let cached = CachedConnection::new(client.clone(), config);
// Fresh connection should not be idle
assert!(!cached.is_idle(Duration::from_secs(1)));
// Wait and check
std::thread::sleep(Duration::from_millis(100));
assert!(cached.is_idle(Duration::from_millis(50)));
// Prevent drop of invalid Arc (would segfault)
std::mem::forget(client);
}
#[test]
fn test_pool_stats() {
let pool = McpConnectionPool::with_config(Duration::from_millis(100), 10);
let stats = pool.stats();
assert_eq!(stats.total_connections, 0);
assert_eq!(stats.active_connections, 0);
assert_eq!(stats.idle_connections, 0);
}
#[test]
#[allow(invalid_value)]
fn test_cleanup_idle_connections() {
let pool = McpConnectionPool::with_config(Duration::from_millis(50), 10);
// Initially empty
assert_eq!(pool.len(), 0);
// Add a connection manually for testing
let config = create_test_config("http://localhost:3000");
let client: Arc<McpClient> =
Arc::new(unsafe { std::mem::MaybeUninit::zeroed().assume_init() });
let cached = CachedConnection::new(client.clone(), config);
pool.connections
.insert("http://localhost:3000".to_string(), cached);
assert_eq!(pool.len(), 1);
// Wait for TTL to expire
std::thread::sleep(Duration::from_millis(100));
// Cleanup should remove idle connection
pool.cleanup_idle_connections();
assert_eq!(pool.len(), 0);
// Prevent drop of invalid Arc (would segfault)
std::mem::forget(client);
}
#[test]
#[allow(invalid_value)]
fn test_find_oldest_connection() {
let pool = McpConnectionPool::new();
// Collect clients to forget at end
let mut clients = Vec::new();
// Add connections with different timestamps
for i in 0..3 {
let url = format!("http://localhost:{}", 3000 + i);
let config = create_test_config(&url);
let client: Arc<McpClient> =
Arc::new(unsafe { std::mem::MaybeUninit::zeroed().assume_init() });
let cached = CachedConnection::new(client.clone(), config);
pool.connections.insert(url, cached);
clients.push(client);
std::thread::sleep(Duration::from_millis(10));
}
// Oldest should be the first one
let oldest = pool.find_oldest_connection();
assert!(oldest.is_some());
assert_eq!(oldest.unwrap(), "http://localhost:3000");
// Prevent drop of invalid Arcs (would segfault)
for client in clients {
std::mem::forget(client);
}
}
#[test]
#[allow(invalid_value)]
fn test_pool_clear() {
let pool = McpConnectionPool::new();
// Add a connection
let config = create_test_config("http://localhost:3000");
let client: Arc<McpClient> =
Arc::new(unsafe { std::mem::MaybeUninit::zeroed().assume_init() });
let cached = CachedConnection::new(client.clone(), config);
pool.connections
.insert("http://localhost:3000".to_string(), cached);
assert_eq!(pool.len(), 1);
pool.clear();
assert_eq!(pool.len(), 0);
assert!(pool.is_empty());
// Prevent drop of invalid Arc (would segfault)
std::mem::forget(client);
}
#[test]
fn test_pool_with_global_proxy() {
use crate::mcp::McpProxyConfig;
// Create proxy config
let proxy = McpProxyConfig {
http: Some("http://proxy.example.com:8080".to_string()),
https: None,
no_proxy: Some("localhost,127.0.0.1".to_string()),
username: None,
password: None,
};
// Create pool with proxy
let pool =
McpConnectionPool::with_full_config(Duration::from_secs(300), 100, Some(proxy.clone()));
// Verify proxy is stored
assert!(pool.global_proxy.is_some());
let stored_proxy = pool.global_proxy.as_ref().unwrap();
assert_eq!(
stored_proxy.http.as_ref().unwrap(),
"http://proxy.example.com:8080"
);
assert_eq!(
stored_proxy.no_proxy.as_ref().unwrap(),
"localhost,127.0.0.1"
);
}
#[test]
fn test_pool_proxy_from_env() {
// Note: This test depends on environment variables
// In production, proxy is loaded from MCP_HTTP_PROXY or HTTP_PROXY env vars
let pool = McpConnectionPool::new();
// Pool should either have proxy from env or None
// We can't assert specific value since it depends on test environment
// Just verify it doesn't panic
assert!(pool.global_proxy.is_some() || pool.global_proxy.is_none());
}
}
// MCP Tool Inventory with TTL-based Caching
//
// This module provides TTL-based caching for MCP tools, prompts, and resources.
// Tools are cached with timestamps and automatically expire after the configured TTL.
// Background refresh tasks can proactively update the inventory.
use std::time::{Duration, Instant};
use dashmap::DashMap;
use crate::mcp::config::{PromptInfo, ResourceInfo, ToolInfo};
/// Cached tool with metadata
#[derive(Clone)]
pub struct CachedTool {
pub server_name: String,
pub tool: ToolInfo,
pub cached_at: Instant,
}
/// Cached prompt with metadata
#[derive(Clone)]
pub struct CachedPrompt {
pub server_name: String,
pub prompt: PromptInfo,
pub cached_at: Instant,
}
/// Cached resource with metadata
#[derive(Clone)]
pub struct CachedResource {
pub server_name: String,
pub resource: ResourceInfo,
pub cached_at: Instant,
}
/// Tool inventory with TTL-based caching
///
/// Provides thread-safe caching of MCP tools, prompts, and resources with automatic expiration.
/// Entries are timestamped and can be queried with TTL validation.
pub struct ToolInventory {
/// Map of tool_name -> cached tool
tools: DashMap<String, CachedTool>,
/// Map of prompt_name -> cached prompt
prompts: DashMap<String, CachedPrompt>,
/// Map of resource_uri -> cached resource
resources: DashMap<String, CachedResource>,
/// Tool cache TTL
tool_ttl: Duration,
/// Last refresh time per server
server_refresh_times: DashMap<String, Instant>,
}
impl ToolInventory {
/// Create a new tool inventory with the specified TTL
pub fn new(tool_ttl: Duration) -> Self {
Self {
tools: DashMap::new(),
prompts: DashMap::new(),
resources: DashMap::new(),
tool_ttl,
server_refresh_times: DashMap::new(),
}
}
// ============================================================================
// Tool Methods
// ============================================================================
/// Get a tool if it exists and is fresh (within TTL)
///
/// Returns None if the tool doesn't exist or has expired.
pub fn get_tool(&self, tool_name: &str) -> Option<(String, ToolInfo)> {
self.tools.get(tool_name).and_then(|entry| {
let cached = entry.value();
// Check if still fresh
if cached.cached_at.elapsed() < self.tool_ttl {
Some((cached.server_name.clone(), cached.tool.clone()))
} else {
// Expired - will be removed by cleanup
None
}
})
}
/// Check if tool exists (regardless of TTL)
pub fn has_tool(&self, tool_name: &str) -> bool {
self.tools.contains_key(tool_name)
}
/// Insert or update a tool
pub fn insert_tool(&self, tool_name: String, server_name: String, tool: ToolInfo) {
self.tools.insert(
tool_name,
CachedTool {
server_name,
tool,
cached_at: Instant::now(),
},
);
}
/// Get all tools (fresh only)
pub fn list_tools(&self) -> Vec<(String, String, ToolInfo)> {
let now = Instant::now();
self.tools
.iter()
.filter_map(|entry| {
let (name, cached) = entry.pair();
if now.duration_since(cached.cached_at) < self.tool_ttl {
Some((
name.clone(),
cached.server_name.clone(),
cached.tool.clone(),
))
} else {
None
}
})
.collect()
}
// ============================================================================
// Prompt Methods
// ============================================================================
/// Get a prompt if it exists and is fresh (within TTL)
pub fn get_prompt(&self, prompt_name: &str) -> Option<(String, PromptInfo)> {
self.prompts.get(prompt_name).and_then(|entry| {
let cached = entry.value();
// Check if still fresh
if cached.cached_at.elapsed() < self.tool_ttl {
Some((cached.server_name.clone(), cached.prompt.clone()))
} else {
None
}
})
}
/// Check if prompt exists (regardless of TTL)
pub fn has_prompt(&self, prompt_name: &str) -> bool {
self.prompts.contains_key(prompt_name)
}
/// Insert or update a prompt
pub fn insert_prompt(&self, prompt_name: String, server_name: String, prompt: PromptInfo) {
self.prompts.insert(
prompt_name,
CachedPrompt {
server_name,
prompt,
cached_at: Instant::now(),
},
);
}
/// Get all prompts (fresh only)
pub fn list_prompts(&self) -> Vec<(String, String, PromptInfo)> {
let now = Instant::now();
self.prompts
.iter()
.filter_map(|entry| {
let (name, cached) = entry.pair();
if now.duration_since(cached.cached_at) < self.tool_ttl {
Some((
name.clone(),
cached.server_name.clone(),
cached.prompt.clone(),
))
} else {
None
}
})
.collect()
}
// ============================================================================
// Resource Methods
// ============================================================================
/// Get a resource if it exists and is fresh (within TTL)
pub fn get_resource(&self, resource_uri: &str) -> Option<(String, ResourceInfo)> {
self.resources.get(resource_uri).and_then(|entry| {
let cached = entry.value();
// Check if still fresh
if cached.cached_at.elapsed() < self.tool_ttl {
Some((cached.server_name.clone(), cached.resource.clone()))
} else {
None
}
})
}
/// Check if resource exists (regardless of TTL)
pub fn has_resource(&self, resource_uri: &str) -> bool {
self.resources.contains_key(resource_uri)
}
/// Insert or update a resource
pub fn insert_resource(
&self,
resource_uri: String,
server_name: String,
resource: ResourceInfo,
) {
self.resources.insert(
resource_uri,
CachedResource {
server_name,
resource,
cached_at: Instant::now(),
},
);
}
/// Get all resources (fresh only)
pub fn list_resources(&self) -> Vec<(String, String, ResourceInfo)> {
let now = Instant::now();
self.resources
.iter()
.filter_map(|entry| {
let (uri, cached) = entry.pair();
if now.duration_since(cached.cached_at) < self.tool_ttl {
Some((
uri.clone(),
cached.server_name.clone(),
cached.resource.clone(),
))
} else {
None
}
})
.collect()
}
// ============================================================================
// Server Management Methods
// ============================================================================
/// Clear all cached items for a specific server (before refresh)
pub fn clear_server_tools(&self, server_name: &str) {
self.tools
.retain(|_, cached| cached.server_name != server_name);
self.prompts
.retain(|_, cached| cached.server_name != server_name);
self.resources
.retain(|_, cached| cached.server_name != server_name);
}
/// Mark server as refreshed
pub fn mark_refreshed(&self, server_name: &str) {
self.server_refresh_times
.insert(server_name.to_string(), Instant::now());
}
/// Check if server needs refresh based on refresh interval
pub fn needs_refresh(&self, server_name: &str, refresh_interval: Duration) -> bool {
self.server_refresh_times
.get(server_name)
.map(|t| t.elapsed() > refresh_interval)
.unwrap_or(true) // Never refreshed = needs refresh
}
/// Get last refresh time for a server
pub fn last_refresh(&self, server_name: &str) -> Option<Instant> {
self.server_refresh_times
.get(server_name)
.map(|t| *t.value())
}
// ============================================================================
// Cleanup Methods
// ============================================================================
/// Cleanup expired entries
///
/// Removes all tools, prompts, and resources that have exceeded their TTL.
/// Should be called periodically by a background task.
pub fn cleanup_expired(&self) {
let now = Instant::now();
// Remove expired tools
self.tools
.retain(|_, cached| now.duration_since(cached.cached_at) < self.tool_ttl);
// Remove expired prompts
self.prompts
.retain(|_, cached| now.duration_since(cached.cached_at) < self.tool_ttl);
// Remove expired resources
self.resources
.retain(|_, cached| now.duration_since(cached.cached_at) < self.tool_ttl);
}
/// Get count of cached items
pub fn counts(&self) -> (usize, usize, usize) {
(self.tools.len(), self.prompts.len(), self.resources.len())
}
/// Clear all cached items
pub fn clear_all(&self) {
self.tools.clear();
self.prompts.clear();
self.resources.clear();
self.server_refresh_times.clear();
}
}
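// Illustrative sketches (not part of the original source): a typical read path that
// pairs a TTL-bounded lookup with refresh tracking, and a periodic cleanup loop for
// `cleanup_expired`. The helper names `lookup_tool` and `spawn_cleanup_task` are
// hypothetical; only the `ToolInventory` methods defined above are assumed, and the
// cleanup task assumes the crate's existing tokio runtime.
#[allow(dead_code)]
fn lookup_tool(
    inventory: &ToolInventory,
    tool_name: &str,
    server_name: &str,
    refresh_interval: Duration,
) -> (Option<(String, ToolInfo)>, bool) {
    // A fresh cache hit (within TTL) comes back directly; the second element tells
    // the caller whether the owning server is due for a background refresh.
    let hit = inventory.get_tool(tool_name);
    let refresh_due = inventory.needs_refresh(server_name, refresh_interval);
    (hit, refresh_due)
}

#[allow(dead_code)]
fn spawn_cleanup_task(inventory: std::sync::Arc<ToolInventory>, every: Duration) {
    // Periodically drop expired tools, prompts, and resources.
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(every).await;
            inventory.cleanup_expired();
        }
    });
}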
#[cfg(test)]
mod tests {
use super::*;
// Helper to create a test tool
fn create_test_tool(name: &str) -> ToolInfo {
ToolInfo {
name: name.to_string(),
description: format!("Test tool: {}", name),
server: "test_server".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {}
})),
}
}
// Helper to create a test prompt
fn create_test_prompt(name: &str) -> PromptInfo {
PromptInfo {
name: name.to_string(),
description: Some(format!("Test prompt: {}", name)),
server: "test_server".to_string(),
arguments: None,
}
}
// Helper to create a test resource
fn create_test_resource(uri: &str) -> ResourceInfo {
ResourceInfo {
uri: uri.to_string(),
name: uri.to_string(),
description: Some(format!("Test resource: {}", uri)),
mime_type: Some("text/plain".to_string()),
server: "test_server".to_string(),
}
}
#[test]
fn test_tool_insert_and_get() {
let inventory = ToolInventory::new(Duration::from_secs(60));
let tool = create_test_tool("test_tool");
inventory.insert_tool("test_tool".to_string(), "server1".to_string(), tool.clone());
let result = inventory.get_tool("test_tool");
assert!(result.is_some());
let (server_name, retrieved_tool) = result.unwrap();
assert_eq!(server_name, "server1");
assert_eq!(retrieved_tool.name, "test_tool");
}
#[test]
fn test_tool_expiration() {
let inventory = ToolInventory::new(Duration::from_millis(100));
let tool = create_test_tool("expiring_tool");
inventory.insert_tool(
"expiring_tool".to_string(),
"server1".to_string(),
tool.clone(),
);
// Should be available immediately
assert!(inventory.get_tool("expiring_tool").is_some());
// Wait for expiration
std::thread::sleep(Duration::from_millis(150));
// Should be expired now
assert!(inventory.get_tool("expiring_tool").is_none());
}
#[test]
fn test_has_tool() {
let inventory = ToolInventory::new(Duration::from_secs(60));
let tool = create_test_tool("check_tool");
assert!(!inventory.has_tool("check_tool"));
inventory.insert_tool("check_tool".to_string(), "server1".to_string(), tool);
assert!(inventory.has_tool("check_tool"));
}
#[test]
fn test_list_tools() {
let inventory = ToolInventory::new(Duration::from_secs(60));
inventory.insert_tool(
"tool1".to_string(),
"server1".to_string(),
create_test_tool("tool1"),
);
inventory.insert_tool(
"tool2".to_string(),
"server1".to_string(),
create_test_tool("tool2"),
);
inventory.insert_tool(
"tool3".to_string(),
"server2".to_string(),
create_test_tool("tool3"),
);
let tools = inventory.list_tools();
assert_eq!(tools.len(), 3);
}
#[test]
fn test_list_tools_filters_expired() {
let inventory = ToolInventory::new(Duration::from_millis(100));
inventory.insert_tool(
"tool1".to_string(),
"server1".to_string(),
create_test_tool("tool1"),
);
// Should have 1 tool
assert_eq!(inventory.list_tools().len(), 1);
// Wait for expiration
std::thread::sleep(Duration::from_millis(150));
// Should have 0 tools (filtered out)
assert_eq!(inventory.list_tools().len(), 0);
}
#[test]
fn test_clear_server_tools() {
let inventory = ToolInventory::new(Duration::from_secs(60));
inventory.insert_tool(
"tool1".to_string(),
"server1".to_string(),
create_test_tool("tool1"),
);
inventory.insert_tool(
"tool2".to_string(),
"server2".to_string(),
create_test_tool("tool2"),
);
assert_eq!(inventory.list_tools().len(), 2);
inventory.clear_server_tools("server1");
let tools = inventory.list_tools();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].0, "tool2");
}
#[test]
fn test_server_refresh_tracking() {
let inventory = ToolInventory::new(Duration::from_secs(60));
// Never refreshed
assert!(inventory.needs_refresh("server1", Duration::from_secs(10)));
// Mark as refreshed
inventory.mark_refreshed("server1");
// Should not need refresh immediately
assert!(!inventory.needs_refresh("server1", Duration::from_secs(10)));
// Wait and check again
std::thread::sleep(Duration::from_millis(100));
assert!(inventory.needs_refresh("server1", Duration::from_millis(50)));
}
#[test]
fn test_cleanup_expired() {
let inventory = ToolInventory::new(Duration::from_millis(100));
inventory.insert_tool(
"tool1".to_string(),
"server1".to_string(),
create_test_tool("tool1"),
);
inventory.insert_tool(
"tool2".to_string(),
"server1".to_string(),
create_test_tool("tool2"),
);
let (tools, _, _) = inventory.counts();
assert_eq!(tools, 2);
// Wait for expiration
std::thread::sleep(Duration::from_millis(150));
// Cleanup expired entries
inventory.cleanup_expired();
let (tools, _, _) = inventory.counts();
assert_eq!(tools, 0);
}
#[test]
fn test_prompt_operations() {
let inventory = ToolInventory::new(Duration::from_secs(60));
let prompt = create_test_prompt("test_prompt");
inventory.insert_prompt(
"test_prompt".to_string(),
"server1".to_string(),
prompt.clone(),
);
assert!(inventory.has_prompt("test_prompt"));
let result = inventory.get_prompt("test_prompt");
assert!(result.is_some());
let (server_name, retrieved_prompt) = result.unwrap();
assert_eq!(server_name, "server1");
assert_eq!(retrieved_prompt.name, "test_prompt");
}
#[test]
fn test_resource_operations() {
let inventory = ToolInventory::new(Duration::from_secs(60));
let resource = create_test_resource("file:///test.txt");
inventory.insert_resource(
"file:///test.txt".to_string(),
"server1".to_string(),
resource.clone(),
);
assert!(inventory.has_resource("file:///test.txt"));
let result = inventory.get_resource("file:///test.txt");
assert!(result.is_some());
let (server_name, retrieved_resource) = result.unwrap();
assert_eq!(server_name, "server1");
assert_eq!(retrieved_resource.uri, "file:///test.txt");
}
#[tokio::test]
async fn test_concurrent_access() {
use std::sync::Arc;
let inventory = Arc::new(ToolInventory::new(Duration::from_secs(60)));
// Spawn multiple tasks that insert tools concurrently
let mut handles = vec![];
for i in 0..10 {
let inv = Arc::clone(&inventory);
let handle = tokio::spawn(async move {
let tool = create_test_tool(&format!("tool_{}", i));
inv.insert_tool(format!("tool_{}", i), format!("server_{}", i % 3), tool);
});
handles.push(handle);
}
// Wait for all tasks to complete
for handle in handles {
handle.await.unwrap();
}
// Should have 10 tools
let (tools, _, _) = inventory.counts();
assert_eq!(tools, 10);
}
#[test]
fn test_clear_all() {
let inventory = ToolInventory::new(Duration::from_secs(60));
inventory.insert_tool(
"tool1".to_string(),
"server1".to_string(),
create_test_tool("tool1"),
);
inventory.insert_prompt(
"prompt1".to_string(),
"server1".to_string(),
create_test_prompt("prompt1"),
);
inventory.insert_resource(
"res1".to_string(),
"server1".to_string(),
create_test_resource("res1"),
);
inventory.mark_refreshed("server1");
let (tools, prompts, resources) = inventory.counts();
assert_eq!(tools, 1);
assert_eq!(prompts, 1);
assert_eq!(resources, 1);
inventory.clear_all();
let (tools, prompts, resources) = inventory.counts();
assert_eq!(tools, 0);
assert_eq!(prompts, 0);
assert_eq!(resources, 0);
assert!(inventory.last_refresh("server1").is_none());
}
}
@@ -7,12 +7,21 @@
 // - Resources: File/data access with subscription support
 // - OAuth: Secure authentication for remote servers
 
-pub mod client_manager;
 pub mod config;
+pub mod connection_pool;
 pub mod error;
+pub mod inventory;
+pub mod manager;
 pub mod oauth;
+pub mod proxy;
 
 // Re-export the main types for convenience
-pub use client_manager::{McpClientManager, PromptInfo, ResourceInfo, ToolInfo};
-pub use config::{McpConfig, McpServerConfig, McpTransport};
+pub use config::{
+    InventoryConfig, McpConfig, McpPoolConfig, McpProxyConfig, McpServerConfig, McpTransport,
+    PromptInfo, ResourceInfo, ToolInfo, WarmupServer,
+};
+pub use connection_pool::{CachedConnection, McpConnectionPool, PoolStats};
 pub use error::{McpError, McpResult};
+pub use inventory::ToolInventory;
+pub use manager::{McpManager, McpManagerStats};
+pub use proxy::{create_http_client, resolve_proxy_config};
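// Illustrative sketch (not part of the diff): with the flattened re-exports above,
// downstream code can pull the new MCP building blocks from `crate::mcp` directly.
// The helper name `build_mcp_parts` and the timeout values are hypothetical.
use std::time::Duration;

use crate::mcp::{McpConnectionPool, ToolInventory};

#[allow(dead_code)]
fn build_mcp_parts() -> (McpConnectionPool, ToolInventory) {
    // Idle timeout and max connections for the pool; TTL for the tool inventory.
    (
        McpConnectionPool::with_config(Duration::from_secs(300), 100),
        ToolInventory::new(Duration::from_secs(300)),
    )
}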
// MCP Proxy Configuration and Resolution
//
// This module provides proxy configuration resolution and HTTP client creation
// for MCP server connections. Proxy settings are MCP-specific and do NOT affect
// LLM API traffic.
use std::time::Duration;
use crate::mcp::{McpError, McpProxyConfig, McpResult, McpServerConfig};
/// Resolve proxy configuration for a server
/// Priority: server.proxy > global.proxy > None
///
/// # Arguments
/// * `server_config` - Server-specific configuration
/// * `global_proxy` - Global proxy configuration from McpConfig
///
/// # Returns
/// The resolved proxy configuration, or None for direct connection
pub fn resolve_proxy_config<'a>(
server_config: &'a McpServerConfig,
global_proxy: Option<&'a McpProxyConfig>,
) -> Option<&'a McpProxyConfig> {
// Priority 1: Check if server has explicit proxy config
// Note: server.proxy = Some(config) uses that config
// server.proxy = None (set explicitly in YAML as null) forces direct connection
// server.proxy not set (field missing) falls back to global
if server_config.proxy.is_some() {
server_config.proxy.as_ref()
} else {
// Priority 2: Fall back to global proxy
global_proxy
}
}
/// Apply proxy configuration to a ClientBuilder
///
/// This is a reusable helper that applies proxy settings without building the client,
/// allowing additional configuration (like auth headers) to be added afterward.
///
/// # Arguments
/// * `builder` - The reqwest::ClientBuilder to configure
/// * `proxy_config` - The proxy configuration to apply
///
/// # Returns
/// The configured builder or error
pub fn apply_proxy_to_builder(
mut builder: reqwest::ClientBuilder,
proxy_cfg: &McpProxyConfig,
) -> McpResult<reqwest::ClientBuilder> {
// Configure HTTP proxy
if let Some(ref http_proxy) = proxy_cfg.http {
let mut proxy = reqwest::Proxy::http(http_proxy)
.map_err(|e| McpError::Config(format!("Invalid HTTP proxy: {}", e)))?;
// Apply no_proxy exclusions
if let Some(ref no_proxy) = proxy_cfg.no_proxy {
proxy = proxy.no_proxy(reqwest::NoProxy::from_string(no_proxy));
}
// Apply authentication if configured
if let (Some(ref username), Some(ref password)) = (&proxy_cfg.username, &proxy_cfg.password)
{
proxy = proxy.basic_auth(username, password);
}
builder = builder.proxy(proxy);
}
// Configure HTTPS proxy
if let Some(ref https_proxy) = proxy_cfg.https {
let mut proxy = reqwest::Proxy::https(https_proxy)
.map_err(|e| McpError::Config(format!("Invalid HTTPS proxy: {}", e)))?;
// Apply no_proxy exclusions
if let Some(ref no_proxy) = proxy_cfg.no_proxy {
proxy = proxy.no_proxy(reqwest::NoProxy::from_string(no_proxy));
}
// Apply authentication if configured
if let (Some(ref username), Some(ref password)) = (&proxy_cfg.username, &proxy_cfg.password)
{
proxy = proxy.basic_auth(username, password);
}
builder = builder.proxy(proxy);
}
Ok(builder)
}
/// Create HTTP client with MCP-specific proxy configuration
///
/// # Arguments
/// * `proxy_config` - Optional proxy configuration to apply
///
/// # Returns
/// A configured reqwest::Client or error
pub fn create_http_client(proxy_config: Option<&McpProxyConfig>) -> McpResult<reqwest::Client> {
let mut builder = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.connect_timeout(Duration::from_secs(10));
// Apply MCP-specific proxy if configured
if let Some(proxy_cfg) = proxy_config {
builder = apply_proxy_to_builder(builder, proxy_cfg)?;
}
builder
.build()
.map_err(|e| McpError::Transport(format!("Failed to build HTTP client: {}", e)))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mcp::McpTransport;
#[test]
fn test_resolve_proxy_no_config() {
let server = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Sse {
url: "http://localhost:3000/sse".to_string(),
token: None,
},
proxy: None,
required: false,
};
let result = resolve_proxy_config(&server, None);
assert!(
result.is_none(),
"Should return None when no proxy configured"
);
}
#[test]
fn test_resolve_proxy_global_only() {
let server = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Sse {
url: "http://localhost:3000/sse".to_string(),
token: None,
},
proxy: None,
required: false,
};
let global = McpProxyConfig {
http: Some("http://global-proxy:8080".to_string()),
https: None,
no_proxy: None,
username: None,
password: None,
};
let result = resolve_proxy_config(&server, Some(&global));
assert!(result.is_some(), "Should use global proxy");
assert_eq!(
result.unwrap().http.as_ref().unwrap(),
"http://global-proxy:8080"
);
}
#[test]
fn test_resolve_proxy_server_override() {
let server_proxy = McpProxyConfig {
http: Some("http://server-proxy:9090".to_string()),
https: None,
no_proxy: None,
username: None,
password: None,
};
let server = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Sse {
url: "http://localhost:3000/sse".to_string(),
token: None,
},
proxy: Some(server_proxy),
required: false,
};
let global = McpProxyConfig {
http: Some("http://global-proxy:8080".to_string()),
https: None,
no_proxy: None,
username: None,
password: None,
};
let result = resolve_proxy_config(&server, Some(&global));
assert!(result.is_some(), "Should use server-specific proxy");
assert_eq!(
result.unwrap().http.as_ref().unwrap(),
"http://server-proxy:9090",
"Server proxy should override global"
);
}
#[test]
fn test_create_http_client_no_proxy() {
let client = create_http_client(None);
assert!(client.is_ok(), "Should create client without proxy");
}
#[test]
fn test_create_http_client_with_proxy() {
let proxy = McpProxyConfig {
http: Some("http://proxy.example.com:8080".to_string()),
https: None,
no_proxy: Some("localhost,127.0.0.1".to_string()),
username: None,
password: None,
};
let client = create_http_client(Some(&proxy));
assert!(client.is_ok(), "Should create client with proxy");
}
#[test]
fn test_create_http_client_with_auth() {
let proxy = McpProxyConfig {
http: Some("http://proxy.example.com:8080".to_string()),
https: None,
no_proxy: None,
username: Some("user".to_string()),
password: Some("pass".to_string()),
};
let client = create_http_client(Some(&proxy));
assert!(
client.is_ok(),
"Should create client with proxy authentication"
);
}
#[test]
fn test_create_http_client_invalid_proxy() {
let proxy = McpProxyConfig {
http: Some("://invalid".to_string()), // Invalid URL format
https: None,
no_proxy: None,
username: None,
password: None,
};
let client = create_http_client(Some(&proxy));
assert!(client.is_err(), "Should fail with invalid proxy URL");
}
}
@@ -127,14 +127,7 @@ impl RouterFactory {
             return Err("OpenAI mode requires at least one worker URL".to_string());
         }
 
-        let router = OpenAIRouter::new(
-            worker_urls,
-            Some(ctx.router_config.circuit_breaker.clone()),
-            ctx.response_storage.clone(),
-            ctx.conversation_storage.clone(),
-            ctx.conversation_item_storage.clone(),
-        )
-        .await?;
+        let router = OpenAIRouter::new(worker_urls, ctx).await?;
 
         Ok(Box::new(router))
     }
@@ -65,6 +65,7 @@ pub async fn route_responses(
     response_storage: SharedResponseStorage,
     conversation_storage: SharedConversationStorage,
     conversation_item_storage: SharedConversationItemStorage,
+    mcp_manager: Arc<crate::mcp::McpManager>,
     background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
 ) -> Response {
     // 1. Validate mutually exclusive parameters
@@ -113,6 +114,7 @@ pub async fn route_responses(
             response_storage,
             conversation_storage,
             conversation_item_storage,
+            mcp_manager,
         )
         .await
     } else if is_background {
@@ -125,6 +127,7 @@ pub async fn route_responses(
             response_storage,
             conversation_storage,
             conversation_item_storage,
+            mcp_manager,
             background_tasks,
         )
         .await
@@ -138,6 +141,7 @@ pub async fn route_responses(
             response_storage,
             conversation_storage,
             conversation_item_storage,
+            mcp_manager,
             None, // No response_id for sync
             None, // No background_tasks for sync
         )
@@ -167,6 +171,7 @@ async fn route_responses_sync(
     response_storage: SharedResponseStorage,
     conversation_storage: SharedConversationStorage,
     conversation_item_storage: SharedConversationItemStorage,
+    mcp_manager: Arc<crate::mcp::McpManager>,
     response_id: Option<String>,
     background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
 ) -> Response {
@@ -179,6 +184,7 @@ async fn route_responses_sync(
         response_storage,
         conversation_storage,
         conversation_item_storage,
+        mcp_manager,
         response_id,
         background_tasks,
     )
@@ -209,6 +215,7 @@ async fn route_responses_internal(
     response_storage: SharedResponseStorage,
     conversation_storage: SharedConversationStorage,
     conversation_item_storage: SharedConversationItemStorage,
+    mcp_manager: Arc<crate::mcp::McpManager>,
     response_id: Option<String>,
     background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
 ) -> Result<ResponsesResponse, String> {
@@ -223,7 +230,10 @@ async fn route_responses_internal(
     // 2. Check if request has MCP tools - if so, use tool loop
     let responses_response = if let Some(tools) = &request.tools {
-        if let Some(mcp_manager) = create_mcp_manager_from_request(tools).await {
+        // Try to create dynamic MCP client from request tools using the manager
+        if let Some(request_mcp_manager) =
+            create_mcp_manager_from_request(&mcp_manager, tools).await
+        {
             debug!("MCP tools detected, using tool loop");
             // Execute with MCP tool loop
@@ -234,13 +244,14 @@ async fn route_responses_internal(
                 headers,
                 model_id,
                 components,
-                mcp_manager,
+                request_mcp_manager,
                 response_id.clone(),
                 background_tasks,
             )
             .await?
         } else {
-            // No MCP manager, execute normally
+            debug!("Failed to create MCP client from request tools");
+            // Fall through to non-MCP execution
             execute_without_mcp(
                 pipeline,
                 &modified_request,
@@ -303,6 +314,7 @@ async fn route_responses_background(
     response_storage: SharedResponseStorage,
     conversation_storage: SharedConversationStorage,
     conversation_item_storage: SharedConversationItemStorage,
+    mcp_manager: Arc<crate::mcp::McpManager>,
     background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
 ) -> Response {
     // Generate response_id for background tracking
@@ -365,6 +377,7 @@ async fn route_responses_background(
     let response_storage_clone = response_storage.clone();
     let conversation_storage_clone = conversation_storage.clone();
     let conversation_item_storage_clone = conversation_item_storage.clone();
+    let mcp_manager_clone = mcp_manager.clone();
     let response_id_clone = response_id.clone();
     let background_tasks_clone = background_tasks.clone();
@@ -382,6 +395,7 @@ async fn route_responses_background(
             response_storage_clone,
             conversation_storage_clone,
             conversation_item_storage_clone,
+            mcp_manager_clone,
             Some(response_id_clone.clone()),
             Some(background_tasks_clone.clone()),
         )
@@ -434,6 +448,7 @@ async fn route_responses_streaming(
     response_storage: SharedResponseStorage,
     conversation_storage: SharedConversationStorage,
     conversation_item_storage: SharedConversationItemStorage,
+    mcp_manager: Arc<crate::mcp::McpManager>,
 ) -> Response {
     // 1. Load conversation history
     let modified_request = match load_conversation_history(
@@ -461,7 +476,10 @@ async fn route_responses_streaming(
     // 2. Check if request has MCP tools - if so, use streaming tool loop
     if let Some(tools) = &request.tools {
-        if let Some(mcp_manager) = create_mcp_manager_from_request(tools).await {
+        // Try to create dynamic MCP client from request tools using the manager
+        if let Some(request_mcp_manager) =
+            create_mcp_manager_from_request(&mcp_manager, tools).await
+        {
             debug!("MCP tools detected in streaming mode, using streaming tool loop");
             return execute_tool_loop_streaming(
@@ -471,7 +489,7 @@ async fn route_responses_streaming(
                 headers,
                 model_id,
                 components,
-                mcp_manager,
+                request_mcp_manager,
                 response_storage,
                 conversation_storage,
                 conversation_item_storage,