Unverified Commit 2fa0462c authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] introduce dp worker abstraction (#8639)

parent 915140fd
...@@ -11,6 +11,6 @@ pub mod worker; ...@@ -11,6 +11,6 @@ pub mod worker;
// Re-export commonly used types at the module level // Re-export commonly used types at the module level
pub use error::{WorkerError, WorkerResult}; pub use error::{WorkerError, WorkerResult};
pub use worker::{ pub use worker::{
start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory, start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
WorkerLoadGuard, WorkerType, WorkerFactory, WorkerLoadGuard, WorkerType,
}; };
use super::{WorkerError, WorkerResult}; use super::{WorkerError, WorkerResult};
use async_trait::async_trait; use async_trait::async_trait;
use futures;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde_json;
use std::fmt; use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
// Shared HTTP client for health checks // Shared HTTP client for worker operations (health checks, server info, etc.)
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| { static WORKER_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder() reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request .timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
.build() .build()
.expect("Failed to create health check HTTP client") .expect("Failed to create worker HTTP client")
}); });
/// Core worker abstraction that represents a backend service /// Core worker abstraction that represents a backend service
...@@ -64,6 +66,43 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -64,6 +66,43 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Clone the worker (for trait objects) /// Clone the worker (for trait objects)
fn clone_worker(&self) -> Box<dyn Worker>; fn clone_worker(&self) -> Box<dyn Worker>;
// === DP-aware methods ===
/// Report whether this worker targets a specific data-parallel (DP) rank.
///
/// The default implementation returns `false`; DP-aware implementations
/// override it.
fn is_dp_aware(&self) -> bool {
false
}
/// Get the base URL without any DP rank suffix.
///
/// The default implementation simply forwards to `url()`, i.e. it treats the
/// worker URL as carrying no suffix.
fn base_url(&self) -> &str {
self.url()
}
/// Get the DP rank if this is a DP-aware worker.
///
/// The default implementation returns `None`.
fn dp_rank(&self) -> Option<usize> {
None
}
/// Get the DP size if this worker is part of a DP group.
///
/// The default implementation returns `None`.
fn dp_size(&self) -> Option<usize> {
None
}
/// Transform a request body before it is sent to this worker.
///
/// The default implementation hands `req` back unchanged and never fails.
async fn prepare_request(&self, req: serde_json::Value) -> WorkerResult<serde_json::Value> {
Ok(req)
}
/// Build the URL that requests for `route` should be sent to.
///
/// The result is `base_url()` immediately followed by `route`; no separator
/// is inserted, so `route` is expected to carry its own leading `/`.
fn endpoint_url(&self, route: &str) -> String {
    [self.base_url(), route].concat()
}
/// Check if this worker can handle a specific request.
///
/// The default implementation ignores the request and always returns `true`.
fn can_handle(&self, _req: &serde_json::Value) -> bool {
true
}
} }
/// Worker type classification /// Worker type classification
...@@ -212,12 +251,7 @@ impl Worker for BasicWorker { ...@@ -212,12 +251,7 @@ impl Worker for BasicWorker {
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request // Use the shared client with a custom timeout for this request
match HEALTH_CHECK_CLIENT match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
.get(&health_url)
.timeout(timeout)
.send()
.await
{
Ok(response) => { Ok(response) => {
if response.status().is_success() { if response.status().is_success() {
self.set_healthy(true); self.set_healthy(true);
...@@ -273,6 +307,160 @@ impl Worker for BasicWorker { ...@@ -273,6 +307,160 @@ impl Worker for BasicWorker {
} }
} }
/// A DP-aware worker that handles data-parallel routing.
///
/// Each instance stands for one DP rank behind a single server URL. It wraps a
/// `BasicWorker` whose URL is tagged as `"{base_url}@{dp_rank}"` (see
/// `DPAwareWorker::new`), while health checks and request endpoints use the
/// untagged `base_url`.
#[derive(Debug, Clone)]
pub struct DPAwareWorker {
/// The underlying basic worker; health, load and request counters are
/// delegated to it.
base_worker: BasicWorker,
/// DP rank for this worker (injected into requests as `data_parallel_rank`)
dp_rank: usize,
/// Total DP size
dp_size: usize,
/// Base URL without DP suffix
base_url: String,
}
impl DPAwareWorker {
    /// Build a DP-aware worker of the given `worker_type` for one rank.
    ///
    /// The wrapped `BasicWorker` is identified by `"{base_url}@{dp_rank}"`,
    /// whereas `base_url` itself is stored untouched for building request URLs.
    pub fn new(base_url: String, dp_rank: usize, dp_size: usize, worker_type: WorkerType) -> Self {
        // Rank-tagged URL: only used to tell the ranks of one server apart.
        let tagged_url = format!("{}@{}", base_url, dp_rank);

        Self {
            base_worker: BasicWorker::new(tagged_url, worker_type),
            dp_rank,
            dp_size,
            base_url,
        }
    }
}
#[async_trait]
impl Worker for DPAwareWorker {
    // --- State and counters are delegated to the wrapped BasicWorker ---

    /// Rank-tagged URL of the wrapped worker.
    fn url(&self) -> &str {
        self.base_worker.url()
    }

    fn worker_type(&self) -> WorkerType {
        self.base_worker.worker_type()
    }

    fn is_healthy(&self) -> bool {
        self.base_worker.is_healthy()
    }

    fn set_healthy(&self, healthy: bool) {
        self.base_worker.set_healthy(healthy);
    }

    /// Probe `{base_url}/health` and record the outcome on this worker.
    ///
    /// The untagged base URL is used for the probe. A successful HTTP status
    /// marks the worker healthy; a transport error or a non-success status
    /// marks it unhealthy and yields `WorkerError::HealthCheckFailed`.
    async fn check_health_async(&self) -> WorkerResult<()> {
        let probe_url = format!("{}/health", self.base_url);
        let limit =
            std::time::Duration::from_secs(self.base_worker.metadata.health_config.timeout_secs);

        // Collapse both failure modes into a human-readable reason string.
        let outcome: Result<(), String> =
            match WORKER_CLIENT.get(&probe_url).timeout(limit).send().await {
                Err(e) => Err(format!("Health check request failed: {}", e)),
                Ok(response) => {
                    let status = response.status();
                    if status.is_success() {
                        Ok(())
                    } else {
                        Err(format!("Health check returned status: {}", status))
                    }
                }
            };

        if let Err(reason) = outcome {
            self.set_healthy(false);
            return Err(WorkerError::HealthCheckFailed {
                url: self.base_url.clone(),
                reason,
            });
        }

        self.set_healthy(true);
        Ok(())
    }

    fn load(&self) -> usize {
        self.base_worker.load()
    }

    fn increment_load(&self) {
        self.base_worker.increment_load();
    }

    fn decrement_load(&self) {
        self.base_worker.decrement_load();
    }

    fn processed_requests(&self) -> usize {
        self.base_worker.processed_requests()
    }

    fn increment_processed(&self) {
        self.base_worker.increment_processed();
    }

    fn metadata(&self) -> &WorkerMetadata {
        self.base_worker.metadata()
    }

    fn clone_worker(&self) -> Box<dyn Worker> {
        Box::new(self.clone())
    }

    // --- DP-specific overrides ---

    fn is_dp_aware(&self) -> bool {
        true
    }

    /// Server URL without the `@rank` tag.
    fn base_url(&self) -> &str {
        self.base_url.as_str()
    }

    fn dp_rank(&self) -> Option<usize> {
        Some(self.dp_rank)
    }

    fn dp_size(&self) -> Option<usize> {
        Some(self.dp_size)
    }

    /// Add this worker's rank to the request as `data_parallel_rank`.
    ///
    /// Fails with `WorkerError::InvalidConfiguration` when `req` is not a JSON
    /// object, because there is then nowhere to put the field.
    async fn prepare_request(&self, mut req: serde_json::Value) -> WorkerResult<serde_json::Value> {
        match req.as_object_mut() {
            Some(fields) => {
                fields.insert(
                    "data_parallel_rank".to_string(),
                    serde_json::json!(self.dp_rank),
                );
            }
            None => {
                return Err(WorkerError::InvalidConfiguration {
                    message: "Request must be a JSON object for DP-aware routing".to_string(),
                });
            }
        }
        Ok(req)
    }

    /// Requests go to the untagged base URL followed directly by `route`.
    fn endpoint_url(&self, route: &str) -> String {
        [self.base_url.as_str(), route].concat()
    }
}
/// Worker factory for creating workers of different types /// Worker factory for creating workers of different types
pub struct WorkerFactory; pub struct WorkerFactory;
...@@ -318,6 +506,133 @@ impl WorkerFactory { ...@@ -318,6 +506,133 @@ impl WorkerFactory {
(regular_workers, prefill_workers, decode_workers) (regular_workers, prefill_workers, decode_workers)
} }
/// Build one boxed `DPAwareWorker` for rank `dp_rank` out of `dp_size`,
/// using `worker_type` for the wrapped worker.
pub fn create_dp_aware(
    base_url: String,
    dp_rank: usize,
    dp_size: usize,
    worker_type: WorkerType,
) -> Box<dyn Worker> {
    let worker = DPAwareWorker::new(base_url, dp_rank, dp_size, worker_type);
    Box::new(worker)
}
/// Query a worker's `/get_server_info` endpoint and return the `dp_size` it reports.
///
/// When `api_key` is `Some`, it is sent as a bearer token.
///
/// # Errors
///
/// * `WorkerError::NetworkError` if the request cannot be sent, the server
///   answers with a non-success status, or the body cannot be parsed as JSON.
/// * `WorkerError::InvalidConfiguration` if the response has no `dp_size`
///   field readable as an unsigned integer, or the value does not fit in `usize`.
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
    let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
    if let Some(key) = api_key {
        req_builder = req_builder.bearer_auth(key);
    }

    let response = req_builder
        .send()
        .await
        .map_err(|e| WorkerError::NetworkError {
            url: url.to_string(),
            error: e.to_string(),
        })?;

    if !response.status().is_success() {
        return Err(WorkerError::NetworkError {
            url: url.to_string(),
            error: format!("Server returned: {}", response.status()),
        });
    }

    let info: serde_json::Value =
        response
            .json()
            .await
            .map_err(|e| WorkerError::NetworkError {
                url: url.to_string(),
                error: format!("Failed to parse JSON: {}", e),
            })?;

    let dp_size = info
        .get("dp_size")
        .and_then(serde_json::Value::as_u64)
        .ok_or_else(|| WorkerError::InvalidConfiguration {
            message: "dp_size not found in server info".to_string(),
        })?;

    // Checked narrowing instead of a hand-written bound test followed by an
    // `as` cast: the conversion itself reports values that do not fit.
    // Fully qualified so it does not depend on `TryFrom` being in the prelude.
    <usize as std::convert::TryFrom<u64>>::try_from(dp_size).map_err(|_| {
        WorkerError::InvalidConfiguration {
            message: format!("dp_size is too large: {}", dp_size),
        }
    })
}
/// Shared helper: ask the server at `url` for its DP size, then build one
/// DP-aware worker of `worker_type` per rank in `0..dp_size`.
///
/// Any error from the DP-size lookup is returned unchanged.
async fn create_dp_aware_workers_of_type(
    url: &str,
    api_key: &Option<String>,
    worker_type: WorkerType,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
    let dp_size = Self::get_worker_dp_size(url, api_key).await?;

    let mut per_rank: Vec<Box<dyn Worker>> = Vec::with_capacity(dp_size);
    for rank in 0..dp_size {
        let worker = Self::create_dp_aware(url.to_string(), rank, dp_size, worker_type.clone());
        per_rank.push(worker);
    }

    Ok(per_rank)
}
/// Expand a single server URL into one `WorkerType::Regular` DP-aware worker
/// per rank reported by that server.
pub async fn create_dp_aware_regular_workers(
    url: &str,
    api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
    let kind = WorkerType::Regular;
    Self::create_dp_aware_workers_of_type(url, api_key, kind).await
}
/// Expand a single server URL into one `WorkerType::Prefill` DP-aware worker
/// per rank reported by that server; every rank gets the same `bootstrap_port`.
pub async fn create_dp_aware_prefill_workers(
    url: &str,
    bootstrap_port: Option<u16>,
    api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
    let kind = WorkerType::Prefill { bootstrap_port };
    Self::create_dp_aware_workers_of_type(url, api_key, kind).await
}
/// Expand a single server URL into one `WorkerType::Decode` DP-aware worker
/// per rank reported by that server.
pub async fn create_dp_aware_decode_workers(
    url: &str,
    api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
    let kind = WorkerType::Decode;
    Self::create_dp_aware_workers_of_type(url, api_key, kind).await
}
/// Build the worker list for the regular router.
///
/// With `dp_aware == false`, each URL becomes one regular worker and no network
/// access happens here. With `dp_aware == true`, every URL is expanded
/// concurrently into its per-rank DP-aware workers; the first failure aborts
/// the whole call and is returned.
pub async fn create_workers(
    urls: Vec<String>,
    dp_aware: bool,
    api_key: &Option<String>,
) -> WorkerResult<Vec<Box<dyn Worker>>> {
    if !dp_aware {
        return Ok(urls.into_iter().map(Self::create_regular).collect());
    }

    // One lookup-and-expand future per URL, all driven concurrently.
    let pending = urls
        .iter()
        .map(|url| Self::create_dp_aware_regular_workers(url, api_key));
    let per_url = futures::future::try_join_all(pending).await?;

    // Flatten while preserving the order of `urls`.
    Ok(per_url.into_iter().flatten().collect())
}
} }
/// Helper trait for collections of workers /// Helper trait for collections of workers
...@@ -1086,4 +1401,245 @@ mod tests { ...@@ -1086,4 +1401,245 @@ mod tests {
// Should be well over 1M ops/sec // Should be well over 1M ops/sec
assert!(ops_per_sec > 1_000_000.0); assert!(ops_per_sec > 1_000_000.0);
} }
// ===== Tests for DPAwareWorker =====
/// A regular DP-aware worker exposes the rank-tagged URL, the untouched base
/// URL, and its DP rank and size.
#[test]
fn test_dp_aware_worker_creation() {
    let base = "http://worker1:8080".to_string();
    let worker = DPAwareWorker::new(base, 2, 4, WorkerType::Regular);

    assert!(worker.is_dp_aware());
    assert_eq!(worker.worker_type(), WorkerType::Regular);
    assert_eq!(worker.dp_size(), Some(4));
    assert_eq!(worker.dp_rank(), Some(2));
    assert_eq!(worker.base_url(), "http://worker1:8080");
    assert_eq!(worker.url(), "http://worker1:8080@2");
}
/// A prefill DP-aware worker keeps its worker type, bootstrap port included.
#[test]
fn test_dp_aware_worker_creation_prefill() {
    let prefill_kind = WorkerType::Prefill {
        bootstrap_port: Some(9090),
    };
    let worker = DPAwareWorker::new("http://worker1:8080".to_string(), 1, 2, prefill_kind);

    assert_eq!(worker.url(), "http://worker1:8080@1");
    assert!(worker.is_dp_aware());

    let expected_kind = WorkerType::Prefill {
        bootstrap_port: Some(9090),
    };
    assert_eq!(worker.worker_type(), expected_kind);
}
/// A decode DP-aware worker at rank 0 is tagged `@0` and reports `Decode`.
#[test]
fn test_dp_aware_worker_creation_decode() {
    let base = String::from("http://worker1:8080");
    let worker = DPAwareWorker::new(base, 0, 4, WorkerType::Decode);

    assert!(worker.is_dp_aware());
    assert_eq!(worker.worker_type(), WorkerType::Decode);
    assert_eq!(worker.url(), "http://worker1:8080@0");
}
/// `prepare_request` keeps the existing fields and adds `data_parallel_rank`.
#[tokio::test]
async fn test_dp_aware_prepare_request() {
    let worker = DPAwareWorker::new("http://worker1:8080".to_string(), 3, 8, WorkerType::Regular);

    let body = serde_json::json!({
        "prompt": "Hello",
        "max_tokens": 100
    });
    let prepared = worker.prepare_request(body).await.unwrap();

    // The injected rank matches the worker's rank...
    assert_eq!(prepared["data_parallel_rank"], 3);
    // ...and the caller's fields are still there.
    assert_eq!(prepared["max_tokens"], 100);
    assert_eq!(prepared["prompt"], "Hello");
}
/// A non-object JSON body is rejected with `InvalidConfiguration`.
#[tokio::test]
async fn test_dp_aware_prepare_request_invalid() {
    let worker = DPAwareWorker::new("http://worker1:8080".to_string(), 0, 4, WorkerType::Regular);

    // A bare JSON string has no fields to inject the rank into.
    let outcome = worker
        .prepare_request(serde_json::json!("not an object"))
        .await;
    assert!(outcome.is_err());

    if let Err(WorkerError::InvalidConfiguration { message }) = outcome {
        assert!(message.contains("JSON object"));
    } else {
        panic!("Expected InvalidConfiguration error");
    }
}
/// Endpoint URLs are built from the base URL, never from the `@rank`-tagged one.
#[test]
fn test_dp_aware_endpoint_url() {
    let worker = DPAwareWorker::new("http://worker1:8080".to_string(), 1, 4, WorkerType::Regular);

    let generate = worker.endpoint_url("/generate");
    assert_eq!(generate, "http://worker1:8080/generate");

    let health = worker.endpoint_url("/health");
    assert_eq!(health, "http://worker1:8080/health");
}
/// Health, load and processed counters behave as on the wrapped worker.
#[test]
fn test_dp_aware_worker_delegated_methods() {
    let worker = DPAwareWorker::new("http://worker1:8080".to_string(), 0, 2, WorkerType::Regular);

    // Health flag: starts healthy and can be flipped.
    assert!(worker.is_healthy());
    worker.set_healthy(false);
    assert!(!worker.is_healthy());

    // Load counter: one increment followed by one decrement returns to zero.
    assert_eq!(worker.load(), 0);
    worker.increment_load();
    assert_eq!(worker.load(), 1);
    worker.decrement_load();
    assert_eq!(worker.load(), 0);

    // Processed counter: starts at zero and counts up.
    assert_eq!(worker.processed_requests(), 0);
    worker.increment_processed();
    assert_eq!(worker.processed_requests(), 1);
}
// ===== Tests for WorkerFactory async methods =====
/// `WorkerFactory::create_dp_aware` yields a boxed worker carrying the
/// rank-tagged URL, the DP rank and size, and the requested worker type.
///
/// This is a plain synchronous test: `create_dp_aware` is not `async` and
/// nothing here awaits, so no Tokio runtime is needed.
#[test]
fn test_factory_create_dp_aware() {
    let worker = WorkerFactory::create_dp_aware(
        "http://worker1:8080".to_string(),
        1,
        4,
        WorkerType::Regular,
    );

    assert_eq!(worker.url(), "http://worker1:8080@1");
    assert!(worker.is_dp_aware());
    assert_eq!(worker.dp_rank(), Some(1));
    assert_eq!(worker.dp_size(), Some(4));
    assert_eq!(worker.worker_type(), WorkerType::Regular);
}
/// `WorkerFactory::create_dp_aware` preserves a prefill worker type, including
/// its bootstrap port.
///
/// This is a plain synchronous test: `create_dp_aware` is not `async` and
/// nothing here awaits, so no Tokio runtime is needed.
#[test]
fn test_factory_create_dp_aware_prefill() {
    let worker = WorkerFactory::create_dp_aware(
        "http://worker1:8080".to_string(),
        0,
        2,
        WorkerType::Prefill {
            bootstrap_port: Some(8090),
        },
    );

    assert_eq!(worker.url(), "http://worker1:8080@0");
    assert!(worker.is_dp_aware());
    assert_eq!(
        worker.worker_type(),
        WorkerType::Prefill {
            bootstrap_port: Some(8090)
        }
    );
}
/// With `dp_aware == false`, `create_workers` returns one non-DP worker per
/// URL, in input order.
#[tokio::test]
async fn test_factory_create_workers_regular() {
    let urls = vec![String::from("http://w1:8080"), String::from("http://w2:8080")];

    let created = WorkerFactory::create_workers(urls, false, &None)
        .await
        .unwrap();
    assert_eq!(created.len(), 2);

    let expected_urls = ["http://w1:8080", "http://w2:8080"];
    for worker in created.iter() {
        assert!(!worker.is_dp_aware());
    }
    for (worker, expected) in created.iter().zip(expected_urls.iter()) {
        assert_eq!(worker.url(), *expected);
    }
}
// ===== Integration tests =====
/// Plain and DP-aware workers of every type can live in one
/// `Vec<Box<dyn Worker>>` and each reports its own DP-awareness and type.
#[tokio::test]
async fn test_mixed_worker_types() {
    // Order matters: the expectation tables below are positional.
    let workers: Vec<Box<dyn Worker>> = vec![
        WorkerFactory::create_regular("http://regular:8080".to_string()),
        WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090)),
        WorkerFactory::create_decode("http://decode:8080".to_string()),
        WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular),
        WorkerFactory::create_dp_aware(
            "http://dp-prefill:8080".to_string(),
            1,
            2,
            WorkerType::Prefill {
                bootstrap_port: None,
            },
        ),
        WorkerFactory::create_dp_aware(
            "http://dp-decode:8080".to_string(),
            0,
            4,
            WorkerType::Decode,
        ),
    ];

    // Every freshly built worker starts healthy with zeroed counters.
    for worker in workers.iter() {
        assert!(worker.is_healthy());
        assert_eq!(worker.load(), 0);
        assert_eq!(worker.processed_requests(), 0);
    }

    // The first three are plain workers, the last three are DP-aware.
    let expect_dp_aware = [false, false, false, true, true, true];
    for (worker, expected) in workers.iter().zip(expect_dp_aware.iter()) {
        assert_eq!(worker.is_dp_aware(), *expected);
    }

    // Worker types survive boxing, for plain and DP-aware workers alike.
    let expected_types = vec![
        WorkerType::Regular,
        WorkerType::Prefill {
            bootstrap_port: Some(9090),
        },
        WorkerType::Decode,
        WorkerType::Regular,
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        WorkerType::Decode,
    ];
    for (worker, expected) in workers.iter().zip(expected_types) {
        assert_eq!(worker.worker_type(), expected);
    }
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment