"llm/vscode:/vscode.git/clone" did not exist on "f4bf1d514f537af9166f72fa00feda04556fc3d5"
Unverified commit 67b65794 authored by Byron Hsu, committed by GitHub

[router] support `/add_worker` api (#2369)

parent 37ee906f
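The new `/add_worker` endpoint lets a client register an additional worker with a running router, which the test below exercises. A minimal sketch of calling it from a client, assuming a router already listening on 127.0.0.1:30000 and a freshly launched worker on 127.0.0.1:30001 (both addresses are hypothetical):

```python
import requests

# Hypothetical addresses; substitute the router and worker you actually launched.
router_url = "http://127.0.0.1:30000"
new_worker_url = "http://127.0.0.1:30001"

# The worker URL is passed as the 'url' query parameter. The handler returns
# 200 on success and 400 with an error message if the parameter is missing.
resp = requests.post(f"{router_url}/add_worker", params={"url": new_worker_url})
print(resp.status_code, resp.text)
```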
import socket
import subprocess
import time
import unittest
@@ -49,7 +50,7 @@ def popen_launch_router(
    # Use current environment
    env = None

    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.time()
    with requests.Session() as session:
@@ -57,6 +58,52 @@ def popen_launch_router(
            try:
                response = session.get(f"{base_url}/health")
                if response.status_code == 200:
                    print(f"Router {base_url} is healthy")
                    return process
            except requests.RequestException:
                pass

            time.sleep(10)

    raise TimeoutError("Router failed to start within the timeout period.")

def find_available_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
):
    _, host, port = base_url.split(":")
    host = host[2:]

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--base-gpu-id",
        "1",
    ]

    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                response = session.get(f"{base_url}/health")
                if response.status_code == 200:
                    print(f"Server {base_url} is healthy")
                    return process
            except requests.RequestException:
                pass
@@ -76,10 +123,13 @@ class TestEvalAccuracyMini(unittest.TestCase):
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )
        cls.other_process = []

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        for process in cls.other_process:
            kill_process_tree(process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
@@ -98,6 +148,35 @@ class TestEvalAccuracyMini(unittest.TestCase):
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

    def test_add_worker(self):
        # 1. start a worker, and wait until it is healthy
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )
        self.other_process.append(worker_process)

        # 2. use the /add_worker api to add it to the router
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # 3. run mmlu
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.65
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

if __name__ == "__main__":
    unittest.main()
@@ -7,18 +7,18 @@ use log::{debug, info};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;

#[derive(Debug)]
pub enum Router {
    RoundRobin {
        worker_urls: Arc<RwLock<Vec<String>>>,
        current_index: AtomicUsize,
    },
    Random {
        worker_urls: Arc<RwLock<Vec<String>>>,
    },
    CacheAware {
        /*
@@ -81,7 +81,7 @@ pub enum Router {
        Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
        during the next eviction cycle.
        */
        worker_urls: Arc<RwLock<Vec<String>>>,
        tree: Arc<Mutex<Tree>>,
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
impl Router {
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
        match policy_config {
            PolicyConfig::RandomConfig => Router::Random {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
            },
            PolicyConfig::RoundRobinConfig => Router::RoundRobin {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                current_index: std::sync::atomic::AtomicUsize::new(0),
            },
            PolicyConfig::CacheAwareConfig {
@@ -183,7 +185,7 @@ impl Router {
                }

                Router::CacheAware {
                    worker_urls: Arc::new(RwLock::new(worker_urls)),
                    tree,
                    running_queue,
                    processed_queue,
@@ -201,10 +203,10 @@ impl Router {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls }
            | Router::CacheAware { worker_urls, .. } => {
                if worker_urls.read().unwrap().is_empty() {
                    None
                } else {
                    Some(worker_urls.read().unwrap()[0].clone())
                }
            }
        }
@@ -228,15 +230,15 @@ impl Router {
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                    )
                    .unwrap();
                worker_urls.read().unwrap()[idx].clone()
            }
            Router::Random { worker_urls } => worker_urls.read().unwrap()
                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
            .clone(),
            Router::CacheAware {
                worker_urls,
@@ -277,7 +279,7 @@ impl Router {
                    .iter()
                    .min_by_key(|(_url, &count)| count)
                    .map(|(url, _)| url.clone())
                    .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
            } else {
                // Use cache-aware routing when load is balanced
                let (matched_text, matched_worker) = tree.prefix_match(&text);
@@ -333,7 +335,10 @@ impl Router {
        // For non-streaming requests, get response first
        let response = match res.bytes().await {
            Ok(body) => HttpResponse::build(status).body(body.to_vec()),
            Err(e) => {
                let error_msg = format!("Failed to get response body: {}", e);
                HttpResponse::InternalServerError().body(error_msg)
            }
        };

        // Then decrement running queue counter if using CacheAware
@@ -379,4 +384,16 @@ impl Router {
            }))
        }
    }

    pub fn add_worker(&self, worker_url: String) {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls }
            | Router::CacheAware { worker_urls, .. } => {
                let mut urls = worker_urls.write().unwrap();
                info!("Added worker: {}", worker_url);
                urls.push(worker_url);
            }
        }
    }
}
use crate::router::PolicyConfig;
use crate::router::Router;
use actix_web::{
    delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder,
};
use bytes::Bytes;
use env_logger::Builder;
use log::{info, LevelFilter};
use std::collections::HashMap;
use std::io::Write;

#[derive(Debug)]
@@ -128,6 +131,22 @@ async fn v1_completions(
    .await
}

#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter")
        }
    };
    data.router.add_worker(worker_url);
    HttpResponse::Ok().finish()
}

pub struct ServerConfig {
    pub host: String,
    pub port: u16,
@@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
            .service(health)
            .service(health_generate)
            .service(get_server_info)
            .service(add_worker)
    })
    .bind((config.host, config.port))?
    .run()