Unverified Commit 780d6a22 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor worker to builder pattern 2/n (#10633)

parent 8b713c72
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult}; use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
use crate::core::CircuitState;
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
use crate::grpc::SglangSchedulerClient; use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use async_trait::async_trait; use async_trait::async_trait;
...@@ -96,22 +98,22 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -96,22 +98,22 @@ pub trait Worker: Send + Sync + fmt::Debug {
if before != after { if before != after {
let from = match before { let from = match before {
crate::core::CircuitState::Closed => "closed", CircuitState::Closed => "closed",
crate::core::CircuitState::Open => "open", CircuitState::Open => "open",
crate::core::CircuitState::HalfOpen => "half_open", CircuitState::HalfOpen => "half_open",
}; };
let to = match after { let to = match after {
crate::core::CircuitState::Closed => "closed", CircuitState::Closed => "closed",
crate::core::CircuitState::Open => "open", CircuitState::Open => "open",
crate::core::CircuitState::HalfOpen => "half_open", CircuitState::HalfOpen => "half_open",
}; };
RouterMetrics::record_cb_state_transition(self.url(), from, to); RouterMetrics::record_cb_state_transition(self.url(), from, to);
} }
let state_code = match self.circuit_breaker().state() { let state_code = match self.circuit_breaker().state() {
crate::core::CircuitState::Closed => 0u8, CircuitState::Closed => 0u8,
crate::core::CircuitState::Open => 1u8, CircuitState::Open => 1u8,
crate::core::CircuitState::HalfOpen => 2u8, CircuitState::HalfOpen => 2u8,
}; };
RouterMetrics::set_cb_state(self.url(), state_code); RouterMetrics::set_cb_state(self.url(), state_code);
} }
...@@ -706,6 +708,20 @@ impl Worker for DPAwareWorker { ...@@ -706,6 +708,20 @@ impl Worker for DPAwareWorker {
pub struct WorkerFactory; pub struct WorkerFactory;
impl WorkerFactory { impl WorkerFactory {
/// Create a BasicWorkerBuilder for customizable worker creation
pub fn builder(url: impl Into<String>) -> BasicWorkerBuilder {
BasicWorkerBuilder::new(url)
}
/// Create a DPAwareWorkerBuilder for customizable DP-aware worker creation
pub fn dp_builder(
base_url: impl Into<String>,
dp_rank: usize,
dp_size: usize,
) -> DPAwareWorkerBuilder {
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
}
/// Create a regular worker /// Create a regular worker
pub fn create_regular(url: String) -> Box<dyn Worker> { pub fn create_regular(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Regular)) Box::new(BasicWorker::new(url, WorkerType::Regular))
...@@ -717,8 +733,9 @@ impl WorkerFactory { ...@@ -717,8 +733,9 @@ impl WorkerFactory {
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
Box::new( Box::new(
BasicWorker::new(url, WorkerType::Regular) BasicWorkerBuilder::new(url)
.with_circuit_breaker_config(circuit_breaker_config), .circuit_breaker_config(circuit_breaker_config)
.build(),
) )
} }
...@@ -737,8 +754,10 @@ impl WorkerFactory { ...@@ -737,8 +754,10 @@ impl WorkerFactory {
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
Box::new( Box::new(
BasicWorker::new(url, WorkerType::Prefill { bootstrap_port }) BasicWorkerBuilder::new(url)
.with_circuit_breaker_config(circuit_breaker_config), .worker_type(WorkerType::Prefill { bootstrap_port })
.circuit_breaker_config(circuit_breaker_config)
.build(),
) )
} }
...@@ -753,8 +772,10 @@ impl WorkerFactory { ...@@ -753,8 +772,10 @@ impl WorkerFactory {
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
Box::new( Box::new(
BasicWorker::new(url, WorkerType::Decode) BasicWorkerBuilder::new(url)
.with_circuit_breaker_config(circuit_breaker_config), .worker_type(WorkerType::Decode)
.circuit_breaker_config(circuit_breaker_config)
.build(),
) )
} }
...@@ -800,8 +821,11 @@ impl WorkerFactory { ...@@ -800,8 +821,11 @@ impl WorkerFactory {
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
Box::new( Box::new(
BasicWorker::with_connection_mode(url, worker_type, ConnectionMode::Grpc { port }) BasicWorkerBuilder::new(url)
.with_circuit_breaker_config(circuit_breaker_config), .worker_type(worker_type)
.connection_mode(ConnectionMode::Grpc { port })
.circuit_breaker_config(circuit_breaker_config)
.build(),
) )
} }
...@@ -811,13 +835,12 @@ impl WorkerFactory { ...@@ -811,13 +835,12 @@ impl WorkerFactory {
labels: std::collections::HashMap<String, String>, labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular) Box::new(
.with_circuit_breaker_config(circuit_breaker_config); BasicWorkerBuilder::new(url)
.labels(labels)
// Add labels to metadata .circuit_breaker_config(circuit_breaker_config)
worker.metadata.labels = labels; .build(),
)
Box::new(worker)
} }
/// Create a prefill worker with labels /// Create a prefill worker with labels
...@@ -827,13 +850,13 @@ impl WorkerFactory { ...@@ -827,13 +850,13 @@ impl WorkerFactory {
labels: std::collections::HashMap<String, String>, labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
let mut worker = BasicWorker::new(url.clone(), WorkerType::Prefill { bootstrap_port }) Box::new(
.with_circuit_breaker_config(circuit_breaker_config); BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Prefill { bootstrap_port })
// Add labels to metadata .labels(labels)
worker.metadata.labels = labels; .circuit_breaker_config(circuit_breaker_config)
.build(),
Box::new(worker) )
} }
/// Create a decode worker with labels /// Create a decode worker with labels
...@@ -842,13 +865,13 @@ impl WorkerFactory { ...@@ -842,13 +865,13 @@ impl WorkerFactory {
labels: std::collections::HashMap<String, String>, labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
let mut worker = BasicWorker::new(url.clone(), WorkerType::Decode) Box::new(
.with_circuit_breaker_config(circuit_breaker_config); BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Decode)
// Add labels to metadata .labels(labels)
worker.metadata.labels = labels; .circuit_breaker_config(circuit_breaker_config)
.build(),
Box::new(worker) )
} }
/// Create a DP-aware worker of specified type /// Create a DP-aware worker of specified type
...@@ -1910,10 +1933,7 @@ mod tests { ...@@ -1910,10 +1933,7 @@ mod tests {
// Initial state should be available // Initial state should be available
assert!(worker.is_available()); assert!(worker.is_available());
assert_eq!( assert_eq!(worker.circuit_breaker().state(), CircuitState::Closed);
worker.circuit_breaker().state(),
crate::core::CircuitState::Closed
);
// Record some failures // Record some failures
worker.record_outcome(false); worker.record_outcome(false);
...@@ -1935,7 +1955,7 @@ mod tests { ...@@ -1935,7 +1955,7 @@ mod tests {
#[test] #[test]
fn test_worker_with_circuit_breaker_config() { fn test_worker_with_circuit_breaker_config() {
let config = crate::core::CircuitBreakerConfig { let config = CircuitBreakerConfig {
failure_threshold: 2, failure_threshold: 2,
success_threshold: 1, success_threshold: 1,
timeout_duration: Duration::from_millis(100), timeout_duration: Duration::from_millis(100),
...@@ -1956,17 +1976,11 @@ mod tests { ...@@ -1956,17 +1976,11 @@ mod tests {
// Should be half-open // Should be half-open
assert!(worker.is_available()); assert!(worker.is_available());
assert_eq!( assert_eq!(worker.circuit_breaker().state(), CircuitState::HalfOpen);
worker.circuit_breaker().state(),
crate::core::CircuitState::HalfOpen
);
// Success should close it // Success should close it
worker.record_outcome(true); worker.record_outcome(true);
assert_eq!( assert_eq!(worker.circuit_breaker().state(), CircuitState::Closed);
worker.circuit_breaker().state(),
crate::core::CircuitState::Closed
);
} }
#[test] #[test]
...@@ -1984,10 +1998,7 @@ mod tests { ...@@ -1984,10 +1998,7 @@ mod tests {
// Should not be available // Should not be available
assert!(!dp_worker.is_available()); assert!(!dp_worker.is_available());
assert_eq!( assert_eq!(dp_worker.circuit_breaker().state(), CircuitState::Open);
dp_worker.circuit_breaker().state(),
crate::core::CircuitState::Open
);
} }
// ===== Integration tests ===== // ===== Integration tests =====
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment