Unverified commit f2d5c492, authored by Simo Lin, committed by GitHub

[router] add worker abstraction (#7960)

parent 2a2d3478
@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
 kube = { version = "0.88.1", features = ["runtime", "derive"] }
 k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
 futures = "0.3"
+async-trait = "0.1"
+once_cell = "1.21"
 # Added for metrics
 metrics = "0.24.2"
 metrics-exporter-prometheus = "0.17.0"
//! Error types for the SGLang router core
//!
//! This module defines error types used throughout the router for worker operations.
use std::fmt;
/// Worker-related errors
#[derive(Debug)]
pub enum WorkerError {
/// Health check failed
HealthCheckFailed { url: String, reason: String },
/// Worker not found
WorkerNotFound { url: String },
/// Invalid worker configuration
InvalidConfiguration { message: String },
/// Network error
NetworkError { url: String, error: String },
/// Worker is at capacity
WorkerAtCapacity { url: String },
}
impl fmt::Display for WorkerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WorkerError::HealthCheckFailed { url, reason } => {
write!(f, "Health check failed for worker {}: {}", url, reason)
}
WorkerError::WorkerNotFound { url } => {
write!(f, "Worker not found: {}", url)
}
WorkerError::InvalidConfiguration { message } => {
write!(f, "Invalid worker configuration: {}", message)
}
WorkerError::NetworkError { url, error } => {
write!(f, "Network error for worker {}: {}", url, error)
}
WorkerError::WorkerAtCapacity { url } => {
write!(f, "Worker at capacity: {}", url)
}
}
}
}
impl std::error::Error for WorkerError {}
/// Result type for worker operations
pub type WorkerResult<T> = Result<T, WorkerError>;
/// Convert from reqwest errors to worker errors
impl From<reqwest::Error> for WorkerError {
fn from(err: reqwest::Error) -> Self {
WorkerError::NetworkError {
url: err.url().map(|u| u.to_string()).unwrap_or_default(),
error: err.to_string(),
}
}
}
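// Illustrative sketch (editor's addition, not part of this commit): how a
// caller might propagate these errors. `probe_worker` and the exact endpoint
// are assumptions for the example; the `?` relies on the From impl above.
async fn probe_worker(client: &reqwest::Client, url: &str) -> WorkerResult<()> {
    // A failed send converts into WorkerError::NetworkError automatically
    let resp = client.get(format!("{}/health", url)).send().await?;
    if resp.status().is_success() {
        Ok(())
    } else {
        Err(WorkerError::HealthCheckFailed {
            url: url.to_string(),
            reason: format!("status {}", resp.status()),
        })
    }
}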
//! Core abstractions for the SGLang router
//!
//! This module contains the fundamental types and traits used throughout the router:
//! - Worker trait and implementations
//! - Error types
//! - Common utilities
pub mod error;
pub mod worker;
// Re-export commonly used types at the module level
pub use error::{WorkerError, WorkerResult};
pub use worker::{
start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory,
WorkerLoadGuard, WorkerType,
};
use super::{WorkerError, WorkerResult};
use async_trait::async_trait;
use once_cell::sync::Lazy;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
// Shared HTTP client for health checks
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
.build()
.expect("Failed to create health check HTTP client")
});
/// Core worker abstraction that represents a backend service
#[async_trait]
pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's URL
fn url(&self) -> &str;
/// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType;
/// Check if the worker is currently healthy
fn is_healthy(&self) -> bool;
/// Set the worker's health status
fn set_healthy(&self, healthy: bool);
/// Perform an async health check on the worker
async fn check_health_async(&self) -> WorkerResult<()>;
/// Synchronous health check wrapper (for compatibility)
fn check_health(&self) -> WorkerResult<()> {
// Use a small runtime for synchronous contexts
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| WorkerError::HealthCheckFailed {
url: self.url().to_string(),
reason: format!("Failed to create runtime: {}", e),
})?
.block_on(self.check_health_async())
}
/// Get the current load (number of active requests)
fn load(&self) -> usize;
/// Increment the load counter
fn increment_load(&self);
/// Decrement the load counter
fn decrement_load(&self);
/// Get the number of processed requests
fn processed_requests(&self) -> usize;
/// Increment the processed requests counter
fn increment_processed(&self);
/// Get worker-specific metadata
fn metadata(&self) -> &WorkerMetadata;
/// Clone the worker (for trait objects)
fn clone_worker(&self) -> Box<dyn Worker>;
}
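// Sketch (editor's illustration, not in the commit): callers can stay generic
// over backends by working with `&dyn Worker`. This picks the healthy worker
// with the smallest load counter.
fn pick_least_loaded(workers: &[Box<dyn Worker>]) -> Option<&dyn Worker> {
    workers
        .iter()
        .filter(|w| w.is_healthy())
        .min_by_key(|w| w.load())
        .map(|w| w.as_ref())
}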
/// Worker type classification
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WorkerType {
/// Regular worker for standard routing
Regular,
/// Prefill worker for PD disaggregated mode
Prefill {
/// Bootstrap port for communication with decode workers
bootstrap_port: Option<u16>,
},
/// Decode worker for PD disaggregated mode
Decode,
}
impl fmt::Display for WorkerType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WorkerType::Regular => write!(f, "Regular"),
WorkerType::Prefill { bootstrap_port } => match bootstrap_port {
Some(port) => write!(f, "Prefill(bootstrap:{})", port),
None => write!(f, "Prefill"),
},
WorkerType::Decode => write!(f, "Decode"),
}
}
}
/// Health check configuration
#[derive(Debug, Clone)]
pub struct HealthConfig {
/// Timeout for health checks in seconds
pub timeout_secs: u64,
/// Interval between health checks in seconds
pub check_interval_secs: u64,
/// Health check endpoint path
pub endpoint: String,
}
impl Default for HealthConfig {
fn default() -> Self {
Self {
timeout_secs: 5,
check_interval_secs: 30,
endpoint: "/health".to_string(),
}
}
}
/// Metadata associated with a worker
#[derive(Debug, Clone)]
pub struct WorkerMetadata {
/// Worker URL
pub url: String,
/// Worker type
pub worker_type: WorkerType,
/// Additional labels/tags
pub labels: std::collections::HashMap<String, String>,
/// Health check configuration
pub health_config: HealthConfig,
}
/// Basic worker implementation
#[derive(Debug, Clone)]
pub struct BasicWorker {
metadata: WorkerMetadata,
load_counter: Arc<AtomicUsize>,
processed_counter: Arc<AtomicUsize>,
healthy: Arc<AtomicBool>,
}
impl BasicWorker {
pub fn new(url: String, worker_type: WorkerType) -> Self {
let metadata = WorkerMetadata {
url: url.clone(),
worker_type,
labels: std::collections::HashMap::new(),
health_config: HealthConfig::default(),
};
Self {
metadata,
load_counter: Arc::new(AtomicUsize::new(0)),
processed_counter: Arc::new(AtomicUsize::new(0)),
healthy: Arc::new(AtomicBool::new(true)),
}
}
pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
self.metadata.labels = labels;
self
}
pub fn with_health_config(mut self, config: HealthConfig) -> Self {
self.metadata.health_config = config;
self
}
}
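// Construction sketch (assumed URL, port, and labels; not part of the commit):
// the builder methods layer metadata and a tighter health-check policy over
// the defaults from HealthConfig::default().
fn example_prefill_worker() -> BasicWorker {
    let mut labels = std::collections::HashMap::new();
    labels.insert("pool".to_string(), "prefill-a".to_string());
    BasicWorker::new(
        "http://prefill-0:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(8998),
        },
    )
    .with_labels(labels)
    .with_health_config(HealthConfig {
        timeout_secs: 2,
        check_interval_secs: 10,
        endpoint: "/health".to_string(),
    })
}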
#[async_trait]
impl Worker for BasicWorker {
fn url(&self) -> &str {
&self.metadata.url
}
fn worker_type(&self) -> WorkerType {
self.metadata.worker_type.clone()
}
fn is_healthy(&self) -> bool {
self.healthy.load(Ordering::Acquire)
}
fn set_healthy(&self, healthy: bool) {
self.healthy.store(healthy, Ordering::Release);
}
async fn check_health_async(&self) -> WorkerResult<()> {
use std::time::Duration;
// Perform actual HTTP health check
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
match HEALTH_CHECK_CLIENT
.get(&health_url)
.timeout(timeout)
.send()
.await
{
Ok(response) => {
if response.status().is_success() {
self.set_healthy(true);
Ok(())
} else {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
reason: format!("Health check returned status: {}", response.status()),
})
}
}
Err(e) => {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
reason: format!("Health check request failed: {}", e),
})
}
}
}
fn load(&self) -> usize {
self.load_counter.load(Ordering::Relaxed)
}
fn increment_load(&self) {
self.load_counter.fetch_add(1, Ordering::Relaxed);
}
fn decrement_load(&self) {
self.load_counter
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
current.checked_sub(1)
})
.ok();
}
fn processed_requests(&self) -> usize {
self.processed_counter.load(Ordering::Relaxed)
}
fn increment_processed(&self) {
self.processed_counter.fetch_add(1, Ordering::Relaxed);
}
fn metadata(&self) -> &WorkerMetadata {
&self.metadata
}
fn clone_worker(&self) -> Box<dyn Worker> {
Box::new(self.clone())
}
}
/// Worker factory for creating workers of different types
pub struct WorkerFactory;
impl WorkerFactory {
/// Create a regular worker
pub fn create_regular(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Regular))
}
/// Create a prefill worker with optional bootstrap port
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
Box::new(BasicWorker::new(
url,
WorkerType::Prefill { bootstrap_port },
))
}
/// Create a decode worker
pub fn create_decode(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Decode))
}
/// Create regular, prefill, and decode workers from their respective URL lists
pub fn create_from_urls(
regular_urls: Vec<String>,
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
) -> (
Vec<Box<dyn Worker>>,
Vec<Box<dyn Worker>>,
Vec<Box<dyn Worker>>,
) {
let regular_workers: Vec<Box<dyn Worker>> =
regular_urls.into_iter().map(Self::create_regular).collect();
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.into_iter()
.map(|(url, port)| Self::create_prefill(url, port))
.collect();
let decode_workers: Vec<Box<dyn Worker>> =
decode_urls.into_iter().map(Self::create_decode).collect();
(regular_workers, prefill_workers, decode_workers)
}
}
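// Usage sketch (placeholder URLs, editor's addition): one call fans the three
// URL lists out into typed trait objects for regular, prefill, and decode pools.
fn example_pools() -> (
    Vec<Box<dyn Worker>>,
    Vec<Box<dyn Worker>>,
    Vec<Box<dyn Worker>>,
) {
    WorkerFactory::create_from_urls(
        vec!["http://regular-0:8000".to_string()],
        vec![("http://prefill-0:8000".to_string(), Some(8998))],
        vec!["http://decode-0:8000".to_string()],
    )
}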
/// Helper trait for collections of workers
pub trait WorkerCollection {
fn healthy_workers(&self) -> Vec<&dyn Worker>;
fn total_load(&self) -> usize;
fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
}
impl WorkerCollection for Vec<Box<dyn Worker>> {
fn healthy_workers(&self) -> Vec<&dyn Worker> {
self.iter()
.filter(|w| w.is_healthy())
.map(|w| w.as_ref())
.collect()
}
fn total_load(&self) -> usize {
self.iter().map(|w| w.load()).sum()
}
fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
self.iter().find(|w| w.url() == url).map(|w| w.as_ref())
}
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
self.iter_mut().find(|w| w.url() == url)
}
}
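// Sketch (assumed URL, editor's addition): the extension trait lets routing
// code query a plain Vec of trait objects directly.
fn log_pool_state(workers: &Vec<Box<dyn Worker>>) {
    tracing::info!(
        "{} healthy workers, total load {}",
        workers.healthy_workers().len(),
        workers.total_load()
    );
    if let Some(w) = workers.find_worker("http://decode-0:8000") {
        tracing::info!("{} load: {}", w.url(), w.load());
    }
}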
/// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
urls.into_iter()
.map(WorkerFactory::create_regular)
.collect()
}
/// Convert worker trait objects back to URLs
pub fn workers_to_urls(workers: &[Box<dyn Worker>]) -> Vec<String> {
workers.iter().map(|w| w.url().to_string()).collect()
}
/// RAII guard for worker load management
pub struct WorkerLoadGuard<'a> {
workers: Vec<&'a dyn Worker>,
}
impl<'a> WorkerLoadGuard<'a> {
/// Create a new load guard for a single worker
pub fn new(worker: &'a dyn Worker) -> Self {
worker.increment_load();
Self {
workers: vec![worker],
}
}
/// Create a new load guard for multiple workers
pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self {
// Increment load counters for all workers
for worker in &workers {
worker.increment_load();
}
Self { workers }
}
}
impl<'a> Drop for WorkerLoadGuard<'a> {
fn drop(&mut self) {
// Decrement load counters for all workers
for worker in &self.workers {
worker.decrement_load();
}
}
}
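// RAII sketch (hypothetical handler, editor's addition): the guard pairs every
// increment with a decrement on drop, so load counters cannot leak on early
// return or panic while a request is in flight.
async fn proxy_once(worker: &dyn Worker) -> WorkerResult<()> {
    let _guard = WorkerLoadGuard::new(worker);
    worker.check_health_async().await?;
    // ...forward the request to worker.url() here...
    Ok(())
} // _guard drops here, decrementing the load counter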
/// Health checker handle with graceful shutdown
pub struct HealthChecker {
handle: tokio::task::JoinHandle<()>,
shutdown: Arc<AtomicBool>,
}
impl fmt::Debug for HealthChecker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HealthChecker")
.field("shutdown", &self.shutdown.load(Ordering::Relaxed))
.finish()
}
}
impl HealthChecker {
/// Shutdown the health checker gracefully
pub async fn shutdown(self) {
self.shutdown.store(true, Ordering::Release);
let _ = self.handle.await;
}
}
/// Start an async background health checker for a collection of workers
pub fn start_health_checker(
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
check_interval_secs: u64,
) -> HealthChecker {
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
let handle = tokio::spawn(async move {
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
loop {
interval.tick().await;
// Check for shutdown signal
if shutdown_clone.load(Ordering::Acquire) {
tracing::info!("Health checker shutting down");
break;
}
// Check health of all workers
let workers_to_check = match workers.read() {
Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
Err(poisoned) => {
tracing::error!("Worker lock poisoned: {}", poisoned);
continue;
}
};
// Perform health checks concurrently
let health_checks = workers_to_check.iter().map(|worker| {
let worker_url = worker.url().to_string();
let was_healthy = worker.is_healthy();
async move {
match worker.check_health_async().await {
Ok(_) => {
if !was_healthy {
tracing::info!("Worker {} is now healthy", worker_url);
}
}
Err(e) => {
if was_healthy {
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
}
}
}
}
});
// Execute all health checks concurrently
futures::future::join_all(health_checks).await;
}
});
HealthChecker { handle, shutdown }
}
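// Wiring sketch (assumed setup, editor's addition): spawn the checker next to
// a shared pool, then shut it down gracefully; the loop observes the shutdown
// flag on its next tick.
async fn example_health_loop() {
    let workers = Arc::new(std::sync::RwLock::new(vec![
        WorkerFactory::create_regular("http://regular-0:8000".to_string()),
    ]));
    let checker = start_health_checker(Arc::clone(&workers), 30);
    // ... serve traffic here ...
    checker.shutdown().await;
}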
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
 pub mod config;
 pub mod logging;
 use std::collections::HashMap;
+pub mod core;
 pub mod openai_api_types;
 pub mod pd_router;
 pub mod pd_types;
 // PD (Prefill-Decode) Router Implementation
 // This module handles routing for disaggregated prefill-decode systems
+use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
 use crate::pd_types::{
-    Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
+    api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
 };
 use crate::tree::Tree;
 use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
 use metrics::{counter, histogram};
 use serde_json::Value;
 use std::collections::HashMap;
-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, Mutex, RwLock};
 use std::time::{Duration, Instant};
 use tracing::{debug, error, info, warn};
@@ -21,49 +21,17 @@ use uuid::Uuid;
 #[derive(Debug)]
 pub struct PDRouter {
-    pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
-    pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
+    pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
+    pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
     pub selection_policy: PDSelectionPolicy,
-    pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
     pub prefill_tree: Option<Arc<Mutex<Tree>>>,
     pub timeout_secs: u64,
    pub interval_secs: u64,
     pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
     pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
     pub http_client: reqwest::Client,
+    _prefill_health_checker: Option<HealthChecker>,
+    _decode_health_checker: Option<HealthChecker>,
 }
-// RAII guard for load tracking to ensure cleanup even on panic
-struct LoadGuard<'a> {
-    tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
-    urls: Vec<String>,
-}
-impl<'a> LoadGuard<'a> {
-    fn new(
-        tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
-        urls: Vec<String>,
-    ) -> Self {
-        // Increment counters
-        for url in &urls {
-            let counter = tracking
-                .entry(url.clone())
-                .or_insert_with(|| Arc::new(AtomicUsize::new(0)));
-            counter.fetch_add(1, Ordering::Relaxed);
-        }
-        LoadGuard { tracking, urls }
-    }
-}
-impl Drop for LoadGuard<'_> {
-    fn drop(&mut self) {
-        // Guaranteed cleanup even on panic
-        for url in &self.urls {
-            if let Some(counter) = self.tracking.get(url) {
-                counter.fetch_sub(1, Ordering::Relaxed);
-            }
-        }
-    }
-}
 impl PDRouter {
@@ -73,9 +41,6 @@ impl PDRouter {
         url: String,
         bootstrap_port: Option<u16>,
     ) -> Result<String, PDRouterError> {
-        // Create EngineInfo for the new prefill server
-        let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
         // Wait for the new server to be healthy
         crate::router::Router::wait_for_healthy_workers(
             &[url.clone()],
@@ -84,6 +49,9 @@ impl PDRouter {
         )
         .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
+        // Create Worker for the new prefill server
+        let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
         // Add to prefill workers list
         let mut workers = self
             .prefill_workers
@@ -93,15 +61,11 @@ impl PDRouter {
             })?;
         // Check if already exists
-        if workers.iter().any(|w| w.url == url) {
+        if workers.iter().any(|w| w.url() == &url) {
             return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
         }
-        workers.push(engine_info);
-        // Initialize load tracking
-        self.load_tracking
-            .insert(url.clone(), Arc::new(AtomicUsize::new(0)));
+        workers.push(worker);
         // Add to cache tree if using cache-aware policy
         if let Some(ref tree) = self.prefill_tree {
@@ -113,9 +77,6 @@ impl PDRouter {
     }
     pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
-        // Create EngineInfo for the new decode server
-        let engine_info = EngineInfo::new_decode(url.clone());
         // Wait for the new server to be healthy
         crate::router::Router::wait_for_healthy_workers(
             &[url.clone()],
@@ -124,6 +85,9 @@ impl PDRouter {
         )
         .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
+        // Create Worker for the new decode server
+        let worker = WorkerFactory::create_decode(url.clone());
         // Add to decode workers list
         let mut workers = self
             .decode_workers
@@ -133,15 +97,14 @@ impl PDRouter {
             })?;
         // Check if already exists
-        if workers.iter().any(|w| w.url == url) {
+        if workers.iter().any(|w| w.url() == &url) {
             return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
         }
-        workers.push(engine_info);
-        // Initialize load tracking
-        self.load_tracking
-            .insert(url.clone(), Arc::new(AtomicUsize::new(0)));
+        workers.push(worker);
+        // Worker tracks its own load internally
         info!("Added decode server: {}", url);
         Ok(format!("Successfully added decode server: {}", url))
@@ -157,7 +120,7 @@ impl PDRouter {
         // Find and remove the server
         let initial_len = workers.len();
-        workers.retain(|w| w.url != url);
+        workers.retain(|w| w.url() != url);
         if workers.len() == initial_len {
             return Err(PDRouterError::WorkerNotFound {
@@ -166,7 +129,7 @@ impl PDRouter {
         }
-        // Remove from load tracking
-        self.load_tracking.remove(url);
+        // Worker load tracking is internal
         // Remove from cache tree if using cache-aware policy
         if let Some(ref tree) = self.prefill_tree {
@@ -174,7 +137,7 @@ impl PDRouter {
             let mut tree_guard = tree.lock().unwrap();
             *tree_guard = Tree::new();
             for worker in workers.iter() {
-                tree_guard.insert("", &worker.url);
+                tree_guard.insert("", worker.url());
             }
         }
@@ -192,7 +155,7 @@ impl PDRouter {
         // Find and remove the server
         let initial_len = workers.len();
-        workers.retain(|w| w.url != url);
+        workers.retain(|w| w.url() != url);
         if workers.len() == initial_len {
             return Err(PDRouterError::WorkerNotFound {
@@ -200,9 +163,6 @@ impl PDRouter {
             });
         }
-        // Remove from load tracking
-        self.load_tracking.remove(url);
         info!("Removed decode server: {}", url);
         Ok(format!("Successfully removed decode server: {}", url))
     }
@@ -214,41 +174,32 @@ impl PDRouter {
         timeout_secs: u64,
         interval_secs: u64,
     ) -> Result<Self, String> {
-        // Convert URLs to EngineInfo
-        let prefill_workers: Vec<EngineInfo> = prefill_urls
+        // Convert URLs to Worker trait objects
+        let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
             .into_iter()
-            .map(|(url, port)| EngineInfo::new_prefill(url, port))
+            .map(|(url, port)| WorkerFactory::create_prefill(url, port))
             .collect();
-        let decode_workers: Vec<EngineInfo> = decode_urls
+        let decode_workers: Vec<Box<dyn Worker>> = decode_urls
             .into_iter()
-            .map(EngineInfo::new_decode)
+            .map(WorkerFactory::create_decode)
             .collect();
         // Wait for PD workers to be healthy
         let all_urls: Vec<String> = prefill_workers
             .iter()
             .chain(decode_workers.iter())
-            .map(|engine| engine.url.clone())
+            .map(|worker| worker.url().to_string())
             .collect();
         crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
-        // Initialize load tracking with atomic counters
-        let load_tracking = Arc::new(dashmap::DashMap::new());
-        for engine in &prefill_workers {
-            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
-        }
-        for engine in &decode_workers {
-            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
-        }
         // Initialize cache-aware components if needed
         let prefill_tree = match &selection_policy {
             PDSelectionPolicy::CacheAware { .. } => {
                 let tree = Arc::new(Mutex::new(Tree::new()));
                 // Initialize tree with prefill workers
-                for engine in &prefill_workers {
-                    tree.lock().unwrap().insert("", &engine.url);
+                for worker in &prefill_workers {
+                    tree.lock().unwrap().insert("", worker.url());
                 }
                 Some(tree)
             }
@@ -283,17 +234,27 @@ impl PDRouter {
             None
         };
+        let prefill_workers = Arc::new(RwLock::new(prefill_workers));
+        let decode_workers = Arc::new(RwLock::new(decode_workers));
+        // Start health checkers for both worker pools
+        let prefill_health_checker =
+            crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
+        let decode_health_checker =
+            crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
         Ok(PDRouter {
-            prefill_workers: Arc::new(RwLock::new(prefill_workers)),
-            decode_workers: Arc::new(RwLock::new(decode_workers)),
+            prefill_workers,
+            decode_workers,
             selection_policy,
-            load_tracking,
             prefill_tree,
             timeout_secs,
            interval_secs,
             worker_loads,
             load_monitor_handle,
             http_client,
+            _prefill_health_checker: Some(prefill_health_checker),
+            _decode_health_checker: Some(decode_health_checker),
         })
     }
@@ -330,11 +291,13 @@ impl PDRouter {
         // Log routing decision
         info!(
             "PD routing: {} -> prefill={}, decode={}",
-            route, prefill.url, decode.url
+            route,
+            prefill.url(),
+            decode.url()
         );
         // Add bootstrap info using the trait method
-        if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
+        if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
             error!("Failed to add bootstrap info: {}", e);
             counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
             return HttpResponse::InternalServerError()
@@ -356,8 +319,8 @@ impl PDRouter {
                 req,
                 json_with_bootstrap,
                 route,
-                &prefill,
-                &decode,
+                prefill.as_ref(),
+                decode.as_ref(),
                 is_stream,
                 return_logprob,
                 start,
@@ -397,11 +360,13 @@ impl PDRouter {
         // Log routing decision
         info!(
             "PD routing: {} -> prefill={}, decode={}",
-            route, prefill.url, decode.url
+            route,
+            prefill.url(),
+            decode.url()
         );
         // Add bootstrap info using the trait method
-        if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
+        if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
             error!("Failed to add bootstrap info: {}", e);
             counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
             return HttpResponse::InternalServerError()
@@ -423,8 +388,8 @@ impl PDRouter {
                 req,
                 json_with_bootstrap,
                 route,
-                &prefill,
-                &decode,
+                prefill.as_ref(),
+                decode.as_ref(),
                 is_stream,
                 return_logprob,
                 start,
@@ -440,22 +405,23 @@ impl PDRouter {
         req: &HttpRequest,
         json_request: serde_json::Value,
         route: &str,
-        prefill: &EngineInfo,
-        decode: &EngineInfo,
+        prefill: &dyn Worker,
+        decode: &dyn Worker,
         is_stream: bool,
         return_logprob: bool,
         start_time: Instant,
     ) -> HttpResponse {
         // Update load tracking for both workers
-        let _guard = LoadGuard::new(
-            &self.load_tracking,
-            vec![prefill.url.clone(), decode.url.clone()],
-        );
+        let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
         // Build requests using .json() method
-        let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
-        let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
+        let mut prefill_request = client
+            .post(api_path(prefill.url(), route))
+            .json(&json_request);
+        let mut decode_request = client
+            .post(api_path(decode.url(), route))
+            .json(&json_request);
         // Copy headers from original request
         for (name, value) in crate::router::copy_request_headers(req) {
@@ -474,9 +440,9 @@ impl PDRouter {
         histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
             .record(duration.as_secs_f64());
         counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
-        counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
+        counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string())
             .increment(1);
-        counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
+        counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string())
             .increment(1);
         // Process decode response
@@ -486,10 +452,11 @@ impl PDRouter {
             .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
         if !status.is_success() {
-            counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1);
+            counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1);
             error!(
                 "Decode server {} returned error status: {}",
-                decode.url, status
+                decode.url(),
+                status
             );
             // Return the error response from decode server
@@ -508,9 +475,10 @@ impl PDRouter {
         if let Err(e) = &prefill_result {
             error!(
                 "Prefill server {} failed (non-critical): {}",
-                prefill.url, e
+                prefill.url(),
+                e
             );
-            counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1);
+            counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1);
         }
         if is_stream {
@@ -559,7 +527,7 @@ impl PDRouter {
             HttpResponse::build(status)
                 .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                 .streaming({
-                    let decode_url = decode.url.clone();
+                    let decode_url = decode.url().to_string();
                     res.bytes_stream().map_err(move |e| {
                         error!("Stream error from decode server {}: {}", decode_url, e);
                         counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
@@ -587,7 +555,7 @@ impl PDRouter {
             }
             Err(e) => {
                 error!("Decode request failed: {}", e);
-                counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
+                counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
                     .increment(1);
                 HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
             }
@@ -652,7 +620,7 @@ impl PDRouter {
     async fn select_pd_pair(
         &self,
         _client: &reqwest::Client,
-    ) -> Result<(EngineInfo, EngineInfo), String> {
+    ) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
         // Check we have workers
         if self
             .prefill_workers
@@ -681,17 +649,17 @@ impl PDRouter {
         }
     }
-    fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
+    fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
         let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
         let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
-        let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
-        let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
+        let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
+        let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();
         Ok((prefill, decode))
     }
-    async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
+    async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
         let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
         let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
@@ -700,33 +668,45 @@ impl PDRouter {
         let loads = self.worker_loads.borrow();
-        let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
-        let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
-        let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
-        let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
+        let p1_load = loads
+            .get(prefill_list[p1_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);
+        let p2_load = loads
+            .get(prefill_list[p2_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);
+        let d1_load = loads
+            .get(decode_list[d1_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);
+        let d2_load = loads
+            .get(decode_list[d2_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);
         info!(
             "Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
-            prefill_list[p1_idx].url,
+            prefill_list[p1_idx].url(),
             p1_load,
-            prefill_list[p2_idx].url,
+            prefill_list[p2_idx].url(),
             p2_load,
-            decode_list[d1_idx].url,
+            decode_list[d1_idx].url(),
             d1_load,
-            decode_list[d2_idx].url,
+            decode_list[d2_idx].url(),
             d2_load
         );
         let selected_prefill = if p1_load <= p2_load {
-            prefill_list[p1_idx].clone()
+            prefill_list[p1_idx].clone_worker()
         } else {
-            prefill_list[p2_idx].clone()
+            prefill_list[p2_idx].clone_worker()
         };
         let selected_decode = if d1_load <= d2_load {
-            decode_list[d1_idx].clone()
+            decode_list[d1_idx].clone_worker()
         } else {
-            decode_list[d2_idx].clone()
+            decode_list[d2_idx].clone_worker()
         };
         Ok((selected_prefill, selected_decode))
@@ -868,11 +848,11 @@ impl PDRouter {
         let mut worker_infos = Vec::new();
         for worker in self.prefill_workers.read().unwrap().iter() {
-            worker_infos.push((worker.url.clone(), "prefill"));
+            worker_infos.push((worker.url().to_string(), "prefill"));
         }
         for worker in self.decode_workers.read().unwrap().iter() {
-            worker_infos.push((worker.url.clone(), "decode"));
+            worker_infos.push((worker.url().to_string(), "decode"));
         }
         // Create tasks with URL tracking
@@ -922,7 +902,7 @@ impl PDRouter {
     pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
         // Get info from the first decode server to match sglang's server info format
         let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
-            workers.first().map(|w| w.url.clone())
+            workers.first().map(|w| w.url().to_string())
         } else {
             return HttpResponse::InternalServerError().body("Failed to access decode workers");
         };
@@ -967,7 +947,7 @@ impl PDRouter {
     pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
         // Get first prefill worker URL to avoid holding lock across await
         let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
-            workers.first().map(|w| w.url.clone())
+            workers.first().map(|w| w.url().to_string())
         } else {
             return HttpResponse::InternalServerError().body("Failed to access prefill workers");
         };
@@ -1005,14 +985,14 @@ impl PDRouter {
             .read()
             .unwrap()
             .iter()
-            .map(|w| w.url.clone())
+            .map(|w| w.url().to_string())
             .collect();
         let d_urls: Vec<_> = self
             .decode_workers
             .read()
             .unwrap()
            .iter()
-            .map(|w| w.url.clone())
+            .map(|w| w.url().to_string())
             .collect();
         let mut prefill_loads = Vec::new();
@@ -1048,7 +1028,7 @@ impl PDRouter {
         // Get model info from the first prefill server (matches original Rust PDLB behavior)
         // Get first prefill worker URL to avoid holding lock across await
         let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
-            workers.first().map(|w| w.url.clone())
+            workers.first().map(|w| w.url().to_string())
         } else {
             return HttpResponse::InternalServerError().body("Failed to access prefill workers");
         };
@@ -1084,13 +1064,13 @@ impl PDRouter {
         // Flush cache on all prefill servers
         for worker in self.prefill_workers.read().unwrap().iter() {
-            let url = format!("{}/flush_cache", worker.url);
+            let url = format!("{}/flush_cache", worker.url());
             tasks.push(client.post(&url).send());
         }
         // Flush cache on all decode servers
         for worker in self.decode_workers.read().unwrap().iter() {
-            let url = format!("{}/flush_cache", worker.url);
+            let url = format!("{}/flush_cache", worker.url());
             tasks.push(client.post(&url).send());
         }
 // Essential PDLB types extracted for PD routing
+use crate::core::{Worker, WorkerType};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -28,52 +29,21 @@ pub enum PDRouterError {
     Timeout { url: String },
 }
-#[derive(Debug, Clone)]
-pub enum EngineType {
-    Prefill,
-    Decode,
-}
-#[derive(Debug, Clone)]
-pub struct EngineInfo {
-    pub engine_type: EngineType,
-    pub url: String,
-    pub bootstrap_port: Option<u16>,
-}
-impl EngineInfo {
-    pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
-        EngineInfo {
-            engine_type: EngineType::Prefill,
-            url,
-            bootstrap_port,
-        }
-    }
-    pub fn new_decode(url: String) -> Self {
-        EngineInfo {
-            engine_type: EngineType::Decode,
-            url,
-            bootstrap_port: None,
-        }
-    }
-    pub fn api_path(&self, api_path: &str) -> String {
-        if api_path.starts_with("/") {
-            format!("{}{}", self.url, api_path)
-        } else {
-            format!("{}/{}", self.url, api_path)
-        }
-    }
-    pub fn get_hostname(&self) -> String {
-        // Simple hostname extraction without external dependencies
-        let url = self
-            .url
-            .trim_start_matches("http://")
-            .trim_start_matches("https://");
-        url.split(':').next().unwrap_or("localhost").to_string()
-    }
-}
+// Helper functions for workers
+pub fn api_path(url: &str, api_path: &str) -> String {
+    if api_path.starts_with("/") {
+        format!("{}{}", url, api_path)
+    } else {
+        format!("{}/{}", url, api_path)
+    }
+}
+pub fn get_hostname(url: &str) -> String {
+    // Simple hostname extraction without external dependencies
+    let url = url
+        .trim_start_matches("http://")
+        .trim_start_matches("https://");
+    url.split(':').next().unwrap_or("localhost").to_string()
+}
 // PD-specific routing policies
@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
         bootstrap_room: BootstrapRoom,
     );
-    fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
+    fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
         let batch_size = self.get_batch_size()?;
+        // Extract bootstrap port from prefill worker if it's a prefill type
+        let bootstrap_port = match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+            _ => None,
+        };
+        let hostname = get_hostname(prefill_worker.url());
         if let Some(batch_size) = batch_size {
             self.set_bootstrap_info(
-                BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
-                BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
+                BootstrapHost::Batch(vec![hostname; batch_size]),
+                BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
                 // Use high-quality random numbers to minimize collision risk
                 BootstrapRoom::Batch(
                     (0..batch_size)
@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
             );
         } else {
             self.set_bootstrap_info(
-                BootstrapHost::Single(prefill_info.get_hostname()),
-                BootstrapPort::Single(prefill_info.bootstrap_port),
+                BootstrapHost::Single(hostname),
+                BootstrapPort::Single(bootstrap_port),
                 BootstrapRoom::Single({
                     // Use high-quality random number for single requests too
                     let r1 = rand::random::<u64>();
+use crate::core::{HealthChecker, Worker, WorkerFactory};
 use crate::pd_router::PDRouter;
 use crate::pd_types::PDSelectionPolicy;
 use crate::tree::Tree;
@@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram};
 use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
 use actix_web::{HttpRequest, HttpResponse};
 use futures_util::{StreamExt, TryStreamExt};
-use std::collections::HashMap;
 use std::fmt::Debug;
 use std::sync::atomic::AtomicUsize;
 use std::sync::{Arc, Mutex, RwLock};
@@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
 #[derive(Debug)]
 pub enum Router {
     RoundRobin {
-        worker_urls: Arc<RwLock<Vec<String>>>,
+        workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
         current_index: AtomicUsize,
         timeout_secs: u64,
         interval_secs: u64,
+        _health_checker: Option<HealthChecker>,
     },
     Random {
-        worker_urls: Arc<RwLock<Vec<String>>>,
+        workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
         timeout_secs: u64,
         interval_secs: u64,
+        _health_checker: Option<HealthChecker>,
     },
     PrefillDecode {
         pd_router: Arc<PDRouter>,
@@ -104,16 +106,15 @@ pub enum Router {
         Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
         during the next eviction cycle.
         */
-        worker_urls: Arc<RwLock<Vec<String>>>,
+        workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
         tree: Arc<Mutex<Tree>>,
-        running_queue: Arc<Mutex<HashMap<String, usize>>>,
-        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
         cache_threshold: f32,
         balance_abs_threshold: usize,
         balance_rel_threshold: f32,
         timeout_secs: u64,
         interval_secs: u64,
         _eviction_thread: Option<thread::JoinHandle<()>>,
+        _health_checker: Option<HealthChecker>,
     },
 }
@@ -192,25 +193,43 @@ impl Router {
             }
         }
+        // Create Worker trait objects from URLs
+        let workers: Vec<Box<dyn Worker>> = worker_urls
+            .iter()
+            .map(|url| WorkerFactory::create_regular(url.clone()))
+            .collect();
         // Create router based on policy...
         Ok(match policy_config {
             PolicyConfig::RandomConfig {
                 timeout_secs,
                 interval_secs,
-            } => Router::Random {
-                worker_urls: Arc::new(RwLock::new(worker_urls)),
-                timeout_secs,
-                interval_secs,
-            },
+            } => {
+                let workers = Arc::new(RwLock::new(workers));
+                let health_checker =
+                    crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
+                Router::Random {
+                    workers,
+                    timeout_secs,
+                    interval_secs,
+                    _health_checker: Some(health_checker),
+                }
+            }
             PolicyConfig::RoundRobinConfig {
                 timeout_secs,
                 interval_secs,
-            } => Router::RoundRobin {
-                worker_urls: Arc::new(RwLock::new(worker_urls)),
-                current_index: std::sync::atomic::AtomicUsize::new(0),
-                timeout_secs,
-                interval_secs,
-            },
+            } => {
+                let workers = Arc::new(RwLock::new(workers));
+                let health_checker =
+                    crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
+                Router::RoundRobin {
+                    workers,
+                    current_index: std::sync::atomic::AtomicUsize::new(0),
+                    timeout_secs,
+                    interval_secs,
+                    _health_checker: Some(health_checker),
+                }
+            }
             PolicyConfig::CacheAwareConfig {
                 cache_threshold,
                 balance_abs_threshold,
@@ -220,24 +239,12 @@ impl Router {
                 timeout_secs,
                 interval_secs,
             } => {
-                let mut running_queue = HashMap::new();
-                for url in &worker_urls {
-                    running_queue.insert(url.clone(), 0);
-                }
-                let mut processed_queue = HashMap::new();
-                for url in &worker_urls {
-                    processed_queue.insert(url.clone(), 0);
-                }
                 let tree = Arc::new(Mutex::new(Tree::new()));
-                let running_queue = Arc::new(Mutex::new(running_queue));
-                let processed_queue = Arc::new(Mutex::new(processed_queue));
                 // Create background eviction thread
                 let tree_clone = Arc::clone(&tree);
-                let processed_queue_clone = Arc::clone(&processed_queue);
-                let running_queue_clone = Arc::clone(&running_queue);
+                let workers = Arc::new(RwLock::new(workers));
+                let workers_clone = Arc::clone(&workers);
                 let eviction_thread = thread::spawn(move || {
                     loop {
                         // Sleep for the specified interval
@@ -246,32 +253,41 @@ impl Router {
                         let locked_tree_clone = tree_clone.lock().unwrap();
                         // Run eviction
                         locked_tree_clone.evict_tenant_by_size(max_tree_size);
+                        drop(locked_tree_clone);
-                        // Print the process queue
-                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
-                        info!("Processed Queue: {:?}", locked_processed_queue);
-                        // Print the running queue
-                        let locked_running_queue = running_queue_clone.lock().unwrap();
-                        info!("Running Queue: {:?}", locked_running_queue);
+                        // Log worker loads and processed requests
+                        let workers_guard = workers_clone.read().unwrap();
+                        let loads: Vec<(String, usize)> = workers_guard
+                            .iter()
+                            .map(|w| (w.url().to_string(), w.load()))
+                            .collect();
+                        info!("Worker loads: {:?}", loads);
+                        let processed: Vec<(String, usize)> = workers_guard
+                            .iter()
+                            .map(|w| (w.url().to_string(), w.processed_requests()))
+                            .collect();
+                        info!("Processed requests: {:?}", processed);
                     }
                 });
-                for url in &worker_urls {
-                    tree.lock().unwrap().insert("", url);
+                for worker in workers.read().unwrap().iter() {
+                    tree.lock().unwrap().insert("", worker.url());
                 }
+                let health_checker =
+                    crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
                 Router::CacheAware {
-                    worker_urls: Arc::new(RwLock::new(worker_urls)),
+                    workers,
                     tree,
-                    running_queue,
-                    processed_queue,
                    cache_threshold,
                     balance_abs_threshold,
                     balance_rel_threshold,
                     timeout_secs,
                     interval_secs,
                     _eviction_thread: Some(eviction_thread),
+                    _health_checker: Some(health_checker),
                 }
             }
             PolicyConfig::PrefillDecodeConfig {
@@ -297,16 +313,18 @@ impl Router {
         })
     }
-    /// Get a reference to the worker URLs shared across threads
-    pub fn get_worker_urls(&self) -> Arc<RwLock<Vec<String>>> {
+    /// Get the current list of worker URLs
+    pub fn get_worker_urls(&self) -> Vec<String> {
         match self {
-            Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
-            Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
-            Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
-            Router::PrefillDecode { .. } => {
-                // For PD mode, return empty list since we manage workers differently
-                Arc::new(RwLock::new(Vec::new()))
-            }
+            Router::RoundRobin { workers, .. }
+            | Router::Random { workers, .. }
+            | Router::CacheAware { workers, .. } => workers
+                .read()
+                .unwrap()
+                .iter()
+                .map(|w| w.url().to_string())
+                .collect(),
+            Router::PrefillDecode { .. } => Vec::new(),
         }
     }
@@ -373,13 +391,14 @@ impl Router {
     fn select_first_worker(&self) -> Result<String, String> {
         match self {
-            Router::RoundRobin { worker_urls, .. }
-            | Router::Random { worker_urls, .. }
-            | Router::CacheAware { worker_urls, .. } => {
-                if worker_urls.read().unwrap().is_empty() {
+            Router::RoundRobin { workers, .. }
+            | Router::Random { workers, .. }
+            | Router::CacheAware { workers, .. } => {
+                let workers_guard = workers.read().unwrap();
+                if workers_guard.is_empty() {
                     Err("No workers are available".to_string())
                 } else {
-                    Ok(worker_urls.read().unwrap()[0].clone())
+                    Ok(workers_guard[0].url().to_string())
                 }
             }
             Router::PrefillDecode { .. } => {
@@ -514,7 +533,7 @@ impl Router {
                 return HttpResponse::NotImplemented()
                     .body("route_to_all not implemented for PrefillDecode mode");
             }
-            _ => self.get_worker_urls().read().unwrap().clone(),
+            _ => self.get_worker_urls(),
         };
         // Send requests to all workers concurrently
@@ -562,7 +581,7 @@ impl Router {
             }
         }
-        let urls = self.get_worker_urls().read().unwrap().clone();
+        let urls = self.get_worker_urls();
         let prefill_urls: Vec<String> = Vec::new();
         let decode_urls = urls;
@@ -631,6 +650,24 @@ impl Router {
                 .increment(1);
         }
+        // For CacheAware router, increment load before request
+        let load_incremented = match self {
+            Router::CacheAware { workers, .. } => {
+                let workers_guard = workers.read().unwrap();
+                if let Some(worker) =
+                    workers_guard.iter().find(|w| w.url() == &worker_url)
+                {
+                    worker.increment_load();
+                    gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
+                        .set(worker.load() as f64);
+                    true
+                } else {
+                    false
+                }
+            }
+            _ => false,
+        };
         // Send typed request directly
         let response = self
             .send_typed_request(
@@ -640,6 +677,7 @@ impl Router {
                 route,
                 &worker_url,
                 is_stream,
+                load_incremented,
             )
             .await;
// Helper method to select worker from text // Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware)
fn select_generate_worker_from_text(&self, text: &str) -> String { fn select_generate_worker_from_text(&self, text: &str) -> String {
match self { match self {
Router::RoundRobin { Router::RoundRobin {
worker_urls, workers,
current_index, current_index,
.. ..
} => { } => {
let workers_guard = workers.read().unwrap();
let idx = current_index let idx = current_index
.fetch_update( .fetch_update(
std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst,
|x| Some((x + 1) % worker_urls.read().unwrap().len()), |x| Some((x + 1) % workers_guard.len()),
) )
.unwrap(); .unwrap();
worker_urls.read().unwrap()[idx].clone() workers_guard[idx].url().to_string()
} }
Router::Random { worker_urls, .. } => worker_urls.read().unwrap() Router::Random { workers, .. } => {
[rand::random::<usize>() % worker_urls.read().unwrap().len()] let workers_guard = workers.read().unwrap();
.clone(), workers_guard[rand::random::<usize>() % workers_guard.len()]
.url()
.to_string()
}
Router::CacheAware { Router::CacheAware {
worker_urls, workers,
tree, tree,
running_queue,
processed_queue,
cache_threshold, cache_threshold,
balance_abs_threshold, balance_abs_threshold,
balance_rel_threshold, balance_rel_threshold,
.. ..
} => { } => {
let tree = tree.lock().unwrap(); let tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap(); let workers_guard = workers.read().unwrap();
// Get current load statistics // Get current load statistics from workers
let max_load = *running_queue.values().max().unwrap_or(&0); let loads: Vec<usize> = workers_guard.iter().map(|w| w.load()).collect();
let min_load = *running_queue.values().min().unwrap_or(&0); let max_load = *loads.iter().max().unwrap_or(&0);
let min_load = *loads.iter().min().unwrap_or(&0);
// Load is considered imbalanced if: // Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND // 1. (max - min) > abs_threshold AND
...@@ -731,11 +772,16 @@ impl Router { ...@@ -731,11 +772,16 @@ impl Router {
let selected_url = if is_imbalanced { let selected_url = if is_imbalanced {
// Log load balancing trigger and current queue state // Log load balancing trigger and current queue state
let worker_loads: Vec<(String, usize)> = workers_guard
.iter()
.map(|w| (w.url().to_string(), w.load()))
.collect();
info!( info!(
"Load balancing triggered due to workload imbalance:\n\ "Load balancing triggered due to workload imbalance:\n\
Max load: {}, Min load: {}\n\ Max load: {}, Min load: {}\n\
Current running queue: {:?}", Current worker loads: {:?}",
max_load, min_load, running_queue max_load, min_load, worker_loads
); );
counter!("sgl_router_load_balancing_events_total").increment(1); counter!("sgl_router_load_balancing_events_total").increment(1);
...@@ -743,11 +789,11 @@ impl Router { ...@@ -743,11 +789,11 @@ impl Router {
gauge!("sgl_router_min_load").set(min_load as f64); gauge!("sgl_router_min_load").set(min_load as f64);
// Use shortest queue routing when load is imbalanced // Use shortest queue routing when load is imbalanced
running_queue workers_guard
.iter() .iter()
.min_by_key(|(_url, &count)| count) .min_by_key(|w| w.load())
.map(|(url, _)| url.clone()) .map(|w| w.url().to_string())
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) .unwrap_or_else(|| workers_guard[0].url().to_string())
} else { } else {
// Use cache-aware routing when load is balanced // Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text); let (matched_text, matched_worker) = tree.prefix_match(&text);
...@@ -763,18 +809,12 @@ impl Router { ...@@ -763,18 +809,12 @@ impl Router {
} }
}; };
// Update queues and tree // Find the selected worker and increment processed counter only
*running_queue.get_mut(&selected_url).unwrap() += 1; if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) {
worker.increment_processed();
*processed_queue counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string())
.lock() .increment(1);
.unwrap() }
.get_mut(&selected_url)
.unwrap() += 1;
gauge!("sgl_router_running_requests", "worker" => selected_url.to_string())
.set(*running_queue.get(&selected_url).unwrap() as f64);
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()).increment(1);
tree.insert(&text, &selected_url); tree.insert(&text, &selected_url);
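Only condition 1 of the imbalance test survives in this diff; the hunk boundary hides condition 2. A standalone reconstruction of the check, where the relative comparison is inferred from the name balance_rel_threshold and should be treated as an assumption:

    fn is_imbalanced(loads: &[usize], abs_threshold: usize, rel_threshold: f32) -> bool {
        let max_load = loads.iter().max().copied().unwrap_or(0);
        let min_load = loads.iter().min().copied().unwrap_or(0);
        // 1. (max - min) > abs_threshold AND
        // 2. max > min * rel_threshold (inferred from the field name)
        (max_load - min_load) > abs_threshold
            && (max_load as f32) > (min_load as f32) * rel_threshold
    }

For example, loads of [12, 2, 3] with abs_threshold = 4 and rel_threshold = 2.0 satisfy both conditions (10 > 4 and 12.0 > 4.0), so the router would fall back to shortest-queue selection.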
@@ -796,6 +836,7 @@ impl Router {
         route: &str,
         worker_url: &str,
         is_stream: bool,
+        load_incremented: bool, // Whether load was incremented for this request
     ) -> HttpResponse {
         let start = Instant::now();
@@ -820,6 +861,22 @@ impl Router {
             Ok(res) => res,
             Err(e) => {
                 error!("Failed to send request to {}: {}", worker_url, e);
+
+                // Decrement load on error for CacheAware router
+                if load_incremented {
+                    if let Router::CacheAware { workers, .. } = self {
+                        if let Ok(workers_guard) = workers.read() {
+                            if let Some(worker) =
+                                workers_guard.iter().find(|w| w.url() == worker_url)
+                            {
+                                worker.decrement_load();
+                                gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
+                                    .set(worker.load() as f64);
+                            }
+                        }
+                    }
+                }
+
                 return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
             }
         };
@@ -837,13 +894,15 @@ impl Router {
             }
         };

-        // Then decrement running queue counter if using CacheAware
-        if let Router::CacheAware { running_queue, .. } = self {
-            if let Ok(mut queue) = running_queue.lock() {
-                if let Some(count) = queue.get_mut(worker_url) {
-                    *count = count.saturating_sub(1);
-                    gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
-                        .set(*count as f64);
+        // Decrement load counter for non-streaming CacheAware requests
+        if load_incremented && !is_stream {
+            if let Router::CacheAware { workers, .. } = self {
+                if let Ok(workers_guard) = workers.read() {
+                    if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
+                        worker.decrement_load();
+                        gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
+                            .set(worker.load() as f64);
+                    }
                 }
             }
         }
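Threading load_incremented through every success, error, and streaming path is easy to get wrong. The core module's re-exported WorkerLoadGuard points at an RAII alternative; a sketch of that pattern, under the assumption that the guard increments on construction and decrements on drop (the real type's API may differ):

    struct LoadGuard<'a> {
        worker: &'a dyn Worker,
    }

    impl<'a> LoadGuard<'a> {
        fn new(worker: &'a dyn Worker) -> Self {
            worker.increment_load();
            Self { worker }
        }
    }

    impl Drop for LoadGuard<'_> {
        fn drop(&mut self) {
            // Runs on every exit path, including early returns and panics
            self.worker.decrement_load();
        }
    }

A guard like this decrements automatically on every exit path; the manual flag is still required here because the streaming case must defer the decrement until the stream reaches its end.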
@@ -855,8 +914,9 @@ impl Router {
             counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);

             response
-        } else if let Router::CacheAware { running_queue, .. } = self {
-            let running_queue = Arc::clone(running_queue);
+        } else if let Router::CacheAware { workers, .. } = self {
+            // For streaming with CacheAware router, we need to manually decrement when done
+            let workers = Arc::clone(workers);
             let worker_url = worker_url.to_string();
             HttpResponse::build(status)
@@ -867,21 +927,28 @@ impl Router {
                         actix_web::error::ErrorInternalServerError("Failed to read stream")
                     })
                     .inspect(move |bytes| {
-                        let bytes = bytes.as_ref().unwrap();
-                        if bytes
-                            .as_ref()
-                            .windows(12)
-                            .any(|window| window == b"data: [DONE]")
-                        {
-                            let mut locked_queue = running_queue.lock().unwrap();
-                            let count = locked_queue.get_mut(&worker_url).unwrap();
-                            *count = count.saturating_sub(1);
-                            gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()).set(*count as f64);
-                            debug!("Streaming is done!!")
+                        if let Ok(bytes) = bytes {
+                            if bytes
+                                .as_ref()
+                                .windows(12)
+                                .any(|window| window == b"data: [DONE]")
+                            {
+                                if let Ok(workers_guard) = workers.read() {
+                                    if let Some(worker) =
+                                        workers_guard.iter().find(|w| w.url() == &worker_url)
+                                    {
+                                        worker.decrement_load();
+                                        gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
+                                            .set(worker.load() as f64);
+                                        debug!("Streaming is done!!")
+                                    }
+                                }
+                            }
                         }
                     }),
             )
         } else {
+            // For non-CacheAware routers, just stream without load tracking
             HttpResponse::build(status)
                 .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                 .streaming(res.bytes_stream().map_err(|_| {
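The stream-termination check above scans each chunk for the 12-byte SSE sentinel. Isolated as a plain helper for illustration (hypothetical name):

    fn contains_done_marker(chunk: &[u8]) -> bool {
        // b"data: [DONE]" is exactly 12 bytes, hence windows(12)
        chunk.windows(12).any(|window| window == b"data: [DONE]")
    }

    #[test]
    fn detects_done_marker() {
        assert!(contains_done_marker(b"data: [DONE]\n\n"));
        assert!(!contains_done_marker(b"data: {\"text\":\"hi\"}\n\n"));
    }

One caveat worth noting: a sentinel split across two network chunks would not match a per-chunk windows(12) scan; in practice the terminal data: [DONE] event is tiny and arrives in a single chunk.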
@@ -935,43 +1002,27 @@ impl Router {
             Ok(res) => {
                 if res.status().is_success() {
                     match self {
-                        Router::RoundRobin { worker_urls, .. }
-                        | Router::Random { worker_urls, .. }
-                        | Router::CacheAware { worker_urls, .. } => {
+                        Router::RoundRobin { workers, .. }
+                        | Router::Random { workers, .. }
+                        | Router::CacheAware { workers, .. } => {
                             info!("Worker {} health check passed", worker_url);
-                            let mut urls = worker_urls.write().unwrap();
-                            if urls.contains(&worker_url.to_string()) {
+                            let mut workers_guard = workers.write().unwrap();
+                            if workers_guard.iter().any(|w| w.url() == worker_url) {
                                 return Err(format!("Worker {} already exists", worker_url));
                             }
                             info!("Added worker: {}", worker_url);
-                            urls.push(worker_url.to_string());
-                            gauge!("sgl_router_active_workers").set(urls.len() as f64);
+                            let new_worker =
+                                WorkerFactory::create_regular(worker_url.to_string());
+                            workers_guard.push(new_worker);
+                            gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
                         }
                         Router::PrefillDecode { .. } => {
                             return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
                         }
                     }

-                    // If cache aware, initialize the queues for the new worker
-                    if let Router::CacheAware {
-                        running_queue,
-                        processed_queue,
-                        tree,
-                        ..
-                    } = self
-                    {
-                        // Add worker to running queue with initial count of 0
-                        running_queue
-                            .lock()
-                            .unwrap()
-                            .insert(worker_url.to_string(), 0);
-
-                        // Add worker to processed queue with initial count of 0
-                        processed_queue
-                            .lock()
-                            .unwrap()
-                            .insert(worker_url.to_string(), 0);
-
-                        // Add worker to tree
+                    // If cache aware, add worker to tree
+                    if let Router::CacheAware { tree, .. } = self {
                         tree.lock().unwrap().insert("", worker_url);
                     }
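Assuming add_worker keeps an async signature returning Result<String, String> (only the error strings are visible in this hunk), the duplicate check behaves as in this hedged sketch with a made-up worker URL:

    async fn add_twice(router: &Router) {
        // First add succeeds once the health check passes
        assert!(router.add_worker("http://worker3:8080").await.is_ok());
        // Second add is rejected before any state changes
        assert_eq!(
            router.add_worker("http://worker3:8080").await.unwrap_err(),
            "Worker http://worker3:8080 already exists"
        );
    }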
@@ -1013,14 +1064,14 @@ impl Router {
     pub fn remove_worker(&self, worker_url: &str) {
         match self {
-            Router::RoundRobin { worker_urls, .. }
-            | Router::Random { worker_urls, .. }
-            | Router::CacheAware { worker_urls, .. } => {
-                let mut urls = worker_urls.write().unwrap();
-                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
-                    urls.remove(index);
+            Router::RoundRobin { workers, .. }
+            | Router::Random { workers, .. }
+            | Router::CacheAware { workers, .. } => {
+                let mut workers_guard = workers.write().unwrap();
+                if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
+                    workers_guard.remove(index);
                     info!("Removed worker: {}", worker_url);
-                    gauge!("sgl_router_active_workers").set(urls.len() as f64);
+                    gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
                 } else {
                     warn!("Worker {} not found, skipping removal", worker_url);
                     return;
@@ -1033,26 +1084,9 @@ impl Router {
         }

         // if cache aware, remove the worker from the tree
-        if let Router::CacheAware {
-            tree,
-            running_queue,
-            processed_queue,
-            ..
-        } = self
-        {
+        if let Router::CacheAware { tree, .. } = self {
             tree.lock().unwrap().remove_tenant(&worker_url);
-            running_queue
-                .lock()
-                .unwrap()
-                .remove(&worker_url.to_string());
-            processed_queue
-                .lock()
-                .unwrap()
-                .remove(&worker_url.to_string());
-            info!(
-                "Removed worker from tree and cleaned up queues: {}",
-                worker_url
-            );
+            info!("Removed worker from tree: {}", worker_url);
         }
     }
@@ -1241,21 +1275,22 @@ mod tests {
     use crate::service_discovery::PodType;

     fn create_test_regular_router() -> Router {
+        let workers = vec![
+            WorkerFactory::create_regular("http://worker1:8080".to_string()),
+            WorkerFactory::create_regular("http://worker2:8080".to_string()),
+        ];
         Router::Random {
-            worker_urls: Arc::new(RwLock::new(vec![
-                "http://worker1:8080".to_string(),
-                "http://worker2:8080".to_string(),
-            ])),
+            workers: Arc::new(RwLock::new(workers)),
             timeout_secs: 5,
             interval_secs: 1,
+            _health_checker: None,
         }
     }

     #[test]
     fn test_router_get_worker_urls_regular() {
         let router = create_test_regular_router();
-        let worker_urls = router.get_worker_urls();
-        let urls = worker_urls.read().unwrap();
+        let urls = router.get_worker_urls();

         assert_eq!(urls.len(), 2);
         assert!(urls.contains(&"http://worker1:8080".to_string()));
@@ -236,8 +236,7 @@ async fn add_worker(

 #[get("/list_workers")]
 async fn list_workers(data: web::Data<AppState>) -> impl Responder {
-    let workers = data.router.get_worker_urls();
-    let worker_list = workers.read().unwrap().clone();
+    let worker_list = data.router.get_worker_urls();
     HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
 }
@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
     info!("✅ Serving router on {}:{}", config.host, config.port);
     info!(
         "✅ Serving workers on {:?}",
-        app_state.router.get_worker_urls().read().unwrap()
+        app_state.router.get_worker_urls()
     );

     HttpServer::new(move || {
@@ -547,11 +547,12 @@ mod tests {
     // Helper to create a Router instance for testing event handlers
     fn create_test_router() -> Arc<Router> {
-        let worker_urls = Arc::new(RwLock::new(Vec::new()));
+        let workers = Arc::new(RwLock::new(Vec::new()));
         Arc::new(Router::Random {
-            worker_urls,
+            workers,
             timeout_secs: 5,
             interval_secs: 1,
+            _health_checker: None,
         })
     }
@@ -878,8 +879,6 @@ mod tests {
         assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
         assert!(!router
             .get_worker_urls()
-            .read()
-            .unwrap()
             .contains(&pod_info.worker_url(port)));
     }
@@ -907,7 +906,7 @@ mod tests {
             .await;

         assert!(tracked_pods.lock().unwrap().is_empty());
-        assert!(router.get_worker_urls().read().unwrap().is_empty());
+        assert!(router.get_worker_urls().is_empty());
     }

     #[tokio::test]
@@ -12,7 +12,7 @@
 mod test_pd_routing {
     use rand::Rng;
     use serde_json::json;
-    use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
+    use sglang_router_rs::pd_types::PDSelectionPolicy;
     use sglang_router_rs::router::{PolicyConfig, Router};

     // Test-only struct to help validate PD request parsing
@@ -51,40 +51,35 @@ mod test_pd_routing {
     // ========================================================================

     #[test]
-    fn test_engine_info_creation() {
-        // Test EngineInfo creation for prefill servers
-        let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
-        match prefill_engine.engine_type {
-            EngineType::Prefill => (),
-            _ => panic!("Expected Prefill engine type"),
-        }
-        assert_eq!(prefill_engine.url, "http://prefill:8080");
-        assert_eq!(prefill_engine.bootstrap_port, Some(9000));
-        assert_eq!(prefill_engine.get_hostname(), "prefill");
-
-        // Test EngineInfo creation for decode servers
-        let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
-        match decode_engine.engine_type {
-            EngineType::Decode => (),
-            _ => panic!("Expected Decode engine type"),
-        }
-        assert_eq!(decode_engine.url, "http://decode:8080");
-        assert_eq!(decode_engine.bootstrap_port, None);
-        assert_eq!(decode_engine.get_hostname(), "decode");
-
-        // Test API path generation
-        assert_eq!(
-            prefill_engine.api_path("/generate"),
-            "http://prefill:8080/generate"
-        );
-        assert_eq!(
-            prefill_engine.api_path("health"),
-            "http://prefill:8080/health"
-        );
-        assert_eq!(
-            decode_engine.api_path("/v1/chat/completions"),
-            "http://decode:8080/v1/chat/completions"
-        );
+    fn test_worker_types() {
+        use sglang_router_rs::core::{WorkerFactory, WorkerType};
+
+        // Test worker creation for prefill servers
+        let prefill_worker =
+            WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
+        assert_eq!(prefill_worker.url(), "http://prefill:8080");
+        match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => {
+                assert_eq!(bootstrap_port, Some(9000));
+            }
+            _ => panic!("Expected Prefill worker type"),
+        }
+
+        // Test worker creation for decode servers
+        let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
+        assert_eq!(decode_worker.url(), "http://decode:8080");
+        match decode_worker.worker_type() {
+            WorkerType::Decode => (),
+            _ => panic!("Expected Decode worker type"),
+        }
+
+        // Test regular worker creation
+        let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
+        assert_eq!(regular_worker.url(), "http://regular:8080");
+        match regular_worker.worker_type() {
+            WorkerType::Regular => (),
+            _ => panic!("Expected Regular worker type"),
+        }
     }

     #[test]
@@ -230,6 +225,9 @@ mod test_pd_routing {
     #[test]
     fn test_bootstrap_injection_simulation() {
+        use sglang_router_rs::core::{WorkerFactory, WorkerType};
+        use sglang_router_rs::pd_types::get_hostname;
+
         // Since we can't test the actual inject_bootstrap_fields function here
         // (it's private in the router module), we'll test the expected behavior
@@ -240,15 +238,24 @@ mod test_pd_routing {
             "temperature": 0.7
         });

+        // Create a prefill worker to simulate injection
+        let prefill_worker =
+            WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
+
+        // Extract bootstrap port from worker type
+        let bootstrap_port = match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+            _ => None,
+        };
+
         // Simulate what inject_bootstrap_fields would do
-        let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
-        single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
-        single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
+        single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
+        single_json["bootstrap_port"] = json!(bootstrap_port);
         single_json["bootstrap_room"] = json!(12345u64); // Random room ID

         // Verify bootstrap fields are added correctly
         assert_eq!(single_json["bootstrap_host"], "prefill1");
-        assert_eq!(single_json["bootstrap_port"], 9000);
+        assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
         assert!(single_json["bootstrap_room"].is_u64());
         assert_eq!(single_json["temperature"], 0.7); // Original field preserved
@@ -259,8 +266,9 @@ mod test_pd_routing {
         });

         let batch_size = 3;
-        batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
-        batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
+        let hostname = get_hostname(prefill_worker.url());
+        batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
+        batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
         batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);

         // Verify batch bootstrap fields
@@ -306,7 +314,9 @@ mod test_pd_routing {
     }

     #[test]
-    fn test_engine_info_hostname_extraction() {
+    fn test_hostname_extraction() {
+        use sglang_router_rs::pd_types::get_hostname;
+
         // Test various URL formats
         let test_cases = vec![
             ("http://localhost:8080", "localhost"),
@@ -318,8 +328,7 @@ mod test_pd_routing {
         ];

         for (url, expected_hostname) in test_cases {
-            let engine = EngineInfo::new_prefill(url.to_string(), None);
-            assert_eq!(engine.get_hostname(), expected_hostname);
+            assert_eq!(get_hostname(url), expected_hostname);
         }
     }
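A minimal sketch of a get_hostname-style helper consistent with the cases listed above, stripping the scheme and then everything from the first ':' or '/'; the actual pd_types implementation may differ:

    fn get_hostname(url: &str) -> String {
        let no_scheme = url
            .strip_prefix("http://")
            .or_else(|| url.strip_prefix("https://"))
            .unwrap_or(url);
        no_scheme
            .split(|c: char| c == ':' || c == '/')
            .next()
            .unwrap_or("")
            .to_string()
    }

For instance, "http://localhost:8080" maps to "localhost", matching the first test case.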
@@ -652,6 +661,9 @@ mod test_pd_routing {
     #[test]
     fn test_bootstrap_injection_with_benchmark_requests() {
+        use sglang_router_rs::core::{WorkerFactory, WorkerType};
+        use sglang_router_rs::pd_types::get_hostname;
+
         // Test bootstrap injection with actual benchmark request patterns
         let mut benchmark_request = json!({
             "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
@@ -664,12 +676,20 @@ mod test_pd_routing {
             "stream": true
         });

-        // Simulate bootstrap injection
-        let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
+        // Create a prefill worker to simulate injection
+        let prefill_worker =
+            WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
+
+        // Extract bootstrap port from worker type
+        let bootstrap_port = match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+            _ => None,
+        };

         let batch_size = 16;
+        let hostname = get_hostname(prefill_worker.url());

-        benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
-        benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
+        benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
+        benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
         benchmark_request["bootstrap_room"] =
             json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
@@ -770,6 +790,9 @@ mod test_pd_routing {
     #[test]
     fn test_large_batch_bootstrap_injection() {
+        use sglang_router_rs::core::{WorkerFactory, WorkerType};
+        use sglang_router_rs::pd_types::get_hostname;
+
         // Test bootstrap injection performance with very large batches
         // This simulates the bench_one_batch_server.py scenario
         let large_batch_sizes = vec![1024, 4096, 8192];
@@ -787,14 +810,19 @@ mod test_pd_routing {
                 "stream": true
             });

-            // Simulate bootstrap injection
-            let prefill_info =
-                EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
+            // Create a prefill worker to simulate injection
+            let prefill_worker =
+                WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
+
+            // Extract bootstrap port from worker type
+            let bootstrap_port = match prefill_worker.worker_type() {
+                WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+                _ => None,
+            };
+            let hostname = get_hostname(prefill_worker.url());

-            large_batch_request["bootstrap_host"] =
-                json!(vec![prefill_info.get_hostname(); batch_size]);
-            large_batch_request["bootstrap_port"] =
-                json!(vec![prefill_info.bootstrap_port; batch_size]);
+            large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
+            large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
             large_batch_request["bootstrap_room"] = json!((0..batch_size)
                 .map(|_| rand::thread_rng().gen::<u64>())
                 .collect::<Vec<_>>());