Unverified Commit ac2a723b authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor worker to builder pattern 3/n (#10647)

parent 56b991b1
...@@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri ...@@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri
use serde_json::{from_str, to_string, to_value, to_vec}; use serde_json::{from_str, to_string, to_value, to_vec};
use std::time::Instant; use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType}; use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType};
use sglang_router_rs::protocols::spec::{ use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent, SamplingParams, StringOrArray, UserMessageContent,
...@@ -12,12 +12,11 @@ use sglang_router_rs::routers::http::pd_types::{ ...@@ -12,12 +12,11 @@ use sglang_router_rs::routers::http::pd_types::{
}; };
fn create_test_worker() -> BasicWorker { fn create_test_worker() -> BasicWorker {
BasicWorker::new( BasicWorkerBuilder::new("http://test-server:8000")
"http://test-server:8000".to_string(), .worker_type(WorkerType::Prefill {
WorkerType::Prefill {
bootstrap_port: Some(5678), bootstrap_port: Some(5678),
}, })
) .build()
} }
// Helper function to get bootstrap info from worker // Helper function to get bootstrap info from worker
......
...@@ -451,7 +451,7 @@ impl Drop for CacheAwarePolicy { ...@@ -451,7 +451,7 @@ impl Drop for CacheAwarePolicy {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::{BasicWorker, WorkerType}; use crate::core::{BasicWorkerBuilder, WorkerType};
#[test] #[test]
fn test_cache_aware_with_balanced_load() { fn test_cache_aware_with_balanced_load() {
...@@ -462,14 +462,16 @@ mod tests { ...@@ -462,14 +462,16 @@ mod tests {
}; };
let policy = CacheAwarePolicy::with_config(config); let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
.build(),
),
]; ];
// Initialize the policy with workers // Initialize the policy with workers
...@@ -497,8 +499,12 @@ mod tests { ...@@ -497,8 +499,12 @@ mod tests {
max_tree_size: 10000, max_tree_size: 10000,
}); });
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); let worker1 = BasicWorkerBuilder::new("http://w1:8000")
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
let worker2 = BasicWorkerBuilder::new("http://w2:8000")
.worker_type(WorkerType::Regular)
.build();
// Create significant load imbalance // Create significant load imbalance
for _ in 0..20 { for _ in 0..20 {
...@@ -524,14 +530,16 @@ mod tests { ...@@ -524,14 +530,16 @@ mod tests {
}; };
let policy = CacheAwarePolicy::with_config(config); let policy = CacheAwarePolicy::with_config(config);
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
.build(),
),
]; ];
policy.init_workers(&workers); policy.init_workers(&workers);
......
...@@ -121,23 +121,26 @@ pub(crate) fn get_healthy_worker_indices(workers: &[Arc<dyn Worker>]) -> Vec<usi ...@@ -121,23 +121,26 @@ pub(crate) fn get_healthy_worker_indices(workers: &[Arc<dyn Worker>]) -> Vec<usi
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::{BasicWorker, WorkerType}; use crate::core::{BasicWorkerBuilder, WorkerType};
#[test] #[test]
fn test_get_healthy_worker_indices() { fn test_get_healthy_worker_indices() {
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
Arc::new(BasicWorker::new( .build(),
"http://w3:8000".to_string(), ),
WorkerType::Regular, Arc::new(
)), BasicWorkerBuilder::new("http://w3:8000")
.worker_type(WorkerType::Regular)
.build(),
),
]; ];
// All healthy initially // All healthy initially
......
...@@ -119,14 +119,20 @@ impl Default for PowerOfTwoPolicy { ...@@ -119,14 +119,20 @@ impl Default for PowerOfTwoPolicy {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::{BasicWorker, WorkerType}; use crate::core::{BasicWorkerBuilder, WorkerType};
#[test] #[test]
fn test_power_of_two_selection() { fn test_power_of_two_selection() {
let policy = PowerOfTwoPolicy::new(); let policy = PowerOfTwoPolicy::new();
let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); let worker1 = BasicWorkerBuilder::new("http://w1:8000")
let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); .worker_type(WorkerType::Regular)
let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular); .build();
let worker2 = BasicWorkerBuilder::new("http://w2:8000")
.worker_type(WorkerType::Regular)
.build();
let worker3 = BasicWorkerBuilder::new("http://w3:8000")
.worker_type(WorkerType::Regular)
.build();
// Set different loads // Set different loads
for _ in 0..10 { for _ in 0..10 {
...@@ -157,14 +163,16 @@ mod tests { ...@@ -157,14 +163,16 @@ mod tests {
fn test_power_of_two_with_cached_loads() { fn test_power_of_two_with_cached_loads() {
let policy = PowerOfTwoPolicy::new(); let policy = PowerOfTwoPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
.build(),
),
]; ];
// Update cached loads // Update cached loads
...@@ -190,10 +198,11 @@ mod tests { ...@@ -190,10 +198,11 @@ mod tests {
#[test] #[test]
fn test_power_of_two_single_worker() { fn test_power_of_two_single_worker() {
let policy = PowerOfTwoPolicy::new(); let policy = PowerOfTwoPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new( let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
))]; .build(),
)];
// With single worker, should always select it // With single worker, should always select it
assert_eq!(policy.select_worker(&workers, None), Some(0)); assert_eq!(policy.select_worker(&workers, None), Some(0));
......
...@@ -51,25 +51,28 @@ impl LoadBalancingPolicy for RandomPolicy { ...@@ -51,25 +51,28 @@ impl LoadBalancingPolicy for RandomPolicy {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::{BasicWorker, WorkerType}; use crate::core::{BasicWorkerBuilder, WorkerType};
use std::collections::HashMap; use std::collections::HashMap;
#[test] #[test]
fn test_random_selection() { fn test_random_selection() {
let policy = RandomPolicy::new(); let policy = RandomPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
Arc::new(BasicWorker::new( .build(),
"http://w3:8000".to_string(), ),
WorkerType::Regular, Arc::new(
)), BasicWorkerBuilder::new("http://w3:8000")
.worker_type(WorkerType::Regular)
.build(),
),
]; ];
// Test multiple selections to ensure randomness // Test multiple selections to ensure randomness
...@@ -89,14 +92,16 @@ mod tests { ...@@ -89,14 +92,16 @@ mod tests {
fn test_random_with_unhealthy_workers() { fn test_random_with_unhealthy_workers() {
let policy = RandomPolicy::new(); let policy = RandomPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
.build(),
),
]; ];
// Mark first worker as unhealthy // Mark first worker as unhealthy
...@@ -111,10 +116,11 @@ mod tests { ...@@ -111,10 +116,11 @@ mod tests {
#[test] #[test]
fn test_random_no_healthy_workers() { fn test_random_no_healthy_workers() {
let policy = RandomPolicy::new(); let policy = RandomPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new( let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
))]; .build(),
)];
workers[0].set_healthy(false); workers[0].set_healthy(false);
assert_eq!(policy.select_worker(&workers, None), None); assert_eq!(policy.select_worker(&workers, None), None);
......
...@@ -60,24 +60,27 @@ impl LoadBalancingPolicy for RoundRobinPolicy { ...@@ -60,24 +60,27 @@ impl LoadBalancingPolicy for RoundRobinPolicy {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::{BasicWorker, WorkerType}; use crate::core::{BasicWorkerBuilder, WorkerType};
#[test] #[test]
fn test_round_robin_selection() { fn test_round_robin_selection() {
let policy = RoundRobinPolicy::new(); let policy = RoundRobinPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
Arc::new(BasicWorker::new( .build(),
"http://w3:8000".to_string(), ),
WorkerType::Regular, Arc::new(
)), BasicWorkerBuilder::new("http://w3:8000")
.worker_type(WorkerType::Regular)
.build(),
),
]; ];
// Should select workers in order: 0, 1, 2, 0, 1, 2, ... // Should select workers in order: 0, 1, 2, 0, 1, 2, ...
...@@ -92,18 +95,21 @@ mod tests { ...@@ -92,18 +95,21 @@ mod tests {
fn test_round_robin_with_unhealthy_workers() { fn test_round_robin_with_unhealthy_workers() {
let policy = RoundRobinPolicy::new(); let policy = RoundRobinPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
Arc::new(BasicWorker::new( .build(),
"http://w3:8000".to_string(), ),
WorkerType::Regular, Arc::new(
)), BasicWorkerBuilder::new("http://w3:8000")
.worker_type(WorkerType::Regular)
.build(),
),
]; ];
// Mark middle worker as unhealthy // Mark middle worker as unhealthy
...@@ -120,14 +126,16 @@ mod tests { ...@@ -120,14 +126,16 @@ mod tests {
fn test_round_robin_reset() { fn test_round_robin_reset() {
let policy = RoundRobinPolicy::new(); let policy = RoundRobinPolicy::new();
let workers: Vec<Arc<dyn Worker>> = vec![ let workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(BasicWorker::new( Arc::new(
"http://w1:8000".to_string(), BasicWorkerBuilder::new("http://w1:8000")
WorkerType::Regular, .worker_type(WorkerType::Regular)
)), .build(),
Arc::new(BasicWorker::new( ),
"http://w2:8000".to_string(), Arc::new(
WorkerType::Regular, BasicWorkerBuilder::new("http://w2:8000")
)), .worker_type(WorkerType::Regular)
.build(),
),
]; ];
// Advance the counter // Advance the counter
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
}; };
use crate::grpc::SglangSchedulerClient; use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
...@@ -130,23 +130,22 @@ impl GrpcPDRouter { ...@@ -130,23 +130,22 @@ impl GrpcPDRouter {
let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls
.iter() .iter()
.map(|(url, bootstrap_port)| { .map(|(url, bootstrap_port)| {
let worker = BasicWorker::with_connection_mode( let worker = BasicWorkerBuilder::new(url.clone())
url.clone(), .worker_type(WorkerType::Prefill {
WorkerType::Prefill {
bootstrap_port: *bootstrap_port, bootstrap_port: *bootstrap_port,
}, })
crate::core::ConnectionMode::Grpc { .connection_mode(crate::core::ConnectionMode::Grpc {
port: *bootstrap_port, port: *bootstrap_port,
}, })
) .circuit_breaker_config(core_cb_config.clone())
.with_circuit_breaker_config(core_cb_config.clone()) .health_config(HealthConfig {
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); })
.build();
Arc::new(worker) as Arc<dyn Worker> Arc::new(worker) as Arc<dyn Worker>
}) })
.collect(); .collect();
...@@ -155,19 +154,18 @@ impl GrpcPDRouter { ...@@ -155,19 +154,18 @@ impl GrpcPDRouter {
let decode_workers: Vec<Arc<dyn Worker>> = decode_urls let decode_workers: Vec<Arc<dyn Worker>> = decode_urls
.iter() .iter()
.map(|url| { .map(|url| {
let worker = BasicWorker::with_connection_mode( let worker = BasicWorkerBuilder::new(url.clone())
url.clone(), .worker_type(WorkerType::Decode)
WorkerType::Decode, .connection_mode(crate::core::ConnectionMode::Grpc { port: None })
crate::core::ConnectionMode::Grpc { port: None }, .circuit_breaker_config(core_cb_config.clone())
) .health_config(HealthConfig {
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); })
.build();
Arc::new(worker) as Arc<dyn Worker> Arc::new(worker) as Arc<dyn Worker>
}) })
.collect(); .collect();
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
}; };
use crate::grpc::SglangSchedulerClient; use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
...@@ -108,20 +108,19 @@ impl GrpcRouter { ...@@ -108,20 +108,19 @@ impl GrpcRouter {
// Move clients from the HashMap to the workers // Move clients from the HashMap to the workers
for url in &worker_urls { for url in &worker_urls {
if let Some(client) = grpc_clients.remove(url) { if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorker::with_connection_mode( let worker = BasicWorkerBuilder::new(url.clone())
url.clone(), .worker_type(WorkerType::Regular)
WorkerType::Regular, .connection_mode(crate::core::ConnectionMode::Grpc { port: None })
crate::core::ConnectionMode::Grpc { port: None }, .circuit_breaker_config(core_cb_config.clone())
) .health_config(HealthConfig {
.with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}) })
.with_grpc_client(client); .grpc_client(client)
.build();
workers.push(Arc::new(worker) as Arc<dyn Worker>); workers.push(Arc::new(worker) as Arc<dyn Worker>);
} else { } else {
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
use super::pd_types::{api_path, PDRouterError}; use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker, is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType, Worker, WorkerFactory, WorkerLoadGuard, WorkerRegistry, WorkerType,
}; };
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
...@@ -389,34 +389,35 @@ impl PDRouter { ...@@ -389,34 +389,35 @@ impl PDRouter {
// Register prefill workers in the registry // Register prefill workers in the registry
for (url, port) in prefill_urls { for (url, port) in prefill_urls {
let worker = BasicWorker::new( let worker = BasicWorkerBuilder::new(url)
url, .worker_type(WorkerType::Prefill {
WorkerType::Prefill {
bootstrap_port: port, bootstrap_port: port,
}, })
) .circuit_breaker_config(core_cb_config.clone())
.with_circuit_breaker_config(core_cb_config.clone()) .health_config(HealthConfig {
.with_health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); })
.build();
ctx.worker_registry.register(Arc::new(worker)); ctx.worker_registry.register(Arc::new(worker));
} }
// Register decode workers in the registry // Register decode workers in the registry
for url in decode_urls { for url in decode_urls {
let worker = BasicWorker::new(url, WorkerType::Decode) let worker = BasicWorkerBuilder::new(url)
.with_circuit_breaker_config(core_cb_config.clone()) .worker_type(WorkerType::Decode)
.with_health_config(HealthConfig { .circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); })
.build();
ctx.worker_registry.register(Arc::new(worker)); ctx.worker_registry.register(Arc::new(worker));
} }
...@@ -2116,7 +2117,7 @@ impl RouterTrait for PDRouter { ...@@ -2116,7 +2117,7 @@ impl RouterTrait for PDRouter {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::core::{BasicWorker, WorkerType}; use crate::core::WorkerType;
fn create_test_pd_router() -> PDRouter { fn create_test_pd_router() -> PDRouter {
let worker_registry = Arc::new(WorkerRegistry::new()); let worker_registry = Arc::new(WorkerRegistry::new());
...@@ -2139,7 +2140,9 @@ mod tests { ...@@ -2139,7 +2140,9 @@ mod tests {
} }
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> { fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
let worker = BasicWorker::new(url, worker_type); let worker = BasicWorkerBuilder::new(url)
.worker_type(worker_type)
.build();
worker.set_healthy(healthy); worker.set_healthy(healthy);
Box::new(worker) Box::new(worker)
} }
......
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker, is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
WorkerRegistry, WorkerType, Worker, WorkerRegistry, WorkerType,
}; };
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
...@@ -87,15 +87,17 @@ impl Router { ...@@ -87,15 +87,17 @@ impl Router {
for url in &worker_urls { for url in &worker_urls {
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// For now, create worker without model_id // For now, create worker without model_id
let worker = BasicWorker::new(url.clone(), WorkerType::Regular) let worker = BasicWorkerBuilder::new(url.clone())
.with_circuit_breaker_config(core_cb_config.clone()) .worker_type(WorkerType::Regular)
.with_health_config(HealthConfig { .circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); })
.build();
let worker_arc = Arc::new(worker); let worker_arc = Arc::new(worker);
ctx.worker_registry.register(worker_arc.clone()); ctx.worker_registry.register(worker_arc.clone());
...@@ -991,11 +993,10 @@ impl Router { ...@@ -991,11 +993,10 @@ impl Router {
} }
info!("Added worker: {}", dp_url); info!("Added worker: {}", dp_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = let new_worker = BasicWorkerBuilder::new(dp_url.to_string())
BasicWorker::new(dp_url.to_string(), WorkerType::Regular) .worker_type(WorkerType::Regular)
.with_circuit_breaker_config( .circuit_breaker_config(self.circuit_breaker_config.clone())
self.circuit_breaker_config.clone(), .build();
);
let worker_arc = Arc::new(new_worker); let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone()); self.worker_registry.register(worker_arc.clone());
...@@ -1028,11 +1029,10 @@ impl Router { ...@@ -1028,11 +1029,10 @@ impl Router {
info!("Added worker: {}", worker_url); info!("Added worker: {}", worker_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = let new_worker = BasicWorkerBuilder::new(worker_url.to_string())
BasicWorker::new(worker_url.to_string(), WorkerType::Regular) .worker_type(WorkerType::Regular)
.with_circuit_breaker_config( .circuit_breaker_config(self.circuit_breaker_config.clone())
self.circuit_breaker_config.clone(), .build();
);
let worker_arc = Arc::new(new_worker); let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone()); self.worker_registry.register(worker_arc.clone());
...@@ -1595,8 +1595,12 @@ mod tests { ...@@ -1595,8 +1595,12 @@ mod tests {
)); ));
// Register test workers // Register test workers
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular); .worker_type(WorkerType::Regular)
.build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.build();
worker_registry.register(Arc::new(worker1)); worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2)); worker_registry.register(Arc::new(worker2));
......
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType}; use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy}; use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
...@@ -16,13 +16,17 @@ fn test_backward_compatibility_with_empty_model_id() { ...@@ -16,13 +16,17 @@ fn test_backward_compatibility_with_empty_model_id() {
let policy = CacheAwarePolicy::with_config(config); let policy = CacheAwarePolicy::with_config(config);
// Create workers with empty model_id (simulating existing routers) // Create workers with empty model_id (simulating existing routers)
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.build();
// No model_id label - should default to "unknown" // No model_id label - should default to "unknown"
let mut labels2 = HashMap::new(); let mut labels2 = HashMap::new();
labels2.insert("model_id".to_string(), "unknown".to_string()); labels2.insert("model_id".to_string(), "unknown".to_string());
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular) let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.with_labels(labels2); .worker_type(WorkerType::Regular)
.labels(labels2)
.build();
// Add workers - should both go to "default" tree // Add workers - should both go to "default" tree
policy.add_worker(&worker1); policy.add_worker(&worker1);
...@@ -53,23 +57,31 @@ fn test_mixed_model_ids() { ...@@ -53,23 +57,31 @@ fn test_mixed_model_ids() {
let policy = CacheAwarePolicy::with_config(config); let policy = CacheAwarePolicy::with_config(config);
// Create workers with different model_id scenarios // Create workers with different model_id scenarios
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.build();
// No model_id label - defaults to "unknown" which goes to "default" tree // No model_id label - defaults to "unknown" which goes to "default" tree
let mut labels2 = HashMap::new(); let mut labels2 = HashMap::new();
labels2.insert("model_id".to_string(), "llama-3".to_string()); labels2.insert("model_id".to_string(), "llama-3".to_string());
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular) let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.with_labels(labels2); .worker_type(WorkerType::Regular)
.labels(labels2)
.build();
let mut labels3 = HashMap::new(); let mut labels3 = HashMap::new();
labels3.insert("model_id".to_string(), "unknown".to_string()); labels3.insert("model_id".to_string(), "unknown".to_string());
let worker3 = BasicWorker::new("http://worker3:8080".to_string(), WorkerType::Regular) let worker3 = BasicWorkerBuilder::new("http://worker3:8080")
.with_labels(labels3); .worker_type(WorkerType::Regular)
.labels(labels3)
.build();
let mut labels4 = HashMap::new(); let mut labels4 = HashMap::new();
labels4.insert("model_id".to_string(), "llama-3".to_string()); labels4.insert("model_id".to_string(), "llama-3".to_string());
let worker4 = BasicWorker::new("http://worker4:8080".to_string(), WorkerType::Regular) let worker4 = BasicWorkerBuilder::new("http://worker4:8080")
.with_labels(labels4); .worker_type(WorkerType::Regular)
.labels(labels4)
.build();
// Add all workers // Add all workers
policy.add_worker(&worker1); policy.add_worker(&worker1);
...@@ -108,10 +120,14 @@ fn test_remove_worker_by_url_backward_compat() { ...@@ -108,10 +120,14 @@ fn test_remove_worker_by_url_backward_compat() {
// Create workers with different model_ids // Create workers with different model_ids
let mut labels1 = HashMap::new(); let mut labels1 = HashMap::new();
labels1.insert("model_id".to_string(), "llama-3".to_string()); labels1.insert("model_id".to_string(), "llama-3".to_string());
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular) let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.with_labels(labels1); .worker_type(WorkerType::Regular)
.labels(labels1)
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular); .build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.build();
// No model_id label - defaults to "unknown" // No model_id label - defaults to "unknown"
// Add workers // Add workers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment