Unverified Commit 38907fe6 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

refactor(pd-router): extract common patterns to reduce code duplication (#9081)

parent f9afa7dc
...@@ -72,6 +72,138 @@ impl PDRouter { ...@@ -72,6 +72,138 @@ impl PDRouter {
}) })
} }
/// Sends a POST to `endpoint` on every worker in `workers`, one worker at a time.
///
/// # Arguments
/// * `workers` - lock-protected worker list; each worker's URL is copied out up front.
/// * `worker_type` - label prefixed to every result and error line.
/// * `endpoint` - path appended to each worker URL after a `/`.
///
/// # Returns
/// A `(results, errors)` pair: one `"... OK"` line per worker that answered with a
/// success status, and one error line per worker that answered with any other status
/// or whose request failed. If the lock cannot be read, a single error line is
/// recorded and no worker is contacted.
async fn process_workers(
    &self,
    workers: &RwLock<Vec<Box<dyn Worker>>>,
    worker_type: &str,
    endpoint: &str,
) -> (Vec<String>, Vec<String>) {
    let mut succeeded = Vec::new();
    let mut failed = Vec::new();

    // Snapshot the URLs so the read guard is gone before the first `.await`.
    let targets: Vec<String> = if let Ok(guard) = workers.read() {
        guard.iter().map(|worker| worker.url().to_string()).collect()
    } else {
        failed.push(format!("Failed to access {} workers", worker_type));
        Vec::new()
    };

    for base in targets {
        let target = format!("{}/{}", base, endpoint);
        let outcome = self.client.post(&target).send().await;
        match outcome {
            // Transport-level failure: no HTTP status is available.
            Err(err) => {
                failed.push(format!("{} {} error: {}", worker_type, base, err));
            }
            Ok(reply) => {
                let status = reply.status();
                if status.is_success() {
                    succeeded.push(format!("{} {}: OK", worker_type, base));
                } else {
                    failed.push(format!(
                        "{} {} returned status: {}",
                        worker_type, base, status
                    ));
                }
            }
        }
    }

    (succeeded, failed)
}
/// Copies the URL of every worker in `workers` into an owned list.
///
/// # Arguments
/// * `workers` - lock-protected worker list to read.
/// * `worker_type` - label interpolated into the error message.
///
/// # Errors
/// Returns `"Failed to access {worker_type} workers"` when the read lock cannot be
/// acquired.
fn get_worker_urls(
    workers: &RwLock<Vec<Box<dyn Worker>>>,
    worker_type: &str,
) -> Result<Vec<String>, String> {
    let guard = workers
        .read()
        .map_err(|_| format!("Failed to access {} workers", worker_type))?;

    // Owned strings let the caller use the URLs after the guard is released.
    Ok(guard.iter().map(|worker| worker.url().to_string()).collect())
}
/// Forwards a GET for `endpoint` to the first worker in `workers` and relays the reply.
///
/// # Arguments
/// * `workers` - lock-protected worker list; only the first entry is contacted.
/// * `endpoint` - path appended to the worker URL after a `/`.
/// * `worker_type` - label used in error messages and log lines.
/// * `headers` - optional `(name, value)` pairs added to the outgoing request.
///
/// # Returns
/// * `200 OK` with the upstream body when the worker answers with a success status.
///   The upstream `Content-Type` is forwarded when present and readable as text.
/// * The upstream status with a text message when the worker answers otherwise.
/// * `503 Service Unavailable` when the worker list is empty.
/// * `500 Internal Server Error` when the lock cannot be read, the request fails,
///   or the body cannot be read.
async fn proxy_to_first_worker(
    &self,
    workers: &RwLock<Vec<Box<dyn Worker>>>,
    endpoint: &str,
    worker_type: &str,
    headers: Option<Vec<(String, String)>>,
) -> Response {
    // Copy the URL out so the lock guard is released before any `.await`.
    let first_worker_url = match workers.read() {
        Ok(workers) => workers.first().map(|w| w.url().to_string()),
        Err(_) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to access {} workers", worker_type),
            )
                .into_response();
        }
    };

    let worker_url = match first_worker_url {
        Some(worker_url) => worker_url,
        None => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("No {} servers available", worker_type),
            )
                .into_response();
        }
    };

    let url = format!("{}/{}", worker_url, endpoint);
    let mut request_builder = self.client.get(&url);
    // Add headers if provided
    if let Some(headers) = headers {
        for (name, value) in headers {
            request_builder = request_builder.header(name, value);
        }
    }

    match request_builder.send().await {
        Ok(res) if res.status().is_success() => {
            // Capture the upstream Content-Type before `bytes()` consumes the
            // response, so JSON payloads are not relayed without their media type.
            // NOTE(review): assumes `Response`/`IntoResponse` are axum's, where a
            // `[(name, value); N]` tuple element sets response headers - confirm.
            let content_type = res
                .headers()
                .get("content-type")
                .and_then(|value| value.to_str().ok())
                .map(str::to_owned);

            match res.bytes().await {
                Ok(body) => match content_type {
                    Some(content_type) => {
                        (StatusCode::OK, [("content-type", content_type)], body).into_response()
                    }
                    // No usable upstream media type: relay the body as before.
                    None => (StatusCode::OK, body).into_response(),
                },
                Err(e) => {
                    error!("Failed to read response body: {}", e);
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to read response body: {}", e),
                    )
                        .into_response()
                }
            }
        }
        Ok(res) => {
            // Rebuild the status from its numeric code; fall back to 500 if the
            // code cannot be represented.
            let status = StatusCode::from_u16(res.status().as_u16())
                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
            (
                status,
                format!("{} server returned status: {}", worker_type, res.status()),
            )
                .into_response()
        }
        Err(e) => {
            error!("Failed to proxy request to {} server: {}", worker_type, e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to proxy request: {}", e),
            )
                .into_response()
        }
    }
}
pub async fn add_prefill_server( pub async fn add_prefill_server(
&self, &self,
url: String, url: String,
...@@ -1384,191 +1516,32 @@ impl RouterTrait for PDRouter { ...@@ -1384,191 +1516,32 @@ impl RouterTrait for PDRouter {
async fn get_server_info(&self, _req: Request<Body>) -> Response { async fn get_server_info(&self, _req: Request<Body>) -> Response {
// Get info from the first decode server to match sglang's server info format // Get info from the first decode server to match sglang's server info format
let first_decode_url = if let Ok(workers) = self.decode_workers.read() { // Note: We use decode workers for server info to match expected format
workers.first().map(|w| w.url().to_string()) self.proxy_to_first_worker(&self.decode_workers, "get_server_info", "decode", None)
} else { .await
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to access decode workers",
)
.into_response();
};
if let Some(worker_url) = first_decode_url {
match self
.client
.get(format!("{}/get_server_info", worker_url))
.send()
.await
{
Ok(res) if res.status().is_success() => {
match res.json::<Value>().await {
Ok(info) => {
// The decode server should already return the proper format
// with tokenizer_path and other fields that bench_one_batch_server.py expects
Json(info).into_response()
}
Err(e) => {
error!("Failed to parse server info: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to parse server info: {}", e),
)
.into_response()
}
}
}
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(
status,
format!("Decode server returned status: {}", res.status()),
)
.into_response()
}
Err(e) => {
error!("Failed to get server info: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get server info: {}", e),
)
.into_response()
}
}
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
"No decode servers available",
)
.into_response()
}
} }
async fn get_models(&self, req: Request<Body>) -> Response { async fn get_models(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues // Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req); let headers = crate::routers::router::copy_request_headers(&req);
// Get first prefill worker URL to avoid holding lock across await // Proxy to first prefill worker
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
workers.first().map(|w| w.url().to_string()) .await
} else {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to access prefill workers",
)
.into_response();
};
if let Some(worker_url) = first_worker_url {
let url = format!("{}/v1/models", worker_url);
let mut request_builder = self.client.get(&url);
// Add headers
for (name, value) in headers {
request_builder = request_builder.header(name, value);
}
match request_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(body) => (StatusCode::OK, body).into_response(),
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
}
},
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(
status,
format!("Prefill server returned status: {}", res.status()),
)
.into_response()
}
Err(e) => {
error!("Failed to get models: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get models: {}", e),
)
.into_response()
}
}
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
"No prefill servers available",
)
.into_response()
}
} }
async fn get_model_info(&self, req: Request<Body>) -> Response { async fn get_model_info(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues // Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req); let headers = crate::routers::router::copy_request_headers(&req);
// Get first prefill worker URL to avoid holding lock across await // Proxy to first prefill worker
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { self.proxy_to_first_worker(
workers.first().map(|w| w.url().to_string()) &self.prefill_workers,
} else { "get_model_info",
return ( "prefill",
StatusCode::INTERNAL_SERVER_ERROR, Some(headers),
"Failed to access prefill workers", )
) .await
.into_response();
};
if let Some(worker_url) = first_worker_url {
let url = format!("{}/get_model_info", worker_url);
let mut request_builder = self.client.get(&url);
// Add headers
for (name, value) in headers {
request_builder = request_builder.header(name, value);
}
match request_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(body) => (StatusCode::OK, body).into_response(),
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response()
}
},
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(
status,
format!("Prefill server returned status: {}", res.status()),
)
.into_response()
}
Err(e) => {
error!("Failed to get model info: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get model info: {}", e),
)
.into_response()
}
}
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
"No prefill servers available",
)
.into_response()
}
} }
async fn route_generate( async fn route_generate(
...@@ -1692,70 +1665,19 @@ impl RouterTrait for PDRouter { ...@@ -1692,70 +1665,19 @@ impl RouterTrait for PDRouter {
} }
async fn flush_cache(&self) -> Response { async fn flush_cache(&self) -> Response {
let mut results = Vec::new(); // Process both prefill and decode workers
let mut errors = Vec::new(); let (prefill_results, prefill_errors) = self
.process_workers(&self.prefill_workers, "Prefill", "flush_cache")
// Get prefill worker URLs first to avoid holding lock across await .await;
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() { let (decode_results, decode_errors) = self
workers .process_workers(&self.decode_workers, "Decode", "flush_cache")
.iter() .await;
.map(|w| w.url().to_string())
.collect::<Vec<_>>() // Combine results and errors
} else { let mut results = prefill_results;
errors.push("Failed to access prefill workers".to_string()); results.extend(decode_results);
Vec::new() let mut errors = prefill_errors;
}; errors.extend(decode_errors);
// Flush prefill workers
for worker_url in prefill_urls {
let url = format!("{}/flush_cache", worker_url);
match self.client.post(&url).send().await {
Ok(res) if res.status().is_success() => {
results.push(format!("Prefill {}: OK", worker_url));
}
Ok(res) => {
errors.push(format!(
"Prefill {} returned status: {}",
worker_url,
res.status()
));
}
Err(e) => {
errors.push(format!("Prefill {} error: {}", worker_url, e));
}
}
}
// Get decode worker URLs first to avoid holding lock across await
let decode_urls = if let Ok(workers) = self.decode_workers.read() {
workers
.iter()
.map(|w| w.url().to_string())
.collect::<Vec<_>>()
} else {
errors.push("Failed to access decode workers".to_string());
Vec::new()
};
// Flush decode workers
for worker_url in decode_urls {
let url = format!("{}/flush_cache", worker_url);
match self.client.post(&url).send().await {
Ok(res) if res.status().is_success() => {
results.push(format!("Decode {}: OK", worker_url));
}
Ok(res) => {
errors.push(format!(
"Decode {} returned status: {}",
worker_url,
res.status()
));
}
Err(e) => {
errors.push(format!("Decode {} error: {}", worker_url, e));
}
}
}
if errors.is_empty() { if errors.is_empty() {
( (
...@@ -1779,50 +1701,38 @@ impl RouterTrait for PDRouter { ...@@ -1779,50 +1701,38 @@ impl RouterTrait for PDRouter {
let mut loads = HashMap::new(); let mut loads = HashMap::new();
let mut errors = Vec::new(); let mut errors = Vec::new();
// Get prefill worker URLs first to avoid holding lock across await // Process prefill workers
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() { match Self::get_worker_urls(&self.prefill_workers, "prefill") {
workers Ok(urls) => {
.iter() for worker_url in urls {
.map(|w| w.url().to_string()) match get_worker_load(&self.client, &worker_url).await {
.collect::<Vec<_>>() Some(load) => {
} else { loads.insert(format!("prefill_{}", worker_url), load);
errors.push("Failed to access prefill workers".to_string()); }
Vec::new() None => {
}; errors.push(format!("Failed to get load from prefill {}", worker_url));
}
// Get loads from prefill workers }
for worker_url in prefill_urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from prefill {}", worker_url));
} }
} }
Err(e) => errors.push(e),
} }
// Get decode worker URLs first to avoid holding lock across await // Process decode workers
let decode_urls = if let Ok(workers) = self.decode_workers.read() { match Self::get_worker_urls(&self.decode_workers, "decode") {
workers Ok(urls) => {
.iter() for worker_url in urls {
.map(|w| w.url().to_string()) match get_worker_load(&self.client, &worker_url).await {
.collect::<Vec<_>>() Some(load) => {
} else { loads.insert(format!("decode_{}", worker_url), load);
errors.push("Failed to access decode workers".to_string()); }
Vec::new() None => {
}; errors.push(format!("Failed to get load from decode {}", worker_url));
}
// Get loads from decode workers }
for worker_url in decode_urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("decode_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from decode {}", worker_url));
} }
} }
Err(e) => errors.push(e),
} }
let response_data = serde_json::json!({ let response_data = serde_json::json!({
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment