Unverified Commit 212f5e48 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] MCP Manager Refactoring - Flat Architecture with Connection Pooling (#12097)

parent fe527812
...@@ -24,12 +24,11 @@ use super::{ ...@@ -24,12 +24,11 @@ use super::{
types::BackgroundTaskInfo, types::BackgroundTaskInfo,
}; };
/// This is a re-export of the shared implementation from openai::mcp /// This is a re-export of the shared implementation from openai::mcp
pub(super) use crate::routers::openai::mcp::mcp_manager_from_request_tools as create_mcp_manager_from_request; pub(super) use crate::routers::openai::mcp::ensure_request_mcp_client as create_mcp_manager_from_request;
use crate::{ use crate::{
data_connector::{ data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
}, },
mcp::McpClientManager,
protocols::{ protocols::{
chat::ChatCompletionResponse, chat::ChatCompletionResponse,
common::{Tool, ToolChoice, ToolChoiceValue}, common::{Tool, ToolChoice, ToolChoiceValue},
...@@ -102,7 +101,7 @@ fn extract_all_tool_calls_from_chat( ...@@ -102,7 +101,7 @@ fn extract_all_tool_calls_from_chat(
/// Execute an MCP tool call /// Execute an MCP tool call
async fn execute_mcp_call( async fn execute_mcp_call(
mcp_mgr: &Arc<McpClientManager>, mcp_mgr: &Arc<crate::mcp::McpManager>,
tool_name: &str, tool_name: &str,
args_json_str: &str, args_json_str: &str,
) -> Result<String, String> { ) -> Result<String, String> {
...@@ -222,7 +221,7 @@ fn generate_mcp_id(prefix: &str) -> String { ...@@ -222,7 +221,7 @@ fn generate_mcp_id(prefix: &str) -> String {
/// Build mcp_list_tools output item /// Build mcp_list_tools output item
fn build_mcp_list_tools_item( fn build_mcp_list_tools_item(
mcp: &Arc<McpClientManager>, mcp: &Arc<crate::mcp::McpManager>,
server_label: &str, server_label: &str,
) -> ResponseOutputItem { ) -> ResponseOutputItem {
let tools = mcp.list_tools(); let tools = mcp.list_tools();
...@@ -287,7 +286,7 @@ pub(super) async fn execute_tool_loop( ...@@ -287,7 +286,7 @@ pub(super) async fn execute_tool_loop(
headers: Option<http::HeaderMap>, headers: Option<http::HeaderMap>,
model_id: Option<String>, model_id: Option<String>,
components: Arc<SharedComponents>, components: Arc<SharedComponents>,
mcp_manager: Arc<McpClientManager>, mcp_manager: Arc<crate::mcp::McpManager>,
response_id: Option<String>, response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>, background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ResponsesResponse, String> { ) -> Result<ResponsesResponse, String> {
...@@ -507,7 +506,7 @@ pub(super) async fn execute_tool_loop_streaming( ...@@ -507,7 +506,7 @@ pub(super) async fn execute_tool_loop_streaming(
headers: Option<http::HeaderMap>, headers: Option<http::HeaderMap>,
model_id: Option<String>, model_id: Option<String>,
components: Arc<SharedComponents>, components: Arc<SharedComponents>,
mcp_manager: Arc<McpClientManager>, mcp_manager: Arc<crate::mcp::McpManager>,
response_storage: SharedResponseStorage, response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage, conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage, conversation_item_storage: SharedConversationItemStorage,
...@@ -598,7 +597,7 @@ async fn execute_tool_loop_streaming_internal( ...@@ -598,7 +597,7 @@ async fn execute_tool_loop_streaming_internal(
headers: Option<http::HeaderMap>, headers: Option<http::HeaderMap>,
model_id: Option<String>, model_id: Option<String>,
components: Arc<SharedComponents>, components: Arc<SharedComponents>,
mcp_manager: Arc<McpClientManager>, mcp_manager: Arc<crate::mcp::McpManager>,
server_label: String, server_label: String,
_response_storage: SharedResponseStorage, _response_storage: SharedResponseStorage,
_conversation_storage: SharedConversationStorage, _conversation_storage: SharedConversationStorage,
......
...@@ -24,6 +24,7 @@ use crate::{ ...@@ -24,6 +24,7 @@ use crate::{
data_connector::{ data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
}, },
mcp::McpManager,
policies::PolicyRegistry, policies::PolicyRegistry,
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
...@@ -60,8 +61,7 @@ pub struct GrpcRouter { ...@@ -60,8 +61,7 @@ pub struct GrpcRouter {
response_storage: SharedResponseStorage, response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage, conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage, conversation_item_storage: SharedConversationItemStorage,
// Optional MCP manager for tool execution (enabled via SGLANG_MCP_CONFIG env var) mcp_manager: Arc<McpManager>,
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
// Background task handles for cancellation support (includes gRPC client for Python abort) // Background task handles for cancellation support (includes gRPC client for Python abort)
background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>, background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
} }
...@@ -94,25 +94,12 @@ impl GrpcRouter { ...@@ -94,25 +94,12 @@ impl GrpcRouter {
let conversation_storage = ctx.conversation_storage.clone(); let conversation_storage = ctx.conversation_storage.clone();
let conversation_item_storage = ctx.conversation_item_storage.clone(); let conversation_item_storage = ctx.conversation_item_storage.clone();
// Optional MCP manager activation via env var path (config-driven gate) // Get MCP manager from app context
let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() { let mcp_manager = ctx
Some(path) if !path.trim().is_empty() => { .mcp_manager
match crate::mcp::McpConfig::from_file(&path).await { .get()
Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await { .ok_or_else(|| "gRPC router requires MCP manager".to_string())?
Ok(mgr) => Some(Arc::new(mgr)), .clone();
Err(err) => {
tracing::warn!("Failed to initialize MCP manager: {}", err);
None
}
},
Err(err) => {
tracing::warn!("Failed to load MCP config from '{}': {}", path, err);
None
}
}
}
_ => None,
};
// Create shared components for pipeline // Create shared components for pipeline
let shared_components = Arc::new(SharedComponents { let shared_components = Arc::new(SharedComponents {
...@@ -285,6 +272,7 @@ impl RouterTrait for GrpcRouter { ...@@ -285,6 +272,7 @@ impl RouterTrait for GrpcRouter {
self.response_storage.clone(), self.response_storage.clone(),
self.conversation_storage.clone(), self.conversation_storage.clone(),
self.conversation_item_storage.clone(), self.conversation_item_storage.clone(),
self.mcp_manager.clone(),
self.background_tasks.clone(), self.background_tasks.clone(),
) )
.await .await
......
...@@ -18,7 +18,7 @@ use tracing::{info, warn}; ...@@ -18,7 +18,7 @@ use tracing::{info, warn};
use super::utils::{event_types, generate_id}; use super::utils::{event_types, generate_id};
use crate::{ use crate::{
mcp::McpClientManager, mcp,
protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest}, protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest},
routers::header_utils::apply_request_headers, routers::header_utils::apply_request_headers,
}; };
...@@ -128,10 +128,19 @@ impl FunctionCallInProgress { ...@@ -128,10 +128,19 @@ impl FunctionCallInProgress {
// MCP Manager Integration // MCP Manager Integration
// ============================================================================ // ============================================================================
/// Build a request-scoped MCP manager from request tools, if present. /// Ensure a dynamic MCP client exists for request-scoped tools.
pub async fn mcp_manager_from_request_tools( ///
/// This function parses request tools to extract MCP server configuration,
/// then ensures a dynamic client exists in the McpManager via `get_or_create_client()`.
/// The McpManager itself is returned (cloned Arc) for convenience, though the main
/// purpose is the side effect of registering the dynamic client.
///
/// Returns Some(manager) if a dynamic MCP tool was found and client was created/retrieved,
/// None if no MCP tools were found or connection failed.
pub async fn ensure_request_mcp_client(
mcp_manager: &Arc<mcp::McpManager>,
tools: &[ResponseTool], tools: &[ResponseTool],
) -> Option<Arc<McpClientManager>> { ) -> Option<Arc<mcp::McpManager>> {
let tool = tools let tool = tools
.iter() .iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?; .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
...@@ -149,23 +158,30 @@ pub async fn mcp_manager_from_request_tools( ...@@ -149,23 +158,30 @@ pub async fn mcp_manager_from_request_tools(
.unwrap_or_else(|| "request-mcp".to_string()); .unwrap_or_else(|| "request-mcp".to_string());
let token = tool.authorization.clone(); let token = tool.authorization.clone();
let transport = if server_url.contains("/sse") { let transport = if server_url.contains("/sse") {
crate::mcp::McpTransport::Sse { mcp::McpTransport::Sse {
url: server_url, url: server_url.clone(),
token, token,
} }
} else { } else {
crate::mcp::McpTransport::Streamable { mcp::McpTransport::Streamable {
url: server_url, url: server_url.clone(),
token, token,
} }
}; };
let cfg = crate::mcp::McpConfig {
servers: vec![crate::mcp::McpServerConfig { name, transport }], // Create server config
let server_config = mcp::McpServerConfig {
name,
transport,
proxy: None,
required: false,
}; };
match McpClientManager::new(cfg).await {
Ok(mgr) => Some(Arc::new(mgr)), // Use McpManager to get or create dynamic client
match mcp_manager.get_or_create_client(server_config).await {
Ok(_client) => Some(mcp_manager.clone()),
Err(err) => { Err(err) => {
warn!("Failed to initialize request-scoped MCP manager: {}", err); warn!("Failed to get/create MCP connection: {}", err);
None None
} }
} }
...@@ -177,7 +193,7 @@ pub async fn mcp_manager_from_request_tools( ...@@ -177,7 +193,7 @@ pub async fn mcp_manager_from_request_tools(
/// Execute an MCP tool call /// Execute an MCP tool call
pub(super) async fn execute_mcp_call( pub(super) async fn execute_mcp_call(
mcp_mgr: &Arc<McpClientManager>, mcp_mgr: &Arc<mcp::McpManager>,
tool_name: &str, tool_name: &str,
args_json_str: &str, args_json_str: &str,
) -> Result<(String, String), String> { ) -> Result<(String, String), String> {
...@@ -204,7 +220,7 @@ pub(super) async fn execute_mcp_call( ...@@ -204,7 +220,7 @@ pub(super) async fn execute_mcp_call(
/// Returns false if client disconnected during execution /// Returns false if client disconnected during execution
pub(super) async fn execute_streaming_tool_calls( pub(super) async fn execute_streaming_tool_calls(
pending_calls: Vec<FunctionCallInProgress>, pending_calls: Vec<FunctionCallInProgress>,
active_mcp: &Arc<McpClientManager>, active_mcp: &Arc<mcp::McpManager>,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
state: &mut ToolLoopState, state: &mut ToolLoopState,
server_label: &str, server_label: &str,
...@@ -269,7 +285,7 @@ pub(super) async fn execute_streaming_tool_calls( ...@@ -269,7 +285,7 @@ pub(super) async fn execute_streaming_tool_calls(
/// Transform payload to replace MCP tools with function tools for streaming /// Transform payload to replace MCP tools with function tools for streaming
pub(super) fn prepare_mcp_payload_for_streaming( pub(super) fn prepare_mcp_payload_for_streaming(
payload: &mut Value, payload: &mut Value,
active_mcp: &Arc<McpClientManager>, active_mcp: &Arc<mcp::McpManager>,
) { ) {
if let Some(obj) = payload.as_object_mut() { if let Some(obj) = payload.as_object_mut() {
// Remove any non-function tools from outgoing payload // Remove any non-function tools from outgoing payload
...@@ -377,7 +393,7 @@ pub(super) fn build_resume_payload( ...@@ -377,7 +393,7 @@ pub(super) fn build_resume_payload(
/// Returns false if client disconnected /// Returns false if client disconnected
pub(super) fn send_mcp_list_tools_events( pub(super) fn send_mcp_list_tools_events(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
mcp: &Arc<McpClientManager>, mcp: &Arc<mcp::McpManager>,
server_label: &str, server_label: &str,
output_index: usize, output_index: usize,
sequence_number: &mut u64, sequence_number: &mut u64,
...@@ -533,7 +549,7 @@ pub(super) fn send_mcp_call_completion_events_with_error( ...@@ -533,7 +549,7 @@ pub(super) fn send_mcp_call_completion_events_with_error(
pub(super) fn inject_mcp_metadata_streaming( pub(super) fn inject_mcp_metadata_streaming(
response: &mut Value, response: &mut Value,
state: &ToolLoopState, state: &ToolLoopState,
mcp: &Arc<McpClientManager>, mcp: &Arc<mcp::McpManager>,
server_label: &str, server_label: &str,
) { ) {
if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) { if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) {
...@@ -573,7 +589,7 @@ pub(super) async fn execute_tool_loop( ...@@ -573,7 +589,7 @@ pub(super) async fn execute_tool_loop(
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
initial_payload: Value, initial_payload: Value,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
active_mcp: &Arc<McpClientManager>, active_mcp: &Arc<mcp::McpManager>,
config: &McpLoopConfig, config: &McpLoopConfig,
) -> Result<Value, String> { ) -> Result<Value, String> {
let mut state = ToolLoopState::new(original_body.input.clone()); let mut state = ToolLoopState::new(original_body.input.clone());
...@@ -734,7 +750,7 @@ pub(super) fn build_incomplete_response( ...@@ -734,7 +750,7 @@ pub(super) fn build_incomplete_response(
mut response: Value, mut response: Value,
state: ToolLoopState, state: ToolLoopState,
reason: &str, reason: &str,
active_mcp: &Arc<McpClientManager>, active_mcp: &Arc<mcp::McpManager>,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
) -> Result<Value, String> { ) -> Result<Value, String> {
let obj = response let obj = response
...@@ -837,7 +853,7 @@ pub(super) fn build_incomplete_response( ...@@ -837,7 +853,7 @@ pub(super) fn build_incomplete_response(
// ============================================================================ // ============================================================================
/// Build an mcp_list_tools output item /// Build an mcp_list_tools output item
pub(super) fn build_mcp_list_tools_item(mcp: &Arc<McpClientManager>, server_label: &str) -> Value { pub(super) fn build_mcp_list_tools_item(mcp: &Arc<mcp::McpManager>, server_label: &str) -> Value {
let tools = mcp.list_tools(); let tools = mcp.list_tools();
let tools_json: Vec<Value> = tools let tools_json: Vec<Value> = tools
.iter() .iter()
......
...@@ -28,7 +28,7 @@ use super::conversations::{ ...@@ -28,7 +28,7 @@ use super::conversations::{
}; };
use super::{ use super::{
mcp::{ mcp::{
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, ensure_request_mcp_client, execute_tool_loop, prepare_mcp_payload_for_streaming,
McpLoopConfig, McpLoopConfig,
}, },
responses::{mask_tools_as_mcp, patch_streaming_response_json}, responses::{mask_tools_as_mcp, patch_streaming_response_json},
...@@ -36,12 +36,12 @@ use super::{ ...@@ -36,12 +36,12 @@ use super::{
utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model}, utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
}; };
use crate::{ use crate::{
config::CircuitBreakerConfig,
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}, core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
data_connector::{ data_connector::{
ConversationId, ListParams, ResponseId, SharedConversationItemStorage, ConversationId, ListParams, ResponseId, SharedConversationItemStorage,
SharedConversationStorage, SharedResponseStorage, SortOrder, SharedConversationStorage, SharedResponseStorage, SortOrder,
}, },
mcp::McpManager,
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest, classify::ClassifyRequest,
...@@ -86,8 +86,8 @@ pub struct OpenAIRouter { ...@@ -86,8 +86,8 @@ pub struct OpenAIRouter {
conversation_storage: SharedConversationStorage, conversation_storage: SharedConversationStorage,
/// Conversation item storage backend /// Conversation item storage backend
conversation_item_storage: SharedConversationItemStorage, conversation_item_storage: SharedConversationItemStorage,
/// Optional MCP manager (enabled via config presence) /// MCP manager (handles both static and dynamic servers)
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>, mcp_manager: Arc<McpManager>,
} }
impl std::fmt::Debug for OpenAIRouter { impl std::fmt::Debug for OpenAIRouter {
...@@ -109,15 +109,10 @@ impl OpenAIRouter { ...@@ -109,15 +109,10 @@ impl OpenAIRouter {
/// Create a new OpenAI router /// Create a new OpenAI router
pub async fn new( pub async fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
circuit_breaker_config: Option<CircuitBreakerConfig>, ctx: &Arc<crate::app_context::AppContext>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
) -> Result<Self, String> { ) -> Result<Self, String> {
let client = reqwest::Client::builder() // Use HTTP client from AppContext
.timeout(Duration::from_secs(300)) let client = ctx.client.clone();
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
// Normalize URLs (remove trailing slashes) // Normalize URLs (remove trailing slashes)
let worker_urls: Vec<String> = worker_urls let worker_urls: Vec<String> = worker_urls
...@@ -125,37 +120,23 @@ impl OpenAIRouter { ...@@ -125,37 +120,23 @@ impl OpenAIRouter {
.map(|url| url.trim_end_matches('/').to_string()) .map(|url| url.trim_end_matches('/').to_string())
.collect(); .collect();
// Convert circuit breaker config // Convert circuit breaker config from AppContext
let core_cb_config = circuit_breaker_config let cb = &ctx.router_config.circuit_breaker;
.map(|cb| CoreCircuitBreakerConfig { let core_cb_config = CoreCircuitBreakerConfig {
failure_threshold: cb.failure_threshold, failure_threshold: cb.failure_threshold,
success_threshold: cb.success_threshold, success_threshold: cb.success_threshold,
timeout_duration: Duration::from_secs(cb.timeout_duration_secs), timeout_duration: Duration::from_secs(cb.timeout_duration_secs),
window_duration: Duration::from_secs(cb.window_duration_secs), window_duration: Duration::from_secs(cb.window_duration_secs),
}) };
.unwrap_or_default();
let circuit_breaker = CircuitBreaker::with_config(core_cb_config); let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
// Optional MCP manager activation via env var path (config-driven gate) // Get MCP manager from AppContext (must be initialized)
let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() { let mcp_manager = ctx
Some(path) if !path.trim().is_empty() => { .mcp_manager
match crate::mcp::McpConfig::from_file(&path).await { .get()
Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await { .ok_or_else(|| "MCP manager not initialized in AppContext".to_string())?
Ok(mgr) => Some(Arc::new(mgr)), .clone();
Err(err) => {
warn!("Failed to initialize MCP manager: {}", err);
None
}
},
Err(err) => {
warn!("Failed to load MCP config from '{}': {}", path, err);
None
}
}
}
_ => None,
};
Ok(Self { Ok(Self {
client, client,
...@@ -163,9 +144,9 @@ impl OpenAIRouter { ...@@ -163,9 +144,9 @@ impl OpenAIRouter {
model_cache: Arc::new(DashMap::new()), model_cache: Arc::new(DashMap::new()),
circuit_breaker, circuit_breaker,
healthy: AtomicBool::new(true), healthy: AtomicBool::new(true),
response_storage, response_storage: ctx.response_storage.clone(),
conversation_storage, conversation_storage: ctx.conversation_storage.clone(),
conversation_item_storage, conversation_item_storage: ctx.conversation_item_storage.clone(),
mcp_manager, mcp_manager,
}) })
} }
...@@ -241,12 +222,17 @@ impl OpenAIRouter { ...@@ -241,12 +222,17 @@ impl OpenAIRouter {
original_previous_response_id: Option<String>, original_previous_response_id: Option<String>,
) -> Response { ) -> Response {
// Check if MCP is active for this request // Check if MCP is active for this request
let req_mcp_manager = if let Some(ref tools) = original_body.tools { // Ensure dynamic client is created if needed
mcp_manager_from_request_tools(tools.as_slice()).await if let Some(ref tools) = original_body.tools {
} else { ensure_request_mcp_client(&self.mcp_manager, tools.as_slice()).await;
}
// Use the tool loop if the manager has any tools available (static or dynamic).
let active_mcp = if self.mcp_manager.list_tools().is_empty() {
None None
} else {
Some(&self.mcp_manager)
}; };
let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
let mut response_json: Value; let mut response_json: Value;
...@@ -984,7 +970,7 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -984,7 +970,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
handle_streaming_response( handle_streaming_response(
&self.client, &self.client,
&self.circuit_breaker, &self.circuit_breaker,
self.mcp_manager.as_ref(), Some(&self.mcp_manager),
self.response_storage.clone(), self.response_storage.clone(),
self.conversation_storage.clone(), self.conversation_storage.clone(),
self.conversation_item_storage.clone(), self.conversation_item_storage.clone(),
......
...@@ -25,8 +25,8 @@ use tracing::warn; ...@@ -25,8 +25,8 @@ use tracing::warn;
use super::conversations::persist_conversation_items; use super::conversations::persist_conversation_items;
use super::{ use super::{
mcp::{ mcp::{
build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming, build_resume_payload, ensure_request_mcp_client, execute_streaming_tool_calls,
mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, inject_mcp_metadata_streaming, prepare_mcp_payload_for_streaming,
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState, send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
}, },
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block}, responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
...@@ -907,7 +907,7 @@ pub(super) fn send_final_response_event( ...@@ -907,7 +907,7 @@ pub(super) fn send_final_response_event(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
sequence_number: &mut u64, sequence_number: &mut u64,
state: &ToolLoopState, state: &ToolLoopState,
active_mcp: Option<&Arc<crate::mcp::McpClientManager>>, active_mcp: Option<&Arc<crate::mcp::McpManager>>,
original_request: &ResponsesRequest, original_request: &ResponsesRequest,
previous_response_id: Option<&str>, previous_response_id: Option<&str>,
server_label: &str, server_label: &str,
...@@ -1138,7 +1138,7 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1138,7 +1138,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
mut payload: Value, mut payload: Value,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
original_previous_response_id: Option<String>, original_previous_response_id: Option<String>,
active_mcp: &Arc<crate::mcp::McpClientManager>, active_mcp: &Arc<crate::mcp::McpManager>,
) -> Response { ) -> Response {
// Transform MCP tools to function tools in payload // Transform MCP tools to function tools in payload
prepare_mcp_payload_for_streaming(&mut payload, active_mcp); prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
...@@ -1491,7 +1491,7 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1491,7 +1491,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
pub(super) async fn handle_streaming_response( pub(super) async fn handle_streaming_response(
client: &reqwest::Client, client: &reqwest::Client,
circuit_breaker: &crate::core::CircuitBreaker, circuit_breaker: &crate::core::CircuitBreaker,
mcp_manager: Option<&Arc<crate::mcp::McpClientManager>>, mcp_manager: Option<&Arc<crate::mcp::McpManager>>,
response_storage: SharedResponseStorage, response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage, conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage, conversation_item_storage: SharedConversationItemStorage,
...@@ -1502,12 +1502,19 @@ pub(super) async fn handle_streaming_response( ...@@ -1502,12 +1502,19 @@ pub(super) async fn handle_streaming_response(
original_previous_response_id: Option<String>, original_previous_response_id: Option<String>,
) -> Response { ) -> Response {
// Check if MCP is active for this request // Check if MCP is active for this request
let req_mcp_manager = if let Some(ref tools) = original_body.tools { // Ensure dynamic client is created if needed
mcp_manager_from_request_tools(tools.as_slice()).await if let (Some(manager), Some(ref tools)) = (mcp_manager, &original_body.tools) {
} else { ensure_request_mcp_client(manager, tools.as_slice()).await;
None }
};
let active_mcp = req_mcp_manager.as_ref().or(mcp_manager); // Use the tool loop if the manager has any tools available (static or dynamic).
let active_mcp = mcp_manager.and_then(|mgr| {
if mgr.list_tools().is_empty() {
None
} else {
Some(mgr)
}
});
// If no MCP is active, use simple pass-through streaming // If no MCP is active, use simple pass-through streaming
if active_mcp.is_none() { if active_mcp.is_none() {
......
...@@ -24,8 +24,8 @@ use crate::{ ...@@ -24,8 +24,8 @@ use crate::{
core::{ core::{
worker_to_info, worker_to_info,
workflow::{ workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber, create_mcp_registration_workflow, create_worker_registration_workflow,
WorkflowEngine, create_worker_removal_workflow, LoggingSubscriber, WorkflowEngine,
}, },
Job, JobQueue, JobQueueConfig, WorkerManager, WorkerType, Job, JobQueue, JobQueueConfig, WorkerManager, WorkerType,
}, },
...@@ -739,11 +739,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -739,11 +739,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
engine.register_workflow(create_worker_registration_workflow()); engine.register_workflow(create_worker_registration_workflow());
engine.register_workflow(create_worker_removal_workflow()); engine.register_workflow(create_worker_removal_workflow());
engine.register_workflow(create_mcp_registration_workflow());
app_context app_context
.workflow_engine .workflow_engine
.set(engine) .set(engine)
.expect("WorkflowEngine should only be initialized once"); .expect("WorkflowEngine should only be initialized once");
info!("Workflow engine initialized with worker registration and removal workflows"); info!("Workflow engine initialized with worker and MCP registration workflows");
info!( info!(
"Initializing workers for routing mode: {:?}", "Initializing workers for routing mode: {:?}",
...@@ -763,6 +764,27 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -763,6 +764,27 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.await .await
.map_err(|e| format!("Failed to submit worker initialization job: {}", e))?; .map_err(|e| format!("Failed to submit worker initialization job: {}", e))?;
if let Some(mcp_config) = &config.router_config.mcp_config {
info!("Found {} MCP server(s) in config", mcp_config.servers.len());
let mcp_job = Job::InitializeMcpServers {
mcp_config: Box::new(mcp_config.clone()),
};
job_queue
.submit(mcp_job)
.await
.map_err(|e| format!("Failed to submit MCP initialization job: {}", e))?;
} else {
info!("No MCP config provided, skipping MCP server initialization");
}
// Start background refresh for all registered static MCP servers
if let Some(mcp_manager) = app_context.mcp_manager.get() {
let refresh_interval = Duration::from_secs(300); // 5 minutes, matches default TTL
let _refresh_handle =
Arc::clone(mcp_manager).spawn_background_refresh_all(refresh_interval);
info!("Started background refresh for all static MCP servers");
}
let worker_stats = app_context.worker_registry.stats(); let worker_stats = app_context.worker_registry.stats();
info!( info!(
"Workers initialized: {} total, {} healthy", "Workers initialized: {} total, {} healthy",
......
...@@ -593,6 +593,7 @@ mod tests { ...@@ -593,6 +593,7 @@ mod tests {
configured_tool_parser: None, configured_tool_parser: None,
worker_job_queue: Arc::new(std::sync::OnceLock::new()), worker_job_queue: Arc::new(std::sync::OnceLock::new()),
workflow_engine: Arc::new(std::sync::OnceLock::new()), workflow_engine: Arc::new(std::sync::OnceLock::new()),
mcp_manager: Arc::new(std::sync::OnceLock::new()),
}) })
} }
......
...@@ -90,7 +90,7 @@ impl TestContext { ...@@ -90,7 +90,7 @@ impl TestContext {
.unwrap(); .unwrap();
// Create app context // Create app context
let app_context = common::create_test_context(config.clone()); let app_context = common::create_test_context(config.clone()).await;
// Submit worker initialization job (same as real server does) // Submit worker initialization job (same as real server does)
if !worker_urls.is_empty() { if !worker_urls.is_empty() {
...@@ -1538,7 +1538,7 @@ mod pd_mode_tests { ...@@ -1538,7 +1538,7 @@ mod pd_mode_tests {
.build_unchecked(); .build_unchecked();
// Create app context // Create app context
let app_context = common::create_test_context(config); let app_context = common::create_test_context(config).await;
// Create router - this might fail due to health check issues // Create router - this might fail due to health check issues
let router_result = RouterFactory::create_router(&app_context).await; let router_result = RouterFactory::create_router(&app_context).await;
......
...@@ -27,7 +27,7 @@ use sglang_router_rs::{ ...@@ -27,7 +27,7 @@ use sglang_router_rs::{
}; };
/// Helper function to create AppContext for tests /// Helper function to create AppContext for tests
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> { pub async fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// Initialize rate limiter // Initialize rate limiter
...@@ -62,9 +62,10 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> { ...@@ -62,9 +62,10 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
config.worker_startup_check_interval_secs, config.worker_startup_check_interval_secs,
))); )));
// Create empty OnceLock for worker job queue and workflow engine // Create empty OnceLock for worker job queue, workflow engine, and mcp manager
let worker_job_queue = Arc::new(OnceLock::new()); let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new()); let workflow_engine = Arc::new(OnceLock::new());
let mcp_manager_lock = Arc::new(OnceLock::new());
let app_context = Arc::new( let app_context = Arc::new(
AppContext::builder() AppContext::builder()
...@@ -82,6 +83,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> { ...@@ -82,6 +83,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
.load_monitor(load_monitor) .load_monitor(load_monitor)
.worker_job_queue(worker_job_queue) .worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine) .workflow_engine(workflow_engine)
.mcp_manager(mcp_manager_lock)
.build() .build()
.unwrap(), .unwrap(),
); );
...@@ -109,6 +111,130 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> { ...@@ -109,6 +111,130 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
.set(engine) .set(engine)
.expect("WorkflowEngine should only be initialized once"); .expect("WorkflowEngine should only be initialized once");
// Initialize MCP manager with empty config
use sglang_router_rs::mcp::{McpConfig, McpManager};
let empty_config = McpConfig {
servers: vec![],
pool: Default::default(),
proxy: None,
warmup: vec![],
inventory: Default::default(),
};
let mcp_manager = McpManager::with_defaults(empty_config)
.await
.expect("Failed to create MCP manager");
app_context
.mcp_manager
.set(Arc::new(mcp_manager))
.ok()
.expect("McpManager should only be initialized once");
app_context
}
/// Helper function to create AppContext for tests with MCP config from file
pub async fn create_test_context_with_mcp_config(
config: RouterConfig,
mcp_config_path: &str,
) -> Arc<AppContext> {
use sglang_router_rs::mcp::{McpConfig, McpManager};
let client = reqwest::Client::new();
// Initialize rate limiter
let rate_limiter = match config.max_concurrent_requests {
n if n <= 0 => None,
n => {
let rate_limit_tokens = config
.rate_limit_tokens_per_second
.filter(|&t| t > 0)
.unwrap_or(n);
Some(Arc::new(TokenBucket::new(
n as usize,
rate_limit_tokens as usize,
)))
}
};
// Initialize registries
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new(config.policy.clone()));
// Initialize storage backends (Memory for tests)
let response_storage = Arc::new(MemoryResponseStorage::new());
let conversation_storage = Arc::new(MemoryConversationStorage::new());
let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());
// Initialize load monitor
let load_monitor = Some(Arc::new(LoadMonitor::new(
worker_registry.clone(),
policy_registry.clone(),
client.clone(),
config.worker_startup_check_interval_secs,
)));
// Create empty OnceLock for worker job queue, workflow engine, and mcp manager
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
let mcp_manager_lock = Arc::new(OnceLock::new());
let app_context = Arc::new(
AppContext::builder()
.router_config(config)
.client(client)
.rate_limiter(rate_limiter)
.tokenizer(None) // tokenizer
.reasoning_parser_factory(None) // reasoning_parser_factory
.tool_parser_factory(None) // tool_parser_factory
.worker_registry(worker_registry)
.policy_registry(policy_registry)
.response_storage(response_storage)
.conversation_storage(conversation_storage)
.conversation_item_storage(conversation_item_storage)
.load_monitor(load_monitor)
.worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine)
.mcp_manager(mcp_manager_lock)
.build()
.unwrap(),
);
// Initialize JobQueue after AppContext is created
let weak_context = Arc::downgrade(&app_context);
let job_queue = sglang_router_rs::core::JobQueue::new(
sglang_router_rs::core::JobQueueConfig::default(),
weak_context,
);
app_context
.worker_job_queue
.set(job_queue)
.expect("JobQueue should only be initialized once");
// Initialize WorkflowEngine and register workflows
use sglang_router_rs::core::workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, WorkflowEngine,
};
let engine = Arc::new(WorkflowEngine::new());
engine.register_workflow(create_worker_registration_workflow());
engine.register_workflow(create_worker_removal_workflow());
app_context
.workflow_engine
.set(engine)
.expect("WorkflowEngine should only be initialized once");
// Initialize MCP manager from config file
let mcp_config = McpConfig::from_file(mcp_config_path)
.await
.expect("Failed to load MCP config from file");
let mcp_manager = McpManager::with_defaults(mcp_config)
.await
.expect("Failed to create MCP manager");
app_context
.mcp_manager
.set(Arc::new(mcp_manager))
.ok()
.expect("McpManager should only be initialized once");
app_context app_context
} }
......
This diff is collapsed.
This diff is collapsed.
...@@ -44,7 +44,7 @@ impl TestContext { ...@@ -44,7 +44,7 @@ impl TestContext {
worker_urls: worker_urls.clone(), worker_urls: worker_urls.clone(),
}; };
let app_context = common::create_test_context(config.clone()); let app_context = common::create_test_context(config.clone()).await;
let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
......
This diff is collapsed.
...@@ -45,7 +45,7 @@ impl TestContext { ...@@ -45,7 +45,7 @@ impl TestContext {
worker_urls: worker_urls.clone(), worker_urls: worker_urls.clone(),
}; };
let app_context = common::create_test_context(config.clone()); let app_context = common::create_test_context(config.clone()).await;
let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
......
This diff is collapsed.
...@@ -218,9 +218,10 @@ mod test_pd_routing { ...@@ -218,9 +218,10 @@ mod test_pd_routing {
config.worker_startup_check_interval_secs, config.worker_startup_check_interval_secs,
))); )));
// Create empty OnceLock for worker job queue and workflow engine // Create empty OnceLock for worker job queue, workflow engine, and mcp manager
let worker_job_queue = Arc::new(OnceLock::new()); let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new()); let workflow_engine = Arc::new(OnceLock::new());
let mcp_manager = Arc::new(OnceLock::new());
Arc::new( Arc::new(
AppContext::builder() AppContext::builder()
...@@ -238,6 +239,7 @@ mod test_pd_routing { ...@@ -238,6 +239,7 @@ mod test_pd_routing {
.load_monitor(load_monitor) .load_monitor(load_monitor)
.worker_job_queue(worker_job_queue) .worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine) .workflow_engine(workflow_engine)
.mcp_manager(mcp_manager)
.build() .build()
.unwrap(), .unwrap(),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment