Unverified commit 0311ce8e, authored by Byron Hsu, committed by GitHub
Browse files

[router] Expose worker startup secs & Return error instead of panic for router init (#3016)

parent 5dfcacfc
...@@ -33,6 +33,7 @@ class RouterArgs: ...@@ -33,6 +33,7 @@ class RouterArgs:
# Routing policy # Routing policy
policy: str = "cache_aware" policy: str = "cache_aware"
worker_startup_timeout_secs: int = 300
cache_threshold: float = 0.5 cache_threshold: float = 0.5
balance_abs_threshold: int = 32 balance_abs_threshold: int = 32
balance_rel_threshold: float = 1.0001 balance_rel_threshold: float = 1.0001
...@@ -87,6 +88,12 @@ class RouterArgs: ...@@ -87,6 +88,12 @@ class RouterArgs:
choices=["random", "round_robin", "cache_aware"], choices=["random", "round_robin", "cache_aware"],
help="Load balancing policy to use", help="Load balancing policy to use",
) )
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
type=int,
default=RouterArgs.worker_startup_timeout_secs,
help="Timeout in seconds for worker startup",
)
parser.add_argument( parser.add_argument(
f"--{prefix}cache-threshold", f"--{prefix}cache-threshold",
type=float, type=float,
...@@ -147,6 +154,9 @@ class RouterArgs: ...@@ -147,6 +154,9 @@ class RouterArgs:
host=args.host, host=args.host,
port=args.port, port=args.port,
policy=getattr(args, f"{prefix}policy"), policy=getattr(args, f"{prefix}policy"),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
),
cache_threshold=getattr(args, f"{prefix}cache_threshold"), cache_threshold=getattr(args, f"{prefix}cache_threshold"),
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"), balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"), balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
...@@ -188,9 +198,10 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -188,9 +198,10 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
router = Router( router = Router(
worker_urls=router_args.worker_urls, worker_urls=router_args.worker_urls,
policy=policy_from_str(router_args.policy),
host=router_args.host, host=router_args.host,
port=router_args.port, port=router_args.port,
policy=policy_from_str(router_args.policy),
worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
cache_threshold=router_args.cache_threshold, cache_threshold=router_args.cache_threshold,
balance_abs_threshold=router_args.balance_abs_threshold, balance_abs_threshold=router_args.balance_abs_threshold,
balance_rel_threshold=router_args.balance_rel_threshold, balance_rel_threshold=router_args.balance_rel_threshold,
...@@ -205,7 +216,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -205,7 +216,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
except Exception as e: except Exception as e:
logger.error(f"Error starting router: {e}") logger.error(f"Error starting router: {e}")
return None raise e
class CustomHelpFormatter( class CustomHelpFormatter(
...@@ -239,10 +250,7 @@ Examples: ...@@ -239,10 +250,7 @@ Examples:
def main() -> None: def main() -> None:
router_args = parse_router_args(sys.argv[1:]) router_args = parse_router_args(sys.argv[1:])
router = launch_router(router_args) launch_router(router_args)
if router is None:
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -68,7 +68,7 @@ def run_server(server_args, dp_rank): ...@@ -68,7 +68,7 @@ def run_server(server_args, dp_rank):
# create new process group # create new process group
os.setpgrp() os.setpgrp()
setproctitle(f"sglang::server") setproctitle("sglang::server")
# Set SGLANG_DP_RANK environment variable # Set SGLANG_DP_RANK environment variable
os.environ["SGLANG_DP_RANK"] = str(dp_rank) os.environ["SGLANG_DP_RANK"] = str(dp_rank)
...@@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]: ...@@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
def cleanup_processes(processes: List[mp.Process]): def cleanup_processes(processes: List[mp.Process]):
for process in processes: for process in processes:
logger.info(f"Terminating process {process.pid}") logger.info(f"Terminating process group {process.pid}")
process.terminate() try:
logger.info("All processes terminated") os.killpg(process.pid, signal.SIGTERM)
except ProcessLookupError:
# Process group may already be terminated
pass
# Wait for processes to terminate
for process in processes:
process.join(timeout=5)
if process.is_alive():
logger.warning(
f"Process {process.pid} did not terminate gracefully, forcing kill"
)
try:
os.killpg(process.pid, signal.SIGKILL)
except ProcessLookupError:
pass
logger.info("All process groups terminated")
def main(): def main():
...@@ -173,7 +190,12 @@ def main(): ...@@ -173,7 +190,12 @@ def main():
] ]
# Start the router # Start the router
router = launch_router(router_args) try:
launch_router(router_args)
except Exception as e:
logger.error(f"Failed to start router: {e}")
cleanup_processes(server_processes)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,6 +17,7 @@ class Router: ...@@ -17,6 +17,7 @@ class Router:
- PolicyType.CacheAware: Distribute requests based on cache state and load balance - PolicyType.CacheAware: Distribute requests based on cache state and load balance
host: Host address to bind the router server. Default: '127.0.0.1' host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001 port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5 tree. Default: 0.5
...@@ -37,6 +38,7 @@ class Router: ...@@ -37,6 +38,7 @@ class Router:
policy: PolicyType = PolicyType.RoundRobin, policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 3001, port: int = 3001,
worker_startup_timeout_secs: int = 300,
cache_threshold: float = 0.50, cache_threshold: float = 0.50,
balance_abs_threshold: int = 32, balance_abs_threshold: int = 32,
balance_rel_threshold: float = 1.0001, balance_rel_threshold: float = 1.0001,
...@@ -50,6 +52,7 @@ class Router: ...@@ -50,6 +52,7 @@ class Router:
policy=policy, policy=policy,
host=host, host=host,
port=port, port=port,
worker_startup_timeout_secs=worker_startup_timeout_secs,
cache_threshold=cache_threshold, cache_threshold=cache_threshold,
balance_abs_threshold=balance_abs_threshold, balance_abs_threshold=balance_abs_threshold,
balance_rel_threshold=balance_rel_threshold, balance_rel_threshold=balance_rel_threshold,
......
...@@ -28,6 +28,7 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -28,6 +28,7 @@ class TestLaunchRouter(unittest.TestCase):
host="127.0.0.1", host="127.0.0.1",
port=30000, port=30000,
policy="cache_aware", policy="cache_aware",
worker_startup_timeout_secs=600,
cache_threshold=0.5, cache_threshold=0.5,
balance_abs_threshold=32, balance_abs_threshold=32,
balance_rel_threshold=1.0001, balance_rel_threshold=1.0001,
......
...@@ -17,6 +17,7 @@ struct Router { ...@@ -17,6 +17,7 @@ struct Router {
port: u16, port: u16,
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: PolicyType, policy: PolicyType,
worker_startup_timeout_secs: u64,
cache_threshold: f32, cache_threshold: f32,
balance_abs_threshold: usize, balance_abs_threshold: usize,
balance_rel_threshold: f32, balance_rel_threshold: f32,
...@@ -34,6 +35,7 @@ impl Router { ...@@ -34,6 +35,7 @@ impl Router {
policy = PolicyType::RoundRobin, policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"), host = String::from("127.0.0.1"),
port = 3001, port = 3001,
worker_startup_timeout_secs = 300,
cache_threshold = 0.50, cache_threshold = 0.50,
balance_abs_threshold = 32, balance_abs_threshold = 32,
balance_rel_threshold = 1.0001, balance_rel_threshold = 1.0001,
...@@ -47,6 +49,7 @@ impl Router { ...@@ -47,6 +49,7 @@ impl Router {
policy: PolicyType, policy: PolicyType,
host: String, host: String,
port: u16, port: u16,
worker_startup_timeout_secs: u64,
cache_threshold: f32, cache_threshold: f32,
balance_abs_threshold: usize, balance_abs_threshold: usize,
balance_rel_threshold: f32, balance_rel_threshold: f32,
...@@ -60,6 +63,7 @@ impl Router { ...@@ -60,6 +63,7 @@ impl Router {
port, port,
worker_urls, worker_urls,
policy, policy,
worker_startup_timeout_secs,
cache_threshold, cache_threshold,
balance_abs_threshold, balance_abs_threshold,
balance_rel_threshold, balance_rel_threshold,
...@@ -72,9 +76,14 @@ impl Router { ...@@ -72,9 +76,14 @@ impl Router {
fn start(&self) -> PyResult<()> { fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy { let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig, PolicyType::Random => router::PolicyConfig::RandomConfig {
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, timeout_secs: self.worker_startup_timeout_secs,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
cache_threshold: self.cache_threshold, cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold, balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold, balance_rel_threshold: self.balance_rel_threshold,
...@@ -93,10 +102,9 @@ impl Router { ...@@ -93,10 +102,9 @@ impl Router {
max_payload_size: self.max_payload_size, max_payload_size: self.max_payload_size,
}) })
.await .await
.unwrap(); .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
}); Ok(())
})
Ok(())
} }
} }
......
...@@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; ...@@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes; use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt}; use futures_util::{StreamExt, TryStreamExt};
use log::{debug, info, warn}; use log::{debug, error, info, warn};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
...@@ -17,9 +17,11 @@ pub enum Router { ...@@ -17,9 +17,11 @@ pub enum Router {
RoundRobin { RoundRobin {
worker_urls: Arc<RwLock<Vec<String>>>, worker_urls: Arc<RwLock<Vec<String>>>,
current_index: AtomicUsize, current_index: AtomicUsize,
timeout_secs: u64,
}, },
Random { Random {
worker_urls: Arc<RwLock<Vec<String>>>, worker_urls: Arc<RwLock<Vec<String>>>,
timeout_secs: u64,
}, },
CacheAware { CacheAware {
/* /*
...@@ -89,36 +91,51 @@ pub enum Router { ...@@ -89,36 +91,51 @@ pub enum Router {
cache_threshold: f32, cache_threshold: f32,
balance_abs_threshold: usize, balance_abs_threshold: usize,
balance_rel_threshold: f32, balance_rel_threshold: f32,
timeout_secs: u64,
_eviction_thread: Option<thread::JoinHandle<()>>, _eviction_thread: Option<thread::JoinHandle<()>>,
}, },
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum PolicyConfig { pub enum PolicyConfig {
RandomConfig, RandomConfig {
RoundRobinConfig, timeout_secs: u64,
},
RoundRobinConfig {
timeout_secs: u64,
},
CacheAwareConfig { CacheAwareConfig {
cache_threshold: f32, cache_threshold: f32,
balance_abs_threshold: usize, balance_abs_threshold: usize,
balance_rel_threshold: f32, balance_rel_threshold: f32,
eviction_interval_secs: u64, eviction_interval_secs: u64,
max_tree_size: usize, max_tree_size: usize,
timeout_secs: u64,
}, },
} }
impl Router { impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> { pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
// Get timeout from policy config
let timeout_secs = match &policy_config {
PolicyConfig::RandomConfig { timeout_secs } => *timeout_secs,
PolicyConfig::RoundRobinConfig { timeout_secs } => *timeout_secs,
PolicyConfig::CacheAwareConfig { timeout_secs, .. } => *timeout_secs,
};
// Wait until all workers are healthy // Wait until all workers are healthy
Self::wait_for_healthy_workers(&worker_urls, 300, 10)?; Self::wait_for_healthy_workers(&worker_urls, timeout_secs, 10)?;
// Create router based on policy... // Create router based on policy...
Ok(match policy_config { Ok(match policy_config {
PolicyConfig::RandomConfig => Router::Random { PolicyConfig::RandomConfig { timeout_secs } => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)), worker_urls: Arc::new(RwLock::new(worker_urls)),
timeout_secs,
}, },
PolicyConfig::RoundRobinConfig => Router::RoundRobin { PolicyConfig::RoundRobinConfig { timeout_secs } => Router::RoundRobin {
worker_urls: Arc::new(RwLock::new(worker_urls)), worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0), current_index: std::sync::atomic::AtomicUsize::new(0),
timeout_secs,
}, },
PolicyConfig::CacheAwareConfig { PolicyConfig::CacheAwareConfig {
cache_threshold, cache_threshold,
...@@ -126,6 +143,7 @@ impl Router { ...@@ -126,6 +143,7 @@ impl Router {
balance_rel_threshold, balance_rel_threshold,
eviction_interval_secs, eviction_interval_secs,
max_tree_size, max_tree_size,
timeout_secs,
} => { } => {
let mut running_queue = HashMap::new(); let mut running_queue = HashMap::new();
for url in &worker_urls { for url in &worker_urls {
...@@ -176,6 +194,7 @@ impl Router { ...@@ -176,6 +194,7 @@ impl Router {
cache_threshold, cache_threshold,
balance_abs_threshold, balance_abs_threshold,
balance_rel_threshold, balance_rel_threshold,
timeout_secs,
_eviction_thread: Some(eviction_thread), _eviction_thread: Some(eviction_thread),
} }
} }
...@@ -192,6 +211,10 @@ impl Router { ...@@ -192,6 +211,10 @@ impl Router {
loop { loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) { if start_time.elapsed() > Duration::from_secs(timeout_secs) {
error!(
"Timeout {}s waiting for workers to become healthy",
timeout_secs
);
return Err(format!( return Err(format!(
"Timeout {}s waiting for workers to become healthy", "Timeout {}s waiting for workers to become healthy",
timeout_secs timeout_secs
...@@ -238,7 +261,7 @@ impl Router { ...@@ -238,7 +261,7 @@ impl Router {
fn select_first_worker(&self) -> Result<String, String> { fn select_first_worker(&self) -> Result<String, String> {
match self { match self {
Router::RoundRobin { worker_urls, .. } Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls } | Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => { | Router::CacheAware { worker_urls, .. } => {
if worker_urls.read().unwrap().is_empty() { if worker_urls.read().unwrap().is_empty() {
Err("No workers are available".to_string()) Err("No workers are available".to_string())
...@@ -349,6 +372,7 @@ impl Router { ...@@ -349,6 +372,7 @@ impl Router {
Router::RoundRobin { Router::RoundRobin {
worker_urls, worker_urls,
current_index, current_index,
..
} => { } => {
let idx = current_index let idx = current_index
.fetch_update( .fetch_update(
...@@ -360,7 +384,7 @@ impl Router { ...@@ -360,7 +384,7 @@ impl Router {
worker_urls.read().unwrap()[idx].clone() worker_urls.read().unwrap()[idx].clone()
} }
Router::Random { worker_urls } => worker_urls.read().unwrap() Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()] [rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(), .clone(),
...@@ -571,13 +595,21 @@ impl Router { ...@@ -571,13 +595,21 @@ impl Router {
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> { pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
let interval_secs = 10; // check every 10 seconds let interval_secs = 10; // check every 10 seconds
let timeout_secs = 300; // 5 minutes let timeout_secs = match self {
Router::Random { timeout_secs, .. } => *timeout_secs,
Router::RoundRobin { timeout_secs, .. } => *timeout_secs,
Router::CacheAware { timeout_secs, .. } => *timeout_secs,
};
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let client = reqwest::Client::new(); let client = reqwest::Client::new();
loop { loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) { if start_time.elapsed() > Duration::from_secs(timeout_secs) {
error!(
"Timeout {}s waiting for worker {} to become healthy",
timeout_secs, worker_url
);
return Err(format!( return Err(format!(
"Timeout {}s waiting for worker {} to become healthy", "Timeout {}s waiting for worker {} to become healthy",
timeout_secs, worker_url timeout_secs, worker_url
...@@ -589,7 +621,7 @@ impl Router { ...@@ -589,7 +621,7 @@ impl Router {
if res.status().is_success() { if res.status().is_success() {
match self { match self {
Router::RoundRobin { worker_urls, .. } Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls } | Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => { | Router::CacheAware { worker_urls, .. } => {
info!("Worker {} health check passed", worker_url); info!("Worker {} health check passed", worker_url);
let mut urls = worker_urls.write().unwrap(); let mut urls = worker_urls.write().unwrap();
...@@ -663,7 +695,7 @@ impl Router { ...@@ -663,7 +695,7 @@ impl Router {
pub fn remove_worker(&self, worker_url: &str) { pub fn remove_worker(&self, worker_url: &str) {
match self { match self {
Router::RoundRobin { worker_urls, .. } Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls } | Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => { | Router::CacheAware { worker_urls, .. } => {
let mut urls = worker_urls.write().unwrap(); let mut urls = worker_urls.write().unwrap();
if let Some(index) = urls.iter().position(|url| url == &worker_url) { if let Some(index) = urls.iter().position(|url| url == &worker_url) {
......
...@@ -18,14 +18,10 @@ impl AppState { ...@@ -18,14 +18,10 @@ impl AppState {
worker_urls: Vec<String>, worker_urls: Vec<String>,
client: reqwest::Client, client: reqwest::Client,
policy_config: PolicyConfig, policy_config: PolicyConfig,
) -> Self { ) -> Result<Self, String> {
// Create router based on policy // Create router based on policy
let router = match Router::new(worker_urls, policy_config) { let router = Router::new(worker_urls, policy_config)?;
Ok(router) => router, Ok(Self { router, client })
Err(error) => panic!("Failed to create router: {}", error),
};
Self { router, client }
} }
} }
...@@ -131,6 +127,7 @@ pub struct ServerConfig { ...@@ -131,6 +127,7 @@ pub struct ServerConfig {
} }
pub async fn startup(config: ServerConfig) -> std::io::Result<()> { pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
// Initialize logger
Builder::new() Builder::new()
.format(|buf, record| { .format(|buf, record| {
use chrono::Local; use chrono::Local;
...@@ -152,24 +149,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ...@@ -152,24 +149,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
) )
.init(); .init();
info!("🚧 Initializing router on {}:{}", config.host, config.port);
info!("🚧 Initializing workers on {:?}", config.worker_urls);
info!("🚧 Policy Config: {:?}", config.policy_config);
info!(
"🚧 Max payload size: {} MB",
config.max_payload_size / (1024 * 1024)
);
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.build() .build()
.expect("Failed to create HTTP client"); .expect("Failed to create HTTP client");
let app_state = web::Data::new(AppState::new( let app_state = web::Data::new(
config.worker_urls.clone(), AppState::new(
client, config.worker_urls.clone(),
config.policy_config.clone(), client,
)); config.policy_config.clone(),
)
info!("✅ Starting router on {}:{}", config.host, config.port); .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
info!("✅ Policy Config: {:?}", config.policy_config);
info!(
"✅ Max payload size: {} MB",
config.max_payload_size / (1024 * 1024)
); );
info!("✅ Serving router on {}:{}", config.host, config.port);
info!("✅ Serving workers on {:?}", config.worker_urls);
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()
.app_data(app_state.clone()) .app_data(app_state.clone())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment