Unverified Commit 5291f32d authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor worker to builder pattern 1/n (#10628)

parent 67073dde
......@@ -11,6 +11,7 @@ pub mod error;
pub mod retry;
pub mod token_bucket;
pub mod worker;
pub mod worker_builder;
pub mod worker_registry;
// Re-export commonly used types at the module level
......@@ -23,4 +24,5 @@ pub use worker::{
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
};
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
......@@ -586,6 +586,22 @@ impl DPAwareWorker {
base_url,
}
}
/// Create a new DP-aware worker with a pre-configured base worker
/// This is primarily used by the builder pattern
pub fn with_base_worker(
base_worker: BasicWorker,
base_url: String,
dp_rank: usize,
dp_size: usize,
) -> Self {
Self {
base_worker,
dp_rank,
dp_size,
base_url,
}
}
}
#[async_trait]
......@@ -1102,7 +1118,7 @@ pub fn start_health_checker(
// Periodically reset load counters to prevent drift
// Only do this when we believe all workers should be idle
if check_count % LOAD_RESET_INTERVAL == 0 {
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
// Only reset if load appears to be very low (likely drift)
if max_load <= 2 {
......
use super::circuit_breaker::CircuitBreakerConfig;
use super::worker::{BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerType};
use crate::grpc::client::SglangSchedulerClient;
use std::collections::HashMap;
/// Builder for creating BasicWorker instances with fluent API
pub struct BasicWorkerBuilder {
// Required fields
url: String,
// Optional fields with defaults
worker_type: WorkerType,
connection_mode: ConnectionMode,
labels: HashMap<String, String>,
health_config: HealthConfig,
circuit_breaker_config: CircuitBreakerConfig,
grpc_client: Option<SglangSchedulerClient>,
}
impl BasicWorkerBuilder {
/// Create a new builder with only the URL (defaults to Regular worker type)
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
worker_type: WorkerType::Regular,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
health_config: HealthConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
grpc_client: None,
}
}
/// Create a new builder with URL and worker type (for backwards compatibility)
pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self {
Self {
url: url.into(),
worker_type,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
health_config: HealthConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
grpc_client: None,
}
}
/// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type;
self
}
/// Set the connection mode (HTTP or gRPC)
pub fn connection_mode(mut self, mode: ConnectionMode) -> Self {
self.connection_mode = mode;
self
}
/// Set labels for worker identification
pub fn labels(mut self, labels: HashMap<String, String>) -> Self {
self.labels = labels;
self
}
/// Add a single label
pub fn label(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.labels.insert(key.into(), value.into());
self
}
/// Set health check configuration
pub fn health_config(mut self, config: HealthConfig) -> Self {
self.health_config = config;
self
}
/// Set circuit breaker configuration
pub fn circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self {
self.circuit_breaker_config = config;
self
}
/// Set gRPC client for gRPC workers
pub fn grpc_client(mut self, client: SglangSchedulerClient) -> Self {
self.grpc_client = Some(client);
self
}
/// Build the BasicWorker instance
pub fn build(self) -> BasicWorker {
// Use the existing constructor methods for now
let mut worker =
BasicWorker::with_connection_mode(self.url, self.worker_type, self.connection_mode);
// Apply optional configurations using existing methods
if !self.labels.is_empty() {
worker = worker.with_labels(self.labels);
}
worker = worker.with_health_config(self.health_config);
worker = worker.with_circuit_breaker_config(self.circuit_breaker_config);
if let Some(client) = self.grpc_client {
worker = worker.with_grpc_client(client);
}
worker
}
}
/// Builder for creating DPAwareWorker instances with fluent API
pub struct DPAwareWorkerBuilder {
// Required fields
base_url: String,
dp_rank: usize,
dp_size: usize,
// Optional fields with defaults
worker_type: WorkerType,
connection_mode: ConnectionMode,
labels: HashMap<String, String>,
health_config: HealthConfig,
circuit_breaker_config: CircuitBreakerConfig,
grpc_client: Option<SglangSchedulerClient>,
}
impl DPAwareWorkerBuilder {
/// Create a new DP-aware worker builder (defaults to Regular worker type)
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
Self {
base_url: base_url.into(),
dp_rank,
dp_size,
worker_type: WorkerType::Regular,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
health_config: HealthConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
grpc_client: None,
}
}
/// Create a new DP-aware worker builder with worker type (for backwards compatibility)
pub fn new_with_type(
base_url: impl Into<String>,
dp_rank: usize,
dp_size: usize,
worker_type: WorkerType,
) -> Self {
Self {
base_url: base_url.into(),
dp_rank,
dp_size,
worker_type,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
health_config: HealthConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
grpc_client: None,
}
}
/// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type;
self
}
/// Set the connection mode (HTTP or gRPC)
pub fn connection_mode(mut self, mode: ConnectionMode) -> Self {
self.connection_mode = mode;
self
}
/// Set labels for worker identification
pub fn labels(mut self, labels: HashMap<String, String>) -> Self {
self.labels = labels;
self
}
/// Add a single label
pub fn label(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.labels.insert(key.into(), value.into());
self
}
/// Set health check configuration
pub fn health_config(mut self, config: HealthConfig) -> Self {
self.health_config = config;
self
}
/// Set circuit breaker configuration
pub fn circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self {
self.circuit_breaker_config = config;
self
}
/// Set gRPC client for gRPC workers
pub fn grpc_client(mut self, client: SglangSchedulerClient) -> Self {
self.grpc_client = Some(client);
self
}
/// Build the DPAwareWorker instance
pub fn build(self) -> DPAwareWorker {
// Create URL with DP rank suffix for identification
let worker_url = format!("{}@{}", self.base_url, self.dp_rank);
// Use BasicWorkerBuilder to create a properly configured base worker
let mut builder = BasicWorkerBuilder::new(worker_url)
.worker_type(self.worker_type)
.connection_mode(self.connection_mode)
.labels(self.labels)
.health_config(self.health_config)
.circuit_breaker_config(self.circuit_breaker_config);
// Add gRPC client if provided
if let Some(client) = self.grpc_client {
builder = builder.grpc_client(client);
}
let base_worker = builder.build();
// Create the DPAwareWorker with the configured base worker
DPAwareWorker::with_base_worker(base_worker, self.base_url, self.dp_rank, self.dp_size)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::worker::Worker;
use std::time::Duration;
#[test]
fn test_basic_worker_builder_minimal() {
// Using new API - defaults to Regular type
let worker = BasicWorkerBuilder::new("http://localhost:8080").build();
assert_eq!(worker.url(), "http://localhost:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular);
assert_eq!(worker.connection_mode(), ConnectionMode::Http);
assert!(worker.is_healthy());
}
#[test]
fn test_basic_worker_builder_with_type() {
// Test setting worker type explicitly
let worker = BasicWorkerBuilder::new("http://localhost:8080")
.worker_type(WorkerType::Decode)
.build();
assert_eq!(worker.url(), "http://localhost:8080");
assert_eq!(worker.worker_type(), WorkerType::Decode);
assert_eq!(worker.connection_mode(), ConnectionMode::Http);
assert!(worker.is_healthy());
}
#[test]
fn test_basic_worker_builder_full() {
let mut labels = HashMap::new();
labels.insert("env".to_string(), "prod".to_string());
labels.insert("region".to_string(), "us-east".to_string());
let health_config = HealthConfig {
endpoint: "/health".to_string(),
timeout_secs: 30,
check_interval_secs: 60,
failure_threshold: 3,
success_threshold: 2,
};
let cb_config = CircuitBreakerConfig {
failure_threshold: 10,
success_threshold: 5,
timeout_duration: Duration::from_millis(2000),
window_duration: Duration::from_millis(30000),
};
let worker = BasicWorkerBuilder::new("http://localhost:8080")
.worker_type(WorkerType::Prefill {
bootstrap_port: None,
})
.connection_mode(ConnectionMode::Grpc { port: Some(50051) })
.labels(labels.clone())
.health_config(health_config.clone())
.circuit_breaker_config(cb_config)
.build();
assert_eq!(worker.url(), "http://localhost:8080");
assert_eq!(
worker.worker_type(),
WorkerType::Prefill {
bootstrap_port: None
}
);
assert_eq!(
worker.connection_mode(),
ConnectionMode::Grpc { port: Some(50051) }
);
assert_eq!(worker.metadata().labels, labels);
// Can't directly compare HealthConfig without PartialEq, so check individual fields
assert_eq!(
worker.metadata().health_config.endpoint,
health_config.endpoint
);
assert_eq!(
worker.metadata().health_config.timeout_secs,
health_config.timeout_secs
);
assert_eq!(
worker.metadata().health_config.check_interval_secs,
health_config.check_interval_secs
);
assert_eq!(
worker.metadata().health_config.failure_threshold,
health_config.failure_threshold
);
assert_eq!(
worker.metadata().health_config.success_threshold,
health_config.success_threshold
);
}
#[test]
fn test_basic_worker_builder_with_single_label() {
let worker = BasicWorkerBuilder::new("http://localhost:8080")
.worker_type(WorkerType::Decode)
.label("env", "staging")
.label("version", "v1.2.3")
.build();
assert_eq!(
worker.metadata().labels.get("env"),
Some(&"staging".to_string())
);
assert_eq!(
worker.metadata().labels.get("version"),
Some(&"v1.2.3".to_string())
);
}
#[test]
fn test_dp_aware_worker_builder_minimal() {
// Using new API - defaults to Regular type
let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 2, 8).build();
assert_eq!(worker.url(), "http://localhost:8080@2");
assert_eq!(worker.dp_rank(), Some(2));
assert_eq!(worker.dp_size(), Some(8));
// Note: base_url is a private field, we can only test through the url() method
assert_eq!(worker.worker_type(), WorkerType::Regular);
}
#[test]
fn test_dp_aware_worker_builder_full() {
let mut labels = HashMap::new();
labels.insert("cluster".to_string(), "main".to_string());
let health_config = HealthConfig {
endpoint: "/status".to_string(),
timeout_secs: 20,
check_interval_secs: 45,
failure_threshold: 5,
success_threshold: 3,
};
let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 3, 16)
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9090),
})
.connection_mode(ConnectionMode::Http)
.labels(labels.clone())
.health_config(health_config.clone())
.build();
assert_eq!(worker.url(), "http://localhost:8080@3");
assert_eq!(worker.dp_rank(), Some(3));
assert_eq!(worker.dp_size(), Some(16));
assert_eq!(worker.metadata().labels, labels);
// Can't directly compare HealthConfig without PartialEq, so check individual fields
assert_eq!(
worker.metadata().health_config.endpoint,
health_config.endpoint
);
assert_eq!(
worker.metadata().health_config.timeout_secs,
health_config.timeout_secs
);
assert_eq!(
worker.metadata().health_config.check_interval_secs,
health_config.check_interval_secs
);
assert_eq!(
worker.metadata().health_config.failure_threshold,
health_config.failure_threshold
);
assert_eq!(
worker.metadata().health_config.success_threshold,
health_config.success_threshold
);
}
#[test]
fn test_dp_aware_worker_with_grpc() {
// Test that DPAwareWorkerBuilder can set a gRPC client
let worker = DPAwareWorkerBuilder::new("grpc://cluster.local", 1, 4)
.worker_type(WorkerType::Decode)
.connection_mode(ConnectionMode::Grpc { port: Some(50051) })
.label("transport", "grpc")
.build();
assert_eq!(worker.url(), "grpc://cluster.local@1");
assert_eq!(worker.dp_rank(), Some(1));
assert_eq!(worker.dp_size(), Some(4));
assert_eq!(worker.worker_type(), WorkerType::Decode);
assert_eq!(
worker.connection_mode(),
ConnectionMode::Grpc { port: Some(50051) }
);
assert_eq!(
worker.metadata().labels.get("transport"),
Some(&"grpc".to_string())
);
// Note: We can't directly test the grpc_client as it's private,
// but the fact that the worker builds successfully with grpc connection mode
// validates that the configuration is properly passed through
}
}
......@@ -390,7 +390,7 @@ impl WorkerRegistry {
// Reset loads periodically
check_count += 1;
if check_count % LOAD_RESET_INTERVAL == 0 {
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
tracing::debug!("Resetting worker loads (cycle {})", check_count);
for worker in &workers {
worker.reset_load();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment