Unverified Commit 873d858b authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor worker to builder pattern 5/n (#10653)

parent 3fa3c22a
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult}; use super::{CircuitBreaker, WorkerError, WorkerResult};
use crate::core::CircuitState; use crate::core::CircuitState;
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder}; use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
use crate::grpc::SglangSchedulerClient; use crate::grpc::SglangSchedulerClient;
...@@ -525,23 +525,6 @@ pub struct DPAwareWorker { ...@@ -525,23 +525,6 @@ pub struct DPAwareWorker {
} }
impl DPAwareWorker { impl DPAwareWorker {
/// Create a new DP-aware worker of any type
pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self {
use crate::core::BasicWorkerBuilder;
// Create URL with DP rank suffix for identification
let worker_url = format!("{}@{}", base_url, dp_rank);
let base_worker = BasicWorkerBuilder::new(worker_url)
.worker_type(worker_type)
.build();
Self {
base_worker,
dp_rank,
dp_size,
base_url,
}
}
/// Create a new DP-aware worker with a pre-configured base worker /// Create a new DP-aware worker with a pre-configured base worker
/// This is primarily used by the builder pattern /// This is primarily used by the builder pattern
pub fn with_base_worker( pub fn with_base_worker(
...@@ -661,95 +644,6 @@ impl Worker for DPAwareWorker { ...@@ -661,95 +644,6 @@ impl Worker for DPAwareWorker {
pub struct WorkerFactory; pub struct WorkerFactory;
impl WorkerFactory { impl WorkerFactory {
/// Create a BasicWorkerBuilder for customizable worker creation
pub fn builder(url: impl Into<String>) -> BasicWorkerBuilder {
BasicWorkerBuilder::new(url)
}
/// Create a DPAwareWorkerBuilder for customizable DP-aware worker creation
pub fn dp_builder(
base_url: impl Into<String>,
dp_rank: usize,
dp_size: usize,
) -> DPAwareWorkerBuilder {
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
}
/// Create a regular worker
pub fn create_regular(url: String) -> Box<dyn Worker> {
use crate::core::BasicWorkerBuilder;
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Regular)
.build(),
)
}
/// Create a prefill worker with optional bootstrap port
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
use crate::core::BasicWorkerBuilder;
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Prefill { bootstrap_port })
.build(),
)
}
/// Create a decode worker
pub fn create_decode(url: String) -> Box<dyn Worker> {
use crate::core::BasicWorkerBuilder;
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Decode)
.build(),
)
}
/// Create a regular worker with custom labels (for multi-router support)
pub fn create_regular_with_labels(
url: String,
labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorkerBuilder::new(url)
.labels(labels)
.circuit_breaker_config(circuit_breaker_config)
.build(),
)
}
/// Create a prefill worker with labels
pub fn create_prefill_with_labels(
url: String,
bootstrap_port: Option<u16>,
labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Prefill { bootstrap_port })
.labels(labels)
.circuit_breaker_config(circuit_breaker_config)
.build(),
)
}
/// Create a decode worker with labels
pub fn create_decode_with_labels(
url: String,
labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Decode)
.labels(labels)
.circuit_breaker_config(circuit_breaker_config)
.build(),
)
}
/// Create a DP-aware worker of specified type /// Create a DP-aware worker of specified type
pub fn create_dp_aware( pub fn create_dp_aware(
base_url: String, base_url: String,
...@@ -757,9 +651,13 @@ impl WorkerFactory { ...@@ -757,9 +651,13 @@ impl WorkerFactory {
dp_size: usize, dp_size: usize,
worker_type: WorkerType, worker_type: WorkerType,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
Box::new(DPAwareWorker::new(base_url, dp_rank, dp_size, worker_type)) Box::new(
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
.worker_type(worker_type)
.build(),
)
} }
#[allow(dead_code)]
/// Get DP size from a worker /// Get DP size from a worker
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> { async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url)); let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
...@@ -807,81 +705,18 @@ impl WorkerFactory { ...@@ -807,81 +705,18 @@ impl WorkerFactory {
Ok(dp_size as usize) Ok(dp_size as usize)
} }
/// Private helper to create DP-aware workers of any type
async fn create_dp_aware_workers_of_type(
url: &str,
api_key: &Option<String>,
worker_type: WorkerType,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
let dp_size = Self::get_worker_dp_size(url, api_key).await?;
let workers = (0..dp_size)
.map(|rank| Self::create_dp_aware(url.to_string(), rank, dp_size, worker_type.clone()))
.collect();
Ok(workers)
}
/// Create DP-aware regular workers from a single URL
pub async fn create_dp_aware_regular_workers(
url: &str,
api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Regular).await
}
/// Create DP-aware prefill workers from a single URL
pub async fn create_dp_aware_prefill_workers(
url: &str,
bootstrap_port: Option<u16>,
api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Prefill { bootstrap_port })
.await
}
/// Create DP-aware decode workers from a single URL
pub async fn create_dp_aware_decode_workers(
url: &str,
api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
Self::create_dp_aware_workers_of_type(url, api_key, WorkerType::Decode).await
}
/// Create workers based on configuration (for regular router)
pub async fn create_workers(
urls: Vec<String>,
dp_aware: bool,
api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
if dp_aware {
// Create futures for all worker creations
let worker_futs = urls
.iter()
.map(|url| Self::create_dp_aware_regular_workers(url, api_key));
// Execute all futures concurrently and flatten results
let all_workers = futures::future::try_join_all(worker_futs)
.await?
.into_iter()
.flatten()
.collect();
Ok(all_workers)
} else {
Ok(urls
.into_iter()
.map(|url| Self::create_regular(url))
.collect())
}
}
} }
/// Convert a list of worker URLs to worker trait objects /// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> { pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
urls.into_iter() urls.into_iter()
.map(WorkerFactory::create_regular) .map(|url| {
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Regular)
.build(),
) as Box<dyn Worker>
})
.collect() .collect()
} }
...@@ -952,7 +787,7 @@ impl HealthChecker { ...@@ -952,7 +787,7 @@ impl HealthChecker {
/// Start an async background health checker for a collection of workers /// Start an async background health checker for a collection of workers
pub fn start_health_checker( pub fn start_health_checker(
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>, workers: Arc<std::sync::RwLock<Vec<Arc<dyn Worker>>>>,
check_interval_secs: u64, check_interval_secs: u64,
) -> HealthChecker { ) -> HealthChecker {
let shutdown = Arc::new(AtomicBool::new(false)); let shutdown = Arc::new(AtomicBool::new(false));
...@@ -1037,6 +872,7 @@ pub fn start_health_checker( ...@@ -1037,6 +872,7 @@ pub fn start_health_checker(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::CircuitBreakerConfig;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
...@@ -1369,7 +1205,11 @@ mod tests { ...@@ -1369,7 +1205,11 @@ mod tests {
// Test WorkerFactory // Test WorkerFactory
#[test] #[test]
fn test_create_regular_worker() { fn test_create_regular_worker() {
let worker = WorkerFactory::create_regular("http://regular:8080".to_string()); let worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://regular:8080")
.worker_type(WorkerType::Regular)
.build(),
);
assert_eq!(worker.url(), "http://regular:8080"); assert_eq!(worker.url(), "http://regular:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular); assert_eq!(worker.worker_type(), WorkerType::Regular);
} }
...@@ -1377,7 +1217,13 @@ mod tests { ...@@ -1377,7 +1217,13 @@ mod tests {
#[test] #[test]
fn test_create_prefill_worker() { fn test_create_prefill_worker() {
// With bootstrap port // With bootstrap port
let worker1 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090)); let worker1: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9090),
})
.build(),
);
assert_eq!(worker1.url(), "http://prefill:8080"); assert_eq!(worker1.url(), "http://prefill:8080");
assert_eq!( assert_eq!(
worker1.worker_type(), worker1.worker_type(),
...@@ -1387,7 +1233,13 @@ mod tests { ...@@ -1387,7 +1233,13 @@ mod tests {
); );
// Without bootstrap port // Without bootstrap port
let worker2 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), None); let worker2: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: None,
})
.build(),
);
assert_eq!( assert_eq!(
worker2.worker_type(), worker2.worker_type(),
WorkerType::Prefill { WorkerType::Prefill {
...@@ -1398,7 +1250,11 @@ mod tests { ...@@ -1398,7 +1250,11 @@ mod tests {
#[test] #[test]
fn test_create_decode_worker() { fn test_create_decode_worker() {
let worker = WorkerFactory::create_decode("http://decode:8080".to_string()); let worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://decode:8080")
.worker_type(WorkerType::Decode)
.build(),
);
assert_eq!(worker.url(), "http://decode:8080"); assert_eq!(worker.url(), "http://decode:8080");
assert_eq!(worker.worker_type(), WorkerType::Decode); assert_eq!(worker.worker_type(), WorkerType::Decode);
} }
...@@ -1424,9 +1280,21 @@ mod tests { ...@@ -1424,9 +1280,21 @@ mod tests {
#[test] #[test]
fn test_load_guard_multiple_workers() { fn test_load_guard_multiple_workers() {
let workers: Vec<Box<dyn Worker>> = vec![ let workers: Vec<Box<dyn Worker>> = vec![
WorkerFactory::create_regular("http://w1:8080".to_string()), Box::new(
WorkerFactory::create_regular("http://w2:8080".to_string()), BasicWorkerBuilder::new("http://w1:8080")
WorkerFactory::create_regular("http://w3:8080".to_string()), .worker_type(WorkerType::Regular)
.build(),
),
Box::new(
BasicWorkerBuilder::new("http://w2:8080")
.worker_type(WorkerType::Regular)
.build(),
),
Box::new(
BasicWorkerBuilder::new("http://w3:8080")
.worker_type(WorkerType::Regular)
.build(),
),
]; ];
let worker_refs: Vec<&dyn Worker> = workers.iter().map(|w| w.as_ref()).collect(); let worker_refs: Vec<&dyn Worker> = workers.iter().map(|w| w.as_ref()).collect();
...@@ -1492,8 +1360,16 @@ mod tests { ...@@ -1492,8 +1360,16 @@ mod tests {
#[test] #[test]
fn test_workers_to_urls() { fn test_workers_to_urls() {
let workers: Vec<Box<dyn Worker>> = vec![ let workers: Vec<Box<dyn Worker>> = vec![
WorkerFactory::create_regular("http://w1:8080".to_string()), Box::new(
WorkerFactory::create_regular("http://w2:8080".to_string()), BasicWorkerBuilder::new("http://w1:8080")
.worker_type(WorkerType::Regular)
.build(),
),
Box::new(
BasicWorkerBuilder::new("http://w2:8080")
.worker_type(WorkerType::Regular)
.build(),
),
]; ];
let urls = workers_to_urls(&workers); let urls = workers_to_urls(&workers);
...@@ -1544,8 +1420,9 @@ mod tests { ...@@ -1544,8 +1420,9 @@ mod tests {
#[test] #[test]
fn test_dp_aware_worker_creation() { fn test_dp_aware_worker_creation() {
let dp_worker = let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 2, 4)
DPAwareWorker::new("http://worker1:8080".to_string(), 2, 4, WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
assert_eq!(dp_worker.url(), "http://worker1:8080@2"); assert_eq!(dp_worker.url(), "http://worker1:8080@2");
assert_eq!(dp_worker.base_url(), "http://worker1:8080"); assert_eq!(dp_worker.base_url(), "http://worker1:8080");
...@@ -1557,14 +1434,11 @@ mod tests { ...@@ -1557,14 +1434,11 @@ mod tests {
#[test] #[test]
fn test_dp_aware_worker_creation_prefill() { fn test_dp_aware_worker_creation_prefill() {
let dp_worker = DPAwareWorker::new( let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 2)
"http://worker1:8080".to_string(), .worker_type(WorkerType::Prefill {
1,
2,
WorkerType::Prefill {
bootstrap_port: Some(9090), bootstrap_port: Some(9090),
}, })
); .build();
assert_eq!(dp_worker.url(), "http://worker1:8080@1"); assert_eq!(dp_worker.url(), "http://worker1:8080@1");
assert!(dp_worker.is_dp_aware()); assert!(dp_worker.is_dp_aware());
...@@ -1578,8 +1452,9 @@ mod tests { ...@@ -1578,8 +1452,9 @@ mod tests {
#[test] #[test]
fn test_dp_aware_worker_creation_decode() { fn test_dp_aware_worker_creation_decode() {
let dp_worker = let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4)
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Decode); .worker_type(WorkerType::Decode)
.build();
assert_eq!(dp_worker.url(), "http://worker1:8080@0"); assert_eq!(dp_worker.url(), "http://worker1:8080@0");
assert!(dp_worker.is_dp_aware()); assert!(dp_worker.is_dp_aware());
...@@ -1588,8 +1463,9 @@ mod tests { ...@@ -1588,8 +1463,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_dp_aware_prepare_request() { async fn test_dp_aware_prepare_request() {
let dp_worker = let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 3, 8)
DPAwareWorker::new("http://worker1:8080".to_string(), 3, 8, WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
let original_req = serde_json::json!({ let original_req = serde_json::json!({
"prompt": "Hello", "prompt": "Hello",
...@@ -1605,8 +1481,9 @@ mod tests { ...@@ -1605,8 +1481,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_dp_aware_prepare_request_invalid() { async fn test_dp_aware_prepare_request_invalid() {
let dp_worker = let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 4)
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
// Non-object JSON should fail // Non-object JSON should fail
let invalid_req = serde_json::json!("not an object"); let invalid_req = serde_json::json!("not an object");
...@@ -1623,8 +1500,9 @@ mod tests { ...@@ -1623,8 +1500,9 @@ mod tests {
#[test] #[test]
fn test_dp_aware_endpoint_url() { fn test_dp_aware_endpoint_url() {
let dp_worker = let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 1, 4)
DPAwareWorker::new("http://worker1:8080".to_string(), 1, 4, WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
assert_eq!( assert_eq!(
dp_worker.endpoint_url("/generate"), dp_worker.endpoint_url("/generate"),
...@@ -1638,8 +1516,9 @@ mod tests { ...@@ -1638,8 +1516,9 @@ mod tests {
#[test] #[test]
fn test_dp_aware_worker_delegated_methods() { fn test_dp_aware_worker_delegated_methods() {
let dp_worker = let dp_worker = DPAwareWorkerBuilder::new("http://worker1:8080", 0, 2)
DPAwareWorker::new("http://worker1:8080".to_string(), 0, 2, WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
// Test health status // Test health status
assert!(dp_worker.is_healthy()); assert!(dp_worker.is_healthy());
...@@ -1698,23 +1577,6 @@ mod tests { ...@@ -1698,23 +1577,6 @@ mod tests {
); );
} }
#[tokio::test]
async fn test_factory_create_workers_regular() {
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
let workers = WorkerFactory::create_workers(urls, false, &None)
.await
.unwrap();
assert_eq!(workers.len(), 2);
assert!(!workers[0].is_dp_aware());
assert!(!workers[1].is_dp_aware());
assert_eq!(workers[0].url(), "http://w1:8080");
assert_eq!(workers[1].url(), "http://w2:8080");
}
// ===== Circuit Breaker Integration Tests =====
#[test] #[test]
fn test_worker_circuit_breaker() { fn test_worker_circuit_breaker() {
use crate::core::BasicWorkerBuilder; use crate::core::BasicWorkerBuilder;
...@@ -1779,8 +1641,9 @@ mod tests { ...@@ -1779,8 +1641,9 @@ mod tests {
#[test] #[test]
fn test_dp_aware_worker_circuit_breaker() { fn test_dp_aware_worker_circuit_breaker() {
let dp_worker = let dp_worker = DPAwareWorkerBuilder::new("http://worker:8080", 0, 2)
DPAwareWorker::new("http://worker:8080".to_string(), 0, 2, WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
// Should have circuit breaker // Should have circuit breaker
assert!(dp_worker.is_available()); assert!(dp_worker.is_available());
...@@ -1800,9 +1663,23 @@ mod tests { ...@@ -1800,9 +1663,23 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_mixed_worker_types() { async fn test_mixed_worker_types() {
// Create a mix of worker types // Create a mix of worker types
let regular = WorkerFactory::create_regular("http://regular:8080".to_string()); let regular: Box<dyn Worker> = Box::new(
let prefill = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090)); BasicWorkerBuilder::new("http://regular:8080")
let decode = WorkerFactory::create_decode("http://decode:8080".to_string()); .worker_type(WorkerType::Regular)
.build(),
);
let prefill: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9090),
})
.build(),
);
let decode: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://decode:8080")
.worker_type(WorkerType::Decode)
.build(),
);
let dp_aware_regular = let dp_aware_regular =
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular); WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
let dp_aware_prefill = WorkerFactory::create_dp_aware( let dp_aware_prefill = WorkerFactory::create_dp_aware(
......
...@@ -424,7 +424,7 @@ pub struct WorkerRegistryStats { ...@@ -424,7 +424,7 @@ pub struct WorkerRegistryStats {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::{CircuitBreakerConfig, WorkerFactory}; use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig};
use std::collections::HashMap; use std::collections::HashMap;
#[test] #[test]
...@@ -437,10 +437,12 @@ mod tests { ...@@ -437,10 +437,12 @@ mod tests {
labels.insert("priority".to_string(), "50".to_string()); labels.insert("priority".to_string(), "50".to_string());
labels.insert("cost".to_string(), "0.8".to_string()); labels.insert("cost".to_string(), "0.8".to_string());
let worker = WorkerFactory::create_regular_with_labels( let worker: Box<dyn Worker> = Box::new(
"http://worker1:8080".to_string(), BasicWorkerBuilder::new("http://worker1:8080")
labels, .worker_type(WorkerType::Regular)
CircuitBreakerConfig::default(), .labels(labels)
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
); );
// Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc) // Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
...@@ -470,26 +472,32 @@ mod tests { ...@@ -470,26 +472,32 @@ mod tests {
// Create workers for different models // Create workers for different models
let mut labels1 = HashMap::new(); let mut labels1 = HashMap::new();
labels1.insert("model_id".to_string(), "llama-3".to_string()); labels1.insert("model_id".to_string(), "llama-3".to_string());
let worker1 = WorkerFactory::create_regular_with_labels( let worker1: Box<dyn Worker> = Box::new(
"http://worker1:8080".to_string(), BasicWorkerBuilder::new("http://worker1:8080")
labels1, .worker_type(WorkerType::Regular)
CircuitBreakerConfig::default(), .labels(labels1)
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
); );
let mut labels2 = HashMap::new(); let mut labels2 = HashMap::new();
labels2.insert("model_id".to_string(), "llama-3".to_string()); labels2.insert("model_id".to_string(), "llama-3".to_string());
let worker2 = WorkerFactory::create_regular_with_labels( let worker2: Box<dyn Worker> = Box::new(
"http://worker2:8080".to_string(), BasicWorkerBuilder::new("http://worker2:8080")
labels2, .worker_type(WorkerType::Regular)
CircuitBreakerConfig::default(), .labels(labels2)
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
); );
let mut labels3 = HashMap::new(); let mut labels3 = HashMap::new();
labels3.insert("model_id".to_string(), "gpt-4".to_string()); labels3.insert("model_id".to_string(), "gpt-4".to_string());
let worker3 = WorkerFactory::create_regular_with_labels( let worker3: Box<dyn Worker> = Box::new(
"http://worker3:8080".to_string(), BasicWorkerBuilder::new("http://worker3:8080")
labels3, .worker_type(WorkerType::Regular)
CircuitBreakerConfig::default(), .labels(labels3)
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
); );
// Register workers // Register workers
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::RouterConfig; use crate::config::RouterConfig;
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType}; use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig, Worker, WorkerRegistry, WorkerType};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest, ResponsesRequest,
...@@ -208,22 +208,29 @@ impl RouterManager { ...@@ -208,22 +208,29 @@ impl RouterManager {
} }
let worker = match config.worker_type.as_deref() { let worker = match config.worker_type.as_deref() {
Some("prefill") => WorkerFactory::create_prefill_with_labels( Some("prefill") => Box::new(
config.url.clone(), BasicWorkerBuilder::new(config.url.clone())
config.bootstrap_port, .worker_type(WorkerType::Prefill {
labels.clone(), bootstrap_port: config.bootstrap_port,
CircuitBreakerConfig::default(), })
), .labels(labels.clone())
Some("decode") => WorkerFactory::create_decode_with_labels( .circuit_breaker_config(CircuitBreakerConfig::default())
config.url.clone(), .build(),
labels.clone(), ) as Box<dyn Worker>,
CircuitBreakerConfig::default(), Some("decode") => Box::new(
), BasicWorkerBuilder::new(config.url.clone())
_ => WorkerFactory::create_regular_with_labels( .worker_type(WorkerType::Decode)
config.url.clone(), .labels(labels.clone())
labels.clone(), .circuit_breaker_config(CircuitBreakerConfig::default())
CircuitBreakerConfig::default(), .build(),
), ) as Box<dyn Worker>,
_ => Box::new(
BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Regular)
.labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
) as Box<dyn Worker>,
}; };
// Register worker // Register worker
......
...@@ -4,7 +4,7 @@ mod test_pd_routing { ...@@ -4,7 +4,7 @@ mod test_pd_routing {
use sglang_router_rs::config::{ use sglang_router_rs::config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
}; };
use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
use sglang_router_rs::routers::http::pd_types::get_hostname; use sglang_router_rs::routers::http::pd_types::get_hostname;
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy; use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::RouterFactory; use sglang_router_rs::routers::RouterFactory;
...@@ -46,11 +46,16 @@ mod test_pd_routing { ...@@ -46,11 +46,16 @@ mod test_pd_routing {
#[test] #[test]
fn test_worker_types() { fn test_worker_types() {
use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
// Test worker creation for prefill servers // Test worker creation for prefill servers
let prefill_worker = let prefill_worker: Box<dyn Worker> = Box::new(
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9000),
})
.build(),
);
assert_eq!(prefill_worker.url(), "http://prefill:8080"); assert_eq!(prefill_worker.url(), "http://prefill:8080");
match prefill_worker.worker_type() { match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => { WorkerType::Prefill { bootstrap_port } => {
...@@ -60,7 +65,11 @@ mod test_pd_routing { ...@@ -60,7 +65,11 @@ mod test_pd_routing {
} }
// Test worker creation for decode servers // Test worker creation for decode servers
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string()); let decode_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://decode:8080")
.worker_type(WorkerType::Decode)
.build(),
);
assert_eq!(decode_worker.url(), "http://decode:8080"); assert_eq!(decode_worker.url(), "http://decode:8080");
match decode_worker.worker_type() { match decode_worker.worker_type() {
WorkerType::Decode => (), WorkerType::Decode => (),
...@@ -68,7 +77,11 @@ mod test_pd_routing { ...@@ -68,7 +77,11 @@ mod test_pd_routing {
} }
// Test regular worker creation // Test regular worker creation
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string()); let regular_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://regular:8080")
.worker_type(WorkerType::Regular)
.build(),
);
assert_eq!(regular_worker.url(), "http://regular:8080"); assert_eq!(regular_worker.url(), "http://regular:8080");
match regular_worker.worker_type() { match regular_worker.worker_type() {
WorkerType::Regular => (), WorkerType::Regular => (),
...@@ -277,8 +290,13 @@ mod test_pd_routing { ...@@ -277,8 +290,13 @@ mod test_pd_routing {
}); });
// Create a prefill worker to simulate injection // Create a prefill worker to simulate injection
let prefill_worker = let prefill_worker: Box<dyn Worker> = Box::new(
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000)); BasicWorkerBuilder::new("http://prefill1:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9000),
})
.build(),
);
// Extract bootstrap port from worker type // Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() { let bootstrap_port = match prefill_worker.worker_type() {
...@@ -660,7 +678,7 @@ mod test_pd_routing { ...@@ -660,7 +678,7 @@ mod test_pd_routing {
#[test] #[test]
fn test_bootstrap_injection_with_benchmark_requests() { fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
// Test bootstrap injection with actual benchmark request patterns // Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({ let mut benchmark_request = json!({
...@@ -675,8 +693,13 @@ mod test_pd_routing { ...@@ -675,8 +693,13 @@ mod test_pd_routing {
}); });
// Create a prefill worker to simulate injection // Create a prefill worker to simulate injection
let prefill_worker = let prefill_worker: Box<dyn Worker> = Box::new(
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9000),
})
.build(),
);
// Extract bootstrap port from worker type // Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() { let bootstrap_port = match prefill_worker.worker_type() {
...@@ -806,8 +829,13 @@ mod test_pd_routing { ...@@ -806,8 +829,13 @@ mod test_pd_routing {
}); });
// Create a prefill worker to simulate injection // Create a prefill worker to simulate injection
let prefill_worker = let prefill_worker: Box<dyn Worker> = Box::new(
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000)); BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9000),
})
.build(),
);
// Extract bootstrap port from worker type // Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() { let bootstrap_port = match prefill_worker.worker_type() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment