Unverified Commit 60dbbd08 authored by Chang Su, committed by GitHub
Browse files

bugfix: Fix `get_worker_urls_for_model` in http/router.rs (#10754)

parent aa1c5cf5
...@@ -133,10 +133,12 @@ impl Router { ...@@ -133,10 +133,12 @@ impl Router {
/// Get worker URLs for a specific model /// Get worker URLs for a specific model
pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> { pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> {
let workers = match model_id { let workers = self.worker_registry.get_workers_filtered(
Some(model) => self.worker_registry.get_by_model_fast(model), model_id,
None => self.worker_registry.get_all(), Some(WorkerType::Regular),
}; Some(ConnectionMode::Http),
false, // get all workers
);
workers.iter().map(|w| w.url().to_string()).collect() workers.iter().map(|w| w.url().to_string()).collect()
} }
...@@ -315,22 +317,6 @@ impl Router { ...@@ -315,22 +317,6 @@ impl Router {
} }
} }
#[allow(dead_code)]
fn select_first_worker_for_model(&self, model_id: Option<&str>) -> Result<String, String> {
let workers = match model_id {
Some(model) => self.worker_registry.get_by_model_fast(model),
None => self.worker_registry.get_all(),
};
if workers.is_empty() {
Err(format!(
"No workers are available for model: {:?}",
model_id
))
} else {
Ok(workers[0].url().to_string())
}
}
pub async fn send_health_check(&self, worker_url: &str) -> Response { pub async fn send_health_check(&self, worker_url: &str) -> Response {
let health_url = if self.dp_aware { let health_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
...@@ -444,11 +430,13 @@ impl Router { ...@@ -444,11 +430,13 @@ impl Router {
model_id: Option<&str>, model_id: Option<&str>,
text: Option<&str>, text: Option<&str>,
) -> Option<Arc<dyn Worker>> { ) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model (O(1) lookup if model_id is provided) // Get workers for the specified model O(1), filtered by connection mode
let workers = match model_id { let workers = self.worker_registry.get_workers_filtered(
Some(model) => self.worker_registry.get_by_model_fast(model), model_id,
None => self.worker_registry.get_all(), Some(WorkerType::Regular),
}; Some(ConnectionMode::Http),
false, // get all workers, we'll filter by is_available() next
);
let available: Vec<Arc<dyn Worker>> = workers let available: Vec<Arc<dyn Worker>> = workers
.iter() .iter()
...@@ -982,8 +970,12 @@ impl Router { ...@@ -982,8 +970,12 @@ impl Router {
self.policy_registry.on_worker_added(model_id, None); self.policy_registry.on_worker_added(model_id, None);
// Initialize cache-aware policy if applicable // Initialize cache-aware policy if applicable
let model_workers = let model_workers = self.worker_registry.get_workers_filtered(
self.worker_registry.get_by_model_fast(model_id); Some(model_id),
Some(WorkerType::Regular),
Some(ConnectionMode::Http),
false,
);
self.policy_registry self.policy_registry
.init_cache_aware_policy(model_id, &model_workers); .init_cache_aware_policy(model_id, &model_workers);
...@@ -1018,7 +1010,12 @@ impl Router { ...@@ -1018,7 +1010,12 @@ impl Router {
self.policy_registry.on_worker_added(model_id, None); self.policy_registry.on_worker_added(model_id, None);
// Initialize cache-aware policy if applicable // Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_by_model_fast(model_id); let model_workers = self.worker_registry.get_workers_filtered(
Some(model_id),
Some(WorkerType::Regular),
Some(ConnectionMode::Http),
false,
);
self.policy_registry self.policy_registry
.init_cache_aware_policy(model_id, &model_workers); .init_cache_aware_policy(model_id, &model_workers);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment