Unverified commit 8b30bec2, authored by Bruce-x-1997 and committed by GitHub
Browse files

[router] fix error response in pd_router (#9505)


Co-authored-by: bruce.xu <bruce.xu@gmicloud.ai>
parent 4aeba40d
......@@ -28,7 +28,7 @@ use axum::{
use futures_util::StreamExt;
use reqwest::Client;
use serde::Serialize;
use serde_json::Value;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
......@@ -808,6 +808,57 @@ impl PDRouter {
.await
}
async fn handle_decode_error_response(
&self,
res: reqwest::Response,
context: &PDRequestContext,
prefill: &dyn Worker,
decode: &dyn Worker,
) -> Response {
let status = res.status();
if context.is_stream {
// Handle streaming error response
let response_headers = header_utils::preserve_response_headers(res.headers());
let error_payload = match res.bytes().await {
Ok(error_body) => {
if let Ok(error_json) = serde_json::from_slice::<Value>(&error_body) {
json!({ "message": error_json, "status": status.as_u16() })
} else {
json!({ "message": String::from_utf8_lossy(&error_body).to_string(), "status": status.as_u16() })
}
}
Err(e) => {
json!({ "message": format!("Decode server error: {}", e), "status": status.as_u16() })
}
};
let sse_data = format!(
"data: {{'error': {}}}",
serde_json::to_string(&error_payload).unwrap_or_default()
);
let error_stream = tokio_stream::once(Ok(axum::body::Bytes::from(sse_data)));
let decode_url = decode.url().to_string();
self.create_streaming_response(
error_stream,
status,
None,
context.return_logprob,
Some(decode_url),
Some(response_headers),
prefill,
decode,
)
} else {
// Handle non-streaming error response
match res.bytes().await {
Ok(error_body) => (status, error_body).into_response(),
Err(e) => (status, format!("Decode server error: {}", e)).into_response(),
}
}
}
// Internal method that performs the actual dual dispatch (without retry logic)
async fn execute_dual_dispatch_internal(
&self,
......@@ -881,16 +932,9 @@ impl PDRouter {
status
);
// Return the error response from decode server
match res.bytes().await {
Ok(error_body) => {
return (status, error_body).into_response();
}
Err(e) => {
return (status, format!("Decode server error: {}", e))
.into_response();
}
}
return self
.handle_decode_error_response(res, &context, prefill, decode)
.await;
}
// Process prefill response for logprobs
......@@ -1034,13 +1078,8 @@ impl PDRouter {
status
);
// Return the error response from decode server
match res.bytes().await {
Ok(error_body) => (status, error_body).into_response(),
Err(e) => {
(status, format!("Decode server error: {}", e)).into_response()
}
}
self.handle_decode_error_response(res, &context, prefill, decode)
.await
} else if context.is_stream {
// Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment