"vscode:/vscode.git/clone" did not exist on "d3dd8e370008f267ccb1657d746a5abea2e88305"
Unverified Commit 5fe39e85 authored by Chang Su, committed by GitHub
Browse files

[router] fix router manager and router init in server (#10499)

parent fa5d0bf6
...@@ -53,7 +53,7 @@ pub struct RouterManager { ...@@ -53,7 +53,7 @@ pub struct RouterManager {
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>, routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
/// Default router for requests without specific routing /// Default router for requests without specific routing
default_router: Option<RouterId>, default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
/// HTTP client for querying worker info /// HTTP client for querying worker info
client: reqwest::Client, client: reqwest::Client,
...@@ -75,27 +75,29 @@ impl RouterManager { ...@@ -75,27 +75,29 @@ impl RouterManager {
worker_registry, worker_registry,
policy_registry, policy_registry,
routers: Arc::new(DashMap::new()), routers: Arc::new(DashMap::new()),
default_router: None, default_router: Arc::new(std::sync::RwLock::new(None)),
client, client,
config, config,
} }
} }
/// Register a router with the manager /// Register a router with the manager
pub fn register_router(&mut self, id: RouterId, router: Arc<dyn RouterTrait>) { pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
// Store router // Store router
self.routers.insert(id.clone(), router); self.routers.insert(id.clone(), router);
// Set as default if first router // Set as default if first router
if self.default_router.is_none() { let mut default_router = self.default_router.write().unwrap();
self.default_router = Some(id.clone()); if default_router.is_none() {
*default_router = Some(id.clone());
info!("Set default router to {}", id.as_str()); info!("Set default router to {}", id.as_str());
} }
} }
/// Set the default router /// Set the default router
pub fn set_default_router(&mut self, id: RouterId) { pub fn set_default_router(&self, id: RouterId) {
self.default_router = Some(id); let mut default_router = self.default_router.write().unwrap();
*default_router = Some(id);
} }
/// Get the number of registered routers /// Get the number of registered routers
...@@ -130,7 +132,8 @@ impl RouterManager { ...@@ -130,7 +132,8 @@ impl RouterManager {
} }
// Fall back to default router // Fall back to default router
if let Some(ref default_id) = self.default_router { let default_router = self.default_router.read().unwrap();
if let Some(ref default_id) = *default_router {
self.routers.get(default_id).map(|r| r.clone()) self.routers.get(default_id).map(|r| r.clone())
} else { } else {
None None
...@@ -808,7 +811,7 @@ impl std::fmt::Debug for RouterManager { ...@@ -808,7 +811,7 @@ impl std::fmt::Debug for RouterManager {
f.debug_struct("RouterManager") f.debug_struct("RouterManager")
.field("routers_count", &self.routers.len()) .field("routers_count", &self.routers.len())
.field("workers_count", &self.worker_registry.get_all().len()) .field("workers_count", &self.worker_registry.get_all().len())
.field("default_router", &self.default_router) .field("default_router", &*self.default_router.read().unwrap())
.finish() .finish()
} }
} }
...@@ -122,6 +122,7 @@ pub struct AppState { ...@@ -122,6 +122,7 @@ pub struct AppState {
pub router: Arc<dyn RouterTrait>, pub router: Arc<dyn RouterTrait>,
pub context: Arc<AppContext>, pub context: Arc<AppContext>,
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>, pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
pub router_manager: Option<Arc<RouterManager>>,
} }
// Fallback handler for unmatched routes // Fallback handler for unmatched routes
...@@ -326,8 +327,8 @@ async fn create_worker( ...@@ -326,8 +327,8 @@ async fn create_worker(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(config): Json<WorkerConfigRequest>, Json(config): Json<WorkerConfigRequest>,
) -> Response { ) -> Response {
// Check if the router is actually a RouterManager (enable_igw=true) // Check if we have a RouterManager (enable_igw=true)
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() { if let Some(router_manager) = &state.router_manager {
// Call RouterManager's add_worker method directly with the full config // Call RouterManager's add_worker method directly with the full config
match router_manager.add_worker(config).await { match router_manager.add_worker(config).await {
Ok(response) => (StatusCode::OK, Json(response)).into_response(), Ok(response) => (StatusCode::OK, Json(response)).into_response(),
...@@ -357,7 +358,7 @@ async fn create_worker( ...@@ -357,7 +358,7 @@ async fn create_worker(
/// GET /workers - List all workers with details /// GET /workers - List all workers with details
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response { async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() { if let Some(router_manager) = &state.router_manager {
let response = router_manager.list_workers(); let response = router_manager.list_workers();
Json(response).into_response() Json(response).into_response()
} else { } else {
...@@ -400,7 +401,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response { ...@@ -400,7 +401,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
/// GET /workers/{url} - Get specific worker info /// GET /workers/{url} - Get specific worker info
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response { async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() { if let Some(router_manager) = &state.router_manager {
if let Some(worker) = router_manager.get_worker(&url) { if let Some(worker) = router_manager.get_worker(&url) {
Json(worker).into_response() Json(worker).into_response()
} else { } else {
...@@ -431,7 +432,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) ...@@ -431,7 +432,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
/// DELETE /workers/{url} - Remove a worker /// DELETE /workers/{url} - Remove a worker
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response { async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() { if let Some(router_manager) = &state.router_manager {
match router_manager.remove_worker_from_registry(&url) { match router_manager.remove_worker_from_registry(&url) {
Ok(response) => (StatusCode::OK, Json(response)).into_response(), Ok(response) => (StatusCode::OK, Json(response)).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(), Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
...@@ -594,69 +595,76 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -594,69 +595,76 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let app_context = Arc::new(app_context); let app_context = Arc::new(app_context);
// Create the appropriate router based on enable_igw flag // Create the appropriate router based on enable_igw flag
let router: Box<dyn RouterTrait> = if config.router_config.enable_igw { let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
info!("Multi-router mode enabled (enable_igw=true)"); if config.router_config.enable_igw {
info!("Multi-router mode enabled (enable_igw=true)");
// Create RouterManager with shared registries from AppContext
let mut router_manager = RouterManager::new( // Create RouterManager with shared registries from AppContext
config.router_config.clone(), let router_manager = Arc::new(RouterManager::new(
client.clone(), config.router_config.clone(),
app_context.worker_registry.clone(), client.clone(),
app_context.policy_registry.clone(), app_context.worker_registry.clone(),
); app_context.policy_registry.clone(),
));
// 1. HTTP Regular Router
match RouterFactory::create_regular_router( // 1. HTTP Regular Router
&[], // Empty worker list - workers added later match RouterFactory::create_regular_router(
&app_context, &[], // Empty worker list - workers added later
) &app_context,
.await )
{ .await
Ok(http_regular) => { {
info!("Created HTTP Regular router"); Ok(http_regular) => {
router_manager.register_router( info!("Created HTTP Regular router");
RouterId::new("http-regular".to_string()), router_manager.register_router(
Arc::from(http_regular), RouterId::new("http-regular".to_string()),
); Arc::from(http_regular),
} );
Err(e) => { }
warn!("Failed to create HTTP Regular router: {e}"); Err(e) => {
warn!("Failed to create HTTP Regular router: {e}");
}
} }
}
// 2. HTTP PD Router // 2. HTTP PD Router
match RouterFactory::create_pd_router( match RouterFactory::create_pd_router(
&[], &[],
&[], &[],
None, None,
None, None,
&config.router_config.policy, &config.router_config.policy,
&app_context, &app_context,
) )
.await .await
{ {
Ok(http_pd) => { Ok(http_pd) => {
info!("Created HTTP PD router"); info!("Created HTTP PD router");
router_manager router_manager
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd)); .register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
} }
Err(e) => { Err(e) => {
warn!("Failed to create HTTP PD router: {e}"); warn!("Failed to create HTTP PD router: {e}");
}
} }
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading // TODO: Add gRPC routers once we have dynamic tokenizer loading
info!( info!(
"RouterManager initialized with {} routers", "RouterManager initialized with {} routers",
router_manager.router_count() router_manager.router_count()
); );
Box::new(router_manager) (
} else { router_manager.clone() as Arc<dyn RouterTrait>,
info!("Single router mode (enable_igw=false)"); Some(router_manager),
// Create single router with the context )
RouterFactory::create_router(&app_context).await? } else {
}; info!("Single router mode (enable_igw=false)");
// Create single router with the context
(
Arc::from(RouterFactory::create_router(&app_context).await?),
None,
)
};
// Start health checker for all workers in the registry // Start health checker for all workers in the registry
let _health_checker = app_context let _health_checker = app_context
...@@ -685,9 +693,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -685,9 +693,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
// Create app state with router and context // Create app state with router and context
let app_state = Arc::new(AppState { let app_state = Arc::new(AppState {
router: Arc::from(router), router,
context: app_context.clone(), context: app_context.clone(),
concurrency_queue_tx: limiter.queue_tx.clone(), concurrency_queue_tx: limiter.queue_tx.clone(),
router_manager,
}); });
let router_arc = Arc::clone(&app_state.router); let router_arc = Arc::clone(&app_state.router);
......
...@@ -29,7 +29,8 @@ pub fn create_test_app( ...@@ -29,7 +29,8 @@ pub fn create_test_app(
let app_state = Arc::new(AppState { let app_state = Arc::new(AppState {
router, router,
context: app_context, context: app_context,
concurrency_queue_tx: None, // No queue for tests concurrency_queue_tx: None,
router_manager: None,
}); });
// Configure request ID headers (use defaults if not specified) // Configure request ID headers (use defaults if not specified)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment