Unverified commit 2f173ea0, authored by Simo Lin, committed by GitHub
Browse files

[router] allow one router to support different model families and serving mode (#10244)

parent 321fecab
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig,
RetryExecutor, Worker, WorkerFactory, WorkerType,
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker,
WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest,
RerankResponse, RerankResult, ResponsesRequest,
......@@ -22,7 +22,7 @@ use axum::{
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
......@@ -30,8 +30,8 @@ use tracing::{debug, error, info, warn};
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
policy: Arc<dyn LoadBalancingPolicy>,
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
client: Client,
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
......@@ -41,7 +41,6 @@ pub struct Router {
circuit_breaker_config: CircuitBreakerConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
}
impl Router {
......@@ -49,7 +48,6 @@ impl Router {
#[allow(clippy::too_many_arguments)]
pub async fn new(
worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
// Update active workers gauge
......@@ -82,45 +80,51 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Create Worker trait objects from URLs with health check config
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
.map(|url| {
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
// Register workers in the registry
// In IGW mode, we need to fetch model info from workers
for url in &worker_urls {
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// For now, create worker without model_id
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
});
// Initialize policy with workers if needed (e.g., for cache-aware)
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
let worker_arc = Arc::new(worker);
ctx.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = ctx.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy and it's the first worker for this model,
// initialize it with the worker
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
}
}
}
let workers = Arc::new(RwLock::new(workers));
let health_checker = crate::core::start_health_checker(
Arc::clone(&workers),
ctx.router_config.worker_startup_check_interval_secs,
);
// Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
let load_monitor_handle = if policy.name() == "power_of_two" {
// Check if default policy is power_of_two for load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let policy_clone = Arc::clone(&policy);
let policy_clone = default_policy.clone();
let client_clone = ctx.client.clone();
Some(Arc::new(tokio::spawn(async move {
......@@ -138,8 +142,8 @@ impl Router {
};
Ok(Router {
workers,
policy,
worker_registry: ctx.worker_registry.clone(),
policy_registry: ctx.policy_registry.clone(),
client: ctx.client.clone(),
worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs: ctx
......@@ -151,18 +155,21 @@ impl Router {
circuit_breaker_config: core_cb_config,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
})
}
/// Get the current list of worker URLs
pub fn get_worker_urls(&self) -> Vec<String> {
self.workers
.read()
.unwrap()
.iter()
.map(|w| w.url().to_string())
.collect()
self.worker_registry.get_all_urls()
}
/// Return the URLs of the workers serving `model_id`, or of every
/// registered worker when no model is specified.
pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> {
    let workers = if let Some(model) = model_id {
        self.worker_registry.get_by_model_fast(model)
    } else {
        self.worker_registry.get_all()
    };
    workers.into_iter().map(|w| w.url().to_string()).collect()
}
pub async fn wait_for_healthy_workers(
......@@ -332,11 +339,27 @@ impl Router {
}
fn select_first_worker(&self) -> Result<String, String> {
let workers_guard = self.workers.read().unwrap();
if workers_guard.is_empty() {
let workers = self.worker_registry.get_all();
if workers.is_empty() {
Err("No workers are available".to_string())
} else {
Ok(workers_guard[0].url().to_string())
Ok(workers[0].url().to_string())
}
}
#[allow(dead_code)]
/// Return the URL of the first worker serving `model_id` (or of any worker
/// when no model is given); errors when none are registered for it.
fn select_first_worker_for_model(&self, model_id: Option<&str>) -> Result<String, String> {
    let workers = if let Some(model) = model_id {
        self.worker_registry.get_by_model_fast(model)
    } else {
        self.worker_registry.get_all()
    };
    workers
        .first()
        .map(|w| w.url().to_string())
        .ok_or_else(|| {
            format!(
                "No workers are available for model: {:?}",
                model_id
            )
        })
}
......@@ -447,20 +470,35 @@ impl Router {
}
}
// New method to route typed requests directly
/// Select worker considering circuit breaker state
fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option<Box<dyn Worker>> {
let workers = self.workers.read().ok()?;
let available: Vec<Box<dyn Worker>> = workers
/// Select worker for a specific model considering circuit breaker state
fn select_worker_for_model(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model (O(1) lookup if model_id is provided)
let workers = match model_id {
Some(model) => self.worker_registry.get_by_model_fast(model),
None => self.worker_registry.get_all(),
};
let available: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.map(|w| w.clone_worker())
.cloned()
.collect();
if available.is_empty() {
return None;
}
let idx = self.policy.select_worker(&available, text)?;
Some(available[idx].clone_worker())
// Get the appropriate policy for this model
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
let idx = policy.select_worker(&available, text)?;
Some(available[idx].clone())
}
pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
......@@ -468,6 +506,7 @@ impl Router {
headers: Option<&HeaderMap>,
typed_req: &T,
route: &str,
model_id: Option<&str>,
) -> Response {
let start = Instant::now();
let is_stream = typed_req.is_stream();
......@@ -477,7 +516,7 @@ impl Router {
&self.retry_config,
// operation per attempt
|_: u32| async {
let worker = match self.select_worker_with_circuit_breaker(Some(&text)) {
let worker = match self.select_worker_for_model(model_id, Some(&text)) {
Some(w) => w,
None => {
RouterMetrics::record_request_error(route, "no_available_workers");
......@@ -490,7 +529,13 @@ impl Router {
};
// Optional load tracking for cache-aware policy
let load_incremented = if self.policy.name() == "cache_aware" {
// Get the policy for this model to check if it's cache-aware
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
let load_incremented = if policy.name() == "cache_aware" {
worker.increment_load();
RouterMetrics::set_running_requests(worker.url(), worker.load());
true
......@@ -654,11 +699,9 @@ impl Router {
// Decrement load on error if it was incremented
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
}
......@@ -687,13 +730,9 @@ impl Router {
Err(e) => {
// IMPORTANT: Decrement load on error before returning
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
}
......@@ -704,18 +743,16 @@ impl Router {
// Decrement load counter for non-streaming requests if it was incremented
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
}
response
} else if load_incremented {
// For streaming with load tracking, we need to manually decrement when done
let workers = Arc::clone(&self.workers);
let registry = Arc::clone(&self.worker_registry);
let worker_url = worker_url.to_string();
// Preserve headers for streaming response
......@@ -739,17 +776,10 @@ impl Router {
.windows(12)
.any(|window| window == b"data: [DONE]")
{
if let Ok(workers_guard) = workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(
&worker_url,
worker.load(),
);
decremented = true;
}
if let Some(worker) = registry.get_by_url(&worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
decremented = true;
}
}
if tx.send(Ok(bytes)).is_err() {
......@@ -763,11 +793,9 @@ impl Router {
}
}
if !decremented {
if let Ok(workers_guard) = workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
}
if let Some(worker) = registry.get_by_url(&worker_url) {
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
}
}
});
......@@ -839,7 +867,6 @@ impl Router {
match client.get(format!("{}/health", worker_url)).send().await {
Ok(res) => {
if res.status().is_success() {
let mut workers_guard = self.workers.write().unwrap();
if self.dp_aware {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
......@@ -848,47 +875,78 @@ impl Router {
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
if workers_guard.iter().any(|w| w.url() == dp_url) {
if self.worker_registry.get_by_url(dp_url).is_some() {
warn!("Worker {} already exists", dp_url);
continue;
}
info!("Added worker: {}", dp_url);
let new_worker = WorkerFactory::create_regular_with_config(
dp_url.to_string(),
self.circuit_breaker_config.clone(),
);
workers_guard.push(new_worker);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker =
BasicWorker::new(dp_url.to_string(), WorkerType::Regular)
.with_circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
worker_added = true;
}
if !worker_added {
return Err(format!("No worker added for {}", worker_url));
}
} else {
if workers_guard.iter().any(|w| w.url() == worker_url) {
if self.worker_registry.get_by_url(worker_url).is_some() {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular_with_config(
worker_url.to_string(),
self.circuit_breaker_config.clone(),
);
workers_guard.push(new_worker);
}
RouterMetrics::set_active_workers(workers_guard.len());
// If cache aware policy, initialize the worker in the tree
if let Some(cache_aware) =
self.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
// Get updated workers after adding
drop(workers_guard);
let workers_guard = self.workers.read().unwrap();
cache_aware.init_workers(&workers_guard);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker =
BasicWorker::new(worker_url.to_string(), WorkerType::Regular)
.with_circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, add this worker to it
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
// Get all workers for this model
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
}
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
return Ok(format!("Successfully added worker: {}", worker_url));
} else {
debug!(
......@@ -931,66 +989,73 @@ impl Router {
if self.dp_aware {
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let mut candidate_workers: Vec<String> = Vec::new();
let mut removed_workers: Vec<String> = Vec::new();
let worker_url_prefix = format!("{}@", worker_url);
{
// find the candidate workers to be removed
let workers_guard = self.workers.read().unwrap();
for w in workers_guard.iter() {
if w.url().starts_with(&worker_url_prefix) {
candidate_workers.push(w.url().to_string());
}
}
}
// Find and remove all workers with matching prefix
let all_workers = self.worker_registry.get_all();
for w in all_workers.iter() {
if w.url().starts_with(&worker_url_prefix) {
// Get model_id before removing
let model_id = w.model_id().to_string();
if self.worker_registry.remove_by_url(w.url()).is_some() {
info!("Removed worker: {}", w.url());
removed_workers.push(w.url().to_string());
{
// do the removing on the worker_urls
let mut workers_guard = self.workers.write().unwrap();
for dp_url in candidate_workers.iter() {
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
workers_guard.remove(index);
info!("Removed worker: {}", dp_url);
removed_workers.push(dp_url.to_string());
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
} else {
warn!("Worker {} not found, skipping removal", dp_url);
continue;
warn!("Worker {} not found, skipping removal", w.url());
}
}
RouterMetrics::set_active_workers(workers_guard.len());
}
// If cache aware policy, remove the workers from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
for dp_url in removed_workers.iter() {
cache_aware.remove_worker(dp_url);
info!("Removed worker from tree: {}", dp_url);
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
// If any models are using cache aware policy, remove the workers from the tree
// Check each removed worker's model and get its policy
for dp_url in removed_workers.iter() {
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
let model_id = worker.model_id();
if let Some(policy) = self.policy_registry.get_policy(model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(dp_url);
info!("Removed worker from cache-aware tree: {}", dp_url);
}
}
}
}
} else {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
// Get the worker first to extract model_id
let model_id = if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
worker.model_id().to_string()
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
};
if self.worker_registry.remove_by_url(worker_url).is_some() {
info!("Removed worker: {}", worker_url);
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
}
// If cache aware policy, remove the workers from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker(worker_url);
info!("Removed worker from tree: {}", worker_url);
// If the model is using cache aware policy, remove the worker from the tree
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(worker_url);
info!("Removed worker from cache-aware tree: {}", worker_url);
}
}
}
}
......@@ -1171,7 +1236,7 @@ impl RouterTrait for Router {
}
async fn health(&self, _req: Request<Body>) -> Response {
let workers = self.workers.read().unwrap();
let workers = self.worker_registry.get_all();
let unhealthy_servers: Vec<_> = workers
.iter()
.filter(|w| !w.is_healthy())
......@@ -1209,16 +1274,19 @@ impl RouterTrait for Router {
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/generate").await
self.route_typed_request(headers, body, "/generate", model_id)
.await
}
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/v1/chat/completions")
self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
.await
}
......@@ -1226,8 +1294,9 @@ impl RouterTrait for Router {
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/v1/completions")
self.route_typed_request(headers, body, "/v1/completions", model_id)
.await
}
......@@ -1235,8 +1304,9 @@ impl RouterTrait for Router {
&self,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
self.route_typed_request(headers, body, "/v1/responses")
self.route_typed_request(headers, body, "/v1/responses", model_id)
.await
}
......@@ -1244,11 +1314,18 @@ impl RouterTrait for Router {
todo!()
}
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response {
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,
body: &RerankRequest,
model_id: Option<&str>,
) -> Response {
if let Err(e) = body.validate() {
return (StatusCode::BAD_REQUEST, e).into_response();
}
let response = self.route_typed_request(headers, body, "/v1/rerank").await;
let response = self
.route_typed_request(headers, body, "/v1/rerank", model_id)
.await;
if response.status().is_success() {
match Self::build_rerank_response(body, response).await {
Ok(rerank_response) => rerank_response,
......@@ -1340,19 +1417,15 @@ impl RouterTrait for Router {
fn readiness(&self) -> Response {
// Regular router is ready if it has at least one healthy worker
let healthy_count = self
.workers
.read()
.unwrap()
.iter()
.filter(|w| w.is_healthy())
.count();
let workers = self.worker_registry.get_all();
let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
let total_workers = workers.len();
if healthy_count > 0 {
Json(serde_json::json!({
"status": "ready",
"healthy_workers": healthy_count,
"total_workers": self.workers.read().unwrap().len()
"total_workers": total_workers
}))
.into_response()
} else {
......@@ -1361,7 +1434,7 @@ impl RouterTrait for Router {
Json(serde_json::json!({
"status": "not_ready",
"reason": "no healthy workers available",
"total_workers": self.workers.read().unwrap().len()
"total_workers": total_workers
})),
)
.into_response()
......@@ -1372,18 +1445,25 @@ impl RouterTrait for Router {
#[cfg(test)]
mod tests {
use super::*;
use crate::policies::RandomPolicy;
use std::collections::HashMap;
fn create_test_regular_router() -> Router {
let workers = vec![
WorkerFactory::create_regular("http://worker1:8080".to_string()),
WorkerFactory::create_regular("http://worker2:8080".to_string()),
];
// Create registries
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new(
crate::config::types::PolicyConfig::RoundRobin,
));
// Register test workers
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2));
let (_, rx) = tokio::sync::watch::channel(HashMap::new());
Router {
workers: Arc::new(RwLock::new(workers)),
policy: Arc::new(RandomPolicy::new()),
worker_registry,
policy_registry,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
......@@ -1393,7 +1473,6 @@ mod tests {
circuit_breaker_config: CircuitBreakerConfig::default(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,
}
}
......@@ -1413,7 +1492,9 @@ mod tests {
let result = router.select_first_worker();
assert!(result.is_ok());
assert_eq!(result.unwrap(), "http://worker1:8080");
let url = result.unwrap();
// DashMap doesn't guarantee order, so just check we get one of the workers
assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
}
#[tokio::test]
......
......@@ -17,6 +17,7 @@ pub mod factory;
pub mod grpc;
pub mod header_utils;
pub mod http;
pub mod router_manager;
pub use factory::RouterFactory;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
......@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
async fn get_model_info(&self, req: Request<Body>) -> Response;
/// Route a generate request
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
-> Response;
async fn route_generate(
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
model_id: Option<&str>,
) -> Response;
/// Route a chat completion request
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response;
/// Route a completion request
......@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
model_id: Option<&str>,
) -> Response;
/// Route a responses request
......@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&self,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response;
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response;
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,
body: &RerankRequest,
model_id: Option<&str>,
) -> Response;
/// Flush cache on all workers
async fn flush_cache(&self) -> Response;
......
//! Router Manager for coordinating multiple routers and workers
//!
//! Provides centralized management based on enable_igw flag:
//! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::RouterConfig;
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
};
use crate::protocols::worker_spec::{
ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo,
WorkerListResponse, WorkerStats, WorkerTypeStats,
};
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use dashmap::DashMap;
use std::sync::Arc;
use tracing::{info, warn};
/// Router identifier
///
/// Newtype over the underlying string id so router keys cannot be mixed up
/// with arbitrary strings (e.g. model ids or worker URLs).
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RouterId(String);

impl RouterId {
    /// Create an identifier from its string form.
    pub fn new(id: String) -> Self {
        Self(id)
    }

    /// Borrow the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// Standard conversion, so callers can use `.into()` / `From::from`.
impl From<String> for RouterId {
    fn from(id: String) -> Self {
        Self(id)
    }
}

/// Renders as the raw id (useful for logs and messages).
impl std::fmt::Display for RouterId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Router Manager - Central coordinator for routers and workers
/// Only created when enable_igw=true
///
/// Shares its worker/policy registries with the routers it coordinates;
/// routing decisions consult `model_routers` first and fall back to
/// `default_router`.
pub struct RouterManager {
    /// Worker registry (single source of truth in multi-router mode)
    worker_registry: Arc<WorkerRegistry>,
    /// Policy registry for managing model-to-policy mappings
    policy_registry: Arc<crate::policies::PolicyRegistry>,
    /// All routers managed by this manager (max 4 routers in Phase 2)
    /// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
    routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
    /// Default router for requests without specific routing
    default_router: Option<RouterId>,
    /// Model to router mapping for model-aware routing
    /// Multiple models can be served by the same router
    model_routers: Arc<DashMap<String, Vec<RouterId>>>,
    /// HTTP client for querying worker info (e.g. server info on add_worker)
    client: reqwest::Client,
    /// Configuration
    #[allow(dead_code)] // May be used in future enhancements
    config: RouterConfig,
}
impl RouterManager {
/// Create a new router manager with shared registries
pub fn new(
config: RouterConfig,
client: reqwest::Client,
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<crate::policies::PolicyRegistry>,
) -> Self {
Self {
worker_registry,
policy_registry,
routers: Arc::new(DashMap::new()),
default_router: None,
model_routers: Arc::new(DashMap::new()),
client,
config,
}
}
/// Register a router with the manager
///
/// Records the router under `id`, maps each of `models` to it, and — when
/// no default router exists yet — promotes this router to be the default.
pub fn register_router(
    &mut self,
    id: RouterId,
    router: Arc<dyn RouterTrait>,
    models: Vec<String>,
) {
    self.routers.insert(id.clone(), router);

    // Each model may be served by several routers; append to its list.
    for model in models {
        self.model_routers
            .entry(model)
            .or_insert_with(Vec::new)
            .push(id.clone());
    }

    // The first router registered becomes the default.
    if self.default_router.is_none() {
        info!("Set default router to {}", id.as_str());
        self.default_router = Some(id);
    }
}
/// Set the default router
///
/// The default router receives requests whose model has no explicit
/// mapping (see `get_router_for_model`). Overwrites any default chosen
/// automatically by `register_router`.
pub fn set_default_router(&mut self, id: RouterId) {
    self.default_router = Some(id);
}
/// Get the number of registered routers
///
/// Counts the entries currently held in the internal router map.
pub fn router_count(&self) -> usize {
    self.routers.len()
}
/// Get router for a specific model
///
/// Resolution order: the first router mapped to `model_id`, then the
/// default router; `None` when neither yields a registered router.
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
    // Prefer a router explicitly mapped to this model.
    let mapped = self
        .model_routers
        .get(model_id)
        .and_then(|ids| ids.first().cloned())
        .and_then(|id| self.routers.get(&id).map(|r| r.clone()));
    if mapped.is_some() {
        return mapped;
    }

    // Otherwise fall back to the default router, if one is configured.
    self.default_router
        .as_ref()
        .and_then(|id| self.routers.get(id).map(|r| r.clone()))
}
/// Get workers for routing decision
///
/// With a model id, returns only that model's workers; otherwise returns
/// every worker in the registry.
pub fn get_workers_for_request(&self, model_id: Option<&str>) -> Vec<Arc<dyn Worker>> {
    match model_id {
        Some(model) => self.worker_registry.get_by_model(model),
        None => self.worker_registry.get_all(),
    }
}
/// Add a worker to the registry
pub async fn add_worker(
&self,
config: WorkerConfigRequest,
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
// Build labels from configuration
let mut labels = config.labels.clone();
// Query server info if model_id not provided
let model_id = if let Some(model_id) = config.model_id {
model_id
} else {
match self.query_server_info(&config.url).await {
Ok(info) => {
// Extract model_id from server info
info.model_id
.or_else(|| {
info.model_path
.as_ref()
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
})
.unwrap_or_else(|| "unknown".to_string())
}
Err(e) => {
warn!("Failed to query server info from {}: {}", config.url, e);
"unknown".to_string()
}
}
};
// Add configuration to labels
labels.insert("model_id".to_string(), model_id.clone());
if let Some(priority) = config.priority {
labels.insert("priority".to_string(), priority.to_string());
}
if let Some(cost) = config.cost {
labels.insert("cost".to_string(), cost.to_string());
}
// Add gRPC-specific configuration if provided
if let Some(tokenizer_path) = config.tokenizer_path {
labels.insert("tokenizer_path".to_string(), tokenizer_path);
}
if let Some(reasoning_parser) = config.reasoning_parser {
labels.insert("reasoning_parser".to_string(), reasoning_parser);
}
if let Some(tool_parser) = config.tool_parser {
labels.insert("tool_parser".to_string(), tool_parser);
}
if let Some(chat_template) = config.chat_template {
labels.insert("chat_template".to_string(), chat_template);
}
// Create worker based on type
// Note: For prefill and decode workers, we can't easily add labels after creation
// since they return Box<dyn Worker>. We'll need to enhance WorkerFactory in the future.
let worker = match config.worker_type.as_deref() {
Some("prefill") => {
// For now, prefill workers won't have custom labels
// TODO: Enhance WorkerFactory to accept labels for prefill workers
WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port)
}
Some("decode") => {
// For now, decode workers won't have custom labels
// TODO: Enhance WorkerFactory to accept labels for decode workers
WorkerFactory::create_decode(config.url.clone())
}
_ => {
// Regular workers can have labels
WorkerFactory::create_regular_with_labels(
config.url.clone(),
labels.clone(),
CircuitBreakerConfig::default(),
)
}
};
// Register worker
let worker_id = self.worker_registry.register(Arc::from(worker));
// Notify PolicyRegistry about the new worker
// Extract policy hint from labels if provided
let policy_hint = labels.get("policy").map(|s| s.as_str());
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
info!(
"Added worker {} with URL {} for model {} using policy {}",
worker_id.as_str(),
config.url,
model_id,
policy.name()
);
// Return worker info
let worker_arc = self.worker_registry.get(&worker_id).unwrap();
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
Ok(WorkerApiResponse {
success: true,
message: format!("Worker {} added successfully", worker_id.as_str()),
worker: Some(worker_info),
})
}
/// Remove a worker from the registry
pub fn remove_worker_from_registry(
    &self,
    url: &str,
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
    // Capture the model_id up front: the worker is gone after removal.
    let model_id = self
        .worker_registry
        .get_by_url(url)
        .map(|w| w.model_id().to_string());
    match self.worker_registry.remove_by_url(url) {
        Some(_) => {
            // Let the PolicyRegistry drop per-model state when the model is known.
            match model_id {
                Some(model) => {
                    self.policy_registry.on_worker_removed(&model);
                    info!("Removed worker with URL {} for model {}", url, model);
                }
                None => info!("Removed worker with URL {}", url),
            }
            Ok(WorkerApiResponse {
                success: true,
                message: format!("Worker {} removed successfully", url),
                worker: None,
            })
        }
        None => Err(WorkerErrorResponse {
            error: format!("Worker with URL {} not found", url),
            code: "WORKER_NOT_FOUND".to_string(),
        }),
    }
}
/// List all workers
pub fn list_workers(&self) -> WorkerListResponse {
    // Convert every (id, worker) registry entry into its API representation.
    let entries = self.worker_registry.get_all_with_ids();
    let worker_infos: Vec<WorkerInfo> = entries
        .iter()
        .map(|(id, worker)| self.worker_to_info(id.as_str(), worker))
        .collect();
    // Translate registry-level statistics into the API's WorkerStats shape.
    let registry_stats = self.worker_registry.stats();
    let stats = WorkerStats {
        total_workers: registry_stats.total_workers,
        healthy_workers: registry_stats.healthy_workers,
        total_models: registry_stats.total_models,
        total_load: registry_stats.total_load,
        by_type: WorkerTypeStats {
            regular: registry_stats.regular_workers,
            prefill: registry_stats.prefill_workers,
            decode: registry_stats.decode_workers,
        },
    };
    WorkerListResponse {
        total: worker_infos.len(),
        workers: worker_infos,
        stats,
    }
}
/// Get worker by URL
pub fn get_worker(&self, url: &str) -> Option<WorkerInfo> {
    // Lookup is keyed by URL, so the registry id is not available here;
    // a placeholder id of "unknown" is reported instead.
    let worker = self.worker_registry.get_by_url(url)?;
    Some(self.worker_to_info("unknown", &worker))
}
/// Query server info from a worker URL
async fn query_server_info(&self, url: &str) -> Result<ServerInfo, String> {
    // Hit the worker's /get_server_info endpoint (trailing slash tolerated).
    let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
    let response = self
        .client
        .get(&info_url)
        .send()
        .await
        .map_err(|e| format!("Failed to connect to server: {}", e))?;
    if !response.status().is_success() {
        return Err(format!("Server returned status: {}", response.status()));
    }
    response
        .json::<ServerInfo>()
        .await
        .map_err(|e| format!("Failed to parse server info: {}", e))
}
/// Convert Worker to WorkerInfo
fn worker_to_info(&self, id: &str, worker: &Arc<dyn Worker>) -> WorkerInfo {
    // Worker/connection type enums are rendered via their Debug form.
    let meta = worker.metadata();
    WorkerInfo {
        id: id.to_string(),
        url: worker.url().to_string(),
        model_id: worker.model_id().to_string(),
        priority: worker.priority(),
        cost: worker.cost(),
        worker_type: format!("{:?}", worker.worker_type()),
        is_healthy: worker.is_healthy(),
        load: worker.load(),
        connection_mode: format!("{:?}", worker.connection_mode()),
        // gRPC-oriented optional settings; absent for plain HTTP workers.
        tokenizer_path: worker.tokenizer_path().map(|p| p.to_string()),
        reasoning_parser: worker.reasoning_parser().map(|p| p.to_string()),
        tool_parser: worker.tool_parser().map(|p| p.to_string()),
        chat_template: worker.chat_template().map(|p| p.to_string()),
        metadata: meta.labels.clone(),
    }
}
// Note: calculate_stats removed - using WorkerRegistry::stats() instead
// === Phase 2: Router Management ===
// Note: Dynamic router creation removed - routers are created and registered externally
/// Get the appropriate router for a request based on headers and request content
pub fn select_router_for_request(
    &self,
    headers: Option<&HeaderMap>,
    model_id: Option<&str>,
) -> Option<Arc<dyn RouterTrait>> {
    // Priority / cost preferences are parsed but not yet acted upon;
    // scoring by worker attributes is still TODO below.
    let _priority_threshold = headers.and_then(|h| {
        h.get("x-worker-priority")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u32>().ok())
    });
    let _max_cost = headers.and_then(|h| {
        h.get("x-max-cost")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<f32>().ok())
    });
    // "x-prefer-pd: true|1" asks for a prefill-decode router when available.
    let prefer_pd = headers
        .and_then(|h| h.get("x-prefer-pd"))
        .and_then(|v| v.to_str().ok())
        .map(|s| s == "true" || s == "1")
        .unwrap_or(false);
    // Restrict to routers serving the requested model; with no model, every
    // registered router is a candidate.
    let candidate_routers: Vec<Arc<dyn RouterTrait>> = match model_id {
        Some(model) => match self.model_routers.get(model) {
            Some(router_ids) => router_ids
                .iter()
                .filter_map(|id| self.routers.get(id).map(|r| r.clone()))
                .collect(),
            None => Vec::new(),
        },
        None => self
            .routers
            .iter()
            .map(|entry| entry.value().clone())
            .collect(),
    };
    // Score candidates; matching the PD preference earns the largest bonus.
    // TODO: Once routers expose worker stats, fold in priority vs the
    // threshold, cost vs max_cost, current load, and health status.
    let mut best: Option<(f64, Arc<dyn RouterTrait>)> = None;
    for router in candidate_routers {
        let is_pd = router.is_pd_mode();
        let mut score = 1.0;
        if prefer_pd && is_pd {
            score += 2.0; // Matches PD preference
        } else if !prefer_pd && !is_pd {
            score += 1.0; // Matches regular preference
        }
        let improved = match &best {
            Some((best_score, _)) => score > *best_score,
            None => true,
        };
        if improved {
            best = Some((score, router));
        }
    }
    // Empty candidate list falls through to None here.
    best.map(|(_, router)| router)
}
}
// Note: Default implementation removed as RouterManager now requires AppContext
// which cannot be defaulted. RouterManager must be created with explicit context.
// === Phase 2: RouterManager as RouterTrait ===
/// RouterManager implements RouterTrait to act as a meta-router
/// that delegates requests to the appropriate underlying router
#[async_trait]
impl WorkerManagement for RouterManager {
    /// Add a worker - in multi-router mode, this adds to the registry
    async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        // Wrap the bare URL in a minimal config; everything else is left
        // unset and discovered (e.g. model_id) by the full add_worker path.
        let config = WorkerConfigRequest {
            url: worker_url.to_string(),
            model_id: None,
            worker_type: None,
            priority: None,
            cost: None,
            labels: std::collections::HashMap::new(),
            bootstrap_port: None,
            tokenizer_path: None,
            reasoning_parser: None,
            tool_parser: None,
            chat_template: None,
        };
        self.add_worker(config)
            .await
            .map(|response| response.message)
            .map_err(|e| e.error)
    }

    /// Remove a worker from the registry
    fn remove_worker(&self, worker_url: &str) {
        // Errors (e.g. unknown URL) are intentionally dropped; this trait
        // method has no channel to report them.
        let _ = self.remove_worker_from_registry(worker_url);
    }

    /// Get all worker URLs from the registry
    fn get_worker_urls(&self) -> Vec<String> {
        self.worker_registry.get_all_urls()
    }
}
#[async_trait]
impl RouterTrait for RouterManager {
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Health check - return 503 if no routers available
async fn health(&self, _req: Request<Body>) -> Response {
// Health check should succeed if RouterManager exists, even without routers
// Individual router health can be checked via specific endpoints
(StatusCode::OK, "RouterManager is healthy").into_response()
}
/// Health generate - check if any router can handle generate requests
async fn health_generate(&self, _req: Request<Body>) -> Response {
// Return 503 since we have no routers with workers
// TODO: Should check if any router has healthy workers
(
StatusCode::SERVICE_UNAVAILABLE,
"No routers with healthy workers available",
)
.into_response()
}
/// Get server information - aggregate from all routers
async fn get_server_info(&self, _req: Request<Body>) -> Response {
// TODO: Aggregate info from all routers with healthy workers
// For now, return basic info about the RouterManager
(
StatusCode::OK,
serde_json::json!({
"router_manager": true,
"routers_count": self.routers.len(),
"workers_count": self.worker_registry.get_all().len()
})
.to_string(),
)
.into_response()
}
/// Get available models - aggregate from all routers
async fn get_models(&self, _req: Request<Body>) -> Response {
// Return models that have registered routers
let models = self
.model_routers
.iter()
.map(|entry| entry.key().clone())
.collect::<Vec<_>>();
if models.is_empty() {
(StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
} else {
(
StatusCode::OK,
serde_json::json!({
"models": models
})
.to_string(),
)
.into_response()
}
}
/// Get model information
async fn get_model_info(&self, _req: Request<Body>) -> Response {
// TODO: Extract model from request and route to appropriate router
// For now, return not implemented
(
StatusCode::NOT_IMPLEMENTED,
"Model info endpoint not yet implemented in RouterManager",
)
.into_response()
}
/// Route a generate request
async fn route_generate(
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
_model_id: Option<&str>,
) -> Response {
// Select router based on headers
// GenerateRequest doesn't have a model field
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
// In multi-model mode, pass None since GenerateRequest doesn't have model field
router.route_generate(headers, body, None).await
} else {
// Return 404 when no router is available for the request
(
StatusCode::NOT_FOUND,
"No router available for this request",
)
.into_response()
}
}
/// Route a chat completion request
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response {
// Select router based on headers and model
let router = self.select_router_for_request(headers, Some(&body.model));
if let Some(router) = router {
// In multi-model mode, pass the model_id to the router
router.route_chat(headers, body, Some(&body.model)).await
} else {
// Return 404 when the specified model is not found
(
StatusCode::NOT_FOUND,
format!("Model '{}' not found or no router available", body.model),
)
.into_response()
}
}
/// Route a completion request
async fn route_completion(
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
_model_id: Option<&str>,
) -> Response {
// Select router based on headers and model
let router = self.select_router_for_request(headers, Some(&body.model));
if let Some(router) = router {
// In multi-model mode, pass the model_id to the router
router
.route_completion(headers, body, Some(&body.model))
.await
} else {
// Return 404 when the specified model is not found
(
StatusCode::NOT_FOUND,
format!("Model '{}' not found or no router available", body.model),
)
.into_response()
}
}
async fn route_responses(
&self,
_headers: Option<&HeaderMap>,
_body: &ResponsesRequest,
_model_id: Option<&str>,
) -> Response {
todo!()
}
/// Route embeddings request
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response {
// Try to select a router based on headers
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router.route_embeddings(headers, body).await
} else {
(
StatusCode::NOT_FOUND,
"No router available for embeddings request",
)
.into_response()
}
}
/// Route rerank request
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,
body: &RerankRequest,
model_id: Option<&str>,
) -> Response {
// Try to select a router based on headers
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router.route_rerank(headers, body, model_id).await
} else {
(
StatusCode::NOT_FOUND,
"No router available for rerank request",
)
.into_response()
}
}
/// Flush cache on all routers and workers
async fn flush_cache(&self) -> Response {
// TODO: Call flush_cache on all routers that have workers
// For now, return success if we have any routers
if self.routers.is_empty() {
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
} else {
// TODO: Actually flush cache on all routers
(StatusCode::OK, "Cache flush requested").into_response()
}
}
/// Get worker loads from all routers
async fn get_worker_loads(&self) -> Response {
// Return worker loads from the registry
let workers = self.worker_registry.get_all();
let loads: Vec<serde_json::Value> = workers
.iter()
.map(|w| {
serde_json::json!({
"url": w.url(),
"model": w.model_id(),
"load": w.load(),
"is_healthy": w.is_healthy()
})
})
.collect();
(
StatusCode::OK,
serde_json::json!({
"workers": loads
})
.to_string(),
)
.into_response()
}
/// Get router type name
fn router_type(&self) -> &'static str {
"manager"
}
/// Server readiness check - check if any router is ready
fn readiness(&self) -> Response {
if self.routers.is_empty() {
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
} else {
// TODO: Check readiness of all routers
(StatusCode::OK, "Ready").into_response()
}
}
}
// Note: get_first_available_router removed - we now properly handle
// router selection based on model and worker availability
impl std::fmt::Debug for RouterManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Summarize counts rather than dumping the registries wholesale.
        let mut dbg = f.debug_struct("RouterManager");
        dbg.field("routers_count", &self.routers.len());
        dbg.field("workers_count", &self.worker_registry.get_all().len());
        dbg.field("default_router", &self.default_router);
        dbg.finish()
    }
}
use crate::config::RouterConfig;
use crate::core::WorkerRegistry;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket;
use crate::policies::PolicyRegistry;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
V1RerankReqInput,
};
use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
use crate::reasoning_parser::ParserFactory;
use crate::routers::router_manager::{RouterId, RouterManager};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
......@@ -36,6 +40,9 @@ pub struct AppContext {
pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>,
pub tool_parser_registry: Option<&'static ParserRegistry>,
pub worker_registry: Arc<WorkerRegistry>, // Shared worker registry
pub policy_registry: Arc<PolicyRegistry>, // Shared policy registry
pub router_manager: Option<Arc<RouterManager>>, // Only present when enable_igw=true
}
impl AppContext {
......@@ -75,6 +82,15 @@ impl AppContext {
(None, None, None)
};
// Initialize shared registries
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new(
router_config.policy.clone(), // Use default policy from config
));
// Initialize RouterManager only when enable_igw is true
let router_manager = None; // Will be initialized in startup() based on config
Ok(Self {
client,
router_config,
......@@ -82,6 +98,9 @@ impl AppContext {
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
worker_registry,
policy_registry,
router_manager,
})
}
}
......@@ -134,7 +153,10 @@ async fn generate(
headers: http::HeaderMap,
Json(body): Json<GenerateRequest>,
) -> Response {
state.router.route_generate(Some(&headers), &body).await
state
.router
.route_generate(Some(&headers), &body, None)
.await
}
async fn v1_chat_completions(
......@@ -142,7 +164,7 @@ async fn v1_chat_completions(
headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>,
) -> Response {
state.router.route_chat(Some(&headers), &body).await
state.router.route_chat(Some(&headers), &body, None).await
}
async fn v1_completions(
......@@ -150,7 +172,10 @@ async fn v1_completions(
headers: http::HeaderMap,
Json(body): Json<CompletionRequest>,
) -> Response {
state.router.route_completion(Some(&headers), &body).await
state
.router
.route_completion(Some(&headers), &body, None)
.await
}
async fn rerank(
......@@ -158,7 +183,7 @@ async fn rerank(
headers: http::HeaderMap,
Json(body): Json<RerankRequest>,
) -> Response {
state.router.route_rerank(Some(&headers), &body).await
state.router.route_rerank(Some(&headers), &body, None).await
}
async fn v1_rerank(
......@@ -168,7 +193,7 @@ async fn v1_rerank(
) -> Response {
state
.router
.route_rerank(Some(&headers), &body.into())
.route_rerank(Some(&headers), &body.into(), None)
.await
}
......@@ -177,7 +202,10 @@ async fn v1_responses(
headers: http::HeaderMap,
Json(body): Json<ResponsesRequest>,
) -> Response {
state.router.route_responses(Some(&headers), &body).await
state
.router
.route_responses(Some(&headers), &body, None)
.await
}
// Worker management endpoints
......@@ -232,6 +260,137 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
state.router.get_worker_loads().await
}
// New RESTful worker management endpoints (when enable_igw=true)
/// POST /workers - Add a new worker with full configuration
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
    // Multi-router mode (enable_igw=true): delegate to the RouterManager,
    // which honors the full worker configuration.
    if let Some(router_manager) = &state.context.router_manager {
        return match router_manager.add_worker(config).await {
            Ok(response) => (StatusCode::OK, Json(response)).into_response(),
            Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
        };
    }
    // Single router mode: only the URL is consumed; the rest of the config
    // is not applied by the router's basic add_worker.
    match state.router.add_worker(&config.url).await {
        Ok(message) => (
            StatusCode::OK,
            Json(WorkerApiResponse {
                success: true,
                message,
                worker: None,
            }),
        )
            .into_response(),
        Err(error) => (
            StatusCode::BAD_REQUEST,
            Json(WorkerErrorResponse {
                error,
                code: "ADD_WORKER_FAILED".to_string(),
            }),
        )
            .into_response(),
    }
}
/// GET /workers - List all workers with details
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
if let Some(router_manager) = &state.context.router_manager {
let response = router_manager.list_workers();
Json(response).into_response()
} else {
// In single router mode, get detailed worker info from registry
let workers = state.context.worker_registry.get_all();
let response = serde_json::json!({
"workers": workers.iter().map(|worker| {
let mut worker_info = serde_json::json!({
"url": worker.url(),
"model_id": worker.model_id(),
"worker_type": format!("{:?}", worker.worker_type()),
"is_healthy": worker.is_healthy(),
"load": worker.load(),
"connection_mode": format!("{:?}", worker.connection_mode()),
"priority": worker.priority(),
"cost": worker.cost(),
});
// Add bootstrap_port for Prefill workers
if let crate::core::WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
}
worker_info
}).collect::<Vec<_>>(),
"total": workers.len(),
"stats": {
"prefill_count": state.context.worker_registry.get_prefill_workers().len(),
"decode_count": state.context.worker_registry.get_decode_workers().len(),
"regular_count": state.context.worker_registry.get_by_type(&crate::core::WorkerType::Regular).len(),
}
});
Json(response).into_response()
}
}
/// GET /workers/{url} - Get specific worker info
async fn get_worker(
    State(state): State<Arc<AppState>>,
    axum::extract::Path(url): axum::extract::Path<String>,
) -> Response {
    // Multi-router mode: full worker info comes from the RouterManager.
    if let Some(router_manager) = &state.context.router_manager {
        return match router_manager.get_worker(&url) {
            Some(worker) => Json(worker).into_response(),
            None => (
                StatusCode::NOT_FOUND,
                Json(WorkerErrorResponse {
                    error: format!("Worker {} not found", url),
                    code: "WORKER_NOT_FOUND".to_string(),
                }),
            )
                .into_response(),
        };
    }
    // Single router mode: the router only tracks URLs, so the response is
    // minimal and partly hard-coded (health assumed, model unknown).
    if state.router.get_worker_urls().contains(&url) {
        Json(serde_json::json!({
            "url": url,
            "model_id": "unknown",
            "is_healthy": true
        }))
        .into_response()
    } else {
        (
            StatusCode::NOT_FOUND,
            Json(WorkerErrorResponse {
                error: format!("Worker {} not found", url),
                code: "WORKER_NOT_FOUND".to_string(),
            }),
        )
            .into_response()
    }
}
/// DELETE /workers/{url} - Remove a worker
async fn delete_worker(
    State(state): State<Arc<AppState>>,
    axum::extract::Path(url): axum::extract::Path<String>,
) -> Response {
    // Multi-router mode: the RouterManager can report unknown-URL failures.
    if let Some(router_manager) = &state.context.router_manager {
        return match router_manager.remove_worker_from_registry(&url) {
            Ok(response) => (StatusCode::OK, Json(response)).into_response(),
            Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
        };
    }
    // Single router mode: remove_worker returns nothing, so this path
    // always reports success.
    state.router.remove_worker(&url);
    (
        StatusCode::OK,
        Json(WorkerApiResponse {
            success: true,
            message: format!("Worker {} removed successfully", url),
            worker: None,
        }),
    )
        .into_response()
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
......@@ -281,11 +440,19 @@ pub fn build_app(
.route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads));
// Worker management routes
let worker_routes = Router::new()
.route("/workers", post(create_worker))
.route("/workers", get(list_workers_rest))
.route("/workers/{url}", get(get_worker))
.route("/workers/{url}", axum::routing::delete(delete_worker));
// Build app with all routes and middleware
Router::new()
.merge(protected_routes)
.merge(public_routes)
.merge(admin_routes)
.merge(worker_routes)
// Request body size limiting
.layer(tower_http::limit::RequestBodyLimitLayer::new(
max_payload_size,
......@@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.expect("Failed to create HTTP client");
// Create the application context with all dependencies
let app_context = Arc::new(AppContext::new(
let app_context = AppContext::new(
config.router_config.clone(),
client.clone(),
config.router_config.max_concurrent_requests,
config.router_config.rate_limit_tokens_per_second,
)?);
)?;
let app_context = Arc::new(app_context);
// Create the appropriate router based on enable_igw flag
let router: Box<dyn RouterTrait> = if config.router_config.enable_igw {
info!("Multi-router mode enabled (enable_igw=true)");
// Create RouterManager with shared registries from AppContext
let mut router_manager = RouterManager::new(
config.router_config.clone(),
client.clone(),
app_context.worker_registry.clone(),
app_context.policy_registry.clone(),
);
// Create HTTP routers at startup (with empty worker lists)
// Workers will be added to these routers dynamically via RouterManager's worker registry
// 1. HTTP Regular Router
match RouterFactory::create_regular_router(
&[], // Empty worker list - workers added later
&app_context,
)
.await
{
Ok(http_regular) => {
info!("Created HTTP Regular router");
router_manager.register_router(
RouterId::new("http-regular".to_string()),
Arc::from(http_regular),
vec![], // Models will be determined by workers
);
}
Err(e) => {
warn!("Failed to create HTTP Regular router: {}", e);
}
}
// 2. HTTP PD Router
match RouterFactory::create_pd_router(
&[], // Empty prefill URLs
&[], // Empty decode URLs
None, // Use default prefill policy
None, // Use default decode policy
&config.router_config.policy,
&app_context,
)
.await
{
Ok(http_pd) => {
info!("Created HTTP PD router");
router_manager.register_router(
RouterId::new("http-pd".to_string()),
Arc::from(http_pd),
vec![],
);
}
Err(e) => {
warn!("Failed to create HTTP PD router: {}", e);
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
// Currently gRPC routers require tokenizer to be initialized first,
// but each model needs its own tokenizer. Once we implement dynamic
// tokenizer loading per model, we can enable gRPC routers here:
// - RouterType::GrpcRegular (RouterId: "grpc-regular")
// - RouterType::GrpcPd (RouterId: "grpc-pd")
// Create router with the context
let router = RouterFactory::create_router(&app_context).await?;
info!(
"RouterManager initialized with {} routers",
router_manager.router_count()
);
Box::new(router_manager)
} else {
info!("Single router mode (enable_igw=false)");
// Create single router with the context
RouterFactory::create_router(&app_context).await?
};
// Start health checker for all workers in the registry
let _health_checker = app_context
.worker_registry
.start_health_checker(config.router_config.health_check.check_interval_secs);
info!(
"Started health checker for workers with {}s interval",
config.router_config.health_check.check_interval_secs
);
// Set up concurrency limiter with queue if configured
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
......
......@@ -579,9 +579,8 @@ mod tests {
// Helper to create a Router instance for testing event handlers
async fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::{PolicyConfig, RouterConfig};
use crate::config::RouterConfig;
use crate::middleware::TokenBucket;
use crate::policies::PolicyFactory;
use crate::routers::http::router::Router;
use crate::server::AppContext;
......@@ -591,15 +590,19 @@ mod tests {
// Create AppContext with minimal components
let app_context = Arc::new(AppContext {
client: reqwest::Client::new(),
router_config,
router_config: router_config.clone(),
rate_limiter: Arc::new(TokenBucket::new(1000, 1000)),
worker_registry: Arc::new(crate::core::WorkerRegistry::new()),
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
router_config.policy.clone(),
)),
tokenizer: None, // HTTP mode doesn't need tokenizer
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
tool_parser_registry: None, // HTTP mode doesn't need tool parser
router_manager: None, // Test doesn't need router manager
});
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, &app_context).await.unwrap();
let router = Router::new(vec![], &app_context).await.unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
......
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
use std::collections::HashMap;
use std::sync::Arc;
#[test]
fn test_backward_compatibility_with_empty_model_id() {
let config = CacheAwareConfig {
cache_threshold: 0.5,
balance_abs_threshold: 2,
balance_rel_threshold: 1.5,
eviction_interval_secs: 0, // Disable background eviction for testing
max_tree_size: 100,
};
let policy = CacheAwarePolicy::with_config(config);
// Create workers with empty model_id (simulating existing routers)
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
// No model_id label - should default to "unknown"
let mut labels2 = HashMap::new();
labels2.insert("model_id".to_string(), "unknown".to_string());
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
.with_labels(labels2);
// Add workers - should both go to "default" tree
policy.add_worker(&worker1);
policy.add_worker(&worker2);
// Create worker list
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1.clone()), Arc::new(worker2.clone())];
// Select worker - should work without errors
let selected = policy.select_worker(&workers, Some("test request"));
assert!(selected.is_some(), "Should select a worker");
// Remove workers - should work without errors
policy.remove_worker(&worker1);
policy.remove_worker(&worker2);
}
#[test]
fn test_mixed_model_ids() {
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0,
        max_tree_size: 100,
    });

    // Helper to build a single-entry model_id label map.
    let labeled = |model: &str| {
        let mut labels = HashMap::new();
        labels.insert("model_id".to_string(), model.to_string());
        labels
    };

    // worker1: no label at all -> defaults to "unknown" ("default" tree).
    let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
    // worker2/worker4: explicit "llama-3"; worker3: explicit "unknown".
    let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
        .with_labels(labeled("llama-3"));
    let worker3 = BasicWorker::new("http://worker3:8080".to_string(), WorkerType::Regular)
        .with_labels(labeled("unknown"));
    let worker4 = BasicWorker::new("http://worker4:8080".to_string(), WorkerType::Regular)
        .with_labels(labeled("llama-3"));

    for w in [&worker1, &worker2, &worker3, &worker4] {
        policy.add_worker(w);
    }

    // Selection restricted to the "default"-tree workers.
    let default_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
    let selected = policy.select_worker(&default_workers, Some("test request"));
    assert!(selected.is_some(), "Should select from default workers");

    // Selection restricted to the llama-3 workers.
    let llama_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
    let selected = policy.select_worker(&llama_workers, Some("test request"));
    assert!(selected.is_some(), "Should select from llama-3 workers");

    // Selection across every worker at once.
    let all_workers: Vec<Arc<dyn Worker>> = vec![
        Arc::new(worker1.clone()),
        Arc::new(worker2.clone()),
        Arc::new(worker3.clone()),
        Arc::new(worker4.clone()),
    ];
    let selected = policy.select_worker(&all_workers, Some("test request"));
    assert!(selected.is_some(), "Should select from all workers");
}
#[test]
fn test_remove_worker_by_url_backward_compat() {
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig::default());

    // One worker tagged with a model, one untagged (defaults to "unknown").
    let mut llama_labels = HashMap::new();
    llama_labels.insert("model_id".to_string(), "llama-3".to_string());
    let tagged = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
        .with_labels(llama_labels);
    let untagged = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);

    policy.add_worker(&tagged);
    policy.add_worker(&untagged);

    // The URL-only removal path can't know the model, so it must sweep
    // every tree for the matching URL.
    policy.remove_worker_by_url("http://worker1:8080");

    let remaining: Vec<Arc<dyn Worker>> = vec![Arc::new(untagged.clone())];
    let selected = policy.select_worker(&remaining, Some("test"));
    assert_eq!(selected, Some(0), "Should only have worker2 left");
}
//! Integration tests for PolicyRegistry with RouterManager
use sglang_router_rs::config::{PolicyConfig, RouterConfig};
use sglang_router_rs::core::WorkerRegistry;
use sglang_router_rs::policies::PolicyRegistry;
use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest;
use sglang_router_rs::routers::router_manager::RouterManager;
use std::collections::HashMap;
use std::sync::Arc;
#[tokio::test]
async fn test_policy_registry_with_router_manager() {
// Create RouterConfig
let config = RouterConfig {
enable_igw: true,
policy: PolicyConfig::RoundRobin,
..Default::default()
};
// Create HTTP client
let client = reqwest::Client::new();
// Create shared registries
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new(PolicyConfig::RoundRobin));
// Create RouterManager with shared registries
let _router_manager = RouterManager::new(
config,
client,
worker_registry.clone(),
policy_registry.clone(),
);
// Test adding workers with different models and policies
// Add first worker for llama-3 with cache_aware policy hint
let mut labels1 = HashMap::new();
labels1.insert("policy".to_string(), "cache_aware".to_string());
let _worker1_config = WorkerConfigRequest {
url: "http://worker1:8000".to_string(),
model_id: Some("llama-3".to_string()),
worker_type: None,
priority: None,
cost: None,
labels: labels1,
bootstrap_port: None,
tokenizer_path: None,
reasoning_parser: None,
tool_parser: None,
chat_template: None,
};
// This would normally connect to a real worker, but for testing we'll just verify the structure
// In a real test, we'd need to mock the worker or use a test server
// Verify PolicyRegistry has the correct policy for llama-3
let _llama_policy = policy_registry.get_policy("llama-3");
// After first worker is added, llama-3 should have a policy
// Add second worker for llama-3 with different policy hint (should be ignored)
let mut labels2 = HashMap::new();
labels2.insert("policy".to_string(), "random".to_string());
let _worker2_config = WorkerConfigRequest {
url: "http://worker2:8000".to_string(),
model_id: Some("llama-3".to_string()),
worker_type: None,
priority: None,
cost: None,
labels: labels2,
bootstrap_port: None,
tokenizer_path: None,
reasoning_parser: None,
tool_parser: None,
chat_template: None,
};
// The second worker should use the same policy as the first (cache_aware)
// Add worker for different model (gpt-4) with random policy
let mut labels3 = HashMap::new();
labels3.insert("policy".to_string(), "random".to_string());
let _worker3_config = WorkerConfigRequest {
url: "http://worker3:8000".to_string(),
model_id: Some("gpt-4".to_string()),
worker_type: None,
priority: None,
cost: None,
labels: labels3,
bootstrap_port: None,
tokenizer_path: None,
reasoning_parser: None,
tool_parser: None,
chat_template: None,
};
// Verify gpt-4 has random policy
let _gpt_policy = policy_registry.get_policy("gpt-4");
// Test removing workers
// When we remove both llama-3 workers, the policy should be cleaned up
println!("PolicyRegistry integration test structure created");
println!("Note: This test requires mocking or test servers to fully execute");
}
#[test]
fn test_policy_registry_cleanup() {
    use sglang_router_rs::config::PolicyConfig;
    use sglang_router_rs::policies::PolicyRegistry;

    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // The first worker for model-1 pins the cache_aware policy.
    let first = registry.on_worker_added("model-1", Some("cache_aware"));
    assert_eq!(first.name(), "cache_aware");

    // A later worker's conflicting hint is ignored; the model keeps its policy.
    let second = registry.on_worker_added("model-1", Some("random"));
    assert_eq!(second.name(), "cache_aware"); // Should still be cache_aware

    assert!(registry.get_policy("model-1").is_some());

    // Removing one of the two workers leaves the policy in place...
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_some());

    // ...removing the last worker cleans the policy up.
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_none());

    println!("✓ PolicyRegistry cleanup test passed");
}
#[test]
fn test_policy_registry_multiple_models() {
    use sglang_router_rs::config::PolicyConfig;
    use sglang_router_rs::policies::PolicyRegistry;

    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // Each model gets its own policy; a missing hint falls back to the default.
    let llama = registry.on_worker_added("llama-3", Some("cache_aware"));
    let gpt = registry.on_worker_added("gpt-4", Some("random"));
    let mistral = registry.on_worker_added("mistral", None); // Uses default

    assert_eq!(llama.name(), "cache_aware");
    assert_eq!(gpt.name(), "random");
    assert_eq!(mistral.name(), "round_robin"); // Default

    // Every model should now be resolvable from the registry.
    for model in ["llama-3", "gpt-4", "mistral"] {
        assert!(registry.get_policy(model).is_some());
    }

    // The full mapping reflects the per-model policy assignments.
    let mappings = registry.get_all_mappings();
    assert_eq!(mappings.len(), 3);
    assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
    assert_eq!(mappings.get("gpt-4").unwrap(), "random");
    assert_eq!(mappings.get("mistral").unwrap(), "round_robin");

    println!("✓ PolicyRegistry multiple models test passed");
}
......@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
rid: None,
};
let response = router.route_generate(None, &generate_request).await;
let response = router.route_generate(None, &generate_request, None).await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
// Test completion endpoint (should also not be supported)
let completion_request = create_minimal_completion_request();
let response = router.route_completion(None, &completion_request).await;
let response = router
.route_completion(None, &completion_request, None)
.await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
}
......@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
chat_request.temperature = Some(0.7);
// Route the request
let response = router.route_chat(None, &chat_request).await;
let response = router.route_chat(None, &chat_request, None).await;
// Should get a successful response from mock server
assert_eq!(response.status(), StatusCode::OK);
......@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
let chat_request: ChatCompletionRequest =
serde_json::from_str(&body_str).unwrap();
router.route_chat(Some(&parts.headers), &chat_request).await
router
.route_chat(Some(&parts.headers), &chat_request, None)
.await
}
}
}),
......@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
});
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
let response = router.route_chat(None, &chat_request).await;
let response = router.route_chat(None, &chat_request, None).await;
assert_eq!(response.status(), StatusCode::OK);
// Should be SSE
......@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
// First few requests should fail and record failures
for _ in 0..3 {
let response = router.route_chat(None, &chat_request).await;
let response = router.route_chat(None, &chat_request, None).await;
// Should get either an error or circuit breaker response
assert!(
response.status() == StatusCode::INTERNAL_SERVER_ERROR
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment