"docs/vscode:/vscode.git/clone" did not exist on "2533f92532b4d48723df7cd5d1de67b53b7fffe5"
Unverified Commit e7387035 authored by Simo Lin, committed by GitHub

[router] consolidate worker get loads (#10880)

parent fe531d6f
@@ -12,7 +12,9 @@ use crate::core::{
    Worker, WorkerFactory, WorkerRegistry, WorkerType,
};
use crate::policies::PolicyRegistry;
-use crate::protocols::worker_spec::{FlushCacheResult, WorkerConfigRequest};
+use crate::protocols::worker_spec::{
+    FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
+};
use crate::server::AppContext;
use futures::future;
use once_cell::sync::Lazy;
@@ -1079,6 +1081,100 @@ impl WorkerManager {
            message,
        })
    }
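    /// Fetch the current load from a single worker via its `/get_load` endpoint,
    /// sending the API key as a bearer token when present. Returns `None` if the
    /// request fails, the status is non-2xx, or the response body cannot be parsed.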
pub async fn get_worker_load(
url: &str,
api_key: Option<&str>,
client: &reqwest::Client,
) -> Option<isize> {
let load_url = format!("{}/get_load", url);
let mut request = client.get(&load_url);
if let Some(key) = api_key {
request = request.bearer_auth(key);
}
match request.send().await {
Ok(response) if response.status().is_success() => {
match response.json::<Value>().await {
Ok(json) => {
if let Some(load) = json.get("load").and_then(|v| v.as_i64()) {
debug!("Worker {} load: {}", url, load);
Some(load as isize)
} else {
warn!("Invalid load response from {}: {:?}", url, json);
None
}
}
Err(e) => {
warn!("Failed to parse load response from {}: {}", url, e);
None
}
}
}
Ok(response) => {
warn!(
"Failed to get load from {}: HTTP {}",
url,
response.status()
);
None
}
Err(e) => {
warn!("Failed to connect to {} for load check: {}", url, e);
None
}
}
}
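    /// Query every worker in the registry concurrently and aggregate the results.
    /// Only HTTP workers are polled; gRPC workers and failed fetches are reported
    /// with a load of -1 and counted as failures.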
pub async fn get_all_worker_loads(
worker_registry: &WorkerRegistry,
client: &reqwest::Client,
) -> WorkerLoadsResult {
let workers = worker_registry.get_all();
let total_workers = workers.len();
// Prepare tasks for parallel execution
let mut tasks = Vec::new();
for worker in &workers {
let url = worker.url().to_string();
let api_key = worker.api_key().clone();
let worker_type = match worker.worker_type() {
WorkerType::Regular => None,
WorkerType::Prefill { .. } => Some("prefill".to_string()),
WorkerType::Decode => Some("decode".to_string()),
};
let is_http = matches!(worker.connection_mode(), ConnectionMode::Http);
let client = client.clone();
tasks.push(async move {
let load = if is_http {
Self::get_worker_load(&url, api_key.as_deref(), &client)
.await
.unwrap_or(-1)
} else {
-1
};
WorkerLoadInfo {
worker: url,
worker_type,
load,
}
});
}
let loads = futures::future::join_all(tasks).await;
let successful = loads.iter().filter(|l| l.load >= 0).count();
let failed = loads.iter().filter(|l| l.load < 0).count();
WorkerLoadsResult {
loads,
total_workers,
successful,
failed,
}
}
}

#[cfg(test)]
...
@@ -215,3 +215,28 @@ pub struct FlushCacheResult {
    /// Human-readable summary message
    pub message: String,
}
/// Result from getting worker loads
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
/// Worker URL and load pairs
pub loads: Vec<WorkerLoadInfo>,
/// Total number of workers
pub total_workers: usize,
/// Number of workers with successful load fetches
pub successful: usize,
/// Number of workers with failed load fetches
pub failed: usize,
}
/// Individual worker load information
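/// Serialized as `{"worker": "<url>", "load": <n>}`; `worker_type` is emitted
/// only when set (prefill/decode workers), since `None` values are skipped.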
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
/// Worker URL
pub worker: String,
/// Worker type (regular, prefill, decode)
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_type: Option<String>,
/// Current load (-1 indicates failure to fetch)
pub load: isize,
}
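
For illustration only (not part of the patch): a minimal, self-contained sketch of how these two structs serialize, assuming `serde` with the derive feature and `serde_json` as dependencies. The struct definitions mirror the ones above, and the worker URLs are hypothetical stand-ins for real registry entries.

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Deserialize, Serialize)]
struct WorkerLoadInfo {
    worker: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    worker_type: Option<String>,
    load: isize,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
struct WorkerLoadsResult {
    loads: Vec<WorkerLoadInfo>,
    total_workers: usize,
    successful: usize,
    failed: usize,
}

fn main() {
    let result = WorkerLoadsResult {
        loads: vec![
            WorkerLoadInfo {
                // hypothetical worker URL
                worker: "http://127.0.0.1:30000".to_string(),
                worker_type: Some("prefill".to_string()),
                load: 3,
            },
            WorkerLoadInfo {
                // regular worker: worker_type is omitted from the JSON entirely
                worker: "http://127.0.0.1:30001".to_string(),
                worker_type: None,
                // -1 marks a failed or unsupported load fetch
                load: -1,
            },
        ],
        total_workers: 2,
        successful: 1,
        failed: 1,
    };
    // Prints compact JSON in field-declaration order, e.g.
    // {"loads":[{"worker":"http://127.0.0.1:30000","worker_type":"prefill","load":3},
    //  {"worker":"http://127.0.0.1:30001","load":-1}],"total_workers":2,"successful":1,"failed":1}
    println!("{}", serde_json::to_string(&result).unwrap());
}
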
@@ -340,10 +340,6 @@ impl RouterTrait for GrpcPDRouter {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }
...
@@ -787,10 +787,6 @@ impl RouterTrait for GrpcRouter {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
    fn router_type(&self) -> &'static str {
        "grpc"
    }
...
@@ -1296,14 +1296,6 @@ impl super::super::RouterTrait for OpenAIRouter {
        }
    }
async fn get_worker_loads(&self) -> Response {
(
StatusCode::FORBIDDEN,
"get_worker_loads not supported for OpenAI router",
)
.into_response()
}
    fn router_type(&self) -> &'static str {
        "openai"
    }
...
use super::pd_types::api_path;
use crate::config::types::RetryConfig;
use crate::core::{
-    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry,
-    WorkerType,
+    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerManager,
+    WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
@@ -18,7 +18,6 @@ use axum::{
    extract::Request,
    http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
use reqwest::Client;
@@ -53,26 +52,6 @@ struct PDRequestContext<'a> {
}

impl PDRouter {
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
(w.url().to_string(), w.api_key().clone())
}
fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry
.get_prefill_workers()
.iter()
.map(|w| self._get_worker_url_and_key(w))
.collect()
}
fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry
.get_decode_workers()
.iter()
.map(|w| self._get_worker_url_and_key(w))
.collect()
}
    async fn proxy_to_first_prefill_worker(
        &self,
        endpoint: &str,
@@ -749,7 +728,10 @@ impl PDRouter {
                let url = url.clone();
                let api_key = api_key.clone();
                async move {
-                    let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0);
+                    let load =
+                        WorkerManager::get_worker_load(&url, api_key.as_deref(), &client)
+                            .await
+                            .unwrap_or(0);
                    (url, load)
                }
            })
@@ -1083,49 +1065,6 @@ impl PDRouter {
    }
}
// Helper functions
async fn get_worker_load(
client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
#[async_trait]
impl RouterTrait for PDRouter {
    fn as_any(&self) -> &dyn std::any::Any {
@@ -1418,44 +1357,6 @@ impl RouterTrait for PDRouter {
        self.execute_dual_dispatch(headers, body, context).await
    }
async fn get_worker_loads(&self) -> Response {
let mut loads = HashMap::new();
let mut errors = Vec::new();
// Process prefill workers
let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key();
for (worker_url, api_key) in prefill_urls_with_key {
match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from prefill {}", worker_url));
}
}
}
// Process decode workers
let decode_urls_with_key = self.get_decode_worker_urls_with_api_key();
for (worker_url, api_key) in decode_urls_with_key {
match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => {
loads.insert(format!("decode_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from decode {}", worker_url));
}
}
}
let response_data = serde_json::json!({
"loads": loads,
"errors": errors
});
(StatusCode::OK, Json(response_data)).into_response()
}
    fn router_type(&self) -> &'static str {
        "pd"
    }
...
use crate::config::types::RetryConfig;
use crate::core::{
-    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
+    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry,
+    WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
@@ -660,58 +661,6 @@ impl Router {
        }
    }
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
    // Background task to monitor worker loads
    async fn monitor_worker_loads(
        worker_urls: Vec<String>,
@@ -728,7 +677,10 @@ impl Router {
            let mut loads = HashMap::new();
            for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
-                if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
+                // Use WorkerManager for consistent load fetching
+                if let Some(load) =
+                    WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await
+                {
                    loads.insert(url.clone(), load);
                }
            }
@@ -745,62 +697,6 @@ impl Router {
        }
    }
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(
client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
debug!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
@@ -953,25 +849,6 @@ impl RouterTrait for Router {
        }
    }
async fn get_worker_loads(&self) -> Response {
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
let mut loads = Vec::new();
// Get loads from all workers
for (url, api_key) in &urls_with_key {
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
loads.push(serde_json::json!({
"worker": url,
"load": load
}));
}
Json(serde_json::json!({
"workers": loads
}))
.into_response()
}
    fn router_type(&self) -> &'static str {
        "regular"
    }
...
@@ -126,9 +126,6 @@ pub trait RouterTrait: Send + Sync + Debug {
        model_id: Option<&str>,
    ) -> Response;
/// Get worker loads (for monitoring)
async fn get_worker_loads(&self) -> Response;
    /// Get router type name
    fn router_type(&self) -> &'static str;
...
@@ -508,30 +508,6 @@ impl RouterTrait for RouterManager {
        }
    }
async fn get_worker_loads(&self) -> Response {
let workers = self.worker_registry.get_all();
let loads: Vec<serde_json::Value> = workers
.iter()
.map(|w| {
serde_json::json!({
"url": w.url(),
"model": w.model_id(),
"load": w.load(),
"is_healthy": w.is_healthy()
})
})
.collect();
(
StatusCode::OK,
serde_json::json!({
"workers": loads
})
.to_string(),
)
.into_response()
}
    fn router_type(&self) -> &'static str {
        "manager"
    }
...
@@ -28,7 +28,7 @@ use axum::{
};
use reqwest::Client;
use serde::Deserialize;
-use serde_json::json;
+use serde_json::{json, Value};
use std::{
    sync::atomic::{AtomicBool, Ordering},
    sync::Arc,
@@ -400,7 +400,28 @@ async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Respo
}

async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
-    state.router.get_worker_loads().await
+    let result =
+        WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
+            .await;
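    // Project each WorkerLoadInfo into a {"worker", "load"} object for the response body.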
let loads: Vec<Value> = result
.loads
.iter()
.map(|info| {
json!({
"worker": &info.worker,
"load": info.load
})
})
.collect();
(
StatusCode::OK,
Json(json!({
"workers": loads
})),
)
.into_response()
}

async fn create_worker(
...