Unverified Commit f2d5c492 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add worker abstraction (#7960)

parent 2a2d3478
...@@ -30,6 +30,8 @@ tracing-appender = "0.2.3" ...@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
kube = { version = "0.88.1", features = ["runtime", "derive"] } kube = { version = "0.88.1", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.21.0", features = ["v1_29"] } k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
futures = "0.3" futures = "0.3"
async-trait = "0.1"
once_cell = "1.21"
# Added for metrics # Added for metrics
metrics = "0.24.2" metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0" metrics-exporter-prometheus = "0.17.0"
......
//! Error types for the SGLang router core
//!
//! This module defines error types used throughout the router for worker operations.
use std::fmt;
/// Worker-related errors
#[derive(Debug)]
pub enum WorkerError {
/// Health check failed
HealthCheckFailed { url: String, reason: String },
/// Worker not found
WorkerNotFound { url: String },
/// Invalid worker configuration
InvalidConfiguration { message: String },
/// Network error
NetworkError { url: String, error: String },
/// Worker is at capacity
WorkerAtCapacity { url: String },
}
impl fmt::Display for WorkerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WorkerError::HealthCheckFailed { url, reason } => {
write!(f, "Health check failed for worker {}: {}", url, reason)
}
WorkerError::WorkerNotFound { url } => {
write!(f, "Worker not found: {}", url)
}
WorkerError::InvalidConfiguration { message } => {
write!(f, "Invalid worker configuration: {}", message)
}
WorkerError::NetworkError { url, error } => {
write!(f, "Network error for worker {}: {}", url, error)
}
WorkerError::WorkerAtCapacity { url } => {
write!(f, "Worker at capacity: {}", url)
}
}
}
}
impl std::error::Error for WorkerError {}
/// Result type for worker operations
pub type WorkerResult<T> = Result<T, WorkerError>;
/// Convert from reqwest errors to worker errors
impl From<reqwest::Error> for WorkerError {
fn from(err: reqwest::Error) -> Self {
WorkerError::NetworkError {
url: err.url().map(|u| u.to_string()).unwrap_or_default(),
error: err.to_string(),
}
}
}
//! Core abstractions for the SGLang router
//!
//! This module contains the fundamental types and traits used throughout the router:
//! - Worker trait and implementations
//! - Error types
//! - Common utilities
pub mod error;
pub mod worker;
// Re-export commonly used types at the module level
pub use error::{WorkerError, WorkerResult};
pub use worker::{
start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory,
WorkerLoadGuard, WorkerType,
};
use super::{WorkerError, WorkerResult};
use async_trait::async_trait;
use once_cell::sync::Lazy;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
// Shared HTTP client for health checks
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
.build()
.expect("Failed to create health check HTTP client")
});
/// Core worker abstraction that represents a backend service
#[async_trait]
pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's URL
fn url(&self) -> &str;
/// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType;
/// Check if the worker is currently healthy
fn is_healthy(&self) -> bool;
/// Set the worker's health status
fn set_healthy(&self, healthy: bool);
/// Perform an async health check on the worker
async fn check_health_async(&self) -> WorkerResult<()>;
/// Synchronous health check wrapper (for compatibility)
fn check_health(&self) -> WorkerResult<()> {
// Use a small runtime for synchronous contexts
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| WorkerError::HealthCheckFailed {
url: self.url().to_string(),
reason: format!("Failed to create runtime: {}", e),
})?
.block_on(self.check_health_async())
}
/// Get the current load (number of active requests)
fn load(&self) -> usize;
/// Increment the load counter
fn increment_load(&self);
/// Decrement the load counter
fn decrement_load(&self);
/// Get the number of processed requests
fn processed_requests(&self) -> usize;
/// Increment the processed requests counter
fn increment_processed(&self);
/// Get worker-specific metadata
fn metadata(&self) -> &WorkerMetadata;
/// Clone the worker (for trait objects)
fn clone_worker(&self) -> Box<dyn Worker>;
}
/// Worker type classification
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WorkerType {
/// Regular worker for standard routing
Regular,
/// Prefill worker for PD disaggregated mode
Prefill {
/// Bootstrap port for communication with decode workers
bootstrap_port: Option<u16>,
},
/// Decode worker for PD disaggregated mode
Decode,
}
impl fmt::Display for WorkerType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WorkerType::Regular => write!(f, "Regular"),
WorkerType::Prefill { bootstrap_port } => match bootstrap_port {
Some(port) => write!(f, "Prefill(bootstrap:{})", port),
None => write!(f, "Prefill"),
},
WorkerType::Decode => write!(f, "Decode"),
}
}
}
/// Health check configuration
#[derive(Debug, Clone)]
pub struct HealthConfig {
/// Timeout for health checks in seconds
pub timeout_secs: u64,
/// Interval between health checks in seconds
pub check_interval_secs: u64,
/// Health check endpoint path
pub endpoint: String,
}
impl Default for HealthConfig {
fn default() -> Self {
Self {
timeout_secs: 5,
check_interval_secs: 30,
endpoint: "/health".to_string(),
}
}
}
/// Metadata associated with a worker
#[derive(Debug, Clone)]
pub struct WorkerMetadata {
/// Worker URL
pub url: String,
/// Worker type
pub worker_type: WorkerType,
/// Additional labels/tags
pub labels: std::collections::HashMap<String, String>,
/// Health check configuration
pub health_config: HealthConfig,
}
/// Basic worker implementation
#[derive(Debug, Clone)]
pub struct BasicWorker {
metadata: WorkerMetadata,
load_counter: Arc<AtomicUsize>,
processed_counter: Arc<AtomicUsize>,
healthy: Arc<AtomicBool>,
}
impl BasicWorker {
pub fn new(url: String, worker_type: WorkerType) -> Self {
let metadata = WorkerMetadata {
url: url.clone(),
worker_type,
labels: std::collections::HashMap::new(),
health_config: HealthConfig::default(),
};
Self {
metadata,
load_counter: Arc::new(AtomicUsize::new(0)),
processed_counter: Arc::new(AtomicUsize::new(0)),
healthy: Arc::new(AtomicBool::new(true)),
}
}
pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
self.metadata.labels = labels;
self
}
pub fn with_health_config(mut self, config: HealthConfig) -> Self {
self.metadata.health_config = config;
self
}
}
#[async_trait]
impl Worker for BasicWorker {
fn url(&self) -> &str {
&self.metadata.url
}
fn worker_type(&self) -> WorkerType {
self.metadata.worker_type.clone()
}
fn is_healthy(&self) -> bool {
self.healthy.load(Ordering::Acquire)
}
fn set_healthy(&self, healthy: bool) {
self.healthy.store(healthy, Ordering::Release);
}
async fn check_health_async(&self) -> WorkerResult<()> {
use std::time::Duration;
// Perform actual HTTP health check
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
match HEALTH_CHECK_CLIENT
.get(&health_url)
.timeout(timeout)
.send()
.await
{
Ok(response) => {
if response.status().is_success() {
self.set_healthy(true);
Ok(())
} else {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
reason: format!("Health check returned status: {}", response.status()),
})
}
}
Err(e) => {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
reason: format!("Health check request failed: {}", e),
})
}
}
}
fn load(&self) -> usize {
self.load_counter.load(Ordering::Relaxed)
}
fn increment_load(&self) {
self.load_counter.fetch_add(1, Ordering::Relaxed);
}
fn decrement_load(&self) {
self.load_counter
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
current.checked_sub(1)
})
.ok();
}
fn processed_requests(&self) -> usize {
self.processed_counter.load(Ordering::Relaxed)
}
fn increment_processed(&self) {
self.processed_counter.fetch_add(1, Ordering::Relaxed);
}
fn metadata(&self) -> &WorkerMetadata {
&self.metadata
}
fn clone_worker(&self) -> Box<dyn Worker> {
Box::new(self.clone())
}
}
/// Worker factory for creating workers of different types
pub struct WorkerFactory;
impl WorkerFactory {
/// Create a regular worker
pub fn create_regular(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Regular))
}
/// Create a prefill worker with optional bootstrap port
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
Box::new(BasicWorker::new(
url,
WorkerType::Prefill { bootstrap_port },
))
}
/// Create a decode worker
pub fn create_decode(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Decode))
}
/// Create workers from URLs with automatic type detection
pub fn create_from_urls(
regular_urls: Vec<String>,
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
) -> (
Vec<Box<dyn Worker>>,
Vec<Box<dyn Worker>>,
Vec<Box<dyn Worker>>,
) {
let regular_workers: Vec<Box<dyn Worker>> =
regular_urls.into_iter().map(Self::create_regular).collect();
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.into_iter()
.map(|(url, port)| Self::create_prefill(url, port))
.collect();
let decode_workers: Vec<Box<dyn Worker>> =
decode_urls.into_iter().map(Self::create_decode).collect();
(regular_workers, prefill_workers, decode_workers)
}
}
/// Helper trait for collections of workers
pub trait WorkerCollection {
fn healthy_workers(&self) -> Vec<&dyn Worker>;
fn total_load(&self) -> usize;
fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
}
impl WorkerCollection for Vec<Box<dyn Worker>> {
fn healthy_workers(&self) -> Vec<&dyn Worker> {
self.iter()
.filter(|w| w.is_healthy())
.map(|w| w.as_ref())
.collect()
}
fn total_load(&self) -> usize {
self.iter().map(|w| w.load()).sum()
}
fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
self.iter().find(|w| w.url() == url).map(|w| w.as_ref())
}
fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
self.iter_mut().find(|w| w.url() == url)
}
}
/// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
urls.into_iter()
.map(WorkerFactory::create_regular)
.collect()
}
/// Convert worker trait objects back to URLs
pub fn workers_to_urls(workers: &[Box<dyn Worker>]) -> Vec<String> {
workers.iter().map(|w| w.url().to_string()).collect()
}
/// RAII guard for worker load management
pub struct WorkerLoadGuard<'a> {
workers: Vec<&'a dyn Worker>,
}
impl<'a> WorkerLoadGuard<'a> {
/// Create a new load guard for a single worker
pub fn new(worker: &'a dyn Worker) -> Self {
worker.increment_load();
Self {
workers: vec![worker],
}
}
/// Create a new load guard for multiple workers
pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self {
// Increment load counters for all workers
for worker in &workers {
worker.increment_load();
}
Self { workers }
}
}
impl<'a> Drop for WorkerLoadGuard<'a> {
fn drop(&mut self) {
// Decrement load counters for all workers
for worker in &self.workers {
worker.decrement_load();
}
}
}
/// Health checker handle with graceful shutdown
pub struct HealthChecker {
handle: tokio::task::JoinHandle<()>,
shutdown: Arc<AtomicBool>,
}
impl fmt::Debug for HealthChecker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HealthChecker")
.field("shutdown", &self.shutdown.load(Ordering::Relaxed))
.finish()
}
}
impl HealthChecker {
/// Shutdown the health checker gracefully
pub async fn shutdown(self) {
self.shutdown.store(true, Ordering::Release);
let _ = self.handle.await;
}
}
/// Start an async background health checker for a collection of workers
pub fn start_health_checker(
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
check_interval_secs: u64,
) -> HealthChecker {
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
let handle = tokio::spawn(async move {
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
loop {
interval.tick().await;
// Check for shutdown signal
if shutdown_clone.load(Ordering::Acquire) {
tracing::info!("Health checker shutting down");
break;
}
// Check health of all workers
let workers_to_check = match workers.read() {
Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
Err(poisoned) => {
tracing::error!("Worker lock poisoned: {}", poisoned);
continue;
}
};
// Perform health checks concurrently
let health_checks = workers_to_check.iter().map(|worker| {
let worker_url = worker.url().to_string();
let was_healthy = worker.is_healthy();
async move {
match worker.check_health_async().await {
Ok(_) => {
if !was_healthy {
tracing::info!("Worker {} is now healthy", worker_url);
}
}
Err(e) => {
if was_healthy {
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
}
}
}
}
});
// Execute all health checks concurrently
futures::future::join_all(health_checks).await;
}
});
HealthChecker { handle, shutdown }
}
...@@ -2,6 +2,7 @@ use pyo3::prelude::*; ...@@ -2,6 +2,7 @@ use pyo3::prelude::*;
pub mod config; pub mod config;
pub mod logging; pub mod logging;
use std::collections::HashMap; use std::collections::HashMap;
pub mod core;
pub mod openai_api_types; pub mod openai_api_types;
pub mod pd_router; pub mod pd_router;
pub mod pd_types; pub mod pd_types;
......
This diff is collapsed.
// Essential PDLB types extracted for PD routing // Essential PDLB types extracted for PD routing
use crate::core::{Worker, WorkerType};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
...@@ -28,52 +29,21 @@ pub enum PDRouterError { ...@@ -28,52 +29,21 @@ pub enum PDRouterError {
Timeout { url: String }, Timeout { url: String },
} }
#[derive(Debug, Clone)] // Helper functions for workers
pub enum EngineType { pub fn api_path(url: &str, api_path: &str) -> String {
Prefill, if api_path.starts_with("/") {
Decode, format!("{}{}", url, api_path)
} } else {
format!("{}/{}", url, api_path)
#[derive(Debug, Clone)]
pub struct EngineInfo {
pub engine_type: EngineType,
pub url: String,
pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
EngineInfo {
engine_type: EngineType::Prefill,
url,
bootstrap_port,
}
}
pub fn new_decode(url: String) -> Self {
EngineInfo {
engine_type: EngineType::Decode,
url,
bootstrap_port: None,
}
}
pub fn api_path(&self, api_path: &str) -> String {
if api_path.starts_with("/") {
format!("{}{}", self.url, api_path)
} else {
format!("{}/{}", self.url, api_path)
}
} }
}
pub fn get_hostname(&self) -> String { pub fn get_hostname(url: &str) -> String {
// Simple hostname extraction without external dependencies // Simple hostname extraction without external dependencies
let url = self let url = url
.url .trim_start_matches("http://")
.trim_start_matches("http://") .trim_start_matches("https://");
.trim_start_matches("https://"); url.split(':').next().unwrap_or("localhost").to_string()
url.split(':').next().unwrap_or("localhost").to_string()
}
} }
// PD-specific routing policies // PD-specific routing policies
...@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync { ...@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
bootstrap_room: BootstrapRoom, bootstrap_room: BootstrapRoom,
); );
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> { fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
let batch_size = self.get_batch_size()?; let batch_size = self.get_batch_size()?;
// Extract bootstrap port from prefill worker if it's a prefill type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(prefill_worker.url());
if let Some(batch_size) = batch_size { if let Some(batch_size) = batch_size {
self.set_bootstrap_info( self.set_bootstrap_info(
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]), BootstrapHost::Batch(vec![hostname; batch_size]),
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]), BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
// Use high-quality random numbers to minimize collision risk // Use high-quality random numbers to minimize collision risk
BootstrapRoom::Batch( BootstrapRoom::Batch(
(0..batch_size) (0..batch_size)
...@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync { ...@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
); );
} else { } else {
self.set_bootstrap_info( self.set_bootstrap_info(
BootstrapHost::Single(prefill_info.get_hostname()), BootstrapHost::Single(hostname),
BootstrapPort::Single(prefill_info.bootstrap_port), BootstrapPort::Single(bootstrap_port),
BootstrapRoom::Single({ BootstrapRoom::Single({
// Use high-quality random number for single requests too // Use high-quality random number for single requests too
let r1 = rand::random::<u64>(); let r1 = rand::random::<u64>();
......
This diff is collapsed.
...@@ -236,8 +236,7 @@ async fn add_worker( ...@@ -236,8 +236,7 @@ async fn add_worker(
#[get("/list_workers")] #[get("/list_workers")]
async fn list_workers(data: web::Data<AppState>) -> impl Responder { async fn list_workers(data: web::Data<AppState>) -> impl Responder {
let workers = data.router.get_worker_urls(); let worker_list = data.router.get_worker_urls();
let worker_list = workers.read().unwrap().clone();
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list })) HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
} }
...@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ...@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
info!("✅ Serving router on {}:{}", config.host, config.port); info!("✅ Serving router on {}:{}", config.host, config.port);
info!( info!(
"✅ Serving workers on {:?}", "✅ Serving workers on {:?}",
app_state.router.get_worker_urls().read().unwrap() app_state.router.get_worker_urls()
); );
HttpServer::new(move || { HttpServer::new(move || {
......
...@@ -547,11 +547,12 @@ mod tests { ...@@ -547,11 +547,12 @@ mod tests {
// Helper to create a Router instance for testing event handlers // Helper to create a Router instance for testing event handlers
fn create_test_router() -> Arc<Router> { fn create_test_router() -> Arc<Router> {
let worker_urls = Arc::new(RwLock::new(Vec::new())); let workers = Arc::new(RwLock::new(Vec::new()));
Arc::new(Router::Random { Arc::new(Router::Random {
worker_urls, workers,
timeout_secs: 5, timeout_secs: 5,
interval_secs: 1, interval_secs: 1,
_health_checker: None,
}) })
} }
...@@ -878,8 +879,6 @@ mod tests { ...@@ -878,8 +879,6 @@ mod tests {
assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
assert!(!router assert!(!router
.get_worker_urls() .get_worker_urls()
.read()
.unwrap()
.contains(&pod_info.worker_url(port))); .contains(&pod_info.worker_url(port)));
} }
...@@ -907,7 +906,7 @@ mod tests { ...@@ -907,7 +906,7 @@ mod tests {
.await; .await;
assert!(tracked_pods.lock().unwrap().is_empty()); assert!(tracked_pods.lock().unwrap().is_empty());
assert!(router.get_worker_urls().read().unwrap().is_empty()); assert!(router.get_worker_urls().is_empty());
} }
#[tokio::test] #[tokio::test]
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
mod test_pd_routing { mod test_pd_routing {
use rand::Rng; use rand::Rng;
use serde_json::json; use serde_json::json;
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy}; use sglang_router_rs::pd_types::PDSelectionPolicy;
use sglang_router_rs::router::{PolicyConfig, Router}; use sglang_router_rs::router::{PolicyConfig, Router};
// Test-only struct to help validate PD request parsing // Test-only struct to help validate PD request parsing
...@@ -51,40 +51,35 @@ mod test_pd_routing { ...@@ -51,40 +51,35 @@ mod test_pd_routing {
// ======================================================================== // ========================================================================
#[test] #[test]
fn test_engine_info_creation() { fn test_worker_types() {
// Test EngineInfo creation for prefill servers use sglang_router_rs::core::{WorkerFactory, WorkerType};
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
match prefill_engine.engine_type { // Test worker creation for prefill servers
EngineType::Prefill => (), let prefill_worker =
_ => panic!("Expected Prefill engine type"), WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
assert_eq!(prefill_worker.url(), "http://prefill:8080");
match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => {
assert_eq!(bootstrap_port, Some(9000));
}
_ => panic!("Expected Prefill worker type"),
} }
assert_eq!(prefill_engine.url, "http://prefill:8080");
assert_eq!(prefill_engine.bootstrap_port, Some(9000)); // Test worker creation for decode servers
assert_eq!(prefill_engine.get_hostname(), "prefill"); let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
assert_eq!(decode_worker.url(), "http://decode:8080");
// Test EngineInfo creation for decode servers match decode_worker.worker_type() {
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string()); WorkerType::Decode => (),
match decode_engine.engine_type { _ => panic!("Expected Decode worker type"),
EngineType::Decode => (),
_ => panic!("Expected Decode engine type"),
} }
assert_eq!(decode_engine.url, "http://decode:8080");
assert_eq!(decode_engine.bootstrap_port, None);
assert_eq!(decode_engine.get_hostname(), "decode");
// Test API path generation // Test regular worker creation
assert_eq!( let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
prefill_engine.api_path("/generate"), assert_eq!(regular_worker.url(), "http://regular:8080");
"http://prefill:8080/generate" match regular_worker.worker_type() {
); WorkerType::Regular => (),
assert_eq!( _ => panic!("Expected Regular worker type"),
prefill_engine.api_path("health"), }
"http://prefill:8080/health"
);
assert_eq!(
decode_engine.api_path("/v1/chat/completions"),
"http://decode:8080/v1/chat/completions"
);
} }
#[test] #[test]
...@@ -230,6 +225,9 @@ mod test_pd_routing { ...@@ -230,6 +225,9 @@ mod test_pd_routing {
#[test] #[test]
fn test_bootstrap_injection_simulation() { fn test_bootstrap_injection_simulation() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Since we can't test the actual inject_bootstrap_fields function here // Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior // (it's private in the router module), we'll test the expected behavior
...@@ -240,15 +238,24 @@ mod test_pd_routing { ...@@ -240,15 +238,24 @@ mod test_pd_routing {
"temperature": 0.7 "temperature": 0.7
}); });
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
// Simulate what inject_bootstrap_fields would do // Simulate what inject_bootstrap_fields would do
let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000)); single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
single_json["bootstrap_host"] = json!(prefill_info.get_hostname()); single_json["bootstrap_port"] = json!(bootstrap_port);
single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
single_json["bootstrap_room"] = json!(12345u64); // Random room ID single_json["bootstrap_room"] = json!(12345u64); // Random room ID
// Verify bootstrap fields are added correctly // Verify bootstrap fields are added correctly
assert_eq!(single_json["bootstrap_host"], "prefill1"); assert_eq!(single_json["bootstrap_host"], "prefill1");
assert_eq!(single_json["bootstrap_port"], 9000); assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
assert!(single_json["bootstrap_room"].is_u64()); assert!(single_json["bootstrap_room"].is_u64());
assert_eq!(single_json["temperature"], 0.7); // Original field preserved assert_eq!(single_json["temperature"], 0.7); // Original field preserved
...@@ -259,8 +266,9 @@ mod test_pd_routing { ...@@ -259,8 +266,9 @@ mod test_pd_routing {
}); });
let batch_size = 3; let batch_size = 3;
batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]); let hostname = get_hostname(prefill_worker.url());
batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]); batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]); batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
// Verify batch bootstrap fields // Verify batch bootstrap fields
...@@ -306,7 +314,9 @@ mod test_pd_routing { ...@@ -306,7 +314,9 @@ mod test_pd_routing {
} }
#[test] #[test]
fn test_engine_info_hostname_extraction() { fn test_hostname_extraction() {
use sglang_router_rs::pd_types::get_hostname;
// Test various URL formats // Test various URL formats
let test_cases = vec![ let test_cases = vec![
("http://localhost:8080", "localhost"), ("http://localhost:8080", "localhost"),
...@@ -318,8 +328,7 @@ mod test_pd_routing { ...@@ -318,8 +328,7 @@ mod test_pd_routing {
]; ];
for (url, expected_hostname) in test_cases { for (url, expected_hostname) in test_cases {
let engine = EngineInfo::new_prefill(url.to_string(), None); assert_eq!(get_hostname(url), expected_hostname);
assert_eq!(engine.get_hostname(), expected_hostname);
} }
} }
...@@ -652,6 +661,9 @@ mod test_pd_routing { ...@@ -652,6 +661,9 @@ mod test_pd_routing {
#[test] #[test]
fn test_bootstrap_injection_with_benchmark_requests() { fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection with actual benchmark request patterns // Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({ let mut benchmark_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16 "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
...@@ -664,12 +676,20 @@ mod test_pd_routing { ...@@ -664,12 +676,20 @@ mod test_pd_routing {
"stream": true "stream": true
}); });
// Simulate bootstrap injection // Create a prefill worker to simulate injection
let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000)); let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let batch_size = 16; let batch_size = 16;
let hostname = get_hostname(prefill_worker.url());
benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]); benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]); benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
benchmark_request["bootstrap_room"] = benchmark_request["bootstrap_room"] =
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>()); json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
...@@ -770,6 +790,9 @@ mod test_pd_routing { ...@@ -770,6 +790,9 @@ mod test_pd_routing {
#[test] #[test]
fn test_large_batch_bootstrap_injection() { fn test_large_batch_bootstrap_injection() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection performance with very large batches // Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario // This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192]; let large_batch_sizes = vec![1024, 4096, 8192];
...@@ -787,14 +810,19 @@ mod test_pd_routing { ...@@ -787,14 +810,19 @@ mod test_pd_routing {
"stream": true "stream": true
}); });
// Simulate bootstrap injection // Create a prefill worker to simulate injection
let prefill_info = let prefill_worker =
EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000)); WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(prefill_worker.url());
large_batch_request["bootstrap_host"] = large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
json!(vec![prefill_info.get_hostname(); batch_size]); large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
large_batch_request["bootstrap_port"] =
json!(vec![prefill_info.bootstrap_port; batch_size]);
large_batch_request["bootstrap_room"] = json!((0..batch_size) large_batch_request["bootstrap_room"] = json!((0..batch_size)
.map(|_| rand::thread_rng().gen::<u64>()) .map(|_| rand::thread_rng().gen::<u64>())
.collect::<Vec<_>>()); .collect::<Vec<_>>());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment