"tests/vscode:/vscode.git/clone" did not exist on "2de9e2df368241cf13f859cf51514cea4e53aed5"
Unverified commit ef995dae authored by Byron Hsu; committed by GitHub
Browse files

[router] Health check on worker before adding to the router (#2392)

parent 75ae9689
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "actix-codec" name = "actix-codec"
...@@ -2219,6 +2219,7 @@ dependencies = [ ...@@ -2219,6 +2219,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"tokenizers", "tokenizers",
"tokio",
] ]
[[package]] [[package]]
...@@ -2475,9 +2476,9 @@ dependencies = [ ...@@ -2475,9 +2476,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.41.0" version = "1.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
......
...@@ -29,6 +29,7 @@ http = "1.1.0" ...@@ -29,6 +29,7 @@ http = "1.1.0"
env_logger = "0.11.5" env_logger = "0.11.5"
log = "0.4.22" log = "0.4.22"
chrono = "0.4.38" chrono = "0.4.38"
tokio = "1.42.0"
[profile.release] [profile.release]
lto = "thin" lto = "thin"
......
...@@ -20,6 +20,7 @@ def popen_launch_router( ...@@ -20,6 +20,7 @@ def popen_launch_router(
base_url: str, base_url: str,
dp_size: int, dp_size: int,
timeout: float, timeout: float,
policy: str = "cache_aware",
): ):
""" """
Launch the router server process. Launch the router server process.
...@@ -29,6 +30,7 @@ def popen_launch_router( ...@@ -29,6 +30,7 @@ def popen_launch_router(
base_url: Server base URL base_url: Server base URL
dp_size: Data parallel size dp_size: Data parallel size
timeout: Server launch timeout timeout: Server launch timeout
policy: Router policy, one of "cache_aware", "round_robin", "random"
""" """
_, host, port = base_url.split(":") _, host, port = base_url.split(":")
host = host[2:] host = host[2:]
...@@ -47,11 +49,10 @@ def popen_launch_router( ...@@ -47,11 +49,10 @@ def popen_launch_router(
str(dp_size), # Convert dp_size to string str(dp_size), # Convert dp_size to string
"--router-eviction-interval", "--router-eviction-interval",
"5", # frequent eviction for testing "5", # frequent eviction for testing
"--router-policy",
policy,
] ]
# Use current environment
env = None
process = subprocess.Popen(command, stdout=None, stderr=None) process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.time() start_time = time.time()
...@@ -99,19 +100,8 @@ def popen_launch_server( ...@@ -99,19 +100,8 @@ def popen_launch_server(
process = subprocess.Popen(command, stdout=None, stderr=None) process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.time() # intentionally don't wait and defer the job to the router health check
with requests.Session() as session:
while time.time() - start_time < timeout:
try:
response = session.get(f"{base_url}/health")
if response.status_code == 200:
print(f"Server {base_url} is healthy")
return process return process
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError("Server failed to start within the timeout period.")
class TestLaunchServer(unittest.TestCase): class TestLaunchServer(unittest.TestCase):
...@@ -135,6 +125,7 @@ class TestLaunchServer(unittest.TestCase): ...@@ -135,6 +125,7 @@ class TestLaunchServer(unittest.TestCase):
self.base_url, self.base_url,
dp_size=2, dp_size=2,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="cache_aware",
) )
args = SimpleNamespace( args = SimpleNamespace(
...@@ -160,6 +151,7 @@ class TestLaunchServer(unittest.TestCase): ...@@ -160,6 +151,7 @@ class TestLaunchServer(unittest.TestCase):
self.base_url, self.base_url,
dp_size=1, dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin", # use round robin to make sure every worker processes requests
) )
# 1. start a worker, and wait until it is healthy # 1. start a worker, and wait until it is healthy
port = find_available_port() port = find_available_port()
...@@ -168,11 +160,13 @@ class TestLaunchServer(unittest.TestCase): ...@@ -168,11 +160,13 @@ class TestLaunchServer(unittest.TestCase):
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
) )
TestLaunchServer.other_process.append(worker_process) TestLaunchServer.other_process.append(worker_process)
# 2. use /add_worker api to add it to the router
# 2. use /add_worker api to add it to the router. It will be used by the router after it is healthy
with requests.Session() as session: with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}") response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}") print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
# 3. run mmlu # 3. run mmlu
args = SimpleNamespace( args = SimpleNamespace(
base_url=self.base_url, base_url=self.base_url,
......
...@@ -3,13 +3,14 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; ...@@ -3,13 +3,14 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes; use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt}; use futures_util::{StreamExt, TryStreamExt};
use log::{debug, info}; use log::{debug, info, warn};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use tokio;
#[derive(Debug)] #[derive(Debug)]
pub enum Router { pub enum Router {
...@@ -385,14 +386,66 @@ impl Router { ...@@ -385,14 +386,66 @@ impl Router {
} }
} }
pub fn add_worker(&self, worker_url: String) { pub async fn add_worker(&self, worker_url: String) -> HttpResponse {
let interval_secs = 10; // check every 10 seconds
let timeout_secs = 300; // 5 minutes
let start_time = std::time::Instant::now();
let client = reqwest::Client::new();
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
return HttpResponse::InternalServerError().body(format!(
"Timeout {}s waiting for worker {} to become healthy",
timeout_secs, worker_url
));
}
match client.get(&format!("{}/health", worker_url)).send().await {
Ok(res) => {
if res.status().is_success() {
match self { match self {
Router::RoundRobin { worker_urls, .. } Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls } | Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => { | Router::CacheAware { worker_urls, .. } => {
info!("Worker {} health check passed", worker_url);
let mut urls = worker_urls.write().unwrap(); let mut urls = worker_urls.write().unwrap();
if urls.contains(&worker_url) {
return HttpResponse::BadRequest()
.body(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url); info!("Added worker: {}", worker_url);
urls.push(worker_url); urls.push(worker_url.clone());
}
}
return HttpResponse::Ok()
.body(format!("Successfully added worker: {}", worker_url));
} else {
info!(
"Worker {} health check failed with status: {}. The worker might still be starting up.",
worker_url, res.status()
);
// if the url does not have http or https prefix, warn users
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
{
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
}
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
continue;
}
}
Err(e) => {
info!("Worker {} health check failed: {}", worker_url, e);
// if the url does not have http or https prefix, warn users
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
}
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
continue;
}
} }
} }
} }
......
...@@ -141,8 +141,7 @@ async fn add_worker( ...@@ -141,8 +141,7 @@ async fn add_worker(
.body("Worker URL required. Provide 'url' query parameter") .body("Worker URL required. Provide 'url' query parameter")
} }
}; };
data.router.add_worker(worker_url); data.router.add_worker(worker_url).await
HttpResponse::Ok().finish()
} }
#[post("/remove_worker")] #[post("/remove_worker")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment