Unverified commit 420c99ac, authored by Chang Su and committed by GitHub
Browse files

[router][grpc] Fix error message format in grpc chat handler (#11307)

parent e3c7f091
...@@ -194,8 +194,7 @@ impl GrpcPDRouter { ...@@ -194,8 +194,7 @@ impl GrpcPDRouter {
let (original_text, token_ids) = match self.resolve_generate_input(body) { let (original_text, token_ids) = match self.resolve_generate_input(body) {
Ok(res) => res, Ok(res) => res,
Err(msg) => { Err(msg) => {
error!("Invalid generate request: {}", msg); return utils::bad_request_error(msg);
return (StatusCode::BAD_REQUEST, msg).into_response();
} }
}; };
...@@ -208,8 +207,7 @@ impl GrpcPDRouter { ...@@ -208,8 +207,7 @@ impl GrpcPDRouter {
{ {
Ok(pair) => pair, Ok(pair) => pair,
Err(e) => { Err(e) => {
warn!("Failed to select PD worker pair: {}", e); return utils::service_unavailable_error(e);
return (StatusCode::SERVICE_UNAVAILABLE, e).into_response();
} }
}; };
...@@ -244,15 +242,13 @@ impl GrpcPDRouter { ...@@ -244,15 +242,13 @@ impl GrpcPDRouter {
) { ) {
Ok(req) => req, Ok(req) => req,
Err(e) => { Err(e) => {
error!("Failed to build generate request: {}", e); return utils::bad_request_error(e);
return (StatusCode::BAD_REQUEST, e).into_response();
} }
}; };
// Step 5: Inject bootstrap metadata // Step 5: Inject bootstrap metadata
if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) {
error!("Failed to inject bootstrap metadata: {}", e); return utils::internal_error_message(e);
return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response();
} }
// Step 6: Get weight version for response metadata // Step 6: Get weight version for response metadata
...@@ -334,8 +330,7 @@ impl GrpcPDRouter { ...@@ -334,8 +330,7 @@ impl GrpcPDRouter {
let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) {
Ok(msgs) => msgs, Ok(msgs) => msgs,
Err(e) => { Err(e) => {
error!("Failed to process chat messages: {}", e); return utils::bad_request_error(e.to_string());
return (StatusCode::BAD_REQUEST, e.to_string()).into_response();
} }
}; };
...@@ -343,12 +338,7 @@ impl GrpcPDRouter { ...@@ -343,12 +338,7 @@ impl GrpcPDRouter {
let encoding = match self.tokenizer.encode(&processed_messages.text) { let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding, Ok(encoding) => encoding,
Err(e) => { Err(e) => {
error!("Tokenization failed: {}", e); return utils::internal_error_message(format!("Tokenization failed: {}", e));
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Tokenization failed: {}", e),
)
.into_response();
} }
}; };
...@@ -368,8 +358,7 @@ impl GrpcPDRouter { ...@@ -368,8 +358,7 @@ impl GrpcPDRouter {
{ {
Ok(pair) => pair, Ok(pair) => pair,
Err(e) => { Err(e) => {
warn!("Failed to select PD worker pair: {}", e); return utils::service_unavailable_error(e);
return (StatusCode::SERVICE_UNAVAILABLE, e).into_response();
} }
}; };
...@@ -402,19 +391,13 @@ impl GrpcPDRouter { ...@@ -402,19 +391,13 @@ impl GrpcPDRouter {
) { ) {
Ok(request) => request, Ok(request) => request,
Err(e) => { Err(e) => {
error!("Failed to build gRPC request: {}", e); return utils::bad_request_error(format!("Invalid request parameters: {}", e));
return (
StatusCode::BAD_REQUEST,
format!("Invalid request parameters: {}", e),
)
.into_response();
} }
}; };
// Step 8: Inject bootstrap metadata into the request // Step 8: Inject bootstrap metadata into the request
if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) {
error!("Failed to inject bootstrap metadata: {}", e); return utils::internal_error_message(e);
return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response();
} }
// Step 9: Handle streaming vs non-streaming // Step 9: Handle streaming vs non-streaming
...@@ -486,12 +469,10 @@ impl GrpcPDRouter { ...@@ -486,12 +469,10 @@ impl GrpcPDRouter {
let prefill_stream = match prefill_result { let prefill_stream = match prefill_result {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
error!("Failed to start prefill generation: {}", e); return utils::internal_error_message(format!(
return ( "Prefill worker failed to start: {}",
StatusCode::INTERNAL_SERVER_ERROR, e
format!("Prefill worker failed to start: {}", e), ));
)
.into_response();
} }
}; };
...@@ -499,12 +480,10 @@ impl GrpcPDRouter { ...@@ -499,12 +480,10 @@ impl GrpcPDRouter {
let decode_stream = match decode_result { let decode_stream = match decode_result {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
error!("Failed to start decode generation: {}", e); return utils::internal_error_message(format!(
return ( "Decode worker failed to start: {}",
StatusCode::INTERNAL_SERVER_ERROR, e
format!("Decode worker failed to start: {}", e), ));
)
.into_response();
} }
}; };
...@@ -592,12 +571,10 @@ impl GrpcPDRouter { ...@@ -592,12 +571,10 @@ impl GrpcPDRouter {
let prefill_stream = match prefill_result { let prefill_stream = match prefill_result {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
error!("Failed to start prefill generation: {}", e); return utils::internal_error_message(format!(
return ( "Prefill worker failed to start: {}",
StatusCode::INTERNAL_SERVER_ERROR, e
format!("Prefill worker failed to start: {}", e), ));
)
.into_response();
} }
}; };
...@@ -605,12 +582,10 @@ impl GrpcPDRouter { ...@@ -605,12 +582,10 @@ impl GrpcPDRouter {
let decode_stream = match decode_result { let decode_stream = match decode_result {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
error!("Failed to start decode generation: {}", e); return utils::internal_error_message(format!(
return ( "Decode worker failed to start: {}",
StatusCode::INTERNAL_SERVER_ERROR, e
format!("Decode worker failed to start: {}", e), ));
)
.into_response();
} }
}; };
......
...@@ -113,8 +113,7 @@ impl GrpcRouter { ...@@ -113,8 +113,7 @@ impl GrpcRouter {
let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) {
Ok(msgs) => msgs, Ok(msgs) => msgs,
Err(e) => { Err(e) => {
error!("Failed to process chat messages: {}", e); return utils::bad_request_error(e.to_string());
return (StatusCode::BAD_REQUEST, e.to_string()).into_response();
} }
}; };
...@@ -122,12 +121,7 @@ impl GrpcRouter { ...@@ -122,12 +121,7 @@ impl GrpcRouter {
let encoding = match self.tokenizer.encode(&processed_messages.text) { let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding, Ok(encoding) => encoding,
Err(e) => { Err(e) => {
error!("Tokenization failed: {}", e); return utils::internal_error_message(format!("Tokenization failed: {}", e));
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Tokenization failed: {}", e),
)
.into_response();
} }
}; };
...@@ -145,8 +139,10 @@ impl GrpcRouter { ...@@ -145,8 +139,10 @@ impl GrpcRouter {
{ {
Some(w) => w, Some(w) => w,
None => { None => {
warn!("No available workers for model: {:?}", model_id); return utils::service_unavailable_error(format!(
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); "No available workers for model: {:?}",
model_id
));
} }
}; };
...@@ -170,12 +166,7 @@ impl GrpcRouter { ...@@ -170,12 +166,7 @@ impl GrpcRouter {
) { ) {
Ok(request) => request, Ok(request) => request,
Err(e) => { Err(e) => {
error!("Failed to build gRPC request: {}", e); return utils::bad_request_error(format!("Invalid request parameters: {}", e));
return (
StatusCode::BAD_REQUEST,
format!("Invalid request parameters: {}", e),
)
.into_response();
} }
}; };
...@@ -200,8 +191,7 @@ impl GrpcRouter { ...@@ -200,8 +191,7 @@ impl GrpcRouter {
let (original_text, token_ids) = match self.resolve_generate_input(body) { let (original_text, token_ids) = match self.resolve_generate_input(body) {
Ok(res) => res, Ok(res) => res,
Err(msg) => { Err(msg) => {
error!("Invalid generate request: {}", msg); return utils::bad_request_error(msg);
return (StatusCode::BAD_REQUEST, msg).into_response();
} }
}; };
...@@ -211,8 +201,10 @@ impl GrpcRouter { ...@@ -211,8 +201,10 @@ impl GrpcRouter {
let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) { let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) {
Some(w) => w, Some(w) => w,
None => { None => {
warn!("No available workers for model: {:?}", model_id); return utils::service_unavailable_error(format!(
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); "No available workers for model: {:?}",
model_id
));
} }
}; };
...@@ -238,8 +230,7 @@ impl GrpcRouter { ...@@ -238,8 +230,7 @@ impl GrpcRouter {
) { ) {
Ok(req) => req, Ok(req) => req,
Err(e) => { Err(e) => {
error!("Failed to build generate request: {}", e); return utils::bad_request_error(e);
return (StatusCode::BAD_REQUEST, e).into_response();
} }
}; };
...@@ -405,16 +396,6 @@ impl GrpcRouter { ...@@ -405,16 +396,6 @@ impl GrpcRouter {
Ok((text.to_string(), encoding.token_ids().to_vec())) Ok((text.to_string(), encoding.token_ids().to_vec()))
} }
fn internal_error_static(msg: &'static str) -> Response {
error!("{}", msg);
(StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
}
fn internal_error_message(message: String) -> Response {
error!("{}", message);
(StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
}
/// Count the number of tool calls in the request message history /// Count the number of tool calls in the request message history
/// This is used for KimiK2 format which needs globally unique indices /// This is used for KimiK2 format which needs globally unique indices
fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize { fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
...@@ -740,12 +721,7 @@ impl GrpcRouter { ...@@ -740,12 +721,7 @@ impl GrpcRouter {
let mut grpc_stream = match client.generate(request).await { let mut grpc_stream = match client.generate(request).await {
Ok(stream) => stream, Ok(stream) => stream,
Err(e) => { Err(e) => {
error!("Failed to start generation: {}", e); return utils::internal_error_message(format!("Generation failed: {}", e));
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Generation failed: {}", e),
)
.into_response();
} }
}; };
...@@ -1183,7 +1159,7 @@ impl GrpcRouter { ...@@ -1183,7 +1159,7 @@ impl GrpcRouter {
let stream = match client.generate(request).await { let stream = match client.generate(request).await {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
return Self::internal_error_message(format!("Failed to start generation: {}", e)) return utils::internal_error_message(format!("Failed to start generation: {}", e))
} }
}; };
...@@ -1193,7 +1169,7 @@ impl GrpcRouter { ...@@ -1193,7 +1169,7 @@ impl GrpcRouter {
}; };
if all_responses.is_empty() { if all_responses.is_empty() {
return Self::internal_error_static("No responses from server"); return utils::internal_error_static("No responses from server");
} }
// Process each response into a ChatChoice // Process each response into a ChatChoice
...@@ -1212,7 +1188,7 @@ impl GrpcRouter { ...@@ -1212,7 +1188,7 @@ impl GrpcRouter {
{ {
Ok(choice) => choices.push(choice), Ok(choice) => choices.push(choice),
Err(e) => { Err(e) => {
return Self::internal_error_message(format!( return utils::internal_error_message(format!(
"Failed to process choice {}: {}", "Failed to process choice {}: {}",
index, e index, e
)); ));
...@@ -1265,7 +1241,7 @@ impl GrpcRouter { ...@@ -1265,7 +1241,7 @@ impl GrpcRouter {
let stream = match client.generate(request).await { let stream = match client.generate(request).await {
Ok(stream) => stream, Ok(stream) => stream,
Err(e) => { Err(e) => {
return Self::internal_error_message(format!("Failed to start generation: {}", e)) return utils::internal_error_message(format!("Failed to start generation: {}", e))
} }
}; };
...@@ -1276,7 +1252,7 @@ impl GrpcRouter { ...@@ -1276,7 +1252,7 @@ impl GrpcRouter {
}; };
if responses.is_empty() { if responses.is_empty() {
return Self::internal_error_static("No completion received from scheduler"); return utils::internal_error_static("No completion received from scheduler");
} }
// Create stop decoder from sampling params // Create stop decoder from sampling params
...@@ -1298,7 +1274,10 @@ impl GrpcRouter { ...@@ -1298,7 +1274,10 @@ impl GrpcRouter {
let outputs = match stop_decoder.process_tokens(&complete.output_ids) { let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
Ok(outputs) => outputs, Ok(outputs) => outputs,
Err(e) => { Err(e) => {
return Self::internal_error_message(format!("Failed to process tokens: {}", e)) return utils::internal_error_message(format!(
"Failed to process tokens: {}",
e
))
} }
}; };
...@@ -1377,7 +1356,7 @@ impl GrpcRouter { ...@@ -1377,7 +1356,7 @@ impl GrpcRouter {
let stream = match client.generate(request).await { let stream = match client.generate(request).await {
Ok(stream) => stream, Ok(stream) => stream,
Err(e) => { Err(e) => {
return Self::internal_error_message(format!("Failed to start generation: {}", e)) return utils::internal_error_message(format!("Failed to start generation: {}", e))
} }
}; };
......
...@@ -14,13 +14,14 @@ pub use crate::tokenizer::StopSequenceDecoder; ...@@ -14,13 +14,14 @@ pub use crate::tokenizer::StopSequenceDecoder;
use axum::{ use axum::{
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json,
}; };
use futures::StreamExt; use futures::StreamExt;
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tonic::codec::Streaming; use tonic::codec::Streaming;
use tracing::{debug, error}; use tracing::{debug, error, warn};
use uuid::Uuid; use uuid::Uuid;
/// Get gRPC client from worker, returning appropriate error response on failure /// Get gRPC client from worker, returning appropriate error response on failure
...@@ -30,22 +31,8 @@ pub async fn get_grpc_client_from_worker( ...@@ -30,22 +31,8 @@ pub async fn get_grpc_client_from_worker(
let client_arc = worker let client_arc = worker
.get_grpc_client() .get_grpc_client()
.await .await
.map_err(|e| { .map_err(|e| internal_error_message(format!("Failed to get gRPC client: {}", e)))?
error!("Failed to get gRPC client from worker: {}", e); .ok_or_else(|| internal_error_static("Selected worker is not configured for gRPC"))?;
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get gRPC client: {}", e),
)
.into_response()
})?
.ok_or_else(|| {
error!("Selected worker is not a gRPC worker");
(
StatusCode::INTERNAL_SERVER_ERROR,
"Selected worker is not configured for gRPC",
)
.into_response()
})?;
let client = client_arc.lock().await.clone(); let client = client_arc.lock().await.clone();
Ok(client) Ok(client)
...@@ -422,12 +409,62 @@ pub fn process_chat_messages( ...@@ -422,12 +409,62 @@ pub fn process_chat_messages(
/// Error response helpers (shared between regular and PD routers) /// Error response helpers (shared between regular and PD routers)
pub fn internal_error_static(msg: &'static str) -> Response { pub fn internal_error_static(msg: &'static str) -> Response {
error!("{}", msg); error!("{}", msg);
(StatusCode::INTERNAL_SERVER_ERROR, msg).into_response() (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": msg,
"type": "internal_error",
"code": 500
}
})),
)
.into_response()
} }
pub fn internal_error_message(message: String) -> Response { pub fn internal_error_message(message: String) -> Response {
error!("{}", message); error!("{}", message);
(StatusCode::INTERNAL_SERVER_ERROR, message).into_response() (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": message,
"type": "internal_error",
"code": 500
}
})),
)
.into_response()
}
/// Build an HTTP 400 (Bad Request) response with a JSON error envelope.
///
/// The message is first logged at `error` level, then embedded in the body as
/// `error.message`, next to the fixed fields `"type": "invalid_request_error"`
/// and `"code": 400`.
pub fn bad_request_error(message: String) -> Response {
    error!("{}", message);

    // Assemble the JSON payload up front so the final expression is just
    // the (status, body) pair being converted into a response.
    let payload = json!({
        "error": {
            "message": message,
            "type": "invalid_request_error",
            "code": 400
        }
    });

    (StatusCode::BAD_REQUEST, Json(payload)).into_response()
}
pub fn service_unavailable_error(message: String) -> Response {
warn!("{}", message);
(
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": {
"message": message,
"type": "service_unavailable",
"code": 503
}
})),
)
.into_response()
} }
/// Create a StopSequenceDecoder from stop parameters /// Create a StopSequenceDecoder from stop parameters
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment