Unverified Commit 8b30bec2 authored by Bruce-x-1997's avatar Bruce-x-1997 Committed by GitHub
Browse files

[router] fix error response in pd_router (#9505)


Co-authored-by: default avatarbruce.xu <bruce.xu@gmicloud.ai>
parent 4aeba40d
...@@ -28,7 +28,7 @@ use axum::{ ...@@ -28,7 +28,7 @@ use axum::{
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::Serialize; use serde::Serialize;
use serde_json::Value; use serde_json::{json, Value};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
...@@ -808,6 +808,57 @@ impl PDRouter { ...@@ -808,6 +808,57 @@ impl PDRouter {
.await .await
} }
async fn handle_decode_error_response(
&self,
res: reqwest::Response,
context: &PDRequestContext,
prefill: &dyn Worker,
decode: &dyn Worker,
) -> Response {
let status = res.status();
if context.is_stream {
// Handle streaming error response
let response_headers = header_utils::preserve_response_headers(res.headers());
let error_payload = match res.bytes().await {
Ok(error_body) => {
if let Ok(error_json) = serde_json::from_slice::<Value>(&error_body) {
json!({ "message": error_json, "status": status.as_u16() })
} else {
json!({ "message": String::from_utf8_lossy(&error_body).to_string(), "status": status.as_u16() })
}
}
Err(e) => {
json!({ "message": format!("Decode server error: {}", e), "status": status.as_u16() })
}
};
let sse_data = format!(
"data: {{'error': {}}}",
serde_json::to_string(&error_payload).unwrap_or_default()
);
let error_stream = tokio_stream::once(Ok(axum::body::Bytes::from(sse_data)));
let decode_url = decode.url().to_string();
self.create_streaming_response(
error_stream,
status,
None,
context.return_logprob,
Some(decode_url),
Some(response_headers),
prefill,
decode,
)
} else {
// Handle non-streaming error response
match res.bytes().await {
Ok(error_body) => (status, error_body).into_response(),
Err(e) => (status, format!("Decode server error: {}", e)).into_response(),
}
}
}
// Internal method that performs the actual dual dispatch (without retry logic) // Internal method that performs the actual dual dispatch (without retry logic)
async fn execute_dual_dispatch_internal( async fn execute_dual_dispatch_internal(
&self, &self,
...@@ -881,16 +932,9 @@ impl PDRouter { ...@@ -881,16 +932,9 @@ impl PDRouter {
status status
); );
// Return the error response from decode server return self
match res.bytes().await { .handle_decode_error_response(res, &context, prefill, decode)
Ok(error_body) => { .await;
return (status, error_body).into_response();
}
Err(e) => {
return (status, format!("Decode server error: {}", e))
.into_response();
}
}
} }
// Process prefill response for logprobs // Process prefill response for logprobs
...@@ -1034,13 +1078,8 @@ impl PDRouter { ...@@ -1034,13 +1078,8 @@ impl PDRouter {
status status
); );
// Return the error response from decode server self.handle_decode_error_response(res, &context, prefill, decode)
match res.bytes().await { .await
Ok(error_body) => (status, error_body).into_response(),
Err(e) => {
(status, format!("Decode server error: {}", e)).into_response()
}
}
} else if context.is_stream { } else if context.is_stream {
// Streaming response without logprobs - direct passthrough // Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string(); let decode_url = decode.url().to_string();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment