Unverified Commit 4f2055ad authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor worker to builder pattern 4/n (#10650)

parent 616a3e20
...@@ -22,7 +22,7 @@ pub use error::{WorkerError, WorkerResult}; ...@@ -22,7 +22,7 @@ pub use error::{WorkerError, WorkerResult};
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor}; pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{ pub use worker::{
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig, start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
}; };
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder}; pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats}; pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
...@@ -328,15 +328,15 @@ pub struct WorkerMetadata { ...@@ -328,15 +328,15 @@ pub struct WorkerMetadata {
/// Basic worker implementation /// Basic worker implementation
#[derive(Clone)] #[derive(Clone)]
pub struct BasicWorker { pub struct BasicWorker {
metadata: WorkerMetadata, pub metadata: WorkerMetadata,
load_counter: Arc<AtomicUsize>, pub load_counter: Arc<AtomicUsize>,
processed_counter: Arc<AtomicUsize>, pub processed_counter: Arc<AtomicUsize>,
healthy: Arc<AtomicBool>, pub healthy: Arc<AtomicBool>,
consecutive_failures: Arc<AtomicUsize>, pub consecutive_failures: Arc<AtomicUsize>,
consecutive_successes: Arc<AtomicUsize>, pub consecutive_successes: Arc<AtomicUsize>,
circuit_breaker: CircuitBreaker, pub circuit_breaker: CircuitBreaker,
/// Optional gRPC client for gRPC workers /// Optional gRPC client for gRPC workers
grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>, pub grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>,
} }
impl fmt::Debug for BasicWorker { impl fmt::Debug for BasicWorker {
...@@ -351,56 +351,6 @@ impl fmt::Debug for BasicWorker { ...@@ -351,56 +351,6 @@ impl fmt::Debug for BasicWorker {
} }
impl BasicWorker { impl BasicWorker {
pub fn new(url: String, worker_type: WorkerType) -> Self {
Self::with_connection_mode(url, worker_type, ConnectionMode::Http)
}
pub fn with_connection_mode(
url: String,
worker_type: WorkerType,
connection_mode: ConnectionMode,
) -> Self {
let metadata = WorkerMetadata {
url: url.clone(),
worker_type,
connection_mode,
labels: std::collections::HashMap::new(),
health_config: HealthConfig::default(),
};
Self {
metadata,
load_counter: Arc::new(AtomicUsize::new(0)),
processed_counter: Arc::new(AtomicUsize::new(0)),
healthy: Arc::new(AtomicBool::new(true)),
consecutive_failures: Arc::new(AtomicUsize::new(0)),
consecutive_successes: Arc::new(AtomicUsize::new(0)),
circuit_breaker: CircuitBreaker::new(),
grpc_client: None,
}
}
pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
self.metadata.labels = labels;
self
}
pub fn with_health_config(mut self, config: HealthConfig) -> Self {
self.metadata.health_config = config;
self
}
pub fn with_circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self {
self.circuit_breaker = CircuitBreaker::with_config(config);
self
}
/// Set the gRPC client for gRPC workers
pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self {
self.grpc_client = Some(Arc::new(Mutex::new(client)));
self
}
pub fn normalised_url(&self) -> WorkerResult<&str> { pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") { if self.url().contains("@") {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
...@@ -577,9 +527,12 @@ pub struct DPAwareWorker { ...@@ -577,9 +527,12 @@ pub struct DPAwareWorker {
impl DPAwareWorker { impl DPAwareWorker {
/// Create a new DP-aware worker of any type /// Create a new DP-aware worker of any type
pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self { pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self {
use crate::core::BasicWorkerBuilder;
// Create URL with DP rank suffix for identification // Create URL with DP rank suffix for identification
let worker_url = format!("{}@{}", base_url, dp_rank); let worker_url = format!("{}@{}", base_url, dp_rank);
let base_worker = BasicWorker::new(worker_url, worker_type); let base_worker = BasicWorkerBuilder::new(worker_url)
.worker_type(worker_type)
.build();
Self { Self {
base_worker, base_worker,
...@@ -724,107 +677,30 @@ impl WorkerFactory { ...@@ -724,107 +677,30 @@ impl WorkerFactory {
/// Create a regular worker /// Create a regular worker
pub fn create_regular(url: String) -> Box<dyn Worker> { pub fn create_regular(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Regular)) use crate::core::BasicWorkerBuilder;
}
/// Create a regular worker with custom circuit breaker configuration
pub fn create_regular_with_config(
url: String,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new( Box::new(
BasicWorkerBuilder::new(url) BasicWorkerBuilder::new(url)
.circuit_breaker_config(circuit_breaker_config) .worker_type(WorkerType::Regular)
.build(), .build(),
) )
} }
/// Create a prefill worker with optional bootstrap port /// Create a prefill worker with optional bootstrap port
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> { pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
Box::new(BasicWorker::new( use crate::core::BasicWorkerBuilder;
url,
WorkerType::Prefill { bootstrap_port },
))
}
/// Create a prefill worker with custom circuit breaker configuration
pub fn create_prefill_with_config(
url: String,
bootstrap_port: Option<u16>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new( Box::new(
BasicWorkerBuilder::new(url) BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Prefill { bootstrap_port }) .worker_type(WorkerType::Prefill { bootstrap_port })
.circuit_breaker_config(circuit_breaker_config)
.build(), .build(),
) )
} }
/// Create a decode worker /// Create a decode worker
pub fn create_decode(url: String) -> Box<dyn Worker> { pub fn create_decode(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Decode)) use crate::core::BasicWorkerBuilder;
}
/// Create a decode worker with custom circuit breaker configuration
pub fn create_decode_with_config(
url: String,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new( Box::new(
BasicWorkerBuilder::new(url) BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
.circuit_breaker_config(circuit_breaker_config)
.build(),
)
}
/// Create workers from URLs with automatic type detection
#[allow(clippy::type_complexity)]
pub fn create_from_urls(
regular_urls: Vec<String>,
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
) -> (
Vec<Box<dyn Worker>>,
Vec<Box<dyn Worker>>,
Vec<Box<dyn Worker>>,
) {
let regular_workers: Vec<Box<dyn Worker>> =
regular_urls.into_iter().map(Self::create_regular).collect();
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.into_iter()
.map(|(url, port)| Self::create_prefill(url, port))
.collect();
let decode_workers: Vec<Box<dyn Worker>> =
decode_urls.into_iter().map(Self::create_decode).collect();
(regular_workers, prefill_workers, decode_workers)
}
/// Create a gRPC worker
pub fn create_grpc(url: String, worker_type: WorkerType, port: Option<u16>) -> Box<dyn Worker> {
Box::new(BasicWorker::with_connection_mode(
url,
worker_type,
ConnectionMode::Grpc { port },
))
}
/// Create a gRPC worker with custom circuit breaker configuration
pub fn create_grpc_with_config(
url: String,
worker_type: WorkerType,
port: Option<u16>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(worker_type)
.connection_mode(ConnectionMode::Grpc { port })
.circuit_breaker_config(circuit_breaker_config)
.build(), .build(),
) )
} }
...@@ -1002,35 +878,6 @@ impl WorkerFactory { ...@@ -1002,35 +878,6 @@ impl WorkerFactory {
} }
} }
/// Helper trait for collections of workers
pub trait WorkerCollection {
fn healthy_workers(&self) -> Vec<&dyn Worker>;
fn total_load(&self) -> usize;
fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
}
impl WorkerCollection for Vec<Box<dyn Worker>> {
fn healthy_workers(&self) -> Vec<&dyn Worker> {
self.iter()
.filter(|w| w.is_healthy())
.map(|w| w.as_ref())
.collect()
}
fn total_load(&self) -> usize {
self.iter().map(|w| w.load()).sum()
}
fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
self.iter().find(|w| w.url() == url).map(|w| w.as_ref())
}
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
self.iter_mut().find(|w| w.url() == url)
}
}
/// Convert a list of worker URLs to worker trait objects /// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> { pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
urls.into_iter() urls.into_iter()
...@@ -1275,7 +1122,10 @@ mod tests { ...@@ -1275,7 +1122,10 @@ mod tests {
// Test BasicWorker // Test BasicWorker
#[test] #[test]
fn test_basic_worker_creation() { fn test_basic_worker_creation() {
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
assert_eq!(worker.url(), "http://test:8080"); assert_eq!(worker.url(), "http://test:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular); assert_eq!(worker.worker_type(), WorkerType::Regular);
assert!(worker.is_healthy()); assert!(worker.is_healthy());
...@@ -1289,8 +1139,11 @@ mod tests { ...@@ -1289,8 +1139,11 @@ mod tests {
labels.insert("env".to_string(), "prod".to_string()); labels.insert("env".to_string(), "prod".to_string());
labels.insert("zone".to_string(), "us-west".to_string()); labels.insert("zone".to_string(), "us-west".to_string());
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) use crate::core::BasicWorkerBuilder;
.with_labels(labels.clone()); let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.labels(labels.clone())
.build();
assert_eq!(worker.metadata().labels, labels); assert_eq!(worker.metadata().labels, labels);
} }
...@@ -1305,8 +1158,11 @@ mod tests { ...@@ -1305,8 +1158,11 @@ mod tests {
success_threshold: 2, success_threshold: 2,
}; };
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) use crate::core::BasicWorkerBuilder;
.with_health_config(custom_config.clone()); let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.health_config(custom_config.clone())
.build();
assert_eq!(worker.metadata().health_config.timeout_secs, 15); assert_eq!(worker.metadata().health_config.timeout_secs, 15);
assert_eq!(worker.metadata().health_config.check_interval_secs, 45); assert_eq!(worker.metadata().health_config.check_interval_secs, 45);
...@@ -1316,21 +1172,26 @@ mod tests { ...@@ -1316,21 +1172,26 @@ mod tests {
// Test Worker trait implementation // Test Worker trait implementation
#[test] #[test]
fn test_worker_url() { fn test_worker_url() {
let worker = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.build();
assert_eq!(worker.url(), "http://worker1:8080"); assert_eq!(worker.url(), "http://worker1:8080");
} }
#[test] #[test]
fn test_worker_type_getter() { fn test_worker_type_getter() {
let regular = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let regular = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
assert_eq!(regular.worker_type(), WorkerType::Regular); assert_eq!(regular.worker_type(), WorkerType::Regular);
let prefill = BasicWorker::new( let prefill = BasicWorkerBuilder::new("http://test:8080")
"http://test:8080".to_string(), .worker_type(WorkerType::Prefill {
WorkerType::Prefill {
bootstrap_port: Some(9090), bootstrap_port: Some(9090),
}, })
); .build();
assert_eq!( assert_eq!(
prefill.worker_type(), prefill.worker_type(),
WorkerType::Prefill { WorkerType::Prefill {
...@@ -1338,13 +1199,18 @@ mod tests { ...@@ -1338,13 +1199,18 @@ mod tests {
} }
); );
let decode = BasicWorker::new("http://test:8080".to_string(), WorkerType::Decode); let decode = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Decode)
.build();
assert_eq!(decode.worker_type(), WorkerType::Decode); assert_eq!(decode.worker_type(), WorkerType::Decode);
} }
#[test] #[test]
fn test_health_status() { fn test_health_status() {
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
// Initial state is healthy // Initial state is healthy
assert!(worker.is_healthy()); assert!(worker.is_healthy());
...@@ -1360,7 +1226,10 @@ mod tests { ...@@ -1360,7 +1226,10 @@ mod tests {
#[test] #[test]
fn test_load_counter_operations() { fn test_load_counter_operations() {
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
// Initial load is 0 // Initial load is 0
assert_eq!(worker.load(), 0); assert_eq!(worker.load(), 0);
...@@ -1390,7 +1259,10 @@ mod tests { ...@@ -1390,7 +1259,10 @@ mod tests {
#[test] #[test]
fn test_processed_counter() { fn test_processed_counter() {
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
// Initial count is 0 // Initial count is 0
assert_eq!(worker.processed_requests(), 0); assert_eq!(worker.processed_requests(), 0);
...@@ -1405,10 +1277,12 @@ mod tests { ...@@ -1405,10 +1277,12 @@ mod tests {
// Test concurrent operations // Test concurrent operations
#[tokio::test] #[tokio::test]
async fn test_concurrent_load_increments() { async fn test_concurrent_load_increments() {
let worker = Arc::new(BasicWorker::new( use crate::core::BasicWorkerBuilder;
"http://test:8080".to_string(), let worker = Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://test:8080")
)); .worker_type(WorkerType::Regular)
.build(),
);
let mut handles = vec![]; let mut handles = vec![];
...@@ -1432,10 +1306,12 @@ mod tests { ...@@ -1432,10 +1306,12 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_concurrent_load_decrements() { async fn test_concurrent_load_decrements() {
let worker = Arc::new(BasicWorker::new( use crate::core::BasicWorkerBuilder;
"http://test:8080".to_string(), let worker = Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://test:8080")
)); .worker_type(WorkerType::Regular)
.build(),
);
// Set initial load to 100 // Set initial load to 100
for _ in 0..100 { for _ in 0..100 {
...@@ -1465,10 +1341,12 @@ mod tests { ...@@ -1465,10 +1341,12 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_concurrent_health_updates() { async fn test_concurrent_health_updates() {
let worker = Arc::new(BasicWorker::new( use crate::core::BasicWorkerBuilder;
"http://test:8080".to_string(), let worker = Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://test:8080")
)); .worker_type(WorkerType::Regular)
.build(),
);
let mut handles = vec![]; let mut handles = vec![];
...@@ -1525,111 +1403,13 @@ mod tests { ...@@ -1525,111 +1403,13 @@ mod tests {
assert_eq!(worker.worker_type(), WorkerType::Decode); assert_eq!(worker.worker_type(), WorkerType::Decode);
} }
#[test]
fn test_create_from_urls() {
let regular_urls = vec![
"http://regular1:8080".to_string(),
"http://regular2:8080".to_string(),
];
let prefill_urls = vec![
("http://prefill1:8080".to_string(), Some(9090)),
("http://prefill2:8080".to_string(), None),
];
let decode_urls = vec![
"http://decode1:8080".to_string(),
"http://decode2:8080".to_string(),
];
let (regular, prefill, decode) =
WorkerFactory::create_from_urls(regular_urls, prefill_urls, decode_urls);
assert_eq!(regular.len(), 2);
assert_eq!(prefill.len(), 2);
assert_eq!(decode.len(), 2);
assert_eq!(regular[0].url(), "http://regular1:8080");
assert_eq!(prefill[0].url(), "http://prefill1:8080");
assert_eq!(decode[0].url(), "http://decode1:8080");
}
// Test WorkerCollection trait
#[test]
fn test_healthy_workers_filter() {
let workers: Vec<Box<dyn Worker>> = vec![
WorkerFactory::create_regular("http://w1:8080".to_string()),
WorkerFactory::create_regular("http://w2:8080".to_string()),
WorkerFactory::create_regular("http://w3:8080".to_string()),
];
// Set some workers unhealthy
workers[0].set_healthy(false);
workers[2].set_healthy(false);
let healthy = workers.healthy_workers();
assert_eq!(healthy.len(), 1);
assert_eq!(healthy[0].url(), "http://w2:8080");
}
#[test]
fn test_total_load_calculation() {
let workers: Vec<Box<dyn Worker>> = vec![
WorkerFactory::create_regular("http://w1:8080".to_string()),
WorkerFactory::create_regular("http://w2:8080".to_string()),
WorkerFactory::create_regular("http://w3:8080".to_string()),
];
// Set different loads
workers[0].increment_load();
workers[0].increment_load(); // load = 2
workers[1].increment_load();
workers[1].increment_load();
workers[1].increment_load(); // load = 3
workers[2].increment_load(); // load = 1
assert_eq!(workers.total_load(), 6);
}
#[test]
fn test_find_worker() {
let workers: Vec<Box<dyn Worker>> = vec![
WorkerFactory::create_regular("http://w1:8080".to_string()),
WorkerFactory::create_regular("http://w2:8080".to_string()),
WorkerFactory::create_regular("http://w3:8080".to_string()),
];
// Found case
let found = workers.find_worker("http://w2:8080");
assert!(found.is_some());
assert_eq!(found.unwrap().url(), "http://w2:8080");
// Not found case
let not_found = workers.find_worker("http://w4:8080");
assert!(not_found.is_none());
}
#[test]
fn test_find_worker_mut() {
let mut workers: Vec<Box<dyn Worker>> = vec![
WorkerFactory::create_regular("http://w1:8080".to_string()),
WorkerFactory::create_regular("http://w2:8080".to_string()),
];
// Find and modify
if let Some(worker) = workers.find_worker_mut("http://w1:8080") {
worker.set_healthy(false);
}
// Verify modification
assert!(!workers[0].is_healthy());
assert!(workers[1].is_healthy());
}
// Test WorkerLoadGuard // Test WorkerLoadGuard
#[test] #[test]
fn test_load_guard_single_worker() { fn test_load_guard_single_worker() {
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
assert_eq!(worker.load(), 0); assert_eq!(worker.load(), 0);
{ {
...@@ -1667,10 +1447,12 @@ mod tests { ...@@ -1667,10 +1447,12 @@ mod tests {
#[test] #[test]
fn test_load_guard_panic_safety() { fn test_load_guard_panic_safety() {
let worker = Arc::new(BasicWorker::new( use crate::core::BasicWorkerBuilder;
"http://test:8080".to_string(), let worker = Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://test:8080")
)); .worker_type(WorkerType::Regular)
.build(),
);
assert_eq!(worker.load(), 0); assert_eq!(worker.load(), 0);
// Clone for use inside catch_unwind // Clone for use inside catch_unwind
...@@ -1723,7 +1505,10 @@ mod tests { ...@@ -1723,7 +1505,10 @@ mod tests {
fn test_check_health_sync_wrapper() { fn test_check_health_sync_wrapper() {
// We can't easily test the actual HTTP call without mocking, // We can't easily test the actual HTTP call without mocking,
// but we can verify the sync wrapper works // but we can verify the sync wrapper works
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
// This will fail because there's no server at this URL, // This will fail because there's no server at this URL,
// but it tests that the sync wrapper doesn't panic // but it tests that the sync wrapper doesn't panic
...@@ -1734,9 +1519,12 @@ mod tests { ...@@ -1734,9 +1519,12 @@ mod tests {
// Performance test for load counter // Performance test for load counter
#[test] #[test]
fn test_load_counter_performance() { fn test_load_counter_performance() {
use crate::core::BasicWorkerBuilder;
use std::time::Instant; use std::time::Instant;
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
let iterations = 1_000_000; let iterations = 1_000_000;
let start = Instant::now(); let start = Instant::now();
...@@ -1929,7 +1717,10 @@ mod tests { ...@@ -1929,7 +1717,10 @@ mod tests {
#[test] #[test]
fn test_worker_circuit_breaker() { fn test_worker_circuit_breaker() {
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
// Initial state should be available // Initial state should be available
assert!(worker.is_available()); assert!(worker.is_available());
...@@ -1962,8 +1753,11 @@ mod tests { ...@@ -1962,8 +1753,11 @@ mod tests {
window_duration: Duration::from_secs(60), window_duration: Duration::from_secs(60),
}; };
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) use crate::core::BasicWorkerBuilder;
.with_circuit_breaker_config(config); let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.circuit_breaker_config(config)
.build();
// Should open after 2 failures // Should open after 2 failures
worker.record_outcome(false); worker.record_outcome(false);
......
use super::circuit_breaker::CircuitBreakerConfig; use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
use super::worker::{BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerType}; use super::worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
};
use crate::grpc::client::SglangSchedulerClient; use crate::grpc::client::SglangSchedulerClient;
use std::collections::HashMap; use std::collections::HashMap;
...@@ -88,23 +90,30 @@ impl BasicWorkerBuilder { ...@@ -88,23 +90,30 @@ impl BasicWorkerBuilder {
/// Build the BasicWorker instance /// Build the BasicWorker instance
pub fn build(self) -> BasicWorker { pub fn build(self) -> BasicWorker {
// Use the existing constructor methods for now use std::sync::{
let mut worker = atomic::{AtomicBool, AtomicUsize},
BasicWorker::with_connection_mode(self.url, self.worker_type, self.connection_mode); Arc,
};
// Apply optional configurations using existing methods use tokio::sync::Mutex;
if !self.labels.is_empty() {
worker = worker.with_labels(self.labels); let metadata = WorkerMetadata {
} url: self.url.clone(),
worker_type: self.worker_type,
worker = worker.with_health_config(self.health_config); connection_mode: self.connection_mode,
worker = worker.with_circuit_breaker_config(self.circuit_breaker_config); labels: self.labels,
health_config: self.health_config,
};
if let Some(client) = self.grpc_client { BasicWorker {
worker = worker.with_grpc_client(client); metadata,
load_counter: Arc::new(AtomicUsize::new(0)),
processed_counter: Arc::new(AtomicUsize::new(0)),
healthy: Arc::new(AtomicBool::new(true)),
consecutive_failures: Arc::new(AtomicUsize::new(0)),
consecutive_successes: Arc::new(AtomicUsize::new(0)),
circuit_breaker: CircuitBreaker::with_config(self.circuit_breaker_config),
grpc_client: self.grpc_client.map(|client| Arc::new(Mutex::new(client))),
} }
worker
} }
} }
......
...@@ -4,7 +4,7 @@ use super::pd_types::{api_path, PDRouterError}; ...@@ -4,7 +4,7 @@ use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
Worker, WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
}; };
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
...@@ -220,13 +220,12 @@ impl PDRouter { ...@@ -220,13 +220,12 @@ impl PDRouter {
// Create Worker for the new prefill server with circuit breaker configuration // Create Worker for the new prefill server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = WorkerFactory::create_prefill_with_config( let worker = BasicWorkerBuilder::new(url.clone())
url.clone(), .worker_type(WorkerType::Prefill { bootstrap_port })
bootstrap_port, .circuit_breaker_config(self.circuit_breaker_config.clone())
self.circuit_breaker_config.clone(), .build();
);
let worker_arc: Arc<dyn Worker> = Arc::from(worker); let worker_arc: Arc<dyn Worker> = Arc::new(worker);
// Register the worker in the registry // Register the worker in the registry
self.worker_registry.register(worker_arc.clone()); self.worker_registry.register(worker_arc.clone());
...@@ -261,12 +260,12 @@ impl PDRouter { ...@@ -261,12 +260,12 @@ impl PDRouter {
// Create Worker for the new decode server with circuit breaker configuration // Create Worker for the new decode server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = WorkerFactory::create_decode_with_config( let worker = BasicWorkerBuilder::new(url.clone())
url.clone(), .worker_type(WorkerType::Decode)
self.circuit_breaker_config.clone(), .circuit_breaker_config(self.circuit_breaker_config.clone())
); .build();
let worker_arc: Arc<dyn Worker> = Arc::from(worker); let worker_arc: Arc<dyn Worker> = Arc::new(worker);
// Register the worker in the registry // Register the worker in the registry
self.worker_registry.register(worker_arc.clone()); self.worker_registry.register(worker_arc.clone());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment