Unverified Commit 2689f0bf authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] multi model registration fix (#10481)

parent 52074240
......@@ -5,7 +5,7 @@
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::RouterConfig;
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest,
......@@ -56,10 +56,6 @@ pub struct RouterManager {
/// Default router for requests without specific routing
default_router: Option<RouterId>,
/// Model to router mapping for model-aware routing
/// Multiple models can be served by the same router
model_routers: Arc<DashMap<String, Vec<RouterId>>>,
/// HTTP client for querying worker info
client: reqwest::Client,
......@@ -81,7 +77,6 @@ impl RouterManager {
policy_registry,
routers: Arc::new(DashMap::new()),
default_router: None,
model_routers: Arc::new(DashMap::new()),
client,
config,
}
......@@ -92,19 +87,11 @@ impl RouterManager {
&mut self,
id: RouterId,
router: Arc<dyn RouterTrait>,
models: Vec<String>,
_models: Vec<String>, // Keep parameter for backward compatibility but ignore it
) {
// Store router
self.routers.insert(id.clone(), router);
// Update model mappings
for model in models {
self.model_routers
.entry(model)
.or_default()
.push(id.clone());
}
// Set as default if first router
if self.default_router.is_none() {
self.default_router = Some(id.clone());
......@@ -122,14 +109,29 @@ impl RouterManager {
self.routers.len()
}
/// Get router for a specific model
/// Get router for a specific model based on worker types
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
// First try model-specific routers
if let Some(router_ids) = self.model_routers.get(model_id) {
if let Some(router_id) = router_ids.first() {
if let Some(router) = self.routers.get(router_id) {
return Some(router.clone());
}
// Query workers for this model from registry
let workers = self.worker_registry.get_by_model(model_id);
if !workers.is_empty() {
// Determine router based on worker types
let has_pd_workers = workers.iter().any(|w| {
matches!(
w.worker_type(),
WorkerType::Prefill { .. } | WorkerType::Decode
)
});
let router_id = if has_pd_workers {
RouterId::new("http-pd".to_string())
} else {
RouterId::new("http-regular".to_string())
};
// Return the router if it exists
if let Some(router) = self.routers.get(&router_id) {
return Some(router.clone());
}
}
......@@ -240,6 +242,17 @@ impl RouterManager {
let policy_hint = labels.get("policy").map(|s| s.as_str());
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
// Log which type of router would handle this worker (for debugging)
let expected_router = match config.worker_type.as_deref() {
Some("prefill") | Some("decode") => "http-pd",
_ => "http-regular",
};
info!(
"Worker for model '{}' would be handled by '{}' router based on type",
model_id, expected_router
);
info!(
"Added worker {} with URL {} for model {} using policy {}",
worker_id.as_str(),
......@@ -272,8 +285,9 @@ impl RouterManager {
if let Some(_worker) = self.worker_registry.remove_by_url(url) {
// Notify PolicyRegistry about worker removal
if let Some(model_id) = model_id {
self.policy_registry.on_worker_removed(&model_id);
if let Some(ref model_id) = model_id {
self.policy_registry.on_worker_removed(model_id);
info!("Removed worker with URL {} for model {}", url, model_id);
} else {
info!("Removed worker with URL {}", url);
......@@ -406,14 +420,10 @@ impl RouterManager {
})
.unwrap_or(false);
// If model specified, find routers serving that model
// If model specified, use get_router_for_model
let candidate_routers = if let Some(model) = model_id {
// Get routers for specific model
if let Some(router_ids) = self.model_routers.get(model) {
router_ids
.iter()
.filter_map(|id| self.routers.get(id).map(|r| r.clone()))
.collect::<Vec<_>>()
if let Some(router) = self.get_router_for_model(model) {
vec![router]
} else {
Vec::new()
}
......@@ -547,14 +557,10 @@ impl RouterTrait for RouterManager {
.into_response()
}
/// Get available models - aggregate from all routers
/// Get available models - query from worker registry
async fn get_models(&self, _req: Request<Body>) -> Response {
// Return models that have registered routers
let models = self
.model_routers
.iter()
.map(|entry| entry.key().clone())
.collect::<Vec<_>>();
// Get models from worker registry
let models = self.worker_registry.get_models();
if models.is_empty() {
(StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment