Unverified Commit bbb81c24 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

Add more api routes (completion, health, etc) to the router (#2146)

parent 52f58fc4
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "sglang-router" name = "sglang-router"
version = "0.0.5" version = "0.0.6"
description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances." description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
requires-python = ">=3.8" requires-python = ">=3.8"
......
...@@ -97,14 +97,27 @@ pub enum PolicyConfig { ...@@ -97,14 +97,27 @@ pub enum PolicyConfig {
}, },
} }
fn get_text_from_request(body: &Bytes) -> String { fn get_text_from_request(body: &Bytes, route: &str) -> String {
// 1. convert body to json // convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap(); let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
// 2. get the text field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
return text.to_string();
}
if route == "generate" {
// get the "text" field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
return text.to_string();
} else if route == "v1/chat/completions" {
// get the messages field as raw text
if let Some(messages) = json.get("messages") {
// Convert messages back to a string, preserving all JSON formatting
return serde_json::to_string(messages).unwrap_or_default();
}
} else if route == "v1/completions" {
let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
return prompt.to_string();
}
return "".to_string();
}
impl Router { impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self { pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
match policy_config { match policy_config {
...@@ -187,8 +200,11 @@ impl Router { ...@@ -187,8 +200,11 @@ impl Router {
client: &reqwest::Client, client: &reqwest::Client,
req: HttpRequest, req: HttpRequest,
body: Bytes, body: Bytes,
route: &str,
) -> HttpResponse { ) -> HttpResponse {
let text = get_text_from_request(&body); let text = get_text_from_request(&body, route);
// For Debug
// println!("text: {:?}, route: {:?}", text, route);
let worker_url = match self { let worker_url = match self {
Router::RoundRobin { Router::RoundRobin {
...@@ -236,13 +252,14 @@ impl Router { ...@@ -236,13 +252,14 @@ impl Router {
if matched_rate > *cache_threshold { if matched_rate > *cache_threshold {
matched_worker.to_string() matched_worker.to_string()
} else { } else {
let m_map: HashMap<String, usize> = tree // For Debug
.tenant_char_count // let m_map: HashMap<String, usize> = tree
.iter() // .tenant_char_count
.map(|entry| (entry.key().clone(), *entry.value())) // .iter()
.collect(); // .map(|entry| (entry.key().clone(), *entry.value()))
// .collect();
println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map); // println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
tree.get_smallest_tenant() tree.get_smallest_tenant()
} }
...@@ -276,7 +293,7 @@ impl Router { ...@@ -276,7 +293,7 @@ impl Router {
.unwrap_or(false); .unwrap_or(false);
let res = match client let res = match client
.post(format!("{}/generate", worker_url.clone())) .post(format!("{}/{}", worker_url.clone(), route))
.header( .header(
"Content-Type", "Content-Type",
req.headers() req.headers()
......
...@@ -33,7 +33,10 @@ async fn forward_request( ...@@ -33,7 +33,10 @@ async fn forward_request(
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
// print the status // print the status
println!("Worker URL: {}, Status: {}", worker_url, status); println!(
"Forwarding Request Worker URL: {}, Route: {}, Status: {}",
worker_url, route, status
);
match res.bytes().await { match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()), Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(), Err(_) => HttpResponse::InternalServerError().finish(),
...@@ -43,8 +46,38 @@ async fn forward_request( ...@@ -43,8 +46,38 @@ async fn forward_request(
} }
} }
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/health".to_string()).await
}
#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
}
#[get("/get_server_args")]
async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
}
#[get("/v1/models")] #[get("/v1/models")]
async fn v1_model(data: web::Data<AppState>) -> impl Responder { async fn v1_models(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() { let worker_url = match data.router.get_first() {
Some(url) => url, Some(url) => url,
None => return HttpResponse::InternalServerError().finish(), None => return HttpResponse::InternalServerError().finish(),
...@@ -65,7 +98,31 @@ async fn get_model_info(data: web::Data<AppState>) -> impl Responder { ...@@ -65,7 +98,31 @@ async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
#[post("/generate")] #[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder { async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
data.router.dispatch(&data.client, req, body).await data.router
.dispatch(&data.client, req, body, "generate")
.await
}
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "v1/chat/completions")
.await
}
#[post("/v1/completions")]
async fn v1_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "v1/completions")
.await
} }
pub async fn startup( pub async fn startup(
...@@ -90,8 +147,13 @@ pub async fn startup( ...@@ -90,8 +147,13 @@ pub async fn startup(
App::new() App::new()
.app_data(app_state.clone()) .app_data(app_state.clone())
.service(generate) .service(generate)
.service(v1_model) .service(v1_chat_completions)
.service(v1_completions)
.service(v1_models)
.service(get_model_info) .service(get_model_info)
.service(health)
.service(health_generate)
.service(get_server_args)
}) })
.bind((host, port))? .bind((host, port))?
.run() .run()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment