Unverified Commit 2f173ea0 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] allow one router to support different model families and serving mode (#10244)

parent 321fecab
......@@ -46,6 +46,9 @@ class Router:
max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
dp_aware: Enable data parallelism aware schedule. Default: False
enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled,
the router can manage multiple models simultaneously with per-model load balancing
policies. Default: False
api_key: The api key used for the authorization with the worker.
Useful when the dp aware scheduling strategy is enabled.
Default: None
......
......@@ -34,6 +34,7 @@ class RouterArgs:
max_tree_size: int = 2**26
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
dp_aware: bool = False
enable_igw: bool = False # Enable IGW (Inference-Gateway) mode for multi-model support
api_key: Optional[str] = None
log_dir: Optional[str] = None
log_level: Optional[str] = None
......@@ -227,6 +228,11 @@ class RouterArgs:
action="store_true",
help="Enable data parallelism aware schedule",
)
parser.add_argument(
f"--{prefix}enable-igw",
action="store_true",
help="Enable IGW (Inference-Gateway) mode for multi-model support",
)
parser.add_argument(
f"--{prefix}api-key",
type=str,
......
......@@ -128,6 +128,7 @@ def _popen_launch_router_only(
timeout: float = 120.0,
*,
dp_aware: bool = False,
enable_igw: bool = False,
api_key: str | None = None,
) -> subprocess.Popen:
host, port = _parse_url(base_url)
......@@ -146,6 +147,8 @@ def _popen_launch_router_only(
]
if dp_aware:
cmd += ["--dp-aware"]
if enable_igw:
cmd += ["--enable-igw"]
if api_key is not None:
cmd += ["--api-key", api_key]
cmd += [
......
......@@ -35,7 +35,7 @@ def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers):
)
assert r.status_code == 200
wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
assert wid == id_b # should have retried onto healthy worker
assert wid in [id_b, id_c] # should have retried onto a healthy worker (B or C)
# mock_workers fixture handles cleanup
......
......@@ -11,6 +11,7 @@ pub mod error;
pub mod retry;
pub mod token_bucket;
pub mod worker;
pub mod worker_registry;
// Re-export commonly used types at the module level
pub use circuit_breaker::{
......@@ -22,3 +23,4 @@ pub use worker::{
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
};
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
......@@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug {
fn can_handle(&self, _req: &serde_json::Value) -> bool {
true
}
// === Multi-router support ===
// TODO: - Enhanced Worker Discovery
// The Worker trait should handle async discovery of metadata from the worker itself
// rather than having service discovery or other components query /get_server_info.
// This keeps service discovery decoupled from worker-specific APIs.
//
// Proposed additions:
// - async fn discover_metadata(&mut self) -> Result<(), Error>
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
// - async fn validate_configuration(&self) -> Result<(), Error>
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
// - Make worker creation async to allow metadata discovery during initialization
//
// This way service discovery just calls router.add_worker() and the worker
// handles its own metadata discovery internally.
/// Get the model ID this worker serves
fn model_id(&self) -> &str {
self.metadata()
.labels
.get("model_id")
.map(|s| s.as_str())
.unwrap_or("unknown")
}
/// Get the priority of this worker (higher value = higher priority)
fn priority(&self) -> u32 {
self.metadata()
.labels
.get("priority")
.and_then(|s| s.parse().ok())
.unwrap_or(50) // Default priority is 50 (mid-range)
}
/// Get the cost factor of this worker (1.0 = baseline)
fn cost(&self) -> f32 {
self.metadata()
.labels
.get("cost")
.and_then(|s| s.parse().ok())
.unwrap_or(1.0)
}
/// Get the tokenizer path for this worker (gRPC mode only)
fn tokenizer_path(&self) -> Option<&str> {
self.metadata()
.labels
.get("tokenizer_path")
.map(|s| s.as_str())
}
/// Get the reasoning parser type for this worker (gRPC mode only)
fn reasoning_parser(&self) -> Option<&str> {
self.metadata()
.labels
.get("reasoning_parser")
.map(|s| s.as_str())
}
/// Get the tool parser type for this worker (gRPC mode only)
fn tool_parser(&self) -> Option<&str> {
self.metadata()
.labels
.get("tool_parser")
.map(|s| s.as_str())
}
/// Get the chat template for this worker (gRPC mode only)
fn chat_template(&self) -> Option<&str> {
self.metadata()
.labels
.get("chat_template")
.map(|s| s.as_str())
}
}
/// Connection mode for worker communication
......@@ -724,6 +800,21 @@ impl WorkerFactory {
)
}
/// Create a regular worker with custom labels (for multi-router support)
///
/// `labels` is stored on the worker's metadata; keys such as "model_id",
/// "priority", and "cost" are read back by the Worker trait's label-based
/// accessors.
pub fn create_regular_with_labels(
    url: String,
    labels: std::collections::HashMap<String, String>,
    circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
    // `url` is moved straight into the worker; the previous `url.clone()`
    // was a needless allocation since `url` was never used afterwards.
    let mut worker = BasicWorker::new(url, WorkerType::Regular)
        .with_circuit_breaker_config(circuit_breaker_config);
    // Attach caller-supplied labels to the worker's metadata
    worker.metadata.labels = labels;
    Box::new(worker)
}
/// Create a DP-aware worker of specified type
pub fn create_dp_aware(
base_url: String,
......@@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker {
}
impl HealthChecker {
/// Create a new HealthChecker
pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc<AtomicBool>) -> Self {
Self { handle, shutdown }
}
/// Shutdown the health checker gracefully
pub async fn shutdown(self) {
self.shutdown.store(true, Ordering::Release);
......@@ -950,7 +1046,7 @@ impl HealthChecker {
/// Start an async background health checker for a collection of workers
pub fn start_health_checker(
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
check_interval_secs: u64,
) -> HealthChecker {
let shutdown = Arc::new(AtomicBool::new(false));
......@@ -1602,9 +1698,11 @@ mod tests {
// Test HealthChecker background task
#[tokio::test]
async fn test_health_checker_startup() {
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
let worker = Arc::new(BasicWorker::new(
"http://w1:8080".to_string(),
)]));
WorkerType::Regular,
)) as Arc<dyn Worker>;
let workers = Arc::new(RwLock::new(vec![worker]));
let checker = start_health_checker(workers.clone(), 60);
......@@ -1617,9 +1715,11 @@ mod tests {
#[tokio::test]
async fn test_health_checker_shutdown() {
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
let worker = Arc::new(BasicWorker::new(
"http://w1:8080".to_string(),
)]));
WorkerType::Regular,
)) as Arc<dyn Worker>;
let workers = Arc::new(RwLock::new(vec![worker]));
let checker = start_health_checker(workers.clone(), 60);
......
//! Worker Registry for multi-router support
//!
//! Provides centralized registry for workers with model-based indexing
use crate::core::{ConnectionMode, Worker, WorkerType};
use dashmap::DashMap;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
/// Unique identifier for a worker
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct WorkerId(String);

impl WorkerId {
    /// Generate a fresh, random worker ID (UUID v4).
    pub fn new() -> Self {
        let uuid = Uuid::new_v4();
        WorkerId(uuid.to_string())
    }

    /// Wrap an existing string as a worker ID.
    pub fn from_string(s: String) -> Self {
        WorkerId(s)
    }

    /// Borrow the ID as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl Default for WorkerId {
fn default() -> Self {
Self::new()
}
}
/// Type alias for the model index to reduce complexity
/// Maps model_id -> shared, RwLock-protected list of workers serving that model.
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;
/// Worker registry with model-based indexing
///
/// Maintains several secondary indexes (model, worker type, connection mode,
/// URL) alongside the primary ID map so lookups along each axis stay cheap.
/// All maps are concurrent (`DashMap`) and wrapped in `Arc` so the spawned
/// health-check task can share the worker map with the registry.
#[derive(Debug)]
pub struct WorkerRegistry {
    /// All workers indexed by ID
    workers: Arc<DashMap<WorkerId, Arc<dyn Worker>>>,
    /// Workers indexed by model ID (stores WorkerId for reference)
    model_workers: Arc<DashMap<String, Vec<WorkerId>>>,
    /// Optimized model index for O(1) lookups (stores Arc<dyn Worker> directly)
    model_index: ModelIndex,
    /// Workers indexed by worker type
    type_workers: Arc<DashMap<WorkerType, Vec<WorkerId>>>,
    /// Workers indexed by connection mode
    connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
    /// URL to worker ID mapping (for backward compatibility)
    url_to_id: Arc<DashMap<String, WorkerId>>,
}
impl WorkerRegistry {
/// Create a new worker registry
pub fn new() -> Self {
Self {
workers: Arc::new(DashMap::new()),
model_workers: Arc::new(DashMap::new()),
model_index: Arc::new(DashMap::new()),
type_workers: Arc::new(DashMap::new()),
connection_workers: Arc::new(DashMap::new()),
url_to_id: Arc::new(DashMap::new()),
}
}
/// Register a new worker, returning its ID.
///
/// If a worker with the same URL is already registered, it is replaced:
/// the old worker's entries are purged from every index first and its
/// `WorkerId` is reused. Previously, re-registering the same URL pushed
/// the same ID again into the model/type/connection indexes, leaving
/// duplicate entries; purging first fixes that.
pub fn register(&self, worker: Arc<dyn Worker>) -> WorkerId {
    // Clone the ID out of the map so the DashMap read guard is dropped
    // before we mutate any maps below (avoids deadlocking `remove`).
    let existing = self.url_to_id.get(worker.url()).map(|e| e.clone());
    let worker_id = if let Some(id) = existing {
        // Worker with this URL already exists: drop its stale index
        // entries, then reuse the same ID for the replacement.
        self.remove(&id);
        id
    } else {
        WorkerId::new()
    };
    // Store worker
    self.workers.insert(worker_id.clone(), worker.clone());
    // Update URL mapping
    self.url_to_id
        .insert(worker.url().to_string(), worker_id.clone());
    // Update model index (both ID-based and optimized)
    let model_id = worker.model_id().to_string();
    self.model_workers
        .entry(model_id.clone())
        .or_default()
        .push(worker_id.clone());
    // Update optimized model index for O(1) lookups
    self.model_index
        .entry(model_id)
        .or_insert_with(|| Arc::new(RwLock::new(Vec::new())))
        .write()
        .expect("RwLock for model_index is poisoned")
        .push(worker.clone());
    // Update type index
    self.type_workers
        .entry(worker.worker_type())
        .or_default()
        .push(worker_id.clone());
    // Update connection mode index
    self.connection_workers
        .entry(worker.connection_mode())
        .or_default()
        .push(worker_id.clone());
    worker_id
}
/// Remove a worker by ID
///
/// Drops the worker from every secondary index and returns it, or
/// `None` if the ID was not registered.
pub fn remove(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
    // Early-return via `?` if the ID is unknown.
    let (_, worker) = self.workers.remove(worker_id)?;
    // URL mapping
    self.url_to_id.remove(worker.url());
    // ID-based model index
    if let Some(mut ids) = self.model_workers.get_mut(worker.model_id()) {
        ids.retain(|id| id != worker_id);
    }
    // Optimized model index (keyed by URL, since it stores workers directly)
    if let Some(entry) = self.model_index.get(worker.model_id()) {
        let url = worker.url();
        entry
            .write()
            .expect("RwLock for model_index is poisoned")
            .retain(|w| w.url() != url);
    }
    // Type index
    if let Some(mut ids) = self.type_workers.get_mut(&worker.worker_type()) {
        ids.retain(|id| id != worker_id);
    }
    // Connection mode index
    if let Some(mut ids) = self.connection_workers.get_mut(&worker.connection_mode()) {
        ids.retain(|id| id != worker_id);
    }
    Some(worker)
}
/// Remove a worker by URL
///
/// Resolves the URL to a worker ID, then delegates to `remove`.
pub fn remove_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
    let (_, worker_id) = self.url_to_id.remove(url)?;
    self.remove(&worker_id)
}
/// Get a worker by ID
pub fn get(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
    let entry = self.workers.get(worker_id)?;
    Some(entry.value().clone())
}

/// Get a worker by URL
pub fn get_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
    let id = self.url_to_id.get(url)?;
    self.get(&id)
}

/// Get all workers for a model
///
/// Walks the ID-based index and resolves each ID through `get`;
/// see `get_by_model_fast` for the pre-resolved variant.
pub fn get_by_model(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
    match self.model_workers.get(model_id) {
        Some(ids) => ids.iter().filter_map(|id| self.get(id)).collect(),
        None => Vec::new(),
    }
}

/// Get all workers for a model (O(1) optimized version)
/// This method uses the pre-indexed model_index for fast lookups
pub fn get_by_model_fast(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
    match self.model_index.get(model_id) {
        Some(list) => list
            .read()
            .expect("RwLock for model_index is poisoned")
            .clone(),
        None => Vec::new(),
    }
}
/// Get all workers by worker type
pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec<Arc<dyn Worker>> {
    match self.type_workers.get(worker_type) {
        Some(ids) => ids.iter().filter_map(|id| self.get(id)).collect(),
        None => Vec::new(),
    }
}

/// Get all prefill workers (regardless of bootstrap_port)
pub fn get_prefill_workers(&self) -> Vec<Arc<dyn Worker>> {
    // Prefill carries a payload, so match on the variant shape instead of
    // looking up a single key in type_workers.
    self.workers
        .iter()
        .filter(|entry| matches!(entry.value().worker_type(), WorkerType::Prefill { .. }))
        .map(|entry| entry.value().clone())
        .collect()
}

/// Get all decode workers
pub fn get_decode_workers(&self) -> Vec<Arc<dyn Worker>> {
    self.get_by_type(&WorkerType::Decode)
}

/// Get all workers by connection mode
pub fn get_by_connection(&self, connection_mode: &ConnectionMode) -> Vec<Arc<dyn Worker>> {
    match self.connection_workers.get(connection_mode) {
        Some(ids) => ids.iter().filter_map(|id| self.get(id)).collect(),
        None => Vec::new(),
    }
}
/// Get all workers
pub fn get_all(&self) -> Vec<Arc<dyn Worker>> {
    let mut all = Vec::with_capacity(self.workers.len());
    for entry in self.workers.iter() {
        all.push(entry.value().clone());
    }
    all
}

/// Get all workers with their IDs
pub fn get_all_with_ids(&self) -> Vec<(WorkerId, Arc<dyn Worker>)> {
    let mut all = Vec::with_capacity(self.workers.len());
    for entry in self.workers.iter() {
        all.push((entry.key().clone(), entry.value().clone()));
    }
    all
}

/// Get all worker URLs
pub fn get_all_urls(&self) -> Vec<String> {
    self.workers
        .iter()
        .map(|entry| entry.value().url().to_owned())
        .collect()
}

/// Get all model IDs with workers
pub fn get_models(&self) -> Vec<String> {
    self.model_workers
        .iter()
        .filter_map(|entry| {
            // Skip models whose worker list was emptied by removals.
            if entry.value().is_empty() {
                None
            } else {
                Some(entry.key().clone())
            }
        })
        .collect()
}
/// Get workers filtered by multiple criteria
///
/// This method allows flexible filtering of workers based on:
/// - model_id: Filter by specific model
/// - worker_type: Filter by worker type (Regular, Prefill, Decode)
/// - connection_mode: Filter by connection mode (Http, Grpc)
/// - healthy_only: Only return healthy workers
pub fn get_workers_filtered(
    &self,
    model_id: Option<&str>,
    worker_type: Option<WorkerType>,
    connection_mode: Option<ConnectionMode>,
    healthy_only: bool,
) -> Vec<Arc<dyn Worker>> {
    // Seed the candidate set from the cheapest index available: the model
    // index gives an O(1) lookup, otherwise fall back to a full scan.
    let candidates = match model_id {
        Some(model) => self.get_by_model_fast(model),
        None => self.get_all(),
    };
    // Apply the remaining filters with guard-style `continue`s.
    let mut result = Vec::new();
    for w in candidates {
        if let Some(wtype) = &worker_type {
            if w.worker_type() != *wtype {
                continue;
            }
        }
        if let Some(conn) = &connection_mode {
            if w.connection_mode() != *conn {
                continue;
            }
        }
        if healthy_only && !w.is_healthy() {
            continue;
        }
        result.push(w);
    }
    result
}
/// Get worker statistics
///
/// Aggregates health, load, and per-type counts over a snapshot of all
/// registered workers.
pub fn stats(&self) -> WorkerRegistryStats {
    let mut stats = WorkerRegistryStats {
        total_workers: self.workers.len(),
        total_models: self.get_models().len(),
        healthy_workers: 0,
        total_load: 0,
        regular_workers: 0,
        prefill_workers: 0,
        decode_workers: 0,
    };
    for worker in self.get_all() {
        if worker.is_healthy() {
            stats.healthy_workers += 1;
        }
        stats.total_load += worker.load();
        match worker.worker_type() {
            WorkerType::Regular => stats.regular_workers += 1,
            WorkerType::Prefill { .. } => stats.prefill_workers += 1,
            WorkerType::Decode => stats.decode_workers += 1,
        }
    }
    stats
}
/// Start a health checker for all workers in the registry
/// This should be called once after the registry is populated with workers
///
/// Spawns a tokio task that, every `check_interval_secs` seconds:
/// - snapshots the current contents of the shared worker map (so workers
///   registered after startup are picked up automatically),
/// - runs `check_health_async` on each worker, and
/// - every 10th cycle, calls `reset_load` on each worker.
///
/// The returned `HealthChecker` carries the shutdown flag; note the flag is
/// only observed after the next interval tick, so shutdown can take up to
/// `check_interval_secs` seconds to take effect.
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = shutdown.clone();
    // Clone the Arc'd map, not the registry, so the task sees live updates.
    let workers_ref = self.workers.clone();
    let handle = tokio::spawn(async move {
        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
        // Counter for periodic load reset (every 10 health check cycles)
        let mut check_count = 0u64;
        const LOAD_RESET_INTERVAL: u64 = 10;
        loop {
            interval.tick().await;
            // Check for shutdown signal
            if shutdown_clone.load(Ordering::Acquire) {
                tracing::debug!("Registry health checker shutting down");
                break;
            }
            // Get all workers from registry (fresh snapshot each cycle)
            let workers: Vec<Arc<dyn crate::core::Worker>> = workers_ref
                .iter()
                .map(|entry| entry.value().clone())
                .collect();
            // Perform health checks; individual failures are intentionally
            // ignored here — the worker records its own health state.
            for worker in &workers {
                let _ = worker.check_health_async().await; // Use async version directly
            }
            // Reset loads periodically
            check_count += 1;
            if check_count % LOAD_RESET_INTERVAL == 0 {
                tracing::debug!("Resetting worker loads (cycle {})", check_count);
                for worker in &workers {
                    worker.reset_load();
                }
            }
        }
    });
    crate::core::HealthChecker::new(handle, shutdown)
}
}
impl Default for WorkerRegistry {
fn default() -> Self {
Self::new()
}
}
/// Statistics for the worker registry
#[derive(Debug, Clone)]
pub struct WorkerRegistryStats {
    /// Total number of registered workers.
    pub total_workers: usize,
    /// Number of distinct model IDs with at least one worker.
    pub total_models: usize,
    /// Workers currently reporting healthy.
    pub healthy_workers: usize,
    /// Sum of `load()` across all workers.
    pub total_load: usize,
    /// Count of `WorkerType::Regular` workers.
    pub regular_workers: usize,
    /// Count of `WorkerType::Prefill` workers.
    pub prefill_workers: usize,
    /// Count of `WorkerType::Decode` workers.
    pub decode_workers: usize,
}
// Unit tests for WorkerRegistry: basic register/lookup/remove lifecycle and
// the optimized per-model index.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{CircuitBreakerConfig, WorkerFactory};
    use std::collections::HashMap;

    // End-to-end lifecycle: register one labelled worker, verify it is
    // visible through every index, check stats, then remove it.
    #[test]
    fn test_worker_registry() {
        let registry = WorkerRegistry::new();
        // Create a worker with labels
        let mut labels = HashMap::new();
        labels.insert("model_id".to_string(), "llama-3-8b".to_string());
        labels.insert("priority".to_string(), "50".to_string());
        labels.insert("cost".to_string(), "0.8".to_string());
        let worker = WorkerFactory::create_regular_with_labels(
            "http://worker1:8080".to_string(),
            labels,
            CircuitBreakerConfig::default(),
        );
        // Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
        let worker_id = registry.register(Arc::from(worker));
        // Verify registration across ID, URL, model, type, and connection indexes
        assert!(registry.get(&worker_id).is_some());
        assert!(registry.get_by_url("http://worker1:8080").is_some());
        assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
        assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
        assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);
        // Test stats
        let stats = registry.stats();
        assert_eq!(stats.total_workers, 1);
        assert_eq!(stats.total_models, 1);
        // Remove worker
        registry.remove(&worker_id);
        assert!(registry.get(&worker_id).is_none());
    }

    // The optimized model index must agree with the ID-based index and must
    // be kept in sync on removal.
    #[test]
    fn test_model_index_fast_lookup() {
        let registry = WorkerRegistry::new();
        // Create workers for different models
        let mut labels1 = HashMap::new();
        labels1.insert("model_id".to_string(), "llama-3".to_string());
        let worker1 = WorkerFactory::create_regular_with_labels(
            "http://worker1:8080".to_string(),
            labels1,
            CircuitBreakerConfig::default(),
        );
        let mut labels2 = HashMap::new();
        labels2.insert("model_id".to_string(), "llama-3".to_string());
        let worker2 = WorkerFactory::create_regular_with_labels(
            "http://worker2:8080".to_string(),
            labels2,
            CircuitBreakerConfig::default(),
        );
        let mut labels3 = HashMap::new();
        labels3.insert("model_id".to_string(), "gpt-4".to_string());
        let worker3 = WorkerFactory::create_regular_with_labels(
            "http://worker3:8080".to_string(),
            labels3,
            CircuitBreakerConfig::default(),
        );
        // Register workers
        registry.register(Arc::from(worker1));
        registry.register(Arc::from(worker2));
        registry.register(Arc::from(worker3));
        // Test get_by_model_fast for llama-3
        let llama_workers = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers.len(), 2);
        let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
        // Test get_by_model_fast for gpt-4
        let gpt_workers = registry.get_by_model_fast("gpt-4");
        assert_eq!(gpt_workers.len(), 1);
        assert_eq!(gpt_workers[0].url(), "http://worker3:8080");
        // Test get_by_model_fast for non-existent model
        let unknown_workers = registry.get_by_model_fast("unknown-model");
        assert_eq!(unknown_workers.len(), 0);
        // Test that both get_by_model and get_by_model_fast return same results
        let llama_workers_slow = registry.get_by_model("llama-3");
        assert_eq!(llama_workers.len(), llama_workers_slow.len());
        // Test removal updates the model index
        registry.remove_by_url("http://worker1:8080");
        let llama_workers_after = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers_after.len(), 1);
        assert_eq!(llama_workers_after[0].url(), "http://worker2:8080");
    }
}
......@@ -63,6 +63,7 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
......@@ -72,10 +73,11 @@ use tracing::debug;
///
/// Routes requests based on cache affinity when load is balanced,
/// switches to shortest-queue routing when load is imbalanced.
/// Maintains separate trees per model for multi-model support.
#[derive(Debug)]
pub struct CacheAwarePolicy {
config: CacheAwareConfig,
tree: Arc<Mutex<Tree>>,
trees: Arc<Mutex<HashMap<String, Tree>>>, // model_id -> Tree
eviction_handle: Option<thread::JoinHandle<()>>,
}
......@@ -85,20 +87,26 @@ impl CacheAwarePolicy {
}
pub fn with_config(config: CacheAwareConfig) -> Self {
let tree = Arc::new(Mutex::new(Tree::new()));
let trees = Arc::new(Mutex::new(HashMap::<String, Tree>::new()));
// Start background eviction thread if configured
let eviction_handle = if config.eviction_interval_secs > 0 {
let tree_clone = Arc::clone(&tree);
let trees_clone = Arc::clone(&trees);
let max_tree_size = config.max_tree_size;
let interval = config.eviction_interval_secs;
Some(thread::spawn(move || loop {
thread::sleep(Duration::from_secs(interval));
if let Ok(tree_guard) = tree_clone.lock() {
tree_guard.evict_tenant_by_size(max_tree_size);
debug!("Cache eviction completed, max_size: {}", max_tree_size);
if let Ok(mut trees_guard) = trees_clone.lock() {
// Evict for all model trees
for (model_id, tree) in trees_guard.iter_mut() {
tree.evict_tenant_by_size(max_tree_size);
debug!(
"Cache eviction completed for model {}, max_size: {}",
model_id, max_tree_size
);
}
}
}))
} else {
......@@ -107,38 +115,97 @@ impl CacheAwarePolicy {
Self {
config,
tree,
trees,
eviction_handle,
}
}
/// Initialize the tree with worker URLs (used only during initial setup)
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
if let Ok(tree) = self.tree.lock() {
pub fn init_workers(&self, workers: &[Arc<dyn Worker>]) {
if let Ok(mut trees) = self.trees.lock() {
// Group workers by model
let mut model_workers: HashMap<String, Vec<&Arc<dyn Worker>>> = HashMap::new();
for worker in workers {
tree.insert("", worker.url());
// Use "default" for unknown/empty model_ids for backward compatibility
let model_id = worker.model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
model_workers.entry(tree_key).or_default().push(worker);
}
// Initialize tree for each model
for (tree_key, model_workers) in model_workers {
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
for worker in model_workers {
tree.insert("", worker.url());
}
}
}
}
/// Add a single worker to the tree (incremental update)
pub fn add_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
pub fn add_worker(&self, worker: &dyn Worker) {
if let Ok(mut trees) = self.trees.lock() {
// For backward compatibility: if model_id is "unknown" or empty,
// use a default tree. This preserves existing behavior for single-model routers.
let model_id = worker.model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
tree.insert("", worker.url());
}
}
/// Add a worker by URL and model (for backward compatibility)
pub fn add_worker_by_url(&self, url: &str, model_id: &str) {
if let Ok(mut trees) = self.trees.lock() {
let tree = trees.entry(model_id.to_string()).or_insert_with(Tree::new);
tree.insert("", url);
}
}
/// Remove a worker from the tree
pub fn remove_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
tree.remove_tenant(url);
pub fn remove_worker(&self, worker: &dyn Worker) {
if let Ok(mut trees) = self.trees.lock() {
// Use same logic as add_worker for consistency
let model_id = worker.model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
if let Some(tree) = trees.get_mut(&tree_key) {
tree.remove_tenant(worker.url());
}
}
}
/// Remove a worker by URL (removes from all model trees for backward compatibility)
pub fn remove_worker_by_url(&self, url: &str) {
if let Ok(mut trees) = self.trees.lock() {
// Remove from all trees since we don't know which model it belongs to
for (_model_id, tree) in trees.iter_mut() {
tree.remove_tenant(url);
}
}
}
/// Run cache eviction to prevent unbounded growth
pub fn evict_cache(&self, max_size: usize) {
if let Ok(tree) = self.tree.lock() {
tree.evict_tenant_by_size(max_size);
if let Ok(mut trees) = self.trees.lock() {
for (model_id, tree) in trees.iter_mut() {
tree.evict_tenant_by_size(max_size);
debug!(
"Cache eviction for model {}, max_size: {}",
model_id, max_size
);
}
}
}
}
......@@ -146,7 +213,7 @@ impl CacheAwarePolicy {
impl LoadBalancingPolicy for CacheAwarePolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
......@@ -155,6 +222,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
return None;
}
// Group workers by model (using "default" for unknown/empty model_ids)
let mut model_workers: HashMap<String, Vec<usize>> = HashMap::new();
for idx in &healthy_indices {
let model_id = workers[*idx].model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
model_workers.entry(tree_key).or_default().push(*idx);
}
// Get current load statistics
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
let max_load = *loads.iter().max().unwrap_or(&0);
......@@ -187,7 +266,14 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
// Even in imbalanced mode, update the tree to maintain cache state
if let Some(text) = request_text {
if let Ok(tree) = self.tree.lock() {
if let Ok(mut trees) = self.trees.lock() {
let model_id = workers[min_load_idx].model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
model_id.to_string()
};
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
tree.insert(text, workers[min_load_idx].url());
}
}
......@@ -203,43 +289,85 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
// Use cache-aware routing when balanced
let text = request_text.unwrap_or("");
if let Ok(tree) = self.tree.lock() {
let (matched_text, matched_worker) = tree.prefix_match(text);
let match_rate = if text.is_empty() {
0.0
} else {
matched_text.chars().count() as f32 / text.chars().count() as f32
};
if let Ok(mut trees) = self.trees.lock() {
let mut best_match_idx: Option<usize> = None;
let mut best_match_rate: f32 = 0.0;
// Find best match across all models
for (model_id, worker_indices) in &model_workers {
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
let (matched_text, matched_worker) = tree.prefix_match(text);
let match_rate = if text.is_empty() {
0.0
} else {
matched_text.chars().count() as f32 / text.chars().count() as f32
};
// Check if this model has the best match
if match_rate > best_match_rate {
// Find the worker index for this URL
if let Some(idx) = worker_indices
.iter()
.find(|&&idx| workers[idx].url() == matched_worker)
{
best_match_idx = Some(*idx);
best_match_rate = match_rate;
}
}
}
let selected_url = if match_rate > self.config.cache_threshold {
// Select worker based on cache threshold
let selected_idx = if let (Some(idx), true) = (
best_match_idx,
best_match_rate > self.config.cache_threshold,
) {
RouterMetrics::record_cache_hit();
matched_worker.to_string()
idx
} else {
RouterMetrics::record_cache_miss();
tree.get_smallest_tenant()
};
// Find the index of the selected worker
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
// Only proceed if the worker is healthy
if workers[selected_idx].is_healthy() {
// Update the tree with this request
tree.insert(text, &selected_url);
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(&selected_url);
// Find model with smallest tree (most cache capacity)
let mut smallest_tree_model = String::new();
let mut smallest_tree_size = usize::MAX;
for model_id in model_workers.keys() {
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
let size = tree.get_used_size_per_tenant().values().sum::<usize>();
if size < smallest_tree_size {
smallest_tree_size = size;
smallest_tree_model = model_id.clone();
}
}
return Some(selected_idx);
// Select least loaded worker from model with most cache capacity
if let Some(worker_indices) = model_workers.get(&smallest_tree_model) {
worker_indices
.iter()
.min_by_key(|&&idx| workers[idx].load())
.copied()
.unwrap_or(healthy_indices[0])
} else {
healthy_indices[0]
}
};
// Update the tree with this request
let model_id = workers[selected_idx].model_id();
let tree_key = if model_id.is_empty() || model_id == "unknown" {
"default".to_string()
} else {
// Selected worker no longer exists, remove it from tree
tree.remove_tenant(&selected_url);
debug!("Removed stale worker {} from cache tree", selected_url);
}
model_id.to_string()
};
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
tree.insert(text, workers[selected_idx].url());
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(workers[selected_idx].url());
RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url());
// Fallback to first healthy worker
return healthy_indices.first().copied();
return Some(selected_idx);
}
// Fallback to first healthy worker if tree operations fail
......@@ -272,8 +400,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
fn select_worker_pair(
&self,
prefill_workers: &[Box<dyn Worker>],
decode_workers: &[Box<dyn Worker>],
prefill_workers: &[Arc<dyn Worker>],
decode_workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<(usize, usize)> {
// DEPRECATED: This method is no longer used when separate policies are configured.
......@@ -333,12 +461,12 @@ mod tests {
..Default::default()
};
let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
......@@ -378,7 +506,7 @@ mod tests {
}
// worker2 has load 0
let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1), Arc::new(worker2)];
policy.init_workers(&workers);
// Should select worker2 (lower load) despite cache affinity
......@@ -395,12 +523,12 @@ mod tests {
..Default::default()
};
let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
......@@ -413,7 +541,7 @@ mod tests {
policy.select_worker(&workers, Some("test2"));
// Remove a worker
policy.remove_worker("http://w1:8000");
policy.remove_worker_by_url("http://w1:8000");
workers[0].set_healthy(false);
// All requests should now go to worker2
......
......@@ -5,17 +5,20 @@
use crate::core::Worker;
use std::fmt::Debug;
use std::sync::Arc;
mod cache_aware;
mod factory;
mod power_of_two;
mod random;
mod registry;
mod round_robin;
pub use cache_aware::CacheAwarePolicy;
pub use factory::PolicyFactory;
pub use power_of_two::PowerOfTwoPolicy;
pub use random::RandomPolicy;
pub use registry::PolicyRegistry;
pub use round_robin::RoundRobinPolicy;
/// Core trait for load balancing policies
......@@ -26,9 +29,10 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
/// Select a single worker from the available workers
///
/// This is used for regular routing mode where requests go to a single worker.
/// Now uses Arc<dyn Worker> for better performance and to avoid unnecessary cloning.
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<usize>;
......@@ -38,8 +42,8 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
/// Default implementation uses select_worker for each array independently.
fn select_worker_pair(
&self,
prefill_workers: &[Box<dyn Worker>],
decode_workers: &[Box<dyn Worker>],
prefill_workers: &[Arc<dyn Worker>],
decode_workers: &[Arc<dyn Worker>],
request_text: Option<&str>,
) -> Option<(usize, usize)> {
// Default implementation: independently select from each pool
......@@ -105,7 +109,7 @@ impl Default for CacheAwareConfig {
}
/// Helper function to filter healthy workers and return their indices
pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
pub(crate) fn get_healthy_worker_indices(workers: &[Arc<dyn Worker>]) -> Vec<usize> {
workers
.iter()
.enumerate()
......@@ -121,16 +125,16 @@ mod tests {
#[test]
fn test_get_healthy_worker_indices() {
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
......
......@@ -5,7 +5,7 @@ use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::{Arc, RwLock};
use tracing::info;
/// Power-of-two choices policy
......@@ -41,7 +41,7 @@ impl PowerOfTwoPolicy {
impl LoadBalancingPolicy for PowerOfTwoPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
......@@ -137,8 +137,8 @@ mod tests {
}
// worker3 has load 0
let workers: Vec<Box<dyn Worker>> =
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
let workers: Vec<Arc<dyn Worker>> =
vec![Arc::new(worker1), Arc::new(worker2), Arc::new(worker3)];
// Run multiple selections
let mut selected_counts = [0; 3];
......@@ -156,12 +156,12 @@ mod tests {
#[test]
fn test_power_of_two_with_cached_loads() {
let policy = PowerOfTwoPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
......@@ -190,7 +190,7 @@ mod tests {
#[test]
fn test_power_of_two_single_worker() {
let policy = PowerOfTwoPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
))];
......
......@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::sync::Arc;
/// Random selection policy
///
......@@ -20,7 +21,7 @@ impl RandomPolicy {
impl LoadBalancingPolicy for RandomPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
......@@ -56,16 +57,16 @@ mod tests {
#[test]
fn test_random_selection() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
......@@ -87,12 +88,12 @@ mod tests {
#[test]
fn test_random_with_unhealthy_workers() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
......@@ -110,7 +111,7 @@ mod tests {
#[test]
fn test_random_no_healthy_workers() {
let policy = RandomPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
))];
......
//! Policy Registry for managing model-to-policy mappings
//!
//! This registry manages the dynamic assignment of load balancing policies to models.
//! When the first worker of a new model is added, it determines the policy for that model.
//! All subsequent workers of the same model use the established policy.
//! When the last worker of a model is removed, the policy mapping is cleaned up.
use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy,
};
use crate::config::types::PolicyConfig;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
/// Registry for managing model-to-policy mappings
///
/// Cheaply cloneable: every field is behind an `Arc`, so clones share the
/// same underlying maps and policy instances.
#[derive(Clone)]
pub struct PolicyRegistry {
    /// Model ID -> policy instance mapping
    model_policies: Arc<RwLock<HashMap<String, Arc<dyn LoadBalancingPolicy>>>>,
    /// Model ID -> worker count; used to remove the policy mapping once the
    /// last worker of a model is gone
    model_worker_counts: Arc<RwLock<HashMap<String, usize>>>,
    /// Default policy instance (cached); used when a model has no mapping
    default_policy: Arc<dyn LoadBalancingPolicy>,
    /// Prefill policy for PD (prefill/decode) mode, if configured
    prefill_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
    /// Decode policy for PD (prefill/decode) mode, if configured
    decode_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
}
impl PolicyRegistry {
/// Create a new PolicyRegistry with a default policy
pub fn new(default_policy_config: PolicyConfig) -> Self {
let default_policy = Self::create_policy_from_config(&default_policy_config);
Self {
model_policies: Arc::new(RwLock::new(HashMap::new())),
model_worker_counts: Arc::new(RwLock::new(HashMap::new())),
default_policy,
prefill_policy: Arc::new(RwLock::new(None)),
decode_policy: Arc::new(RwLock::new(None)),
}
}
/// Called when a worker is added
/// Returns the policy that should be used for this worker's model
pub fn on_worker_added(
&self,
model_id: &str,
policy_hint: Option<&str>,
) -> Arc<dyn LoadBalancingPolicy> {
// Increment worker count
{
let mut counts = self.model_worker_counts.write().unwrap();
*counts.entry(model_id.to_string()).or_insert(0) += 1;
debug!(
"Worker added for model {}, count: {}",
model_id,
counts.get(model_id).unwrap()
);
}
// Check if model already has a policy
{
let policies = self.model_policies.read().unwrap();
if let Some(existing_policy) = policies.get(model_id) {
debug!(
"Model {} already has policy: {}",
model_id,
existing_policy.name()
);
return Arc::clone(existing_policy);
}
}
// New model - determine policy
let policy = self.determine_policy_for_model(model_id, policy_hint);
info!(
"Assigning policy {} to new model {}",
policy.name(),
model_id
);
// Store policy for this model
{
let mut policies = self.model_policies.write().unwrap();
policies.insert(model_id.to_string(), Arc::clone(&policy));
}
policy
}
/// Called when a worker is removed
pub fn on_worker_removed(&self, model_id: &str) {
let should_cleanup = {
let mut counts = self.model_worker_counts.write().unwrap();
if let Some(count) = counts.get_mut(model_id) {
*count = count.saturating_sub(1);
debug!("Worker removed for model {}, count: {}", model_id, *count);
if *count == 0 {
counts.remove(model_id);
true
} else {
false
}
} else {
warn!(
"Attempted to remove worker for model {} with no registered workers",
model_id
);
false
}
};
// Clean up policy if this was the last worker
if should_cleanup {
let mut policies = self.model_policies.write().unwrap();
if let Some(policy) = policies.remove(model_id) {
info!(
"Removed policy {} for model {} (last worker removed)",
policy.name(),
model_id
);
// Policy will be dropped here, cleaning up any resources
drop(policy);
}
}
}
/// Get the policy for a model
pub fn get_policy(&self, model_id: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
self.model_policies.read().unwrap().get(model_id).cloned()
}
/// Get the default policy
pub fn get_default_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
Arc::clone(&self.default_policy)
}
/// Get policy for a model, or default if not found
pub fn get_policy_or_default(&self, model_id: &str) -> Arc<dyn LoadBalancingPolicy> {
self.get_policy(model_id)
.unwrap_or_else(|| self.get_default_policy())
}
/// Determine policy for a new model
fn determine_policy_for_model(
&self,
model_id: &str,
policy_hint: Option<&str>,
) -> Arc<dyn LoadBalancingPolicy> {
// 1. Check policy hint from worker
if let Some(policy_type) = policy_hint {
debug!("Using policy hint '{}' for model {}", policy_type, model_id);
return self.create_policy_from_type(policy_type);
}
// 2. Use default policy
debug!("Using default policy for model {}", model_id);
Arc::clone(&self.default_policy)
}
/// Create a policy from a type string
fn create_policy_from_type(&self, policy_type: &str) -> Arc<dyn LoadBalancingPolicy> {
match policy_type {
"round_robin" => Arc::new(RoundRobinPolicy::new()),
"random" => Arc::new(RandomPolicy::new()),
"cache_aware" => Arc::new(CacheAwarePolicy::new()),
"power_of_two" => Arc::new(PowerOfTwoPolicy::new()),
_ => {
warn!("Unknown policy type '{}', using default", policy_type);
Arc::clone(&self.default_policy)
}
}
}
/// Create a policy from a PolicyConfig
fn create_policy_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
match config {
PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
PolicyConfig::Random => Arc::new(RandomPolicy::new()),
PolicyConfig::CacheAware {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
} => {
let cache_config = CacheAwareConfig {
cache_threshold: *cache_threshold,
balance_abs_threshold: *balance_abs_threshold,
balance_rel_threshold: *balance_rel_threshold,
eviction_interval_secs: *eviction_interval_secs,
max_tree_size: *max_tree_size,
};
Arc::new(CacheAwarePolicy::with_config(cache_config))
}
PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
}
}
/// Get current model->policy mappings (for debugging/monitoring)
pub fn get_all_mappings(&self) -> HashMap<String, String> {
let policies = self.model_policies.read().unwrap();
policies
.iter()
.map(|(model, policy)| (model.clone(), policy.name().to_string()))
.collect()
}
/// Get worker counts per model
pub fn get_worker_counts(&self) -> HashMap<String, usize> {
self.model_worker_counts.read().unwrap().clone()
}
/// Clear all policies (useful for testing)
pub fn clear(&self) {
let mut policies = self.model_policies.write().unwrap();
policies.clear();
let mut counts = self.model_worker_counts.write().unwrap();
counts.clear();
}
/// Set the prefill policy for PD mode
pub fn set_prefill_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
let mut prefill_policy = self.prefill_policy.write().unwrap();
*prefill_policy = Some(policy);
}
/// Set the decode policy for PD mode
pub fn set_decode_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
let mut decode_policy = self.decode_policy.write().unwrap();
*decode_policy = Some(policy);
}
/// Get the prefill policy for PD mode, or default if not set
pub fn get_prefill_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
let prefill_policy = self.prefill_policy.read().unwrap();
prefill_policy
.as_ref()
.map(Arc::clone)
.unwrap_or_else(|| self.get_default_policy())
}
/// Get the decode policy for PD mode, or default if not set
pub fn get_decode_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
let decode_policy = self.decode_policy.read().unwrap();
decode_policy
.as_ref()
.map(Arc::clone)
.unwrap_or_else(|| self.get_default_policy())
}
}
// Manual Debug impl: prints only the name of `default_policy` (instead of its
// full Debug form) and omits the PD-mode prefill/decode policies entirely,
// keeping the output compact.
impl std::fmt::Debug for PolicyRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PolicyRegistry")
            .field("model_policies", &self.model_policies)
            .field("model_worker_counts", &self.model_worker_counts)
            .field("default_policy", &self.default_policy.name())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Covers policy assignment: first worker's hint wins, later hints are
    // ignored, and distinct models may carry distinct policies.
    #[test]
    fn test_policy_registry_basic() {
        let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

        // First worker of a model sets the policy
        let policy1 = registry.on_worker_added("llama-3", Some("cache_aware"));
        assert_eq!(policy1.name(), "cache_aware");

        // Second worker of same model uses existing policy
        let policy2 = registry.on_worker_added("llama-3", Some("round_robin"));
        assert_eq!(policy2.name(), "cache_aware"); // Ignores hint, uses existing

        // Different model can have different policy
        let policy3 = registry.on_worker_added("gpt-4", Some("random"));
        assert_eq!(policy3.name(), "random");

        // Check mappings
        let mappings = registry.get_all_mappings();
        assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
        assert_eq!(mappings.get("gpt-4").unwrap(), "random");

        // Check worker counts
        let counts = registry.get_worker_counts();
        assert_eq!(*counts.get("llama-3").unwrap(), 2);
        assert_eq!(*counts.get("gpt-4").unwrap(), 1);
    }

    // Covers cleanup: the policy mapping survives until the last worker of a
    // model is removed, then both the mapping and the count disappear.
    #[test]
    fn test_policy_registry_cleanup() {
        let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

        // Add workers
        registry.on_worker_added("llama-3", Some("cache_aware"));
        registry.on_worker_added("llama-3", None);
        assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&2));

        // Remove one worker - policy should remain
        registry.on_worker_removed("llama-3");
        assert!(registry.get_policy("llama-3").is_some());
        assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&1));

        // Remove last worker - policy should be cleaned up
        registry.on_worker_removed("llama-3");
        assert!(registry.get_policy("llama-3").is_none());
        assert_eq!(registry.get_worker_counts().get("llama-3"), None);
    }

    // Covers the fallback path: with no hint, a new model gets the
    // registry's default policy.
    #[test]
    fn test_default_policy() {
        let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

        // No hint, no template - uses default
        let policy = registry.on_worker_added("unknown-model", None);
        assert_eq!(policy.name(), "round_robin");

        // Get default directly
        let default = registry.get_default_policy();
        assert_eq!(default.name(), "round_robin");
    }
}
......@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Round-robin selection policy
///
......@@ -24,7 +25,7 @@ impl RoundRobinPolicy {
impl LoadBalancingPolicy for RoundRobinPolicy {
fn select_worker(
&self,
workers: &[Box<dyn Worker>],
workers: &[Arc<dyn Worker>],
_request_text: Option<&str>,
) -> Option<usize> {
let healthy_indices = get_healthy_worker_indices(workers);
......@@ -64,16 +65,16 @@ mod tests {
#[test]
fn test_round_robin_selection() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
......@@ -90,16 +91,16 @@ mod tests {
#[test]
fn test_round_robin_with_unhealthy_workers() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w3:8000".to_string(),
WorkerType::Regular,
)),
......@@ -118,12 +119,12 @@ mod tests {
#[test]
fn test_round_robin_reset() {
let policy = RoundRobinPolicy::new();
let workers: Vec<Box<dyn Worker>> = vec![
Box::new(BasicWorker::new(
let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new(
"http://w1:8000".to_string(),
WorkerType::Regular,
)),
Box::new(BasicWorker::new(
Arc::new(BasicWorker::new(
"http://w2:8000".to_string(),
WorkerType::Regular,
)),
......
......@@ -3,3 +3,4 @@
pub mod spec;
pub mod validation;
pub mod worker_spec;
//! Worker management API specifications
//!
//! Defines the request/response structures for worker management endpoints
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Worker configuration for API requests
///
/// Optional fields are omitted from serialized output when `None`; `labels`
/// defaults to an empty map when absent on deserialization.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerConfigRequest {
    /// Worker URL (required)
    pub url: String,

    /// Model ID (optional, will query from server if not provided)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,

    /// Worker priority (optional, default: 50, higher = preferred)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u32>,

    /// Worker cost factor (optional, default: 1.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f32>,

    /// Worker type (optional: "regular", "prefill", "decode")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,

    /// Bootstrap port for prefill workers (optional)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,

    // gRPC-specific configuration (optional, ignored in HTTP mode)
    /// Tokenizer path for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Reasoning parser type for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,

    /// Tool parser type for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,

    /// Chat template for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,

    /// Additional labels (optional; empty map when missing)
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub labels: HashMap<String, String>,
}
/// Worker information for API responses
///
/// Serialize-only: returned by worker management endpoints. Optional fields
/// and the metadata map are omitted from the output when empty.
#[derive(Debug, Clone, Serialize)]
pub struct WorkerInfo {
    /// Worker unique identifier
    pub id: String,
    /// Worker URL
    pub url: String,
    /// Model ID this worker serves
    pub model_id: String,
    /// Worker priority
    pub priority: u32,
    /// Worker cost factor
    pub cost: f32,
    /// Worker type
    pub worker_type: String,
    /// Whether the worker is healthy
    pub is_healthy: bool,
    /// Current load on the worker
    pub load: usize,
    /// Connection mode (http or grpc)
    pub connection_mode: String,

    // gRPC-specific fields (None for HTTP workers)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,

    /// Additional metadata
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, String>,
}
/// Worker list response
#[derive(Debug, Clone, Serialize)]
pub struct WorkerListResponse {
    /// List of workers
    pub workers: Vec<WorkerInfo>,
    /// Total count of workers in the list
    pub total: usize,
    /// Aggregate statistics over the listed workers
    pub stats: WorkerStats,
}
/// Worker statistics
#[derive(Debug, Clone, Serialize)]
pub struct WorkerStats {
    /// Total number of workers
    pub total_workers: usize,
    /// Number of workers currently reporting healthy
    pub healthy_workers: usize,
    /// Number of distinct models served
    pub total_models: usize,
    /// Sum of load across all workers
    pub total_load: usize,
    /// Worker counts broken down by type
    pub by_type: WorkerTypeStats,
}
/// Worker statistics by type
#[derive(Debug, Clone, Serialize)]
pub struct WorkerTypeStats {
    /// Number of regular workers
    pub regular: usize,
    /// Number of prefill workers
    pub prefill: usize,
    /// Number of decode workers
    pub decode: usize,
}
/// Worker update request
#[derive(Debug, Clone, Deserialize)]
pub struct WorkerUpdateRequest {
/// Update priority
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<u32>,
/// Update cost
#[serde(skip_serializing_if = "Option::is_none")]
pub cost: Option<f32>,
/// Update labels
#[serde(skip_serializing_if = "Option::is_none")]
pub labels: Option<HashMap<String, String>>,
}
/// Generic API response for worker management operations
#[derive(Debug, Clone, Serialize)]
pub struct WorkerApiResponse {
    /// Whether the operation succeeded
    pub success: bool,
    /// Human-readable status message
    pub message: String,
    /// The affected worker, when applicable; omitted from output when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker: Option<WorkerInfo>,
}
/// Error response
#[derive(Debug, Clone, Serialize)]
pub struct WorkerErrorResponse {
    /// Human-readable error description
    pub error: String,
    /// Machine-readable error code
    pub code: String,
}
/// Server info response from /get_server_info endpoint
#[derive(Debug, Clone, Deserialize)]
pub struct ServerInfo {
#[serde(skip_serializing_if = "Option::is_none")]
pub model_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model_path: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cost: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_type: Option<String>,
// gRPC-specific
#[serde(skip_serializing_if = "Option::is_none")]
pub tokenizer_path: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_parser: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_parser: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chat_template: Option<String>,
}
......@@ -15,11 +15,6 @@ pub struct RouterFactory;
impl RouterFactory {
/// Create a router instance from application context
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// Check if IGW mode is enabled
if ctx.router_config.enable_igw {
return Self::create_igw_router(ctx).await;
}
// Check connection mode and route to appropriate implementation
match ctx.router_config.connection_mode {
ConnectionMode::Grpc => {
......@@ -53,8 +48,7 @@ impl RouterFactory {
// Route to HTTP implementation based on routing mode
match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
.await
Self::create_regular_router(worker_urls, ctx).await
}
RoutingMode::PrefillDecode {
prefill_urls,
......@@ -80,23 +74,19 @@ impl RouterFactory {
}
}
/// Create a regular router with injected policy
async fn create_regular_router(
/// Create a regular router
pub async fn create_regular_router(
worker_urls: &[String],
policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// Create policy
let policy = PolicyFactory::create_from_config(policy_config);
// Create regular router with injected policy and context
let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
// Create regular router with context
let router = Router::new(worker_urls.to_vec(), ctx).await?;
Ok(Box::new(router))
}
/// Create a PD router with injected policy
async fn create_pd_router(
pub async fn create_pd_router(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
prefill_policy_config: Option<&PolicyConfig>,
......@@ -104,21 +94,18 @@ impl RouterFactory {
main_policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// Create policies - use specific policies if provided, otherwise fall back to main policy
// Initialize policies in PolicyRegistry - use specific policies if provided, otherwise fall back to main policy
let prefill_policy =
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Create PD router with separate policies and context
let router = PDRouter::new(
prefill_urls.to_vec(),
decode_urls.to_vec(),
prefill_policy,
decode_policy,
ctx,
)
.await?;
// Set the prefill and decode policies in the registry
ctx.policy_registry.set_prefill_policy(prefill_policy);
ctx.policy_registry.set_decode_policy(decode_policy);
// Create PD router with context (policies are in PolicyRegistry)
let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?;
Ok(Box::new(router))
}
......@@ -186,10 +173,4 @@ impl RouterFactory {
Ok(Box::new(router))
}
/// Create an IGW router (placeholder for future implementation)
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// For now, return an error indicating IGW is not yet implemented
Err("IGW mode is not yet implemented".to_string())
}
}
......@@ -27,9 +27,9 @@ use tracing::{info, warn};
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
/// Prefill worker connections
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
prefill_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
/// Decode worker connections
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
decode_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
/// gRPC clients for prefill workers
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// gRPC clients for decode workers
......@@ -127,7 +127,7 @@ impl GrpcPDRouter {
}
// Create Prefill Worker trait objects with gRPC connection mode
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls
.iter()
.map(|(url, bootstrap_port)| {
let worker = BasicWorker::with_connection_mode(
......@@ -147,12 +147,12 @@ impl GrpcPDRouter {
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
Arc::new(worker) as Arc<dyn Worker>
})
.collect();
// Create Decode Worker trait objects with gRPC connection mode
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
let decode_workers: Vec<Arc<dyn Worker>> = decode_urls
.iter()
.map(|url| {
let worker = BasicWorker::with_connection_mode(
......@@ -168,7 +168,7 @@ impl GrpcPDRouter {
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
Arc::new(worker) as Arc<dyn Worker>
})
.collect();
......@@ -269,6 +269,7 @@ impl RouterTrait for GrpcPDRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -277,6 +278,7 @@ impl RouterTrait for GrpcPDRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -285,6 +287,7 @@ impl RouterTrait for GrpcPDRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -293,6 +296,7 @@ impl RouterTrait for GrpcPDRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ResponsesRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -305,6 +309,7 @@ impl RouterTrait for GrpcPDRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::RerankRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......
......@@ -27,7 +27,7 @@ use tracing::{info, warn};
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcRouter {
/// Worker connections
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
/// gRPC clients for each worker
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy
......@@ -103,7 +103,7 @@ impl GrpcRouter {
}
// Create Worker trait objects with gRPC connection mode
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
let mut workers: Vec<Arc<dyn Worker>> = Vec::new();
// Move clients from the HashMap to the workers
for url in &worker_urls {
......@@ -123,7 +123,7 @@ impl GrpcRouter {
})
.with_grpc_client(client);
workers.push(Box::new(worker) as Box<dyn Worker>);
workers.push(Arc::new(worker) as Arc<dyn Worker>);
} else {
warn!("No gRPC client for worker {}, skipping", url);
}
......@@ -202,6 +202,7 @@ impl RouterTrait for GrpcRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -210,6 +211,7 @@ impl RouterTrait for GrpcRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -218,6 +220,7 @@ impl RouterTrait for GrpcRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -226,6 +229,7 @@ impl RouterTrait for GrpcRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ResponsesRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......@@ -238,6 +242,7 @@ impl RouterTrait for GrpcRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::RerankRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
......
......@@ -186,6 +186,7 @@ impl super::super::RouterTrait for OpenAIRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &GenerateRequest,
_model_id: Option<&str>,
) -> Response {
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
(
......@@ -199,6 +200,7 @@ impl super::super::RouterTrait for OpenAIRouter {
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response {
if !self.circuit_breaker.can_execute() {
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
......@@ -326,6 +328,7 @@ impl super::super::RouterTrait for OpenAIRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &CompletionRequest,
_model_id: Option<&str>,
) -> Response {
// Completion endpoint not implemented for OpenAI backend
(
......@@ -339,6 +342,7 @@ impl super::super::RouterTrait for OpenAIRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ResponsesRequest,
_model_id: Option<&str>,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
......@@ -383,7 +387,12 @@ impl super::super::RouterTrait for OpenAIRouter {
.into_response()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: &RerankRequest) -> Response {
async fn route_rerank(
&self,
_headers: Option<&HeaderMap>,
_body: &RerankRequest,
_model_id: Option<&str>,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Rerank endpoint not implemented for OpenAI backend",
......
......@@ -3,11 +3,11 @@
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig,
RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker,
WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest,
ResponsesRequest, StringOrArray, UserMessageContent,
......@@ -27,7 +27,7 @@ use reqwest::Client;
use serde::Serialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
......@@ -35,10 +35,8 @@ use tracing::{debug, error, info, warn};
#[derive(Debug)]
pub struct PDRouter {
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>,
pub worker_startup_timeout_secs: u64,
pub worker_startup_check_interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
......@@ -48,25 +46,22 @@ pub struct PDRouter {
pub prefill_client: Client,
pub retry_config: RetryConfig,
pub circuit_breaker_config: CircuitBreakerConfig,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
// Channel for sending prefill responses to background workers for draining
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
}
// Request context for PD router operations.
// Carries per-request routing parameters through retry attempts; the lifetime
// ties the optional model filter to the caller's borrowed model id.
#[derive(Clone)]
struct PDRequestContext<'a> {
    // Downstream route path for this request (static endpoint string)
    route: &'static str,
    // Batch size when the request is a batch; None for single requests
    batch_size: Option<usize>,
    // Whether the client requested a streaming response
    is_stream: bool,
    // Whether logprobs were requested and must be returned
    return_logprob: bool,
    // Request text extracted for cache-aware routing, if any policy needs it
    request_text: Option<String>,
    // Optional model filter used to select model-specific workers
    model_id: Option<&'a str>,
}
impl PDRouter {
// Dynamic worker management methods for service discovery
// Private helper method to perform health check on a new server
async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> {
crate::routers::http::router::Router::wait_for_healthy_workers(
......@@ -83,24 +78,16 @@ impl PDRouter {
// Generic helper for processing all workers with an endpoint
async fn process_workers(
&self,
workers: &RwLock<Vec<Box<dyn Worker>>>,
worker_type_enum: WorkerType,
worker_type: &str,
endpoint: &str,
) -> (Vec<String>, Vec<String>) {
let mut results = Vec::new();
let mut errors = Vec::new();
// Get worker URLs first to avoid holding lock across await
let urls = match workers.read() {
Ok(workers) => workers
.iter()
.map(|w| w.url().to_string())
.collect::<Vec<_>>(),
Err(_) => {
errors.push(format!("Failed to access {} workers", worker_type));
Vec::new()
}
};
// Get workers from registry based on type
let workers = self.worker_registry.get_by_type(&worker_type_enum);
let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
// Process each worker
for worker_url in urls {
......@@ -126,98 +113,95 @@ impl PDRouter {
(results, errors)
}
// Helper to get worker URLs from a worker collection
fn get_worker_urls(
workers: &RwLock<Vec<Box<dyn Worker>>>,
worker_type: &str,
) -> Result<Vec<String>, String> {
workers
.read()
.map(|workers| {
workers
.iter()
.map(|w| w.url().to_string())
.collect::<Vec<_>>()
})
.map_err(|_| format!("Failed to access {} workers", worker_type))
// Helper to get prefill worker URLs
fn get_prefill_worker_urls(&self) -> Vec<String> {
self.worker_registry
.get_prefill_workers()
.iter()
.map(|w| w.url().to_string())
.collect()
}
// Generic helper for proxying requests to the first worker
async fn proxy_to_first_worker(
// Helper to get decode worker URLs
fn get_decode_worker_urls(&self) -> Vec<String> {
self.worker_registry
.get_decode_workers()
.iter()
.map(|w| w.url().to_string())
.collect()
}
// Helper for proxying requests to the first prefill worker
async fn proxy_to_first_prefill_worker(
&self,
workers: &RwLock<Vec<Box<dyn Worker>>>,
endpoint: &str,
worker_type: &str,
headers: Option<Vec<(String, String)>>,
) -> Response {
// Get first worker URL to avoid holding lock across await
let first_worker_url = match workers.read() {
Ok(workers) => workers.first().map(|w| w.url().to_string()),
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to access {} workers", worker_type),
)
.into_response();
}
};
let workers = self.worker_registry.get_prefill_workers();
let first_worker_url = workers.first().map(|w| w.url().to_string());
if let Some(worker_url) = first_worker_url {
let url = format!("{}/{}", worker_url, endpoint);
let mut request_builder = self.client.get(&url);
self.proxy_to_worker(worker_url, endpoint, headers).await
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
"No prefill servers available".to_string(),
)
.into_response()
}
}
// Add headers if provided
if let Some(headers) = headers {
for (name, value) in headers {
request_builder = request_builder.header(name, value);
}
// Generic helper for proxying to a specific worker
async fn proxy_to_worker(
&self,
worker_url: String,
endpoint: &str,
headers: Option<Vec<(String, String)>>,
) -> Response {
let url = format!("{}/{}", worker_url, endpoint);
let mut request_builder = self.client.get(&url);
// Add headers if provided
if let Some(headers) = headers {
for (name, value) in headers {
request_builder = request_builder.header(name, value);
}
}
match request_builder.send().await {
Ok(res) if res.status().is_success() => {
let response_headers = header_utils::preserve_response_headers(res.headers());
match res.bytes().await {
Ok(body) => {
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = StatusCode::OK;
*response.headers_mut() = response_headers;
response
}
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
}
match request_builder.send().await {
Ok(res) if res.status().is_success() => {
let response_headers = header_utils::preserve_response_headers(res.headers());
match res.bytes().await {
Ok(body) => {
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = StatusCode::OK;
*response.headers_mut() = response_headers;
response
}
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
}
}
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(
status,
format!("{} server returned status: {}", worker_type, res.status()),
)
.into_response()
}
Err(e) => {
error!("Failed to proxy request to {} server: {}", worker_type, e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to proxy request: {}", e),
)
.into_response()
}
}
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
format!("No {} servers available", worker_type),
)
.into_response()
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(status, format!("{} server returned status: ", res.status())).into_response()
}
Err(e) => {
error!("Failed to proxy request server: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to proxy request: {}", e),
)
.into_response()
}
}
}
......@@ -229,36 +213,37 @@ impl PDRouter {
// Wait for the new server to be healthy
self.wait_for_server_health(&url).await?;
// Check if already exists
if self.worker_registry.get_by_url(&url).is_some() {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
// Create Worker for the new prefill server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = WorkerFactory::create_prefill_with_config(
url.clone(),
bootstrap_port,
self.circuit_breaker_config.clone(),
);
// Add to prefill workers list
let mut workers = self
.prefill_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "prefill_workers write".to_string(),
})?;
let worker_arc: Arc<dyn Worker> = Arc::from(worker);
// Check if already exists
if workers.iter().any(|w| w.url() == url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
// Register the worker in the registry
self.worker_registry.register(worker_arc.clone());
workers.push(worker);
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// Update cache-aware policy if applicable
drop(workers); // Release write lock
if let Some(cache_policy) = self
.prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.add_worker(&url);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
info!("Added prefill server: {}", url);
......@@ -269,35 +254,36 @@ impl PDRouter {
// Wait for the new server to be healthy
self.wait_for_server_health(&url).await?;
// Check if already exists
if self.worker_registry.get_by_url(&url).is_some() {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
// Create Worker for the new decode server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = WorkerFactory::create_decode_with_config(
url.clone(),
self.circuit_breaker_config.clone(),
);
// Add to decode workers list
let mut workers = self
.decode_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "decode_workers write".to_string(),
})?;
let worker_arc: Arc<dyn Worker> = Arc::from(worker);
// Check if already exists
if workers.iter().any(|w| w.url() == url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
// Register the worker in the registry
self.worker_registry.register(worker_arc.clone());
workers.push(worker);
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// Update cache-aware policy if applicable
drop(workers); // Release write lock
if let Some(cache_policy) = self
.decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.add_worker(&url);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
info!("Added decode server: {}", url);
......@@ -305,73 +291,91 @@ impl PDRouter {
}
pub async fn remove_prefill_server(&self, url: &str) -> Result<String, PDRouterError> {
let mut workers = self
.prefill_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "prefill_workers write".to_string(),
})?;
// Check if worker exists and get model_id
let model_id = match self.worker_registry.get_by_url(url) {
Some(worker) => worker.model_id().to_string(),
None => {
return Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
});
}
};
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url() != url);
// Remove from registry
let removed = self.worker_registry.remove_by_url(url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
});
}
if removed.is_some() {
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
// Remove from cache-aware policy if applicable
if let Some(cache_policy) = self
.prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.remove_worker(url);
// Get the policy for this model to update cache-aware if needed
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
}
info!("Removed prefill server: {}", url);
Ok(format!("Successfully removed prefill server: {}", url))
if removed.is_some() {
info!("Removed prefill server: {}", url);
Ok(format!("Successfully removed prefill server: {}", url))
} else {
Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
})
}
}
pub async fn remove_decode_server(&self, url: &str) -> Result<String, PDRouterError> {
let mut workers = self
.decode_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "decode_workers write".to_string(),
})?;
// Check if worker exists and get model_id
let model_id = match self.worker_registry.get_by_url(url) {
Some(worker) => worker.model_id().to_string(),
None => {
return Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
});
}
};
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url() != url);
// Remove from registry
let removed = self.worker_registry.remove_by_url(url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
});
}
if removed.is_some() {
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
// Remove from cache-aware policy if applicable
if let Some(cache_policy) = self
.decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.remove_worker(url);
// Get the policy for this model to update cache-aware if needed
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
}
info!("Removed decode server: {}", url);
Ok(format!("Successfully removed decode server: {}", url))
if removed.is_some() {
info!("Removed decode server: {}", url);
Ok(format!("Successfully removed decode server: {}", url))
} else {
Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
})
}
}
#[allow(clippy::too_many_arguments)]
pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
......@@ -383,16 +387,28 @@ impl PDRouter {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert URLs to Worker trait objects with health check config
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.into_iter()
.map(|(url, port)| {
let worker = BasicWorker::new(
url,
WorkerType::Prefill {
bootstrap_port: port,
},
)
// Register prefill workers in the registry
for (url, port) in prefill_urls {
let worker = BasicWorker::new(
url,
WorkerType::Prefill {
bootstrap_port: port,
},
)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
ctx.worker_registry.register(Arc::new(worker));
}
// Register decode workers in the registry
for url in decode_urls {
let worker = BasicWorker::new(url, WorkerType::Decode)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
......@@ -401,30 +417,13 @@ impl PDRouter {
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
.into_iter()
.map(|url| {
let worker = BasicWorker::new(url, WorkerType::Decode)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
ctx.worker_registry.register(Arc::new(worker));
}
// Wait for PD workers to be healthy (skip if empty - for service discovery mode)
let all_urls: Vec<String> = prefill_workers
// Get all workers from registry for health check
let all_workers = ctx.worker_registry.get_all();
let all_urls: Vec<String> = all_workers
.iter()
.chain(decode_workers.iter())
.map(|worker| worker.url().to_string())
.collect();
if !all_urls.is_empty() {
......@@ -436,25 +435,19 @@ impl PDRouter {
.await?;
}
// Initialize cache-aware policies with workers
if let Some(cache_policy) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&prefill_workers);
}
if let Some(cache_policy) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&decode_workers);
}
// Initialize cache-aware policies with workers from registry
// Note: We need to get workers by type and convert to Box<dyn Worker> for CacheAwarePolicy
// This is a temporary workaround until CacheAwarePolicy is updated to work with Arc<dyn Worker>
// TODO: Update CacheAwarePolicy to accept Arc<dyn Worker> instead of Box<dyn Worker>
// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
// Get policies from registry to check if we need load monitoring
let prefill_policy = ctx.policy_registry.get_prefill_policy();
let decode_policy = ctx.policy_registry.get_decode_policy();
let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone();
......@@ -478,18 +471,8 @@ impl PDRouter {
None
};
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers));
// Start health checkers for both worker pools
let prefill_health_checker = crate::core::start_health_checker(
Arc::clone(&prefill_workers),
ctx.router_config.health_check.check_interval_secs,
);
let decode_health_checker = crate::core::start_health_checker(
Arc::clone(&decode_workers),
ctx.router_config.health_check.check_interval_secs,
);
// Note: Health checking is now handled centrally by RouterManager
// Individual routers no longer need to manage health checkers
// Build a dedicated prefill client for fire-and-forget semantics
let prefill_client = reqwest::Client::builder()
......@@ -570,10 +553,8 @@ impl PDRouter {
});
Ok(PDRouter {
prefill_workers,
decode_workers,
prefill_policy,
decode_policy,
worker_registry: Arc::clone(&ctx.worker_registry),
policy_registry: Arc::clone(&ctx.policy_registry),
worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs: ctx
.router_config
......@@ -585,8 +566,6 @@ impl PDRouter {
prefill_drain_tx,
retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
})
}
......@@ -721,7 +700,7 @@ impl PDRouter {
&self,
headers: Option<&HeaderMap>,
original_request: &T,
context: PDRequestContext,
context: PDRequestContext<'_>,
) -> Response {
let start_time = Instant::now();
......@@ -736,14 +715,16 @@ impl PDRouter {
let context = context.clone();
async move {
// Select workers fresh for each attempt
let (prefill, decode) =
match self.select_pd_pair(context.request_text.as_deref()).await {
Ok(pair) => pair,
Err(e) => {
RouterMetrics::record_pd_error("server_selection");
return Self::handle_server_selection_error(e);
}
};
let (prefill, decode) = match self
.select_pd_pair(context.request_text.as_deref(), context.model_id)
.await
{
Ok(pair) => pair,
Err(e) => {
RouterMetrics::record_pd_error("server_selection");
return Self::handle_server_selection_error(e);
}
};
debug!(
"PD retry attempt {} using prefill={} decode={}",
......@@ -806,7 +787,7 @@ impl PDRouter {
async fn handle_decode_error_response(
&self,
res: reqwest::Response,
context: &PDRequestContext,
context: &PDRequestContext<'_>,
prefill: &dyn Worker,
decode: &dyn Worker,
) -> Response {
......@@ -859,7 +840,7 @@ impl PDRouter {
&self,
headers: Option<&HeaderMap>,
json_request: Value,
context: PDRequestContext,
context: PDRequestContext<'_>,
prefill: &dyn Worker,
decode: &dyn Worker,
start_time: Instant,
......@@ -1131,35 +1112,56 @@ impl PDRouter {
// Check if either prefill or decode policy needs request text
fn policies_need_request_text(&self) -> bool {
self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text()
// Check both prefill and decode policies
let prefill_policy = self.policy_registry.get_prefill_policy();
let decode_policy = self.policy_registry.get_decode_policy();
prefill_policy.needs_request_text() || decode_policy.needs_request_text()
}
// Select a pair of prefill and decode servers considering circuit breaker state
async fn select_pd_pair(
&self,
request_text: Option<&str>,
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
// Get read locks for both worker lists
let prefill_workers = self
.prefill_workers
.read()
.map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?;
let decode_workers = self
.decode_workers
.read()
.map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?;
model_id: Option<&str>,
) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
// Get workers from registry - filter by model if provided
let prefill_workers = if let Some(model) = model_id {
// Get model-specific workers and filter for prefill type
self.worker_registry
.get_by_model_fast(model)
.into_iter()
.filter(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }))
.collect()
} else {
self.worker_registry.get_prefill_workers()
};
let decode_workers = if let Some(model) = model_id {
// Get model-specific workers and filter for decode type
self.worker_registry
.get_by_model_fast(model)
.into_iter()
.filter(|w| matches!(w.worker_type(), WorkerType::Decode))
.collect()
} else {
self.worker_registry.get_decode_workers()
};
// Select workers using helper function
let prefill = Self::pick_worker_by_policy(
// Use separate policies for prefill and decode to avoid counter conflicts
let prefill_policy = self.policy_registry.get_prefill_policy();
let decode_policy = self.policy_registry.get_decode_policy();
let prefill = Self::pick_worker_by_policy_arc(
&prefill_workers,
&*self.prefill_policy,
&*prefill_policy,
request_text,
"prefill",
)?;
let decode = Self::pick_worker_by_policy(
let decode = Self::pick_worker_by_policy_arc(
&decode_workers,
&*self.decode_policy,
&*decode_policy,
request_text,
"decode",
)?;
......@@ -1167,13 +1169,13 @@ impl PDRouter {
Ok((prefill, decode))
}
// Helper function to select a worker using the policy
fn pick_worker_by_policy(
workers: &[Box<dyn Worker>],
// Helper function to select a worker using the policy (Arc version)
fn pick_worker_by_policy_arc(
workers: &[Arc<dyn Worker>],
policy: &dyn LoadBalancingPolicy,
request_text: Option<&str>,
worker_type: &str,
) -> Result<Box<dyn Worker>, String> {
) -> Result<Arc<dyn Worker>, String> {
// Check if we have any workers
if workers.is_empty() {
return Err(format!(
......@@ -1183,10 +1185,10 @@ impl PDRouter {
}
// Filter available workers (healthy + circuit breaker not open)
let available_workers: Vec<Box<dyn Worker>> = workers
let available_workers: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.map(|w| w.clone_worker())
.cloned()
.collect();
if available_workers.is_empty() {
......@@ -1196,11 +1198,19 @@ impl PDRouter {
));
}
// Let policy select from available workers only
match policy.select_worker(&available_workers, request_text) {
Some(idx) => Ok(available_workers[idx].clone_worker()),
None => Err(format!("Policy could not select a {} worker", worker_type)),
}
// Let policy select from available workers (no conversion needed now!)
let selected_idx = policy
.select_worker(&available_workers, request_text)
.ok_or_else(|| {
format!(
"Policy {} failed to select a {} worker",
policy.name(),
worker_type
)
})?;
// Return the selected Arc worker
Ok(available_workers[selected_idx].clone())
}
// Background task to monitor worker loads with shared client
......@@ -1272,9 +1282,8 @@ impl PDRouter {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
// Clone the worker collections for the spawned task
let prefill_workers = self.prefill_workers.clone();
let decode_workers = self.decode_workers.clone();
// Clone the registry for the spawned task
let registry = self.worker_registry.clone();
tokio::spawn(async move {
// Use a flag to track whether stream completed successfully
......@@ -1321,31 +1330,21 @@ impl PDRouter {
// Always decrement load after streaming (either completes or errors)
// Find and decrement prefill worker
if let Ok(prefill_workers_guard) = prefill_workers.read() {
for worker in prefill_workers_guard.iter() {
if worker.url() == prefill_url.as_str() {
worker.decrement_load();
debug!(
"Decremented load for prefill worker: {} (stream_completed: {})",
prefill_url, stream_completed
);
break;
}
}
if let Some(worker) = registry.get_by_url(&prefill_url) {
worker.decrement_load();
debug!(
"Decremented load for prefill worker: {} (stream_completed: {})",
prefill_url, stream_completed
);
}
// Find and decrement decode worker
if let Ok(decode_workers_guard) = decode_workers.read() {
for worker in decode_workers_guard.iter() {
if worker.url() == decode_url_str.as_str() {
worker.decrement_load();
debug!(
"Decremented load for decode worker: {} (stream_completed: {})",
decode_url_str, stream_completed
);
break;
}
}
if let Some(worker) = registry.get_by_url(&decode_url_str) {
worker.decrement_load();
debug!(
"Decremented load for decode worker: {} (stream_completed: {})",
decode_url_str, stream_completed
);
}
});
......@@ -1626,42 +1625,24 @@ impl WorkerManagement for PDRouter {
}
fn remove_worker(&self, worker_url: &str) {
// For PD router, we would need to know if it's a prefill or decode server
// For now, try both
if let Ok(mut workers) = self.prefill_workers.write() {
if let Some(index) = workers.iter().position(|w| w.url() == worker_url) {
workers.remove(index);
info!("Removed prefill worker: {}", worker_url);
return;
}
}
if let Ok(mut workers) = self.decode_workers.write() {
if let Some(index) = workers.iter().position(|w| w.url() == worker_url) {
workers.remove(index);
info!("Removed decode worker: {}", worker_url);
// Remove from registry
if let Some(worker) = self.worker_registry.remove_by_url(worker_url) {
match worker.worker_type() {
WorkerType::Prefill { .. } => {
info!("Removed prefill worker: {}", worker_url);
}
WorkerType::Decode => {
info!("Removed decode worker: {}", worker_url);
}
_ => {
info!("Removed worker: {}", worker_url);
}
}
}
}
fn get_worker_urls(&self) -> Vec<String> {
let mut urls = Vec::new();
// Add prefill worker URLs
if let Ok(workers) = self.prefill_workers.read() {
for worker in workers.iter() {
urls.push(worker.url().to_string());
}
}
// Add decode worker URLs
if let Ok(workers) = self.decode_workers.read() {
for worker in workers.iter() {
urls.push(worker.url().to_string());
}
}
urls
self.worker_registry.get_all_urls()
}
}
......@@ -1677,19 +1658,16 @@ impl RouterTrait for PDRouter {
let mut all_healthy = true;
let mut unhealthy_servers = Vec::new();
// Check prefill servers
for worker in self.prefill_workers.read().unwrap().iter() {
// Check all workers
for worker in self.worker_registry.get_all() {
if !worker.is_healthy() {
all_healthy = false;
unhealthy_servers.push(format!("Prefill: {}", worker.url()));
}
}
// Check decode servers
for worker in self.decode_workers.read().unwrap().iter() {
if !worker.is_healthy() {
all_healthy = false;
unhealthy_servers.push(format!("Decode: {}", worker.url()));
let worker_type = match worker.worker_type() {
WorkerType::Prefill { .. } => "Prefill",
WorkerType::Decode => "Decode",
_ => "Worker",
};
unhealthy_servers.push(format!("{}: {}", worker_type, worker.url()));
}
}
......@@ -1709,7 +1687,7 @@ impl RouterTrait for PDRouter {
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
// Select a random worker pair using the policy
let (prefill, decode) = match self.select_pd_pair(None).await {
let (prefill, decode) = match self.select_pd_pair(None, None).await {
Ok(pair) => pair,
Err(e) => {
return (
......@@ -1789,7 +1767,7 @@ impl RouterTrait for PDRouter {
async fn get_server_info(&self, _req: Request<Body>) -> Response {
// Get info from the first decode server to match sglang's server info format
// Note: We use decode workers for server info to match expected format
self.proxy_to_first_worker(&self.decode_workers, "get_server_info", "decode", None)
self.proxy_to_first_prefill_worker("get_server_info", None)
.await
}
......@@ -1798,7 +1776,7 @@ impl RouterTrait for PDRouter {
let headers = header_utils::copy_request_headers(&req);
// Proxy to first prefill worker
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
self.proxy_to_first_prefill_worker("v1/models", Some(headers))
.await
}
......@@ -1807,19 +1785,15 @@ impl RouterTrait for PDRouter {
let headers = header_utils::copy_request_headers(&req);
// Proxy to first prefill worker
self.proxy_to_first_worker(
&self.prefill_workers,
"get_model_info",
"prefill",
Some(headers),
)
.await
self.proxy_to_first_prefill_worker("get_model_info", Some(headers))
.await
}
async fn route_generate(
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
model_id: Option<&str>,
) -> Response {
// Extract parameters
let is_stream = body.stream;
......@@ -1850,6 +1824,7 @@ impl RouterTrait for PDRouter {
is_stream,
return_logprob,
request_text,
model_id,
};
// Execute with retry and bootstrap injection
......@@ -1860,6 +1835,7 @@ impl RouterTrait for PDRouter {
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response {
// Extract parameters
let is_stream = body.stream;
......@@ -1889,6 +1865,7 @@ impl RouterTrait for PDRouter {
is_stream,
return_logprob,
request_text,
model_id,
};
// Execute with retry and bootstrap injection
......@@ -1899,6 +1876,7 @@ impl RouterTrait for PDRouter {
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
model_id: Option<&str>,
) -> Response {
// Extract parameters
let is_stream = body.stream;
......@@ -1924,6 +1902,7 @@ impl RouterTrait for PDRouter {
is_stream,
return_logprob,
request_text,
model_id,
};
// Execute with retry and bootstrap injection
......@@ -1934,6 +1913,7 @@ impl RouterTrait for PDRouter {
&self,
_headers: Option<&HeaderMap>,
_body: &ResponsesRequest,
_model_id: Option<&str>,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
......@@ -1946,7 +1926,12 @@ impl RouterTrait for PDRouter {
todo!()
}
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response {
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,
body: &RerankRequest,
model_id: Option<&str>,
) -> Response {
// Extract text for cache-aware routing
let req_text = if self.policies_need_request_text() {
Some(body.query.clone())
......@@ -1961,6 +1946,7 @@ impl RouterTrait for PDRouter {
is_stream: false,
return_logprob: false,
request_text: req_text,
model_id,
};
// Execute with retry and bootstrap injection
......@@ -1970,10 +1956,16 @@ impl RouterTrait for PDRouter {
async fn flush_cache(&self) -> Response {
// Process both prefill and decode workers
let (prefill_results, prefill_errors) = self
.process_workers(&self.prefill_workers, "Prefill", "flush_cache")
.process_workers(
WorkerType::Prefill {
bootstrap_port: None,
},
"Prefill",
"flush_cache",
)
.await;
let (decode_results, decode_errors) = self
.process_workers(&self.decode_workers, "Decode", "flush_cache")
.process_workers(WorkerType::Decode, "Decode", "flush_cache")
.await;
// Combine results and errors
......@@ -2005,37 +1997,29 @@ impl RouterTrait for PDRouter {
let mut errors = Vec::new();
// Process prefill workers
match Self::get_worker_urls(&self.prefill_workers, "prefill") {
Ok(urls) => {
for worker_url in urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from prefill {}", worker_url));
}
}
let prefill_urls = self.get_prefill_worker_urls();
for worker_url in prefill_urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from prefill {}", worker_url));
}
}
Err(e) => errors.push(e),
}
// Process decode workers
match Self::get_worker_urls(&self.decode_workers, "decode") {
Ok(urls) => {
for worker_url in urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("decode_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from decode {}", worker_url));
}
}
let decode_urls = self.get_decode_worker_urls();
for worker_url in decode_urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("decode_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from decode {}", worker_url));
}
}
Err(e) => errors.push(e),
}
let response_data = serde_json::json!({
......@@ -2052,24 +2036,15 @@ impl RouterTrait for PDRouter {
fn readiness(&self) -> Response {
// PD router is ready if it has at least one healthy prefill AND one healthy decode worker
let healthy_prefill_count = self
.prefill_workers
.read()
.unwrap()
.iter()
.filter(|w| w.is_healthy())
.count();
let prefill_workers = self.worker_registry.get_prefill_workers();
let decode_workers = self.worker_registry.get_decode_workers();
let healthy_decode_count = self
.decode_workers
.read()
.unwrap()
.iter()
.filter(|w| w.is_healthy())
.count();
let healthy_prefill_count = prefill_workers.iter().filter(|w| w.is_healthy()).count();
let total_prefill = self.prefill_workers.read().unwrap().len();
let total_decode = self.decode_workers.read().unwrap().len();
let healthy_decode_count = decode_workers.iter().filter(|w| w.is_healthy()).count();
let total_prefill = prefill_workers.len();
let total_decode = decode_workers.len();
if healthy_prefill_count > 0 && healthy_decode_count > 0 {
Json(json!({
......@@ -2117,17 +2092,15 @@ impl RouterTrait for PDRouter {
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
use crate::policies::RandomPolicy;
fn create_test_pd_router() -> PDRouter {
let prefill_policy = Arc::new(RandomPolicy::new());
let decode_policy = Arc::new(RandomPolicy::new());
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry =
Arc::new(PolicyRegistry::new(crate::config::PolicyConfig::RoundRobin));
PDRouter {
prefill_workers: Arc::new(RwLock::new(vec![])),
decode_workers: Arc::new(RwLock::new(vec![])),
prefill_policy,
decode_policy,
worker_registry,
policy_registry,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
......@@ -2137,8 +2110,6 @@ mod tests {
prefill_drain_tx: mpsc::channel(100).0,
retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
_prefill_health_checker: None,
_decode_health_checker: None,
}
}
......@@ -2162,12 +2133,14 @@ mod tests {
},
true,
);
router.prefill_workers.write().unwrap().push(worker);
router.worker_registry.register(Arc::from(worker));
// Try to add the same URL again - this would fail during health check in real scenario
// For unit test, we test the duplicate check logic
let workers = router.prefill_workers.read().unwrap();
let exists = workers.iter().any(|w| w.url() == "http://localhost:8000");
let exists = router
.worker_registry
.get_by_url("http://localhost:8000")
.is_some();
assert!(exists);
}
......@@ -2191,8 +2164,8 @@ mod tests {
true,
);
router.prefill_workers.write().unwrap().push(worker1);
router.prefill_workers.write().unwrap().push(worker2);
router.worker_registry.register(Arc::from(worker1));
router.worker_registry.register(Arc::from(worker2));
// Remove one
let result = router.remove_prefill_server("http://worker1").await;
......@@ -2200,7 +2173,7 @@ mod tests {
assert!(result.is_ok());
assert!(result.unwrap().contains("Successfully removed"));
let workers = router.prefill_workers.read().unwrap();
let workers = router.worker_registry.get_prefill_workers();
assert_eq!(workers.len(), 1);
assert_eq!(workers[0].url(), "http://worker2");
}
......@@ -2226,44 +2199,42 @@ mod tests {
// Add server first
let worker = create_test_worker("http://decode1".to_string(), WorkerType::Decode, true);
router.decode_workers.write().unwrap().push(worker);
router.worker_registry.register(Arc::from(worker));
let result = router.remove_decode_server("http://decode1").await;
assert!(result.is_ok());
assert!(result.unwrap().contains("Successfully removed"));
let workers = router.decode_workers.read().unwrap();
let workers = router.worker_registry.get_decode_workers();
assert_eq!(workers.len(), 0);
}
// ============= Lock Error Handling Tests =============
#[test]
fn test_lock_operations() {
fn test_registry_operations() {
let router = create_test_pd_router();
// Test read/write locks work correctly
{
let read_guard = router.prefill_workers.read().unwrap();
assert_eq!(read_guard.len(), 0);
}
// Test registry operations
let workers = router.worker_registry.get_all();
assert_eq!(workers.len(), 0);
{
let mut write_guard = router.prefill_workers.write().unwrap();
write_guard.push(create_test_worker(
"http://test".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
));
}
// Add a worker
let worker = create_test_worker(
"http://test".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
router.worker_registry.register(Arc::from(worker));
{
let read_guard = router.prefill_workers.read().unwrap();
assert_eq!(read_guard.len(), 1);
}
let workers = router.worker_registry.get_all();
assert_eq!(workers.len(), 1);
let prefill_workers = router.worker_registry.get_prefill_workers();
assert_eq!(prefill_workers.len(), 1);
}
// ============= Bootstrap Injection Tests =============
......@@ -2297,15 +2268,11 @@ mod tests {
let decode_worker =
create_test_worker("http://decode".to_string(), WorkerType::Decode, true);
router
.prefill_workers
.write()
.unwrap()
.push(unhealthy_worker);
router.prefill_workers.write().unwrap().push(healthy_worker);
router.decode_workers.write().unwrap().push(decode_worker);
router.worker_registry.register(Arc::from(unhealthy_worker));
router.worker_registry.register(Arc::from(healthy_worker));
router.worker_registry.register(Arc::from(decode_worker));
let result = router.select_pd_pair(None).await;
let result = router.select_pd_pair(None, None).await;
assert!(result.is_ok());
let (prefill, _decode) = result.unwrap();
......@@ -2319,7 +2286,7 @@ mod tests {
async fn test_empty_worker_lists() {
let router = create_test_pd_router();
let result = router.select_pd_pair(None).await;
let result = router.select_pd_pair(None, None).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("No prefill workers available"));
......@@ -2331,7 +2298,7 @@ mod tests {
async fn test_health_endpoints() {
let router = create_test_pd_router();
// Add healthy workers
// Add healthy workers - create_test_worker returns Box<dyn Worker>, convert to Arc
let prefill_worker = create_test_worker(
"http://localhost:8000".to_string(),
WorkerType::Prefill {
......@@ -2345,8 +2312,8 @@ mod tests {
true,
);
router.prefill_workers.write().unwrap().push(prefill_worker);
router.decode_workers.write().unwrap().push(decode_worker);
router.worker_registry.register(Arc::from(prefill_worker));
router.worker_registry.register(Arc::from(decode_worker));
// Test health endpoint
let http_req = axum::http::Request::builder()
......@@ -2367,8 +2334,13 @@ mod tests {
async fn test_load_monitor_updates() {
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
let mut router = create_test_pd_router();
router.prefill_policy = power_of_two_policy.clone();
router.decode_policy = power_of_two_policy;
// Set power_of_two policies in the registry
router
.policy_registry
.set_prefill_policy(power_of_two_policy.clone());
router
.policy_registry
.set_decode_policy(power_of_two_policy);
// Create load channel
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
......@@ -2423,7 +2395,7 @@ mod tests {
let router = create_test_pd_router();
// Add workers
// Add workers - create_test_worker returns Box<dyn Worker>, convert to Arc
let prefill_worker = create_test_worker(
"http://prefill".to_string(),
WorkerType::Prefill {
......@@ -2434,18 +2406,15 @@ mod tests {
let decode_worker =
create_test_worker("http://decode".to_string(), WorkerType::Decode, true);
router.prefill_workers.write().unwrap().push(prefill_worker);
router.decode_workers.write().unwrap().push(decode_worker);
// Get references to the workers - clone to avoid holding lock across await
let (prefill_ref, decode_ref) = {
let workers = router.prefill_workers.read().unwrap();
let prefill = workers[0].clone_worker();
drop(workers);
let workers = router.decode_workers.read().unwrap();
let decode = workers[0].clone_worker();
(prefill, decode)
};
router.worker_registry.register(Arc::from(prefill_worker));
router.worker_registry.register(Arc::from(decode_worker));
// Get references to the workers from registry
let prefill_workers = router.worker_registry.get_prefill_workers();
let decode_workers = router.worker_registry.get_decode_workers();
let prefill_ref = prefill_workers[0].clone();
let decode_ref = decode_workers[0].clone();
// Initially load should be 0
assert_eq!(prefill_ref.load(), 0);
......@@ -2512,7 +2481,7 @@ mod tests {
},
true,
);
router_clone.prefill_workers.write().unwrap().push(worker);
router_clone.worker_registry.register(Arc::from(worker));
});
handles.push(handle);
}
......@@ -2523,7 +2492,7 @@ mod tests {
}
// Check final state
let workers = router.prefill_workers.read().unwrap();
let workers = router.worker_registry.get_prefill_workers();
assert_eq!(workers.len(), 5);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment