Unverified Commit f2d5c492 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add worker abstraction (#7960)

parent 2a2d3478
......@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
kube = { version = "0.88.1", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
futures = "0.3"
async-trait = "0.1"
once_cell = "1.21"
# Added for metrics
metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0"
......
//! Error types for the SGLang router core
//!
//! This module defines error types used throughout the router for worker operations.
use std::fmt;
/// Errors that can arise while managing or talking to workers.
#[derive(Debug)]
pub enum WorkerError {
    /// Health check failed
    HealthCheckFailed { url: String, reason: String },
    /// Worker not found
    WorkerNotFound { url: String },
    /// Invalid worker configuration
    InvalidConfiguration { message: String },
    /// Network error
    NetworkError { url: String, error: String },
    /// Worker is at capacity
    WorkerAtCapacity { url: String },
}

impl fmt::Display for WorkerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use WorkerError::*;
        match self {
            HealthCheckFailed { url: u, reason: r } => {
                write!(f, "Health check failed for worker {}: {}", u, r)
            }
            WorkerNotFound { url: u } => write!(f, "Worker not found: {}", u),
            InvalidConfiguration { message: m } => {
                write!(f, "Invalid worker configuration: {}", m)
            }
            NetworkError { url: u, error: e } => {
                write!(f, "Network error for worker {}: {}", u, e)
            }
            WorkerAtCapacity { url: u } => write!(f, "Worker at capacity: {}", u),
        }
    }
}

impl std::error::Error for WorkerError {}

/// Convenience alias for results of worker operations.
pub type WorkerResult<T> = Result<T, WorkerError>;
/// Convert from reqwest errors to worker errors
impl From<reqwest::Error> for WorkerError {
fn from(err: reqwest::Error) -> Self {
WorkerError::NetworkError {
url: err.url().map(|u| u.to_string()).unwrap_or_default(),
error: err.to_string(),
}
}
}
//! Core abstractions for the SGLang router
//!
//! This module contains the fundamental types and traits used throughout the router:
//! - Worker trait and implementations
//! - Error types
//! - Common utilities
pub mod error;
pub mod worker;
// Re-export commonly used types at the module level
pub use error::{WorkerError, WorkerResult};
pub use worker::{
start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory,
WorkerLoadGuard, WorkerType,
};
use super::{WorkerError, WorkerResult};
use async_trait::async_trait;
use once_cell::sync::Lazy;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
// Shared HTTP client for health checks
// Built lazily on first use; reusing a single client lets reqwest pool
// connections across repeated probes instead of re-handshaking each time.
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
    reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
        .build()
        .expect("Failed to create health check HTTP client")
});
/// Core worker abstraction that represents a backend service
///
/// Implementations must be thread-safe (`Send + Sync`). Callers such as the
/// background health checker clone workers via `clone_worker` and expect the
/// clone to share load/health state with the original.
#[async_trait]
pub trait Worker: Send + Sync + fmt::Debug {
    /// Get the worker's URL
    fn url(&self) -> &str;
    /// Get the worker's type (Regular, Prefill, or Decode)
    fn worker_type(&self) -> WorkerType;
    /// Check if the worker is currently healthy
    fn is_healthy(&self) -> bool;
    /// Set the worker's health status
    fn set_healthy(&self, healthy: bool);
    /// Perform an async health check on the worker
    async fn check_health_async(&self) -> WorkerResult<()>;
    /// Synchronous health check wrapper (for compatibility)
    ///
    /// Builds a fresh single-threaded Tokio runtime on every call, so this
    /// is comparatively expensive.
    /// NOTE(review): `block_on` panics when invoked from inside an existing
    /// Tokio runtime — only call this from purely synchronous contexts.
    fn check_health(&self) -> WorkerResult<()> {
        // Use a small runtime for synchronous contexts
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| WorkerError::HealthCheckFailed {
                url: self.url().to_string(),
                reason: format!("Failed to create runtime: {}", e),
            })?
            .block_on(self.check_health_async())
    }
    /// Get the current load (number of active requests)
    fn load(&self) -> usize;
    /// Increment the load counter
    fn increment_load(&self);
    /// Decrement the load counter
    fn decrement_load(&self);
    /// Get the number of processed requests
    fn processed_requests(&self) -> usize;
    /// Increment the processed requests counter
    fn increment_processed(&self);
    /// Get worker-specific metadata
    fn metadata(&self) -> &WorkerMetadata;
    /// Clone the worker (for trait objects)
    ///
    /// NOTE(review): the health checker relies on the clone sharing health
    /// state with the original — true for `BasicWorker`, whose flag and
    /// counters are `Arc`-backed; confirm for any new implementation.
    fn clone_worker(&self) -> Box<dyn Worker>;
}
/// Worker type classification
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WorkerType {
    /// Regular worker for standard routing
    Regular,
    /// Prefill worker for PD disaggregated mode
    Prefill {
        /// Bootstrap port for communication with decode workers
        bootstrap_port: Option<u16>,
    },
    /// Decode worker for PD disaggregated mode
    Decode,
}

impl fmt::Display for WorkerType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The bootstrap port is rendered only when one is configured.
        match self {
            WorkerType::Regular => f.write_str("Regular"),
            WorkerType::Decode => f.write_str("Decode"),
            WorkerType::Prefill {
                bootstrap_port: Some(port),
            } => write!(f, "Prefill(bootstrap:{})", port),
            WorkerType::Prefill {
                bootstrap_port: None,
            } => f.write_str("Prefill"),
        }
    }
}
/// Health check configuration
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// Timeout for health checks in seconds
    pub timeout_secs: u64,
    /// Interval between health checks in seconds
    pub check_interval_secs: u64,
    /// Health check endpoint path
    pub endpoint: String,
}

impl Default for HealthConfig {
    /// Defaults: 5s probe timeout, 30s between probes, `/health` endpoint.
    fn default() -> Self {
        HealthConfig {
            endpoint: String::from("/health"),
            timeout_secs: 5,
            check_interval_secs: 30,
        }
    }
}
/// Metadata associated with a worker
///
/// Cloned wholesale when a worker is cloned, so labels and health-check
/// settings travel with the worker.
#[derive(Debug, Clone)]
pub struct WorkerMetadata {
    /// Worker URL
    pub url: String,
    /// Worker type
    pub worker_type: WorkerType,
    /// Additional labels/tags
    pub labels: std::collections::HashMap<String, String>,
    /// Health check configuration
    pub health_config: HealthConfig,
}
/// Basic worker implementation
///
/// The load/processed counters and the health flag are `Arc`-backed, so
/// clones of a `BasicWorker` observe and mutate the same shared state.
#[derive(Debug, Clone)]
pub struct BasicWorker {
    metadata: WorkerMetadata,
    load_counter: Arc<AtomicUsize>,
    processed_counter: Arc<AtomicUsize>,
    healthy: Arc<AtomicBool>,
}

impl BasicWorker {
    /// Create a worker for `url` with default labels and health-check
    /// settings. Workers start out marked healthy.
    pub fn new(url: String, worker_type: WorkerType) -> Self {
        // `url` is moved straight into the metadata; the previous
        // `url.clone()` here was redundant (the local was never reused).
        let metadata = WorkerMetadata {
            url,
            worker_type,
            labels: std::collections::HashMap::new(),
            health_config: HealthConfig::default(),
        };
        Self {
            metadata,
            load_counter: Arc::new(AtomicUsize::new(0)),
            processed_counter: Arc::new(AtomicUsize::new(0)),
            healthy: Arc::new(AtomicBool::new(true)),
        }
    }

    /// Builder-style setter for label metadata.
    pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
        self.metadata.labels = labels;
        self
    }

    /// Builder-style setter for the health-check configuration.
    pub fn with_health_config(mut self, config: HealthConfig) -> Self {
        self.metadata.health_config = config;
        self
    }
}
#[async_trait]
impl Worker for BasicWorker {
    fn url(&self) -> &str {
        &self.metadata.url
    }
    fn worker_type(&self) -> WorkerType {
        self.metadata.worker_type.clone()
    }
    // Acquire/Release pairing makes a health-status write by one thread
    // visible to subsequent readers on other threads.
    fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Acquire)
    }
    fn set_healthy(&self, healthy: bool) {
        self.healthy.store(healthy, Ordering::Release);
    }
    /// Probe the configured health endpoint over HTTP and update the
    /// cached health flag to match the outcome (success status => healthy,
    /// non-success or transport error => unhealthy).
    async fn check_health_async(&self) -> WorkerResult<()> {
        use std::time::Duration;

        // Perform actual HTTP health check
        let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
        let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);

        // Use the shared client with a custom timeout for this request
        match HEALTH_CHECK_CLIENT
            .get(&health_url)
            .timeout(timeout)
            .send()
            .await
        {
            Ok(response) => {
                if response.status().is_success() {
                    self.set_healthy(true);
                    Ok(())
                } else {
                    self.set_healthy(false);
                    Err(WorkerError::HealthCheckFailed {
                        url: self.url().to_string(),
                        reason: format!("Health check returned status: {}", response.status()),
                    })
                }
            }
            Err(e) => {
                self.set_healthy(false);
                Err(WorkerError::HealthCheckFailed {
                    url: self.url().to_string(),
                    reason: format!("Health check request failed: {}", e),
                })
            }
        }
    }
    // Load/processed counters are simple statistics; Relaxed ordering is
    // sufficient since no other memory accesses synchronize with them.
    fn load(&self) -> usize {
        self.load_counter.load(Ordering::Relaxed)
    }
    fn increment_load(&self) {
        self.load_counter.fetch_add(1, Ordering::Relaxed);
    }
    // Saturating decrement: checked_sub returns None at zero, which makes
    // fetch_update leave the counter unchanged instead of underflowing.
    fn decrement_load(&self) {
        self.load_counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_sub(1)
            })
            .ok();
    }
    fn processed_requests(&self) -> usize {
        self.processed_counter.load(Ordering::Relaxed)
    }
    fn increment_processed(&self) {
        self.processed_counter.fetch_add(1, Ordering::Relaxed);
    }
    fn metadata(&self) -> &WorkerMetadata {
        &self.metadata
    }
    // The derived Clone copies the Arc handles, so the boxed clone shares
    // health and load state with this instance.
    fn clone_worker(&self) -> Box<dyn Worker> {
        Box::new(self.clone())
    }
}
/// Worker factory for creating workers of different types
pub struct WorkerFactory;

impl WorkerFactory {
    /// Create a regular worker
    pub fn create_regular(url: String) -> Box<dyn Worker> {
        Box::new(BasicWorker::new(url, WorkerType::Regular))
    }

    /// Create a prefill worker with optional bootstrap port
    pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
        let worker_type = WorkerType::Prefill { bootstrap_port };
        Box::new(BasicWorker::new(url, worker_type))
    }

    /// Create a decode worker
    pub fn create_decode(url: String) -> Box<dyn Worker> {
        Box::new(BasicWorker::new(url, WorkerType::Decode))
    }

    /// Build (regular, prefill, decode) worker lists from their URL lists.
    pub fn create_from_urls(
        regular_urls: Vec<String>,
        prefill_urls: Vec<(String, Option<u16>)>,
        decode_urls: Vec<String>,
    ) -> (
        Vec<Box<dyn Worker>>,
        Vec<Box<dyn Worker>>,
        Vec<Box<dyn Worker>>,
    ) {
        let regular: Vec<Box<dyn Worker>> = regular_urls
            .into_iter()
            .map(|url| Self::create_regular(url))
            .collect();
        let prefill: Vec<Box<dyn Worker>> = prefill_urls
            .into_iter()
            .map(|(url, bootstrap_port)| Self::create_prefill(url, bootstrap_port))
            .collect();
        let decode: Vec<Box<dyn Worker>> = decode_urls
            .into_iter()
            .map(|url| Self::create_decode(url))
            .collect();
        (regular, prefill, decode)
    }
}
/// Helper trait for collections of workers
///
/// Convenience queries implemented below for `Vec<Box<dyn Worker>>`.
pub trait WorkerCollection {
    /// All workers currently reporting healthy.
    fn healthy_workers(&self) -> Vec<&dyn Worker>;
    /// Sum of active-request counts across all workers.
    fn total_load(&self) -> usize;
    /// Look up a worker by exact URL match.
    fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
    /// Mutable lookup of a worker by exact URL match.
    fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
}
impl WorkerCollection for Vec<Box<dyn Worker>> {
    fn healthy_workers(&self) -> Vec<&dyn Worker> {
        // Keep only workers whose cached health flag is set.
        self.iter()
            .filter_map(|w| if w.is_healthy() { Some(w.as_ref()) } else { None })
            .collect()
    }

    fn total_load(&self) -> usize {
        self.iter().fold(0, |acc, w| acc + w.load())
    }

    fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
        let idx = self.iter().position(|w| w.url() == url)?;
        Some(self[idx].as_ref())
    }

    fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
        let idx = self.iter().position(|w| w.url() == url)?;
        self.get_mut(idx)
    }
}
/// Convert a list of worker URLs to regular worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
    let mut workers = Vec::with_capacity(urls.len());
    for url in urls {
        workers.push(WorkerFactory::create_regular(url));
    }
    workers
}
/// Convert worker trait objects back to their URLs
pub fn workers_to_urls(workers: &[Box<dyn Worker>]) -> Vec<String> {
    let mut urls = Vec::with_capacity(workers.len());
    for worker in workers {
        urls.push(worker.url().to_string());
    }
    urls
}
/// RAII guard for worker load management
///
/// Increments each worker's load counter on construction and decrements it
/// again on drop, so active-request accounting stays balanced even if the
/// request handler panics or returns early.
//
// `#[must_use]` because a guard that is not bound to a variable is dropped
// immediately, silently undoing the increment and corrupting load stats.
#[must_use = "dropping the guard immediately decrements the load; bind it for the duration of the request"]
pub struct WorkerLoadGuard<'a> {
    workers: Vec<&'a dyn Worker>,
}

impl<'a> WorkerLoadGuard<'a> {
    /// Create a new load guard for a single worker
    pub fn new(worker: &'a dyn Worker) -> Self {
        worker.increment_load();
        Self {
            workers: vec![worker],
        }
    }

    /// Create a new load guard for multiple workers
    pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self {
        // Increment load counters for all workers
        for worker in &workers {
            worker.increment_load();
        }
        Self { workers }
    }
}

impl<'a> Drop for WorkerLoadGuard<'a> {
    fn drop(&mut self) {
        // Decrement load counters for all workers
        for worker in &self.workers {
            worker.decrement_load();
        }
    }
}
/// Health checker handle with graceful shutdown
///
/// Holds the background task's join handle plus the shutdown flag shared
/// with that task.
pub struct HealthChecker {
    handle: tokio::task::JoinHandle<()>,
    shutdown: Arc<AtomicBool>,
}

impl fmt::Debug for HealthChecker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Manual impl: only the shutdown flag is meaningful to display.
        f.debug_struct("HealthChecker")
            .field("shutdown", &self.shutdown.load(Ordering::Relaxed))
            .finish()
    }
}

impl HealthChecker {
    /// Shutdown the health checker gracefully
    ///
    /// Sets the shutdown flag and waits for the background task to exit.
    /// NOTE(review): the task only observes the flag after its next interval
    /// tick, so shutdown can take up to one full check interval to complete.
    pub async fn shutdown(self) {
        self.shutdown.store(true, Ordering::Release);
        let _ = self.handle.await;
    }
}
/// Start an async background health checker for a collection of workers
///
/// Spawns a Tokio task that, every `check_interval_secs`, snapshots the
/// worker list and probes each worker concurrently. Returns a handle for
/// graceful shutdown. Must be called from within a Tokio runtime
/// (uses `tokio::spawn`).
pub fn start_health_checker(
    workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
    check_interval_secs: u64,
) -> HealthChecker {
    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = shutdown.clone();
    let handle = tokio::spawn(async move {
        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
        loop {
            interval.tick().await;
            // Check for shutdown signal
            // (only observed once per tick, so shutdown latency is up to
            // one full interval)
            if shutdown_clone.load(Ordering::Acquire) {
                tracing::info!("Health checker shutting down");
                break;
            }
            // Check health of all workers.
            // Snapshot via clone_worker so the RwLock is not held across the
            // await points below; the snapshot is expected to share health
            // state with the originals (true for BasicWorker, whose flag is
            // Arc-backed — confirm for other Worker impls).
            let workers_to_check = match workers.read() {
                Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
                Err(poisoned) => {
                    tracing::error!("Worker lock poisoned: {}", poisoned);
                    continue;
                }
            };
            // Perform health checks concurrently
            let health_checks = workers_to_check.iter().map(|worker| {
                let worker_url = worker.url().to_string();
                let was_healthy = worker.is_healthy();
                async move {
                    match worker.check_health_async().await {
                        Ok(_) => {
                            // Log only on state transitions to avoid noise
                            if !was_healthy {
                                tracing::info!("Worker {} is now healthy", worker_url);
                            }
                        }
                        Err(e) => {
                            if was_healthy {
                                tracing::warn!("Worker {} health check failed: {}", worker_url, e);
                            }
                        }
                    }
                }
            });
            // Execute all health checks concurrently
            futures::future::join_all(health_checks).await;
        }
    });
    HealthChecker { handle, shutdown }
}
......@@ -2,6 +2,7 @@ use pyo3::prelude::*;
pub mod config;
pub mod logging;
use std::collections::HashMap;
pub mod core;
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
......
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
use crate::pd_types::{
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
};
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
......@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
use metrics::{counter, histogram};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use tracing::{debug, error, info, warn};
......@@ -21,49 +21,17 @@ use uuid::Uuid;
#[derive(Debug)]
pub struct PDRouter {
pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub selection_policy: PDSelectionPolicy,
pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
pub timeout_secs: u64,
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub http_client: reqwest::Client,
}
// RAII guard for load tracking to ensure cleanup even on panic
struct LoadGuard<'a> {
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
urls: Vec<String>,
}
impl<'a> LoadGuard<'a> {
fn new(
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
urls: Vec<String>,
) -> Self {
// Increment counters
for url in &urls {
let counter = tracking
.entry(url.clone())
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
counter.fetch_add(1, Ordering::Relaxed);
}
LoadGuard { tracking, urls }
}
}
impl Drop for LoadGuard<'_> {
fn drop(&mut self) {
// Guaranteed cleanup even on panic
for url in &self.urls {
if let Some(counter) = self.tracking.get(url) {
counter.fetch_sub(1, Ordering::Relaxed);
}
}
}
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
}
impl PDRouter {
......@@ -73,9 +41,6 @@ impl PDRouter {
url: String,
bootstrap_port: Option<u16>,
) -> Result<String, PDRouterError> {
// Create EngineInfo for the new prefill server
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
// Wait for the new server to be healthy
crate::router::Router::wait_for_healthy_workers(
&[url.clone()],
......@@ -84,6 +49,9 @@ impl PDRouter {
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
// Create Worker for the new prefill server
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
// Add to prefill workers list
let mut workers = self
.prefill_workers
......@@ -93,15 +61,11 @@ impl PDRouter {
})?;
// Check if already exists
if workers.iter().any(|w| w.url == url) {
if workers.iter().any(|w| w.url() == &url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
workers.push(engine_info);
// Initialize load tracking
self.load_tracking
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
workers.push(worker);
// Add to cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
......@@ -113,9 +77,6 @@ impl PDRouter {
}
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
// Create EngineInfo for the new decode server
let engine_info = EngineInfo::new_decode(url.clone());
// Wait for the new server to be healthy
crate::router::Router::wait_for_healthy_workers(
&[url.clone()],
......@@ -124,6 +85,9 @@ impl PDRouter {
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
// Create Worker for the new decode server
let worker = WorkerFactory::create_decode(url.clone());
// Add to decode workers list
let mut workers = self
.decode_workers
......@@ -133,15 +97,14 @@ impl PDRouter {
})?;
// Check if already exists
if workers.iter().any(|w| w.url == url) {
if workers.iter().any(|w| w.url() == &url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
workers.push(engine_info);
workers.push(worker);
// Initialize load tracking
self.load_tracking
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
// Worker tracks its own load internally
info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url))
......@@ -157,7 +120,7 @@ impl PDRouter {
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url != url);
workers.retain(|w| w.url() != url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
......@@ -166,7 +129,7 @@ impl PDRouter {
}
// Remove from load tracking
self.load_tracking.remove(url);
// Worker load tracking is internal
// Remove from cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
......@@ -174,7 +137,7 @@ impl PDRouter {
let mut tree_guard = tree.lock().unwrap();
*tree_guard = Tree::new();
for worker in workers.iter() {
tree_guard.insert("", &worker.url);
tree_guard.insert("", worker.url());
}
}
......@@ -192,7 +155,7 @@ impl PDRouter {
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url != url);
workers.retain(|w| w.url() != url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
......@@ -200,9 +163,6 @@ impl PDRouter {
});
}
// Remove from load tracking
self.load_tracking.remove(url);
info!("Removed decode server: {}", url);
Ok(format!("Successfully removed decode server: {}", url))
}
......@@ -214,41 +174,32 @@ impl PDRouter {
timeout_secs: u64,
interval_secs: u64,
) -> Result<Self, String> {
// Convert URLs to EngineInfo
let prefill_workers: Vec<EngineInfo> = prefill_urls
// Convert URLs to Worker trait objects
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.into_iter()
.map(|(url, port)| EngineInfo::new_prefill(url, port))
.map(|(url, port)| WorkerFactory::create_prefill(url, port))
.collect();
let decode_workers: Vec<EngineInfo> = decode_urls
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
.into_iter()
.map(EngineInfo::new_decode)
.map(WorkerFactory::create_decode)
.collect();
// Wait for PD workers to be healthy
let all_urls: Vec<String> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|engine| engine.url.clone())
.map(|worker| worker.url().to_string())
.collect();
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
// Initialize load tracking with atomic counters
let load_tracking = Arc::new(dashmap::DashMap::new());
for engine in &prefill_workers {
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
}
for engine in &decode_workers {
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
}
// Initialize cache-aware components if needed
let prefill_tree = match &selection_policy {
PDSelectionPolicy::CacheAware { .. } => {
let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with prefill workers
for engine in &prefill_workers {
tree.lock().unwrap().insert("", &engine.url);
for worker in &prefill_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
}
......@@ -283,17 +234,27 @@ impl PDRouter {
None
};
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers));
// Start health checkers for both worker pools
let prefill_health_checker =
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
let decode_health_checker =
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
Ok(PDRouter {
prefill_workers: Arc::new(RwLock::new(prefill_workers)),
decode_workers: Arc::new(RwLock::new(decode_workers)),
prefill_workers,
decode_workers,
selection_policy,
load_tracking,
prefill_tree,
timeout_secs,
interval_secs,
worker_loads,
load_monitor_handle,
http_client,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
})
}
......@@ -330,11 +291,13 @@ impl PDRouter {
// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route, prefill.url, decode.url
route,
prefill.url(),
decode.url()
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
return HttpResponse::InternalServerError()
......@@ -356,8 +319,8 @@ impl PDRouter {
req,
json_with_bootstrap,
route,
&prefill,
&decode,
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
......@@ -397,11 +360,13 @@ impl PDRouter {
// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route, prefill.url, decode.url
route,
prefill.url(),
decode.url()
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
return HttpResponse::InternalServerError()
......@@ -423,8 +388,8 @@ impl PDRouter {
req,
json_with_bootstrap,
route,
&prefill,
&decode,
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
......@@ -440,22 +405,23 @@ impl PDRouter {
req: &HttpRequest,
json_request: serde_json::Value,
route: &str,
prefill: &EngineInfo,
decode: &EngineInfo,
prefill: &dyn Worker,
decode: &dyn Worker,
is_stream: bool,
return_logprob: bool,
start_time: Instant,
) -> HttpResponse {
// Update load tracking for both workers
let _guard = LoadGuard::new(
&self.load_tracking,
vec![prefill.url.clone(), decode.url.clone()],
);
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
// Build requests using .json() method
let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
let mut prefill_request = client
.post(api_path(prefill.url(), route))
.json(&json_request);
let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
let mut decode_request = client
.post(api_path(decode.url(), route))
.json(&json_request);
// Copy headers from original request
for (name, value) in crate::router::copy_request_headers(req) {
......@@ -474,9 +440,9 @@ impl PDRouter {
histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string())
.increment(1);
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string())
.increment(1);
// Process decode response
......@@ -486,10 +452,11 @@ impl PDRouter {
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !status.is_success() {
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1);
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1);
error!(
"Decode server {} returned error status: {}",
decode.url, status
decode.url(),
status
);
// Return the error response from decode server
......@@ -508,9 +475,10 @@ impl PDRouter {
if let Err(e) = &prefill_result {
error!(
"Prefill server {} failed (non-critical): {}",
prefill.url, e
prefill.url(),
e
);
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1);
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1);
}
if is_stream {
......@@ -559,7 +527,7 @@ impl PDRouter {
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming({
let decode_url = decode.url.clone();
let decode_url = decode.url().to_string();
res.bytes_stream().map_err(move |e| {
error!("Stream error from decode server {}: {}", decode_url, e);
counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
......@@ -587,7 +555,7 @@ impl PDRouter {
}
Err(e) => {
error!("Decode request failed: {}", e);
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
.increment(1);
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
}
......@@ -652,7 +620,7 @@ impl PDRouter {
async fn select_pd_pair(
&self,
_client: &reqwest::Client,
) -> Result<(EngineInfo, EngineInfo), String> {
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
// Check we have workers
if self
.prefill_workers
......@@ -681,17 +649,17 @@ impl PDRouter {
}
}
fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();
Ok((prefill, decode))
}
async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
......@@ -700,33 +668,45 @@ impl PDRouter {
let loads = self.worker_loads.borrow();
let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
let p1_load = loads
.get(prefill_list[p1_idx].url())
.copied()
.unwrap_or(isize::MAX);
let p2_load = loads
.get(prefill_list[p2_idx].url())
.copied()
.unwrap_or(isize::MAX);
let d1_load = loads
.get(decode_list[d1_idx].url())
.copied()
.unwrap_or(isize::MAX);
let d2_load = loads
.get(decode_list[d2_idx].url())
.copied()
.unwrap_or(isize::MAX);
info!(
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
prefill_list[p1_idx].url,
prefill_list[p1_idx].url(),
p1_load,
prefill_list[p2_idx].url,
prefill_list[p2_idx].url(),
p2_load,
decode_list[d1_idx].url,
decode_list[d1_idx].url(),
d1_load,
decode_list[d2_idx].url,
decode_list[d2_idx].url(),
d2_load
);
let selected_prefill = if p1_load <= p2_load {
prefill_list[p1_idx].clone()
prefill_list[p1_idx].clone_worker()
} else {
prefill_list[p2_idx].clone()
prefill_list[p2_idx].clone_worker()
};
let selected_decode = if d1_load <= d2_load {
decode_list[d1_idx].clone()
decode_list[d1_idx].clone_worker()
} else {
decode_list[d2_idx].clone()
decode_list[d2_idx].clone_worker()
};
Ok((selected_prefill, selected_decode))
......@@ -868,11 +848,11 @@ impl PDRouter {
let mut worker_infos = Vec::new();
for worker in self.prefill_workers.read().unwrap().iter() {
worker_infos.push((worker.url.clone(), "prefill"));
worker_infos.push((worker.url().to_string(), "prefill"));
}
for worker in self.decode_workers.read().unwrap().iter() {
worker_infos.push((worker.url.clone(), "decode"));
worker_infos.push((worker.url().to_string(), "decode"));
}
// Create tasks with URL tracking
......@@ -922,7 +902,7 @@ impl PDRouter {
pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
// Get info from the first decode server to match sglang's server info format
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
workers.first().map(|w| w.url.clone())
workers.first().map(|w| w.url().to_string())
} else {
return HttpResponse::InternalServerError().body("Failed to access decode workers");
};
......@@ -967,7 +947,7 @@ impl PDRouter {
pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
// Get first prefill worker URL to avoid holding lock across await
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
workers.first().map(|w| w.url.clone())
workers.first().map(|w| w.url().to_string())
} else {
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
};
......@@ -1005,14 +985,14 @@ impl PDRouter {
.read()
.unwrap()
.iter()
.map(|w| w.url.clone())
.map(|w| w.url().to_string())
.collect();
let d_urls: Vec<_> = self
.decode_workers
.read()
.unwrap()
.iter()
.map(|w| w.url.clone())
.map(|w| w.url().to_string())
.collect();
let mut prefill_loads = Vec::new();
......@@ -1048,7 +1028,7 @@ impl PDRouter {
// Get model info from the first prefill server (matches original Rust PDLB behavior)
// Get first prefill worker URL to avoid holding lock across await
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
workers.first().map(|w| w.url.clone())
workers.first().map(|w| w.url().to_string())
} else {
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
};
......@@ -1084,13 +1064,13 @@ impl PDRouter {
// Flush cache on all prefill servers
for worker in self.prefill_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url);
let url = format!("{}/flush_cache", worker.url());
tasks.push(client.post(&url).send());
}
// Flush cache on all decode servers
for worker in self.decode_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url);
let url = format!("{}/flush_cache", worker.url());
tasks.push(client.post(&url).send());
}
......
// Essential PDLB types extracted for PD routing
use crate::core::{Worker, WorkerType};
use serde::{Deserialize, Serialize};
use serde_json::Value;
......@@ -28,52 +29,21 @@ pub enum PDRouterError {
Timeout { url: String },
}
/// Role of an engine in prefill/decode (PD) disaggregated serving.
#[derive(Debug, Clone)]
pub enum EngineType {
    /// Engine that serves the prefill phase (see `EngineInfo::new_prefill`).
    Prefill,
    /// Engine that serves the decode phase (see `EngineInfo::new_decode`).
    Decode,
}
/// Connection details for a single prefill or decode engine.
#[derive(Debug, Clone)]
pub struct EngineInfo {
    /// Whether this engine handles the prefill or the decode phase.
    pub engine_type: EngineType,
    /// Base URL of the engine, e.g. `http://host:port`.
    pub url: String,
    /// Port used for the PD bootstrap handshake; always `None` for
    /// decode engines (set only by `new_prefill`).
    pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
/// Build an `EngineInfo` for a prefill engine at `url`, with an
/// optional bootstrap port used for the PD handshake.
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
    EngineInfo {
        engine_type: EngineType::Prefill,
        url,
        bootstrap_port,
    }
}
/// Build an `EngineInfo` for a decode engine at `url`.
/// Decode engines never use a bootstrap port, so it is set to `None`.
pub fn new_decode(url: String) -> Self {
    EngineInfo {
        engine_type: EngineType::Decode,
        url,
        bootstrap_port: None,
    }
}
pub fn api_path(&self, api_path: &str) -> String {
if api_path.starts_with("/") {
format!("{}{}", self.url, api_path)
} else {
format!("{}/{}", self.url, api_path)
}
// Helper functions for workers

/// Join a base worker URL with an API path, ensuring exactly one `/`
/// separator between them.
///
/// `url` is used as-is (it is assumed not to end with `/` — TODO confirm
/// against callers); a leading `/` on `api_path` is preserved, otherwise
/// one is inserted.
pub fn api_path(url: &str, api_path: &str) -> String {
    // char pattern is the idiomatic (and cheaper) form for a single-char check
    if api_path.starts_with('/') {
        format!("{}{}", url, api_path)
    } else {
        format!("{}/{}", url, api_path)
    }
}
pub fn get_hostname(&self) -> String {
// Simple hostname extraction without external dependencies
let url = self
.url
.trim_start_matches("http://")
.trim_start_matches("https://");
url.split(':').next().unwrap_or("localhost").to_string()
}
/// Extract the hostname portion of a worker URL.
///
/// Strips an `http://` / `https://` scheme prefix, then returns everything
/// before the first `:` (i.e. the host without the port). Falls back to
/// `"localhost"` for the degenerate no-segment case. Note: not IPv6-aware —
/// a bracketed IPv6 literal would be split at its first `:`.
pub fn get_hostname(url: &str) -> String {
    // Drop the scheme, if present; trim_start_matches is a no-op when the
    // pattern does not occur at the start.
    let without_scheme = url
        .trim_start_matches("http://")
        .trim_start_matches("https://");
    // The first ':'-separated segment is the host. split() always yields at
    // least one item, but keep the defensive fallback.
    match without_scheme.split(':').next() {
        Some(host) => host.to_string(),
        None => "localhost".to_string(),
    }
}
// PD-specific routing policies
......@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
bootstrap_room: BootstrapRoom,
);
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
let batch_size = self.get_batch_size()?;
// Extract bootstrap port from prefill worker if it's a prefill type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(prefill_worker.url());
if let Some(batch_size) = batch_size {
self.set_bootstrap_info(
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
BootstrapHost::Batch(vec![hostname; batch_size]),
BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
// Use high-quality random numbers to minimize collision risk
BootstrapRoom::Batch(
(0..batch_size)
......@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
);
} else {
self.set_bootstrap_info(
BootstrapHost::Single(prefill_info.get_hostname()),
BootstrapPort::Single(prefill_info.bootstrap_port),
BootstrapHost::Single(hostname),
BootstrapPort::Single(bootstrap_port),
BootstrapRoom::Single({
// Use high-quality random number for single requests too
let r1 = rand::random::<u64>();
......
use crate::core::{HealthChecker, Worker, WorkerFactory};
use crate::pd_router::PDRouter;
use crate::pd_types::PDSelectionPolicy;
use crate::tree::Tree;
......@@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram};
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use futures_util::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
......@@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
#[derive(Debug)]
pub enum Router {
RoundRobin {
worker_urls: Arc<RwLock<Vec<String>>>,
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
current_index: AtomicUsize,
timeout_secs: u64,
interval_secs: u64,
_health_checker: Option<HealthChecker>,
},
Random {
worker_urls: Arc<RwLock<Vec<String>>>,
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
timeout_secs: u64,
interval_secs: u64,
_health_checker: Option<HealthChecker>,
},
PrefillDecode {
pd_router: Arc<PDRouter>,
......@@ -104,16 +106,15 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Arc<RwLock<Vec<String>>>,
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
timeout_secs: u64,
interval_secs: u64,
_eviction_thread: Option<thread::JoinHandle<()>>,
_health_checker: Option<HealthChecker>,
},
}
......@@ -192,25 +193,43 @@ impl Router {
}
}
// Create Worker trait objects from URLs
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
.map(|url| WorkerFactory::create_regular(url.clone()))
.collect();
// Create router based on policy...
Ok(match policy_config {
PolicyConfig::RandomConfig {
timeout_secs,
interval_secs,
} => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
timeout_secs,
interval_secs,
},
} => {
let workers = Arc::new(RwLock::new(workers));
let health_checker =
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
Router::Random {
workers,
timeout_secs,
interval_secs,
_health_checker: Some(health_checker),
}
}
PolicyConfig::RoundRobinConfig {
timeout_secs,
interval_secs,
} => Router::RoundRobin {
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
timeout_secs,
interval_secs,
},
} => {
let workers = Arc::new(RwLock::new(workers));
let health_checker =
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
Router::RoundRobin {
workers,
current_index: std::sync::atomic::AtomicUsize::new(0),
timeout_secs,
interval_secs,
_health_checker: Some(health_checker),
}
}
PolicyConfig::CacheAwareConfig {
cache_threshold,
balance_abs_threshold,
......@@ -220,24 +239,12 @@ impl Router {
timeout_secs,
interval_secs,
} => {
let mut running_queue = HashMap::new();
for url in &worker_urls {
running_queue.insert(url.clone(), 0);
}
let mut processed_queue = HashMap::new();
for url in &worker_urls {
processed_queue.insert(url.clone(), 0);
}
let tree = Arc::new(Mutex::new(Tree::new()));
let running_queue = Arc::new(Mutex::new(running_queue));
let processed_queue = Arc::new(Mutex::new(processed_queue));
// Create background eviction thread
let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue);
let running_queue_clone = Arc::clone(&running_queue);
let workers = Arc::new(RwLock::new(workers));
let workers_clone = Arc::clone(&workers);
let eviction_thread = thread::spawn(move || {
loop {
// Sleep for the specified interval
......@@ -246,32 +253,41 @@ impl Router {
let locked_tree_clone = tree_clone.lock().unwrap();
// Run eviction
locked_tree_clone.evict_tenant_by_size(max_tree_size);
// Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap();
info!("Processed Queue: {:?}", locked_processed_queue);
// Print the running queue
let locked_running_queue = running_queue_clone.lock().unwrap();
info!("Running Queue: {:?}", locked_running_queue);
drop(locked_tree_clone);
// Log worker loads and processed requests
let workers_guard = workers_clone.read().unwrap();
let loads: Vec<(String, usize)> = workers_guard
.iter()
.map(|w| (w.url().to_string(), w.load()))
.collect();
info!("Worker loads: {:?}", loads);
let processed: Vec<(String, usize)> = workers_guard
.iter()
.map(|w| (w.url().to_string(), w.processed_requests()))
.collect();
info!("Processed requests: {:?}", processed);
}
});
for url in &worker_urls {
tree.lock().unwrap().insert("", url);
for worker in workers.read().unwrap().iter() {
tree.lock().unwrap().insert("", worker.url());
}
let health_checker =
crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
Router::CacheAware {
worker_urls: Arc::new(RwLock::new(worker_urls)),
workers,
tree,
running_queue,
processed_queue,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
timeout_secs,
interval_secs,
_eviction_thread: Some(eviction_thread),
_health_checker: Some(health_checker),
}
}
PolicyConfig::PrefillDecodeConfig {
......@@ -297,16 +313,18 @@ impl Router {
})
}
/// Get a reference to the worker URLs shared across threads
pub fn get_worker_urls(&self) -> Arc<RwLock<Vec<String>>> {
/// Get the current list of worker URLs
pub fn get_worker_urls(&self) -> Vec<String> {
match self {
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
Router::PrefillDecode { .. } => {
// For PD mode, return empty list since we manage workers differently
Arc::new(RwLock::new(Vec::new()))
}
Router::RoundRobin { workers, .. }
| Router::Random { workers, .. }
| Router::CacheAware { workers, .. } => workers
.read()
.unwrap()
.iter()
.map(|w| w.url().to_string())
.collect(),
Router::PrefillDecode { .. } => Vec::new(),
}
}
......@@ -373,13 +391,14 @@ impl Router {
fn select_first_worker(&self) -> Result<String, String> {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.read().unwrap().is_empty() {
Router::RoundRobin { workers, .. }
| Router::Random { workers, .. }
| Router::CacheAware { workers, .. } => {
let workers_guard = workers.read().unwrap();
if workers_guard.is_empty() {
Err("No workers are available".to_string())
} else {
Ok(worker_urls.read().unwrap()[0].clone())
Ok(workers_guard[0].url().to_string())
}
}
Router::PrefillDecode { .. } => {
......@@ -514,7 +533,7 @@ impl Router {
return HttpResponse::NotImplemented()
.body("route_to_all not implemented for PrefillDecode mode");
}
_ => self.get_worker_urls().read().unwrap().clone(),
_ => self.get_worker_urls(),
};
// Send requests to all workers concurrently
......@@ -562,7 +581,7 @@ impl Router {
}
}
let urls = self.get_worker_urls().read().unwrap().clone();
let urls = self.get_worker_urls();
let prefill_urls: Vec<String> = Vec::new();
let decode_urls = urls;
......@@ -631,6 +650,24 @@ impl Router {
.increment(1);
}
// For CacheAware router, increment load before request
let load_incremented = match self {
Router::CacheAware { workers, .. } => {
let workers_guard = workers.read().unwrap();
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == &worker_url)
{
worker.increment_load();
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
.set(worker.load() as f64);
true
} else {
false
}
}
_ => false,
};
// Send typed request directly
let response = self
.send_typed_request(
......@@ -640,6 +677,7 @@ impl Router {
route,
&worker_url,
is_stream,
load_incremented,
)
.await;
......@@ -684,44 +722,47 @@ impl Router {
}
}
// Helper method to select worker from text
// Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware)
fn select_generate_worker_from_text(&self, text: &str) -> String {
match self {
Router::RoundRobin {
worker_urls,
workers,
current_index,
..
} => {
let workers_guard = workers.read().unwrap();
let idx = current_index
.fetch_update(
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst,
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
|x| Some((x + 1) % workers_guard.len()),
)
.unwrap();
worker_urls.read().unwrap()[idx].clone()
workers_guard[idx].url().to_string()
}
Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(),
Router::Random { workers, .. } => {
let workers_guard = workers.read().unwrap();
workers_guard[rand::random::<usize>() % workers_guard.len()]
.url()
.to_string()
}
Router::CacheAware {
worker_urls,
workers,
tree,
running_queue,
processed_queue,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
..
} => {
let tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap();
let workers_guard = workers.read().unwrap();
// Get current load statistics
let max_load = *running_queue.values().max().unwrap_or(&0);
let min_load = *running_queue.values().min().unwrap_or(&0);
// Get current load statistics from workers
let loads: Vec<usize> = workers_guard.iter().map(|w| w.load()).collect();
let max_load = *loads.iter().max().unwrap_or(&0);
let min_load = *loads.iter().min().unwrap_or(&0);
// Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND
......@@ -731,11 +772,16 @@ impl Router {
let selected_url = if is_imbalanced {
// Log load balancing trigger and current queue state
let worker_loads: Vec<(String, usize)> = workers_guard
.iter()
.map(|w| (w.url().to_string(), w.load()))
.collect();
info!(
"Load balancing triggered due to workload imbalance:\n\
Max load: {}, Min load: {}\n\
Current running queue: {:?}",
max_load, min_load, running_queue
Current worker loads: {:?}",
max_load, min_load, worker_loads
);
counter!("sgl_router_load_balancing_events_total").increment(1);
......@@ -743,11 +789,11 @@ impl Router {
gauge!("sgl_router_min_load").set(min_load as f64);
// Use shortest queue routing when load is imbalanced
running_queue
workers_guard
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
.min_by_key(|w| w.load())
.map(|w| w.url().to_string())
.unwrap_or_else(|| workers_guard[0].url().to_string())
} else {
// Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text);
......@@ -763,18 +809,12 @@ impl Router {
}
};
// Update queues and tree
*running_queue.get_mut(&selected_url).unwrap() += 1;
*processed_queue
.lock()
.unwrap()
.get_mut(&selected_url)
.unwrap() += 1;
gauge!("sgl_router_running_requests", "worker" => selected_url.to_string())
.set(*running_queue.get(&selected_url).unwrap() as f64);
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()).increment(1);
// Find the selected worker and increment processed counter only
if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) {
worker.increment_processed();
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string())
.increment(1);
}
tree.insert(&text, &selected_url);
......@@ -796,6 +836,7 @@ impl Router {
route: &str,
worker_url: &str,
is_stream: bool,
load_incremented: bool, // Whether load was incremented for this request
) -> HttpResponse {
let start = Instant::now();
......@@ -820,6 +861,22 @@ impl Router {
Ok(res) => res,
Err(e) => {
error!("Failed to send request to {}: {}", worker_url, e);
// Decrement load on error for CacheAware router
if load_incremented {
if let Router::CacheAware { workers, .. } = self {
if let Ok(workers_guard) = workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == worker_url)
{
worker.decrement_load();
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
.set(worker.load() as f64);
}
}
}
}
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
}
};
......@@ -837,13 +894,15 @@ impl Router {
}
};
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(worker_url) {
*count = count.saturating_sub(1);
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
.set(*count as f64);
// Decrement load counter for non-streaming CacheAware requests
if load_incremented && !is_stream {
if let Router::CacheAware { workers, .. } = self {
if let Ok(workers_guard) = workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
.set(worker.load() as f64);
}
}
}
}
......@@ -855,8 +914,9 @@ impl Router {
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
} else if let Router::CacheAware { workers, .. } = self {
// For streaming with CacheAware router, we need to manually decrement when done
let workers = Arc::clone(workers);
let worker_url = worker_url.to_string();
HttpResponse::build(status)
......@@ -867,21 +927,28 @@ impl Router {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
let bytes = bytes.as_ref().unwrap();
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()).set(*count as f64);
debug!("Streaming is done!!")
if let Ok(bytes) = bytes {
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
if let Ok(workers_guard) = workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == &worker_url)
{
worker.decrement_load();
gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
.set(worker.load() as f64);
debug!("Streaming is done!!")
}
}
}
}
}),
)
} else {
// For non-CacheAware routers, just stream without load tracking
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| {
......@@ -935,43 +1002,27 @@ impl Router {
Ok(res) => {
if res.status().is_success() {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => {
Router::RoundRobin { workers, .. }
| Router::Random { workers, .. }
| Router::CacheAware { workers, .. } => {
info!("Worker {} health check passed", worker_url);
let mut urls = worker_urls.write().unwrap();
if urls.contains(&worker_url.to_string()) {
let mut workers_guard = workers.write().unwrap();
if workers_guard.iter().any(|w| w.url() == worker_url) {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
urls.push(worker_url.to_string());
gauge!("sgl_router_active_workers").set(urls.len() as f64);
let new_worker =
WorkerFactory::create_regular(worker_url.to_string());
workers_guard.push(new_worker);
gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
}
Router::PrefillDecode { .. } => {
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
}
}
// If cache aware, initialize the queues for the new worker
if let Router::CacheAware {
running_queue,
processed_queue,
tree,
..
} = self
{
// Add worker to running queue with initial count of 0
running_queue
.lock()
.unwrap()
.insert(worker_url.to_string(), 0);
// Add worker to processed queue with initial count of 0
processed_queue
.lock()
.unwrap()
.insert(worker_url.to_string(), 0);
// If cache aware, add worker to tree
if let Router::CacheAware { tree, .. } = self {
// Add worker to tree
tree.lock().unwrap().insert("", worker_url);
}
......@@ -1013,14 +1064,14 @@ impl Router {
pub fn remove_worker(&self, worker_url: &str) {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => {
let mut urls = worker_urls.write().unwrap();
if let Some(index) = urls.iter().position(|url| url == &worker_url) {
urls.remove(index);
Router::RoundRobin { workers, .. }
| Router::Random { workers, .. }
| Router::CacheAware { workers, .. } => {
let mut workers_guard = workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed worker: {}", worker_url);
gauge!("sgl_router_active_workers").set(urls.len() as f64);
gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
......@@ -1033,26 +1084,9 @@ impl Router {
}
// if cache aware, remove the worker from the tree
if let Router::CacheAware {
tree,
running_queue,
processed_queue,
..
} = self
{
if let Router::CacheAware { tree, .. } = self {
tree.lock().unwrap().remove_tenant(&worker_url);
running_queue
.lock()
.unwrap()
.remove(&worker_url.to_string());
processed_queue
.lock()
.unwrap()
.remove(&worker_url.to_string());
info!(
"Removed worker from tree and cleaned up queues: {}",
worker_url
);
info!("Removed worker from tree: {}", worker_url);
}
}
......@@ -1241,21 +1275,22 @@ mod tests {
use crate::service_discovery::PodType;
fn create_test_regular_router() -> Router {
let workers = vec![
WorkerFactory::create_regular("http://worker1:8080".to_string()),
WorkerFactory::create_regular("http://worker2:8080".to_string()),
];
Router::Random {
worker_urls: Arc::new(RwLock::new(vec![
"http://worker1:8080".to_string(),
"http://worker2:8080".to_string(),
])),
workers: Arc::new(RwLock::new(workers)),
timeout_secs: 5,
interval_secs: 1,
_health_checker: None,
}
}
#[test]
fn test_router_get_worker_urls_regular() {
let router = create_test_regular_router();
let worker_urls = router.get_worker_urls();
let urls = worker_urls.read().unwrap();
let urls = router.get_worker_urls();
assert_eq!(urls.len(), 2);
assert!(urls.contains(&"http://worker1:8080".to_string()));
......
......@@ -236,8 +236,7 @@ async fn add_worker(
#[get("/list_workers")]
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
let workers = data.router.get_worker_urls();
let worker_list = workers.read().unwrap().clone();
let worker_list = data.router.get_worker_urls();
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
}
......@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
info!("✅ Serving router on {}:{}", config.host, config.port);
info!(
"✅ Serving workers on {:?}",
app_state.router.get_worker_urls().read().unwrap()
app_state.router.get_worker_urls()
);
HttpServer::new(move || {
......
......@@ -547,11 +547,12 @@ mod tests {
// Helper to create a Router instance for testing event handlers
fn create_test_router() -> Arc<Router> {
let worker_urls = Arc::new(RwLock::new(Vec::new()));
let workers = Arc::new(RwLock::new(Vec::new()));
Arc::new(Router::Random {
worker_urls,
workers,
timeout_secs: 5,
interval_secs: 1,
_health_checker: None,
})
}
......@@ -878,8 +879,6 @@ mod tests {
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
assert!(!router
.get_worker_urls()
.read()
.unwrap()
.contains(&pod_info.worker_url(port)));
}
......@@ -907,7 +906,7 @@ mod tests {
.await;
assert!(tracked_pods.lock().unwrap().is_empty());
assert!(router.get_worker_urls().read().unwrap().is_empty());
assert!(router.get_worker_urls().is_empty());
}
#[tokio::test]
......
......@@ -12,7 +12,7 @@
mod test_pd_routing {
use rand::Rng;
use serde_json::json;
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
use sglang_router_rs::pd_types::PDSelectionPolicy;
use sglang_router_rs::router::{PolicyConfig, Router};
// Test-only struct to help validate PD request parsing
......@@ -51,40 +51,35 @@ mod test_pd_routing {
// ========================================================================
#[test]
fn test_engine_info_creation() {
// Test EngineInfo creation for prefill servers
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
match prefill_engine.engine_type {
EngineType::Prefill => (),
_ => panic!("Expected Prefill engine type"),
fn test_worker_types() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
// Test worker creation for prefill servers
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
assert_eq!(prefill_worker.url(), "http://prefill:8080");
match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => {
assert_eq!(bootstrap_port, Some(9000));
}
_ => panic!("Expected Prefill worker type"),
}
assert_eq!(prefill_engine.url, "http://prefill:8080");
assert_eq!(prefill_engine.bootstrap_port, Some(9000));
assert_eq!(prefill_engine.get_hostname(), "prefill");
// Test EngineInfo creation for decode servers
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
match decode_engine.engine_type {
EngineType::Decode => (),
_ => panic!("Expected Decode engine type"),
// Test worker creation for decode servers
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
assert_eq!(decode_worker.url(), "http://decode:8080");
match decode_worker.worker_type() {
WorkerType::Decode => (),
_ => panic!("Expected Decode worker type"),
}
assert_eq!(decode_engine.url, "http://decode:8080");
assert_eq!(decode_engine.bootstrap_port, None);
assert_eq!(decode_engine.get_hostname(), "decode");
// Test API path generation
assert_eq!(
prefill_engine.api_path("/generate"),
"http://prefill:8080/generate"
);
assert_eq!(
prefill_engine.api_path("health"),
"http://prefill:8080/health"
);
assert_eq!(
decode_engine.api_path("/v1/chat/completions"),
"http://decode:8080/v1/chat/completions"
);
// Test regular worker creation
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
assert_eq!(regular_worker.url(), "http://regular:8080");
match regular_worker.worker_type() {
WorkerType::Regular => (),
_ => panic!("Expected Regular worker type"),
}
}
#[test]
......@@ -230,6 +225,9 @@ mod test_pd_routing {
#[test]
fn test_bootstrap_injection_simulation() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior
......@@ -240,15 +238,24 @@ mod test_pd_routing {
"temperature": 0.7
});
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
// Simulate what inject_bootstrap_fields would do
let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
single_json["bootstrap_port"] = json!(bootstrap_port);
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
// Verify bootstrap fields are added correctly
assert_eq!(single_json["bootstrap_host"], "prefill1");
assert_eq!(single_json["bootstrap_port"], 9000);
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
assert!(single_json["bootstrap_room"].is_u64());
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
......@@ -259,8 +266,9 @@ mod test_pd_routing {
});
let batch_size = 3;
batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
let hostname = get_hostname(prefill_worker.url());
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
// Verify batch bootstrap fields
......@@ -306,7 +314,9 @@ mod test_pd_routing {
}
#[test]
fn test_engine_info_hostname_extraction() {
fn test_hostname_extraction() {
use sglang_router_rs::pd_types::get_hostname;
// Test various URL formats
let test_cases = vec![
("http://localhost:8080", "localhost"),
......@@ -318,8 +328,7 @@ mod test_pd_routing {
];
for (url, expected_hostname) in test_cases {
let engine = EngineInfo::new_prefill(url.to_string(), None);
assert_eq!(engine.get_hostname(), expected_hostname);
assert_eq!(get_hostname(url), expected_hostname);
}
}
......@@ -652,6 +661,9 @@ mod test_pd_routing {
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
......@@ -664,12 +676,20 @@ mod test_pd_routing {
"stream": true
});
// Simulate bootstrap injection
let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let batch_size = 16;
let hostname = get_hostname(prefill_worker.url());
benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
benchmark_request["bootstrap_room"] =
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
......@@ -770,6 +790,9 @@ mod test_pd_routing {
#[test]
fn test_large_batch_bootstrap_injection() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192];
......@@ -787,14 +810,19 @@ mod test_pd_routing {
"stream": true
});
// Simulate bootstrap injection
let prefill_info =
EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(prefill_worker.url());
large_batch_request["bootstrap_host"] =
json!(vec![prefill_info.get_hostname(); batch_size]);
large_batch_request["bootstrap_port"] =
json!(vec![prefill_info.bootstrap_port; batch_size]);
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
large_batch_request["bootstrap_room"] = json!((0..batch_size)
.map(|_| rand::thread_rng().gen::<u64>())
.collect::<Vec<_>>());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment