Unverified commit d4de9a62, authored by Byron Hsu, committed by GitHub
Browse files

[router] Refactor: decouple select and send stage (#2440)

parent 7310aede
......@@ -106,28 +106,6 @@ pub enum PolicyConfig {
},
}
/// Extracts the prompt text that routing decisions are based on from a JSON
/// request body.
///
/// `route` selects which JSON field carries the prompt:
/// - `generate`: the string field `"text"`
/// - `v1/chat/completions`: the `"messages"` value, re-serialized as JSON text
/// - `v1/completions`: the string field `"prompt"`
///
/// A single leading `/` on `route` is ignored, so `"generate"` and
/// `"/generate"` are treated the same.
///
/// Returns an empty string when the body is not valid JSON, the route is not
/// one of the above, or the expected field is missing or not a string.
fn get_text_from_request(body: &Bytes, route: &str) -> String {
    // The body comes straight from the client; malformed JSON must not panic
    // the request handler, so fall back to "no text".
    let json = match serde_json::from_slice::<serde_json::Value>(body) {
        Ok(value) => value,
        Err(_) => return String::new(),
    };

    // Reads a top-level string field, defaulting to "" when absent or non-string.
    let string_field = |key: &str| -> String {
        json.get(key)
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_string()
    };

    match route.strip_prefix('/').unwrap_or(route) {
        "generate" => string_field("text"),
        // Chat requests are keyed on the raw JSON of the whole message list.
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => string_field("prompt"),
        _ => String::new(),
    }
}
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
// Wait until all workers are healthy
......@@ -204,20 +182,6 @@ impl Router {
})
}
/// Returns a clone of the first worker URL, or `None` when the worker list
/// is empty.
///
/// Panics if the worker-list lock is poisoned.
pub fn get_first(&self) -> Option<String> {
    match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls }
        | Router::CacheAware { worker_urls, .. } => {
            // Take the read lock once: checking emptiness and indexing under
            // two separate guards lets a writer empty the list in between,
            // which would make the `[0]` index panic.
            worker_urls.read().unwrap().first().cloned()
        }
    }
}
fn wait_for_healthy_workers(
worker_urls: &[String],
timeout_secs: u64,
......@@ -271,14 +235,76 @@ impl Router {
}
}
pub async fn dispatch(
/// Selects the first worker in the list.
///
/// Returns the worker URL, or `Err("No workers are available")` when the
/// worker list is empty. Panics if the worker-list lock is poisoned.
fn select_first_worker(&self) -> Result<String, String> {
    match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls }
        | Router::CacheAware { worker_urls, .. } => {
            // A single read guard covers both the emptiness check and the
            // read; with two separate guards a writer could empty the list
            // in between and the `[0]` index would panic.
            worker_urls
                .read()
                .unwrap()
                .first()
                .cloned()
                .ok_or_else(|| "No workers are available".to_string())
        }
    }
}
/// Sends a GET request to `{worker_url}{route}` and relays the worker's
/// status code and body back as an `HttpResponse`.
///
/// Failure to reach the worker or to read its response body is reported as a
/// 500 Internal Server Error whose body describes the failure.
// NOTE(review): the visible caller `route_to_first` passes only
// (client, worker_url, route); `req` and `body` look like leftovers from the
// pre-refactor `dispatch` signature — confirm the intended parameter list.
async fn send_request(
&self,
client: &reqwest::Client,
req: HttpRequest,
body: Bytes,
worker_url: String,
route: &str,
) -> HttpResponse {
// NOTE(review): `text` is not used anywhere else in this body, and the call
// targets a free function `get_text_from_request` — presumably stale; confirm.
let text = get_text_from_request(&body, route);
match client.get(format!("{}{}", worker_url, route)).send().await {
Ok(res) => {
// Translate the reqwest status into an actix status via its numeric code;
// fall back to 500 if the code cannot be converted.
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e)),
}
}
// The request never produced a response from the worker.
Err(e) => HttpResponse::InternalServerError().body(format!(
"Failed to send request to worker {}: {}",
worker_url, e
)),
}
}
/// Routes a request for `route` to the first worker in the list.
///
/// Worker selection and sending are separate stages: selection failure is
/// answered with a 500 Internal Server Error carrying the selection error
/// text; otherwise the response comes from `send_request`.
pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
    // Select stage: bail out early when no worker can be chosen.
    let first_worker = match self.select_first_worker() {
        Ok(url) => url,
        Err(reason) => return HttpResponse::InternalServerError().body(reason),
    };

    // Send stage.
    self.send_request(client, first_worker, route).await
}
/// Extracts the prompt text that worker selection is based on from a JSON
/// request body.
///
/// `route` selects which JSON field carries the prompt:
/// - `/generate`: the string field `"text"`
/// - `/v1/chat/completions`: the `"messages"` value, re-serialized as JSON text
/// - `/v1/completions`: the string field `"prompt"`
///
/// The leading `/` is optional. The request handlers pass routes with a
/// leading slash (e.g. `"/generate"`), so comparing against the slash-less
/// names alone would never match and the text would always come back empty.
///
/// Returns an empty string when the body is not valid JSON, the route is not
/// one of the above, or the expected field is missing or not a string.
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
    // The body comes straight from the client; malformed JSON must not panic
    // the request handler, so fall back to "no text".
    let json = match serde_json::from_slice::<serde_json::Value>(body) {
        Ok(value) => value,
        Err(_) => return String::new(),
    };

    // Reads a top-level string field, defaulting to "" when absent or non-string.
    let string_field = |key: &str| -> String {
        json.get(key)
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_string()
    };

    match route.strip_prefix('/').unwrap_or(route) {
        "generate" => string_field("text"),
        // Chat requests are keyed on the raw JSON of the whole message list.
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => string_field("prompt"),
        _ => String::new(),
    }
}
// TODO: return Result<String, String> instead of panicking
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
let text = self.get_text_from_request(&body, route);
let worker_url = match self {
Router::RoundRobin {
......@@ -366,12 +392,23 @@ impl Router {
}
};
worker_url
}
async fn send_generate_request(
&self,
client: &reqwest::Client,
req: HttpRequest,
body: Bytes,
route: &str,
worker_url: &str,
) -> HttpResponse {
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false);
let res = match client
.post(format!("{}/{}", worker_url.clone(), route))
.post(format!("{}{}", worker_url, route))
.header(
"Content-Type",
req.headers()
......@@ -403,7 +440,7 @@ impl Router {
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(&worker_url) {
if let Some(count) = queue.get_mut(worker_url) {
*count = count.saturating_sub(1);
}
}
......@@ -412,7 +449,7 @@ impl Router {
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.clone();
let worker_url = worker_url.to_string();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
......@@ -431,7 +468,7 @@ impl Router {
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
debug!("streaming is done!!")
debug!("Streaming is done!!")
}
}),
)
......@@ -444,6 +481,18 @@ impl Router {
}
}
/// Routes a generation request in two stages: select a worker based on the
/// request body and route, then send the request to that worker.
///
/// Returns whatever `send_generate_request` produces for the chosen worker.
pub async fn route_generate_request(
    &self,
    client: &reqwest::Client,
    req: HttpRequest,
    body: Bytes,
    route: &str,
) -> HttpResponse {
    // Select stage: only borrows the body so it can still be moved into the send stage.
    let target = self.select_generate_worker(&body, route);

    // Send stage.
    self.send_generate_request(client, req, body, route, target.as_str())
        .await
}
pub async fn add_worker(&self, worker_url: String) -> Result<String, String> {
let interval_secs = 10; // check every 10 seconds
let timeout_secs = 300; // 5 minutes
......
......@@ -29,84 +29,41 @@ impl AppState {
}
}
/// Sends a GET request to `{worker_url}{route}` and relays the worker's
/// status code and body back to the caller.
///
/// Returns a 500 Internal Server Error whose body names the failure when the
/// worker cannot be reached or its response body cannot be read.
async fn forward_request(
    client: &reqwest::Client,
    worker_url: String,
    route: String,
) -> HttpResponse {
    let res = match client.get(format!("{}{}", worker_url, route)).send().await {
        Ok(res) => res,
        // Report why the worker could not be reached instead of an empty 500.
        Err(e) => {
            return HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            ));
        }
    };

    // Translate the reqwest status into an actix status via its numeric code;
    // fall back to 500 if the code cannot be converted.
    let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
        .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
    // print the status
    println!(
        "Forwarding Request Worker URL: {}, Route: {}, Status: {}",
        worker_url, route, status
    );

    match res.bytes().await {
        Ok(body) => HttpResponse::build(status).body(body.to_vec()),
        Err(e) => HttpResponse::InternalServerError()
            .body(format!("Failed to read response body: {}", e)),
    }
}
/// `GET /health`: forwards the request to the first worker and relays its response.
///
/// Responds with 500 Internal Server Error when no worker is available.
// NOTE(review): this body contains both the `get_first` + `forward_request`
// path and the `route_to_first` call — presumably only one is meant to remain; confirm.
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/health".to_string()).await
data.router.route_to_first(&data.client, "/health").await
}
/// `GET /health_generate`: forwards the request to the first worker and relays its response.
///
/// Responds with 500 Internal Server Error when no worker is available.
// NOTE(review): this body contains both the `get_first` + `forward_request`
// path and the `route_to_first` call — presumably only one is meant to remain; confirm.
#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
data.router
.route_to_first(&data.client, "/health_generate")
.await
}
/// `GET /get_server_info`: forwards the request to the first worker and relays its response.
///
/// Responds with 500 Internal Server Error when no worker is available.
// NOTE(review): this body contains both the `get_first` + `forward_request`
// path and the `route_to_first` call — presumably only one is meant to remain; confirm.
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
data.router
.route_to_first(&data.client, "/get_server_info")
.await
}
/// `GET /v1/models`: forwards the request to the first worker and relays its response.
///
/// Responds with 500 Internal Server Error when no worker is available.
// NOTE(review): this body contains both the `get_first` + `forward_request`
// path and the `route_to_first` call — presumably only one is meant to remain; confirm.
#[get("/v1/models")]
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/v1/models".to_string()).await
data.router.route_to_first(&data.client, "/v1/models").await
}
/// `GET /get_model_info`: forwards the request to the first worker and relays its response.
///
/// Responds with 500 Internal Server Error when no worker is available.
// NOTE(review): this body contains both the `get_first` + `forward_request`
// path and the `route_to_first` call — presumably only one is meant to remain; confirm.
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
data.router
.route_to_first(&data.client, "/get_model_info")
.await
}
/// `POST /generate`: hands the incoming request and its raw body to the router,
/// which picks a worker and returns that worker's response.
// NOTE(review): both a `.dispatch(..., "generate")` call and a
// `.route_generate_request(..., "/generate")` call appear in this chain —
// presumably only one is meant to remain; note the route strings differ by a leading `/`.
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "generate")
.route_generate_request(&data.client, req, body, "/generate")
.await
}
......@@ -117,7 +74,7 @@ async fn v1_chat_completions(
data: web::Data<AppState>,
) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "v1/chat/completions")
.route_generate_request(&data.client, req, body, "/v1/chat/completions")
.await
}
......@@ -128,7 +85,7 @@ async fn v1_completions(
data: web::Data<AppState>,
) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "v1/completions")
.route_generate_request(&data.client, req, body, "/v1/completions")
.await
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment