Unverified Commit e7387035 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] consolidate worker get loads (#10880)

parent fe531d6f
......@@ -12,7 +12,9 @@ use crate::core::{
Worker, WorkerFactory, WorkerRegistry, WorkerType,
};
use crate::policies::PolicyRegistry;
use crate::protocols::worker_spec::{FlushCacheResult, WorkerConfigRequest};
use crate::protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
};
use crate::server::AppContext;
use futures::future;
use once_cell::sync::Lazy;
......@@ -1079,6 +1081,100 @@ impl WorkerManager {
message,
})
}
pub async fn get_worker_load(
url: &str,
api_key: Option<&str>,
client: &reqwest::Client,
) -> Option<isize> {
let load_url = format!("{}/get_load", url);
let mut request = client.get(&load_url);
if let Some(key) = api_key {
request = request.bearer_auth(key);
}
match request.send().await {
Ok(response) if response.status().is_success() => {
match response.json::<Value>().await {
Ok(json) => {
if let Some(load) = json.get("load").and_then(|v| v.as_i64()) {
debug!("Worker {} load: {}", url, load);
Some(load as isize)
} else {
warn!("Invalid load response from {}: {:?}", url, json);
None
}
}
Err(e) => {
warn!("Failed to parse load response from {}: {}", url, e);
None
}
}
}
Ok(response) => {
warn!(
"Failed to get load from {}: HTTP {}",
url,
response.status()
);
None
}
Err(e) => {
warn!("Failed to connect to {} for load check: {}", url, e);
None
}
}
}
pub async fn get_all_worker_loads(
worker_registry: &WorkerRegistry,
client: &reqwest::Client,
) -> WorkerLoadsResult {
let workers = worker_registry.get_all();
let total_workers = workers.len();
// Prepare tasks for parallel execution
let mut tasks = Vec::new();
for worker in &workers {
let url = worker.url().to_string();
let api_key = worker.api_key().clone();
let worker_type = match worker.worker_type() {
WorkerType::Regular => None,
WorkerType::Prefill { .. } => Some("prefill".to_string()),
WorkerType::Decode => Some("decode".to_string()),
};
let is_http = matches!(worker.connection_mode(), ConnectionMode::Http);
let client = client.clone();
tasks.push(async move {
let load = if is_http {
Self::get_worker_load(&url, api_key.as_deref(), &client)
.await
.unwrap_or(-1)
} else {
-1
};
WorkerLoadInfo {
worker: url,
worker_type,
load,
}
});
}
let loads = futures::future::join_all(tasks).await;
let successful = loads.iter().filter(|l| l.load >= 0).count();
let failed = loads.iter().filter(|l| l.load < 0).count();
WorkerLoadsResult {
loads,
total_workers,
successful,
failed,
}
}
}
#[cfg(test)]
......
......@@ -215,3 +215,28 @@ pub struct FlushCacheResult {
/// Human-readable summary message
pub message: String,
}
/// Result from getting worker loads
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
/// Worker URL and load pairs
pub loads: Vec<WorkerLoadInfo>,
/// Total number of workers
pub total_workers: usize,
/// Number of workers with successful load fetches
pub successful: usize,
/// Number of workers with failed load fetches
pub failed: usize,
}
/// Individual worker load information
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
/// Worker URL
pub worker: String,
/// Worker type (regular, prefill, decode)
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_type: Option<String>,
/// Current load (-1 indicates failure to fetch)
pub load: isize,
}
......@@ -340,10 +340,6 @@ impl RouterTrait for GrpcPDRouter {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
fn router_type(&self) -> &'static str {
"grpc_pd"
}
......
......@@ -787,10 +787,6 @@ impl RouterTrait for GrpcRouter {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
fn router_type(&self) -> &'static str {
"grpc"
}
......
......@@ -1296,14 +1296,6 @@ impl super::super::RouterTrait for OpenAIRouter {
}
}
async fn get_worker_loads(&self) -> Response {
(
StatusCode::FORBIDDEN,
"get_worker_loads not supported for OpenAI router",
)
.into_response()
}
fn router_type(&self) -> &'static str {
"openai"
}
......
use super::pd_types::api_path;
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry,
WorkerType,
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerLoadGuard, WorkerManager,
WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
......@@ -18,7 +18,6 @@ use axum::{
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use futures_util::StreamExt;
use reqwest::Client;
......@@ -53,26 +52,6 @@ struct PDRequestContext<'a> {
}
impl PDRouter {
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
(w.url().to_string(), w.api_key().clone())
}
fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry
.get_prefill_workers()
.iter()
.map(|w| self._get_worker_url_and_key(w))
.collect()
}
fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry
.get_decode_workers()
.iter()
.map(|w| self._get_worker_url_and_key(w))
.collect()
}
async fn proxy_to_first_prefill_worker(
&self,
endpoint: &str,
......@@ -749,7 +728,10 @@ impl PDRouter {
let url = url.clone();
let api_key = api_key.clone();
async move {
let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0);
let load =
WorkerManager::get_worker_load(&url, api_key.as_deref(), &client)
.await
.unwrap_or(0);
(url, load)
}
})
......@@ -1083,49 +1065,6 @@ impl PDRouter {
}
}
// Helper functions
async fn get_worker_load(
client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
#[async_trait]
impl RouterTrait for PDRouter {
fn as_any(&self) -> &dyn std::any::Any {
......@@ -1418,44 +1357,6 @@ impl RouterTrait for PDRouter {
self.execute_dual_dispatch(headers, body, context).await
}
async fn get_worker_loads(&self) -> Response {
let mut loads = HashMap::new();
let mut errors = Vec::new();
// Process prefill workers
let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key();
for (worker_url, api_key) in prefill_urls_with_key {
match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from prefill {}", worker_url));
}
}
}
// Process decode workers
let decode_urls_with_key = self.get_decode_worker_urls_with_api_key();
for (worker_url, api_key) in decode_urls_with_key {
match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => {
loads.insert(format!("decode_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from decode {}", worker_url));
}
}
}
let response_data = serde_json::json!({
"loads": loads,
"errors": errors
});
(StatusCode::OK, Json(response_data)).into_response()
}
fn router_type(&self) -> &'static str {
"pd"
}
......
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry,
WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
......@@ -660,58 +661,6 @@ impl Router {
}
}
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
// Background task to monitor worker loads
async fn monitor_worker_loads(
worker_urls: Vec<String>,
......@@ -728,7 +677,10 @@ impl Router {
let mut loads = HashMap::new();
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
// Use WorkerManager for consistent load fetching
if let Some(load) =
WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await
{
loads.insert(url.clone(), load);
}
}
......@@ -745,62 +697,6 @@ impl Router {
}
}
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(
client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
debug!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
async fn build_rerank_response(
req: &RerankRequest,
response: Response,
......@@ -953,25 +849,6 @@ impl RouterTrait for Router {
}
}
async fn get_worker_loads(&self) -> Response {
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
let mut loads = Vec::new();
// Get loads from all workers
for (url, api_key) in &urls_with_key {
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
loads.push(serde_json::json!({
"worker": url,
"load": load
}));
}
Json(serde_json::json!({
"workers": loads
}))
.into_response()
}
fn router_type(&self) -> &'static str {
"regular"
}
......
......@@ -126,9 +126,6 @@ pub trait RouterTrait: Send + Sync + Debug {
model_id: Option<&str>,
) -> Response;
/// Get worker loads (for monitoring)
async fn get_worker_loads(&self) -> Response;
/// Get router type name
fn router_type(&self) -> &'static str;
......
......@@ -508,30 +508,6 @@ impl RouterTrait for RouterManager {
}
}
async fn get_worker_loads(&self) -> Response {
let workers = self.worker_registry.get_all();
let loads: Vec<serde_json::Value> = workers
.iter()
.map(|w| {
serde_json::json!({
"url": w.url(),
"model": w.model_id(),
"load": w.load(),
"is_healthy": w.is_healthy()
})
})
.collect();
(
StatusCode::OK,
serde_json::json!({
"workers": loads
})
.to_string(),
)
.into_response()
}
fn router_type(&self) -> &'static str {
"manager"
}
......
......@@ -28,7 +28,7 @@ use axum::{
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use serde_json::{json, Value};
use std::{
sync::atomic::{AtomicBool, Ordering},
sync::Arc,
......@@ -400,7 +400,28 @@ async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Respo
}
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.get_worker_loads().await
let result =
WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
.await;
let loads: Vec<Value> = result
.loads
.iter()
.map(|info| {
json!({
"worker": &info.worker,
"load": info.load
})
})
.collect();
(
StatusCode::OK,
Json(json!({
"workers": loads
})),
)
.into_response()
}
async fn create_worker(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment