Unverified commit 2689f0bf, authored by Chang Su, committed by GitHub
Browse files

[router] multi model registration fix (#10481)

parent 52074240
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::RouterConfig; use crate::config::RouterConfig;
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry}; use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest, ResponsesRequest,
...@@ -56,10 +56,6 @@ pub struct RouterManager { ...@@ -56,10 +56,6 @@ pub struct RouterManager {
/// Default router for requests without specific routing /// Default router for requests without specific routing
default_router: Option<RouterId>, default_router: Option<RouterId>,
/// Model to router mapping for model-aware routing
/// Multiple models can be served by the same router
model_routers: Arc<DashMap<String, Vec<RouterId>>>,
/// HTTP client for querying worker info /// HTTP client for querying worker info
client: reqwest::Client, client: reqwest::Client,
...@@ -81,7 +77,6 @@ impl RouterManager { ...@@ -81,7 +77,6 @@ impl RouterManager {
policy_registry, policy_registry,
routers: Arc::new(DashMap::new()), routers: Arc::new(DashMap::new()),
default_router: None, default_router: None,
model_routers: Arc::new(DashMap::new()),
client, client,
config, config,
} }
...@@ -92,19 +87,11 @@ impl RouterManager { ...@@ -92,19 +87,11 @@ impl RouterManager {
&mut self, &mut self,
id: RouterId, id: RouterId,
router: Arc<dyn RouterTrait>, router: Arc<dyn RouterTrait>,
models: Vec<String>, _models: Vec<String>, // Keep parameter for backward compatibility but ignore it
) { ) {
// Store router // Store router
self.routers.insert(id.clone(), router); self.routers.insert(id.clone(), router);
// Update model mappings
for model in models {
self.model_routers
.entry(model)
.or_default()
.push(id.clone());
}
// Set as default if first router // Set as default if first router
if self.default_router.is_none() { if self.default_router.is_none() {
self.default_router = Some(id.clone()); self.default_router = Some(id.clone());
...@@ -122,16 +109,31 @@ impl RouterManager { ...@@ -122,16 +109,31 @@ impl RouterManager {
self.routers.len() self.routers.len()
} }
/// Get router for a specific model /// Get router for a specific model based on worker types
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> { pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
// First try model-specific routers // Query workers for this model from registry
if let Some(router_ids) = self.model_routers.get(model_id) { let workers = self.worker_registry.get_by_model(model_id);
if let Some(router_id) = router_ids.first() {
if let Some(router) = self.routers.get(router_id) { if !workers.is_empty() {
// Determine router based on worker types
let has_pd_workers = workers.iter().any(|w| {
matches!(
w.worker_type(),
WorkerType::Prefill { .. } | WorkerType::Decode
)
});
let router_id = if has_pd_workers {
RouterId::new("http-pd".to_string())
} else {
RouterId::new("http-regular".to_string())
};
// Return the router if it exists
if let Some(router) = self.routers.get(&router_id) {
return Some(router.clone()); return Some(router.clone());
} }
} }
}
// Fall back to default router // Fall back to default router
if let Some(ref default_id) = self.default_router { if let Some(ref default_id) = self.default_router {
...@@ -240,6 +242,17 @@ impl RouterManager { ...@@ -240,6 +242,17 @@ impl RouterManager {
let policy_hint = labels.get("policy").map(|s| s.as_str()); let policy_hint = labels.get("policy").map(|s| s.as_str());
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint); let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
// Log which type of router would handle this worker (for debugging)
let expected_router = match config.worker_type.as_deref() {
Some("prefill") | Some("decode") => "http-pd",
_ => "http-regular",
};
info!(
"Worker for model '{}' would be handled by '{}' router based on type",
model_id, expected_router
);
info!( info!(
"Added worker {} with URL {} for model {} using policy {}", "Added worker {} with URL {} for model {} using policy {}",
worker_id.as_str(), worker_id.as_str(),
...@@ -272,8 +285,9 @@ impl RouterManager { ...@@ -272,8 +285,9 @@ impl RouterManager {
if let Some(_worker) = self.worker_registry.remove_by_url(url) { if let Some(_worker) = self.worker_registry.remove_by_url(url) {
// Notify PolicyRegistry about worker removal // Notify PolicyRegistry about worker removal
if let Some(model_id) = model_id { if let Some(ref model_id) = model_id {
self.policy_registry.on_worker_removed(&model_id); self.policy_registry.on_worker_removed(model_id);
info!("Removed worker with URL {} for model {}", url, model_id); info!("Removed worker with URL {} for model {}", url, model_id);
} else { } else {
info!("Removed worker with URL {}", url); info!("Removed worker with URL {}", url);
...@@ -406,14 +420,10 @@ impl RouterManager { ...@@ -406,14 +420,10 @@ impl RouterManager {
}) })
.unwrap_or(false); .unwrap_or(false);
// If model specified, find routers serving that model // If model specified, use get_router_for_model
let candidate_routers = if let Some(model) = model_id { let candidate_routers = if let Some(model) = model_id {
// Get routers for specific model if let Some(router) = self.get_router_for_model(model) {
if let Some(router_ids) = self.model_routers.get(model) { vec![router]
router_ids
.iter()
.filter_map(|id| self.routers.get(id).map(|r| r.clone()))
.collect::<Vec<_>>()
} else { } else {
Vec::new() Vec::new()
} }
...@@ -547,14 +557,10 @@ impl RouterTrait for RouterManager { ...@@ -547,14 +557,10 @@ impl RouterTrait for RouterManager {
.into_response() .into_response()
} }
/// Get available models - aggregate from all routers /// Get available models - query from worker registry
async fn get_models(&self, _req: Request<Body>) -> Response { async fn get_models(&self, _req: Request<Body>) -> Response {
// Return models that have registered routers // Get models from worker registry
let models = self let models = self.worker_registry.get_models();
.model_routers
.iter()
.map(|entry| entry.key().clone())
.collect::<Vec<_>>();
if models.is_empty() { if models.is_empty() {
(StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response() (StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment