Unverified Commit 67b65794 authored by Byron Hsu, committed by GitHub

[router] support `/add_worker` api (#2369)

parent 37ee906f
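
This commit adds a `POST /add_worker` endpoint to the router (the worker address is passed in the `url` query parameter and pushed into the shared worker list, now guarded by an `RwLock`), plus a test that launches an extra worker, registers it, and re-runs the MMLU eval. As a quick illustration of the new endpoint, a client can register an already-running worker like this (a minimal sketch: the router and worker addresses below are placeholders, not values from this commit):

```python
import requests

# Register an already-running worker with the router via the new endpoint.
# The endpoint expects the worker address in the `url` query parameter and
# returns 200 on success, or 400 if `url` is missing.
router_url = "http://127.0.0.1:30000"   # assumed router address
worker_url = "http://127.0.0.1:8001"    # assumed worker address

response = requests.post(f"{router_url}/add_worker", params={"url": worker_url})
print(response.status_code, response.text)
response.raise_for_status()
```
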
import socket
import subprocess
import time
import unittest
@@ -49,7 +50,7 @@ def popen_launch_router(
# Use current environment
env = None
process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.time()
with requests.Session() as session:
@@ -57,6 +58,52 @@ def popen_launch_router(
try:
response = session.get(f"{base_url}/health")
if response.status_code == 200:
print(f"Router {base_url} is healthy")
return process
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError("Router failed to start within the timeout period.")
def find_available_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def popen_launch_server(
model: str,
base_url: str,
timeout: float,
):
_, host, port = base_url.split(":")
host = host[2:]
command = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
model,
"--host",
host,
"--port",
port,
"--base-gpu-id",
"1",
]
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.time()
with requests.Session() as session:
while time.time() - start_time < timeout:
try:
response = session.get(f"{base_url}/health")
if response.status_code == 200:
print(f"Server {base_url} is healthy")
return process
except requests.RequestException:
pass
@@ -76,10 +123,13 @@ class TestEvalAccuracyMini(unittest.TestCase):
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
)
cls.other_process = []
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
for process in cls.other_process:
kill_process_tree(process.pid)
def test_mmlu(self):
args = SimpleNamespace(
@@ -98,6 +148,35 @@ class TestEvalAccuracyMini(unittest.TestCase):
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
def test_add_worker(self):
# 1. start a worker, and wait until it is healthy
port = find_available_port()
worker_url = f"http://127.0.0.1:{port}"
worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
self.other_process.append(worker_process)
# 2. use the /add_worker api to add it to the router
with requests.Session() as session:
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(response.status_code, 200)
# 3. run mmlu
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
score = metrics["score"]
THRESHOLD = 0.65
passed = score >= THRESHOLD
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg)
if __name__ == "__main__":
unittest.main()
@@ -7,18 +7,18 @@ use log::{debug, info};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
#[derive(Debug)]
pub enum Router {
RoundRobin {
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
current_index: AtomicUsize,
},
Random {
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
},
CacheAware {
/*
@@ -81,7 +81,7 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
match policy_config {
PolicyConfig::RandomConfig => Router::Random { worker_urls },
PolicyConfig::RandomConfig => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
},
PolicyConfig::RoundRobinConfig => Router::RoundRobin {
worker_urls,
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
},
PolicyConfig::CacheAwareConfig {
@@ -183,7 +185,7 @@ impl Router {
}
Router::CacheAware {
worker_urls,
worker_urls: Arc::new(RwLock::new(worker_urls)),
tree,
running_queue,
processed_queue,
@@ -201,10 +203,10 @@ impl Router {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.is_empty() {
if worker_urls.read().unwrap().is_empty() {
None
} else {
Some(worker_urls[0].clone())
Some(worker_urls.read().unwrap()[0].clone())
}
}
}
@@ -228,15 +230,15 @@ impl Router {
.fetch_update(
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst,
|x| Some((x + 1) % worker_urls.len()),
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
)
.unwrap();
worker_urls[idx].clone()
worker_urls.read().unwrap()[idx].clone()
}
Router::Random { worker_urls } => {
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
}
Router::Random { worker_urls } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(),
Router::CacheAware {
worker_urls,
@@ -277,7 +279,7 @@ impl Router {
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
} else {
// Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text);
@@ -333,7 +335,10 @@ impl Router {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(),
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
HttpResponse::InternalServerError().body(error_msg)
}
};
// Then decrement running queue counter if using CacheAware
@@ -379,4 +384,16 @@ impl Router {
}))
}
}
pub fn add_worker(&self, worker_url: String) {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
let mut urls = worker_urls.write().unwrap();
info!("Added worker: {}", worker_url);
urls.push(worker_url);
}
}
}
}
use crate::router::PolicyConfig;
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use actix_web::{
delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder,
};
use bytes::Bytes;
use env_logger::Builder;
use log::{info, LevelFilter};
use std::collections::HashMap;
use std::io::Write;
#[derive(Debug)]
@@ -128,6 +131,22 @@ async fn v1_completions(
.await
}
#[post("/add_worker")]
async fn add_worker(
query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>,
) -> impl Responder {
let worker_url = match query.get("url") {
Some(url) => url.to_string(),
None => {
return HttpResponse::BadRequest()
.body("Worker URL required. Provide 'url' query parameter")
}
};
data.router.add_worker(worker_url);
HttpResponse::Ok().finish()
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
@@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(health)
.service(health_generate)
.service(get_server_info)
.service(add_worker)
})
.bind((config.host, config.port))?
.run()