Unverified commit 0b41a293, authored by Chang Su, committed by GitHub

[router][grpc] Consolidate error message building in error.rs (#12301)

parent d31d48b3
//! Centralized error response handling for all routers
//!
//! This module provides consistent error responses across OpenAI and gRPC routers,
//! ensuring all errors follow OpenAI's API error format.
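//!
//! A minimal usage sketch (hypothetical handler; the `Result<Response, Response>`
//! plus `?` shape below matches how call sites in this commit consume these helpers):
//!
//! ```ignore
//! async fn get_conversation(id: &str) -> Result<Response, Response> {
//!     if id.is_empty() {
//!         return Err(bad_request("Conversation ID must not be empty"));
//!     }
//!     Err(not_found(format!("Conversation '{}' not found", id)))
//! }
//! ```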
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;
use tracing::{error, warn};

/// Create a 500 Internal Server Error response
///
/// Use this for unexpected server-side errors, database failures, etc.
///
/// # Example
/// ```ignore
/// return Err(internal_error("Database connection failed"));
/// ```
pub fn internal_error(message: impl Into<String>) -> Response {
    let msg = message.into();
    error!("{}", msg);
    (
        StatusCode::INTERNAL_SERVER_ERROR,
        Json(json!({
            "error": {
                "message": msg,
                "type": "internal_error",
                "code": 500
            }
        })),
    )
        .into_response()
}

/// Create a 400 Bad Request response
///
/// Use this for invalid request parameters, malformed JSON, validation errors, etc.
///
/// # Example
/// ```ignore
/// return Err(bad_request("Invalid conversation ID format"));
/// ```
pub fn bad_request(message: impl Into<String>) -> Response {
    let msg = message.into();
    error!("{}", msg);
    (
        StatusCode::BAD_REQUEST,
        Json(json!({
            "error": {
                "message": msg,
                "type": "invalid_request_error",
                "code": 400
            }
        })),
    )
        .into_response()
}

/// Create a 404 Not Found response
///
/// Use this for resources that don't exist (conversations, responses, etc.)
///
/// # Example
/// ```ignore
/// return Err(not_found(format!("Conversation '{}' not found", id)));
/// ```
pub fn not_found(message: impl Into<String>) -> Response {
    let msg = message.into();
    warn!("{}", msg);
    (
        StatusCode::NOT_FOUND,
        Json(json!({
            "error": {
                "message": msg,
                "type": "invalid_request_error",
                "code": 404
            }
        })),
    )
        .into_response()
}

/// Create a 503 Service Unavailable response
///
/// Use this for temporary service issues like no workers available, rate limiting, etc.
///
/// # Example
/// ```ignore
/// return Err(service_unavailable("No workers available for this model"));
/// ```
pub fn service_unavailable(message: impl Into<String>) -> Response {
    let msg = message.into();
    warn!("{}", msg);
    (
        StatusCode::SERVICE_UNAVAILABLE,
        Json(json!({
            "error": {
                "message": msg,
                "type": "service_unavailable",
                "code": 503
            }
        })),
    )
        .into_response()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_internal_error_string() {
        let response = internal_error("Test error");
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_internal_error_format() {
        let response = internal_error(format!("Error: {}", 42));
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_bad_request() {
        let response = bad_request("Invalid input");
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_not_found() {
        let response = not_found("Resource not found");
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_service_unavailable() {
        let response = service_unavailable("No workers");
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }
}
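
The hunks below apply the same mechanical migration at every call site: the old `utils` helpers (`internal_error_static` for `&'static str`, `internal_error_message` for `String`, plus `bad_request_error` and `service_unavailable_error`) collapse into the `impl Into<String>` helpers above. A condensed sketch of the resulting call-site shape (hypothetical function; `check_responses` and its types are illustrative only):

```rust
use axum::response::Response;
use crate::routers::grpc::error;

// Hypothetical call site condensing two patterns from the hunks below:
// `ok_or_else` for static messages, `map_err` + `format!` for built ones.
// Both now go through the same helper, since it accepts impl Into<String>.
fn check_responses(all_responses: &[serde_json::Value]) -> Result<String, Response> {
    let first = all_responses
        .first()
        .ok_or_else(|| error::internal_error("No responses from server"))?;
    serde_json::to_string(first)
        .map_err(|e| error::internal_error(format!("Serialization failed: {}", e)))
}
```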
......@@ -18,7 +18,7 @@ use crate::{
},
routers::grpc::{
context::{DispatchMetadata, ExecutionResult},
utils,
error, utils,
},
};
......@@ -66,7 +66,7 @@ impl HarmonyResponseProcessor {
// Collect all completed responses (one per choice)
let all_responses = Self::collect_responses(execution_result).await?;
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
return Err(error::internal_error("No responses from server"));
}
// Build choices by parsing output with HarmonyParserAdapter
......@@ -84,7 +84,7 @@ impl HarmonyResponseProcessor {
// Parse Harmony channels with HarmonyParserAdapter
let mut parser = HarmonyParserAdapter::new().map_err(|e| {
utils::internal_error_message(format!("Failed to create Harmony parser: {}", e))
error::internal_error(format!("Failed to create Harmony parser: {}", e))
})?;
// Parse Harmony channels with finish_reason and matched_stop
......@@ -94,9 +94,7 @@ impl HarmonyResponseProcessor {
complete.finish_reason.clone(),
matched_stop.clone(),
)
.map_err(|e| {
utils::internal_error_message(format!("Harmony parsing failed: {}", e))
})?;
.map_err(|e| error::internal_error(format!("Harmony parsing failed: {}", e)))?;
// Build response message (assistant)
let message = ChatCompletionMessage {
......@@ -195,17 +193,17 @@ impl HarmonyResponseProcessor {
// Collect all completed responses
let all_responses = Self::collect_responses(execution_result).await?;
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
return Err(error::internal_error("No responses from server"));
}
// For Responses API, we only process the first response (n=1)
let complete = all_responses
.first()
.ok_or_else(|| utils::internal_error_static("No complete response"))?;
.ok_or_else(|| error::internal_error("No complete response"))?;
// Parse Harmony channels
let mut parser = HarmonyParserAdapter::new().map_err(|e| {
utils::internal_error_message(format!("Failed to create Harmony parser: {}", e))
error::internal_error(format!("Failed to create Harmony parser: {}", e))
})?;
// Convert matched_stop from proto to JSON
......@@ -224,7 +222,7 @@ impl HarmonyResponseProcessor {
complete.finish_reason.clone(),
matched_stop,
)
.map_err(|e| utils::internal_error_message(format!("Harmony parsing failed: {}", e)))?;
.map_err(|e| error::internal_error(format!("Harmony parsing failed: {}", e)))?;
// VALIDATION: Check if model incorrectly generated Tool role messages
// This happens when the model copies the format of tool result messages
......
......@@ -61,10 +61,10 @@ use crate::{
routers::{
grpc::{
context::SharedComponents,
error,
harmony::processor::ResponsesIterationResult,
pipeline::RequestPipeline,
responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
utils,
},
openai::mcp::ensure_request_mcp_client,
},
......@@ -285,7 +285,7 @@ pub async fn serve_harmony_responses(
// Safety check: prevent infinite loops
if iteration_count > MAX_TOOL_ITERATIONS {
return Err(utils::internal_error_message(format!(
return Err(error::internal_error(format!(
"Maximum tool iterations ({}) exceeded",
MAX_TOOL_ITERATIONS
)));
......@@ -333,7 +333,7 @@ pub async fn serve_harmony_responses(
execute_mcp_tools(&ctx.mcp_manager, &tool_calls, tracking).await?
} else {
// Should never happen (we only get tool_calls when has_mcp_tools=true)
return Err(utils::internal_error_static(
return Err(error::internal_error(
"Tool calls found but MCP tracking not initialized",
));
};
......@@ -734,7 +734,7 @@ async fn execute_mcp_tools(
// Parse tool arguments from JSON string
let args_str = tool_call.function.arguments.as_deref().unwrap_or("{}");
let args: JsonValue = serde_json::from_str(args_str).map_err(|e| {
utils::internal_error_message(format!(
error::internal_error(format!(
"Invalid tool arguments JSON for tool '{}': {}",
tool_call.function.name, e
))
......@@ -1111,7 +1111,7 @@ async fn load_previous_messages(
.get_response_chain(&prev_id, None)
.await
.map_err(|e| {
utils::internal_error_message(format!(
error::internal_error(format!(
"Failed to load previous response chain for {}: {}",
prev_id_str, e
))
......
......@@ -13,6 +13,7 @@ use crate::{
},
routers::grpc::{
context::{PreparationOutput, RequestContext, RequestType},
error,
stages::PipelineStage,
utils,
},
......@@ -56,7 +57,7 @@ impl PipelineStage for HarmonyPreparationStage {
let request_arc = ctx.responses_request_arc();
self.prepare_responses(ctx, &request_arc).await?;
} else {
return Err(utils::bad_request_error(
return Err(error::bad_request(
"Only Chat and Responses requests supported in Harmony pipeline".to_string(),
));
}
......@@ -78,7 +79,7 @@ impl HarmonyPreparationStage {
) -> Result<Option<Response>, Response> {
// Validate - reject logprobs
if request.logprobs {
return Err(utils::bad_request_error(
return Err(error::bad_request(
"logprobs are not supported for Harmony models".to_string(),
));
}
......@@ -97,7 +98,7 @@ impl HarmonyPreparationStage {
let build_output = self
.builder
.build_from_chat(&body_ref)
.map_err(|e| utils::bad_request_error(format!("Harmony build failed: {}", e)))?;
.map_err(|e| error::bad_request(format!("Harmony build failed: {}", e)))?;
// Step 4: Store results
ctx.state.preparation = Some(PreparationOutput {
......@@ -132,7 +133,7 @@ impl HarmonyPreparationStage {
let build_output = self
.builder
.build_from_responses(request)
.map_err(|e| utils::bad_request_error(format!("Harmony build failed: {}", e)))?;
.map_err(|e| error::bad_request(format!("Harmony build failed: {}", e)))?;
// Store results in preparation output
ctx.state.preparation = Some(PreparationOutput {
......@@ -202,7 +203,7 @@ impl HarmonyPreparationStage {
// Validate specific function exists
if specific_function.is_some() && tools_to_use.is_empty() {
return Err(Box::new(utils::bad_request_error(format!(
return Err(Box::new(error::bad_request(format!(
"Tool '{}' not found in tools list",
specific_function.unwrap()
))));
......@@ -236,7 +237,7 @@ impl HarmonyPreparationStage {
});
serde_json::to_string(&structural_tag).map_err(|e| {
Box::new(utils::internal_error_message(format!(
Box::new(error::internal_error(format!(
"Failed to serialize structural tag: {}",
e
)))
......
......@@ -13,8 +13,8 @@ use crate::{
grpc_client::proto::{DisaggregatedParams, GenerateRequest},
routers::grpc::{
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
error,
stages::PipelineStage,
utils,
},
};
......@@ -69,14 +69,14 @@ impl PipelineStage for HarmonyRequestBuildingStage {
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation not completed"))?;
.ok_or_else(|| error::internal_error("Preparation not completed"))?;
// Get clients
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
.ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
......@@ -87,7 +87,7 @@ impl PipelineStage for HarmonyRequestBuildingStage {
RequestType::Chat(_) => format!("chatcmpl-{}", Uuid::new_v4()),
RequestType::Responses(_) => format!("responses-{}", Uuid::new_v4()),
RequestType::Generate(_) => {
return Err(utils::bad_request_error(
return Err(error::bad_request(
"Generate requests are not supported with Harmony models".to_string(),
));
}
......@@ -111,9 +111,7 @@ impl PipelineStage for HarmonyRequestBuildingStage {
None,
prep.tool_constraints.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?
.map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?
}
RequestType::Responses(request) => builder_client
.build_generate_request_from_responses(
......@@ -123,9 +121,7 @@ impl PipelineStage for HarmonyRequestBuildingStage {
prep.token_ids.clone(),
prep.harmony_stop_ids.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?,
.map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?,
_ => unreachable!(),
};
......
......@@ -8,8 +8,8 @@ use axum::response::Response;
use super::super::{HarmonyResponseProcessor, HarmonyStreamingProcessor};
use crate::routers::grpc::{
context::{FinalResponse, RequestContext, RequestType},
error,
stages::PipelineStage,
utils,
};
/// Harmony Response Processing stage: Parse and format Harmony responses
......@@ -51,14 +51,14 @@ impl PipelineStage for HarmonyResponseProcessingStage {
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
.ok_or_else(|| error::internal_error("No execution result"))?;
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
.ok_or_else(|| error::internal_error("Dispatch metadata not set"))?;
// For streaming, delegate to streaming processor and return SSE response
if is_streaming {
......@@ -97,14 +97,14 @@ impl PipelineStage for HarmonyResponseProcessingStage {
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
.ok_or_else(|| error::internal_error("No execution result"))?;
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
.ok_or_else(|| error::internal_error("Dispatch metadata not set"))?;
let responses_request = ctx.responses_request_arc();
let iteration_result = self
......@@ -115,7 +115,7 @@ impl PipelineStage for HarmonyResponseProcessingStage {
ctx.state.response.responses_iteration_result = Some(iteration_result);
Ok(None)
}
RequestType::Generate(_) => Err(utils::internal_error_static(
RequestType::Generate(_) => Err(error::internal_error(
"Generate requests not supported in Harmony pipeline",
)),
}
......
......@@ -3,6 +3,7 @@
use crate::{grpc_client::proto, protocols::common::StringOrArray};
pub mod context;
pub mod error;
pub mod harmony;
pub mod pd_router;
pub mod pipeline;
......
......@@ -11,7 +11,7 @@ use tracing::{debug, error};
// Import all stage types from the stages module
use super::stages::*;
use super::{context::*, harmony, processing, responses::BackgroundTaskInfo, streaming, utils};
use super::{context::*, error, harmony, processing, responses::BackgroundTaskInfo, streaming};
use crate::{
core::WorkerRegistry,
policies::PolicyRegistry,
......@@ -228,9 +228,9 @@ impl RequestPipeline {
match ctx.state.response.final_response {
Some(FinalResponse::Chat(response)) => axum::Json(response).into_response(),
Some(FinalResponse::Generate(_)) => {
utils::internal_error_static("Internal error: wrong response type")
error::internal_error("Internal error: wrong response type")
}
None => utils::internal_error_static("No response produced"),
None => error::internal_error("No response produced"),
}
}
......@@ -272,9 +272,9 @@ impl RequestPipeline {
match ctx.state.response.final_response {
Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
Some(FinalResponse::Chat(_)) => {
utils::internal_error_static("Internal error: wrong response type")
error::internal_error("Internal error: wrong response type")
}
None => utils::internal_error_static("No response produced"),
None => error::internal_error("No response produced"),
}
}
......@@ -303,7 +303,7 @@ impl RequestPipeline {
match stage.execute(&mut ctx).await {
Ok(Some(_response)) => {
// Streaming not supported for responses sync mode
return Err(utils::bad_request_error(
return Err(error::bad_request(
"Streaming is not supported in this context".to_string(),
));
}
......@@ -360,10 +360,10 @@ impl RequestPipeline {
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Chat(response)) => Ok(response),
Some(FinalResponse::Generate(_)) => Err(utils::internal_error_static(
"Internal error: wrong response type",
)),
None => Err(utils::internal_error_static("No response produced")),
Some(FinalResponse::Generate(_)) => {
Err(error::internal_error("Internal error: wrong response type"))
}
None => Err(error::internal_error("No response produced")),
}
}
......@@ -384,7 +384,7 @@ impl RequestPipeline {
_model_id: Option<String>,
_components: Arc<SharedComponents>,
) -> Response {
utils::internal_error_static("Responses API execution not yet implemented")
error::internal_error("Responses API execution not yet implemented")
}
/// Execute Harmony Responses API request through all pipeline stages
......@@ -451,7 +451,7 @@ impl RequestPipeline {
.responses_iteration_result
.take()
.ok_or_else(|| {
utils::internal_error_static("No ResponsesIterationResult produced by pipeline")
error::internal_error("No ResponsesIterationResult produced by pipeline")
})
}
......@@ -501,6 +501,6 @@ impl RequestPipeline {
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No ExecutionResult produced by pipeline"))
.ok_or_else(|| error::internal_error("No ExecutionResult produced by pipeline"))
}
}
......@@ -11,7 +11,7 @@ use tracing::error;
use super::{
context::{DispatchMetadata, ExecutionResult},
utils,
error, utils,
};
use crate::{
grpc_client::proto,
......@@ -104,7 +104,7 @@ impl ResponseProcessor {
};
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
return Err(error::internal_error("No responses from server"));
}
Ok(all_responses)
......@@ -332,7 +332,7 @@ impl ResponseProcessor {
{
Ok(choice) => choices.push(choice),
Err(e) => {
return Err(utils::internal_error_message(format!(
return Err(error::internal_error(format!(
"Failed to process choice {}: {}",
index, e
)));
......@@ -447,7 +447,7 @@ impl ResponseProcessor {
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
Ok(outputs) => outputs,
Err(e) => {
return Err(utils::internal_error_message(format!(
return Err(error::internal_error(format!(
"Failed to process tokens: {}",
e
)))
......
......@@ -67,7 +67,10 @@ use crate::{
ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage,
},
},
routers::openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
routers::{
grpc::error,
openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
},
};
// ============================================================================
......@@ -863,11 +866,9 @@ async fn execute_without_mcp(
model_id: Option<String>,
response_id: Option<String>,
) -> Result<ResponsesResponse, Response> {
use crate::routers::grpc::utils;
// Convert ResponsesRequest → ChatCompletionRequest
let chat_request = conversions::responses_to_chat(modified_request)
.map_err(|e| utils::bad_request_error(format!("Failed to convert request: {}", e)))?;
.map_err(|e| error::bad_request(format!("Failed to convert request: {}", e)))?;
// Execute chat pipeline (errors already have proper HTTP status codes)
let chat_response = ctx
......@@ -883,9 +884,8 @@ async fn execute_without_mcp(
.await?; // Preserve the Response error as-is
// Convert ChatCompletionResponse → ResponsesResponse
conversions::chat_to_responses(&chat_response, original_request, response_id).map_err(|e| {
utils::internal_error_message(format!("Failed to convert to responses format: {}", e))
})
conversions::chat_to_responses(&chat_response, original_request, response_id)
.map_err(|e| error::internal_error(format!("Failed to convert to responses format: {}", e)))
}
/// Load conversation history and response chains, returning modified request
......@@ -962,15 +962,10 @@ async fn load_conversation_history(
.conversation_storage
.get_conversation(&conv_id)
.await
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to check conversation: {}",
e
))
})?;
.map_err(|e| error::internal_error(format!("Failed to check conversation: {}", e)))?;
if conversation.is_none() {
return Err(crate::routers::grpc::utils::bad_request_error(format!(
return Err(error::not_found(format!(
"Conversation '{}' not found. Please create the conversation first using the conversations API.",
conv_id_str
)));
......
......@@ -19,6 +19,7 @@ use tracing::{debug, warn};
use uuid::Uuid;
use super::{
super::error,
conversions,
streaming::{OutputItemType, ResponseStreamEventEmitter},
};
......@@ -247,12 +248,8 @@ pub(super) async fn execute_tool_loop(
loop {
// Convert to chat request
let mut chat_request = conversions::responses_to_chat(&current_request).map_err(|e| {
crate::routers::grpc::utils::bad_request_error(format!(
"Failed to convert request: {}",
e
))
})?;
let mut chat_request = conversions::responses_to_chat(&current_request)
.map_err(|e| error::bad_request(format!("Failed to convert request: {}", e)))?;
// Add MCP tools to chat request so LLM knows about them
chat_request.tools = Some(chat_tools.clone());
......@@ -301,10 +298,7 @@ pub(super) async fn execute_tool_loop(
response_id.clone(),
)
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to convert to responses format: {}",
e
))
error::internal_error(format!("Failed to convert to responses format: {}", e))
})?;
// Mark as completed but with incomplete details
......@@ -423,10 +417,7 @@ pub(super) async fn execute_tool_loop(
response_id.clone(),
)
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to convert to responses format: {}",
e
))
error::internal_error(format!("Failed to convert to responses format: {}", e))
})?;
// Inject MCP metadata into output
......
......@@ -6,7 +6,7 @@ use axum::response::Response;
use super::PipelineStage;
use crate::routers::grpc::{
context::{ClientSelection, RequestContext, WorkerSelection},
utils,
error, utils,
};
/// Client acquisition stage: Get gRPC clients from selected workers
......@@ -19,7 +19,7 @@ impl PipelineStage for ClientAcquisitionStage {
.state
.workers
.as_ref()
.ok_or_else(|| utils::internal_error_static("Worker selection not completed"))?;
.ok_or_else(|| error::internal_error("Worker selection not completed"))?;
let clients = match workers {
WorkerSelection::Single { worker } => {
......
......@@ -8,7 +8,7 @@ use axum::response::Response;
use super::PipelineStage;
use crate::routers::grpc::{
context::{DispatchMetadata, RequestContext, RequestType, WorkerSelection},
utils,
error,
};
/// Dispatch metadata stage: Prepare metadata for dispatch
......@@ -21,7 +21,7 @@ impl PipelineStage for DispatchMetadataStage {
.state
.proto_request
.as_ref()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
.ok_or_else(|| error::internal_error("Proto request not built"))?;
let request_id = proto_request.request_id.clone();
let model = match &ctx.input.request_type {
......
......@@ -10,7 +10,7 @@ use crate::{
protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
routers::grpc::{
context::{PreparationOutput, RequestContext, RequestType},
utils,
error, utils,
},
tokenizer::traits::Tokenizer,
};
......@@ -56,7 +56,7 @@ impl PreparationStage {
match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
Ok(msgs) => msgs,
Err(e) => {
return Err(utils::bad_request_error(e));
return Err(error::bad_request(e));
}
};
......@@ -64,10 +64,7 @@ impl PreparationStage {
let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Tokenization failed: {}",
e
)));
return Err(error::internal_error(format!("Tokenization failed: {}", e)));
}
};
......@@ -75,9 +72,8 @@ impl PreparationStage {
// Step 4: Build tool constraints if needed
let tool_call_constraint = if let Some(tools) = body_ref.tools.as_ref() {
utils::generate_tool_constraints(tools, &request.tool_choice, &request.model).map_err(
|e| utils::bad_request_error(format!("Invalid tool configuration: {}", e)),
)?
utils::generate_tool_constraints(tools, &request.tool_choice, &request.model)
.map_err(|e| error::bad_request(format!("Invalid tool configuration: {}", e)))?
} else {
None
};
......@@ -124,7 +120,7 @@ impl PreparationStage {
let (original_text, token_ids) = match self.resolve_generate_input(ctx, request) {
Ok(res) => res,
Err(msg) => {
return Err(utils::bad_request_error(msg));
return Err(error::bad_request(msg));
}
};
......
......@@ -15,7 +15,7 @@ use crate::{
grpc_client::proto,
routers::grpc::{
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
utils,
error,
},
};
......@@ -37,13 +37,13 @@ impl PipelineStage for RequestBuildingStage {
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation not completed"))?;
.ok_or_else(|| error::internal_error("Preparation not completed"))?;
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
.ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
// Get client for building request (use prefill client if PD mode)
let builder_client = match clients {
......@@ -69,9 +69,7 @@ impl PipelineStage for RequestBuildingStage {
.clone(),
prep.tool_constraints.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?
.map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?
}
RequestType::Generate(request) => {
let request_id = request
......@@ -86,7 +84,7 @@ impl PipelineStage for RequestBuildingStage {
prep.original_text.clone(),
prep.token_ids.clone(),
)
.map_err(utils::bad_request_error)?
.map_err(error::bad_request)?
}
RequestType::Responses(_request) => {
// Responses API builds request during the MCP loop
......
......@@ -5,13 +5,15 @@ use axum::response::Response;
use super::PipelineStage;
use crate::{
grpc_client::proto,
grpc_client::{proto, sglang_scheduler::AbortOnDropStream},
routers::grpc::{
context::{ClientSelection, ExecutionResult, RequestContext},
utils,
error,
},
};
type StreamResult = Result<AbortOnDropStream, Box<dyn std::error::Error + Send + Sync>>;
/// Request execution stage: Execute gRPC requests (single or dual dispatch)
pub struct RequestExecutionStage {
mode: ExecutionMode,
......@@ -37,13 +39,13 @@ impl PipelineStage for RequestExecutionStage {
.state
.proto_request
.take()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
.ok_or_else(|| error::internal_error("Proto request not built"))?;
let clients = ctx
.state
.clients
.as_mut()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
.ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
let result = match self.mode {
ExecutionMode::Single => self.execute_single(proto_request, clients).await?,
......@@ -70,11 +72,12 @@ impl RequestExecutionStage {
) -> Result<ExecutionResult, Response> {
let client = clients
.single_mut()
.ok_or_else(|| utils::internal_error_static("Expected single client but got dual"))?;
.ok_or_else(|| error::internal_error("Expected single client but got dual"))?;
let stream = client.generate(proto_request).await.map_err(|e| {
utils::internal_error_message(format!("Failed to start generation: {}", e))
})?;
let stream = client
.generate(proto_request)
.await
.map_err(|e| error::internal_error(format!("Failed to start generation: {}", e)))?;
Ok(ExecutionResult::Single { stream })
}
......@@ -86,12 +89,12 @@ impl RequestExecutionStage {
) -> Result<ExecutionResult, Response> {
let (prefill_client, decode_client) = clients
.dual_mut()
.ok_or_else(|| utils::internal_error_static("Expected dual clients but got single"))?;
.ok_or_else(|| error::internal_error("Expected dual clients but got single"))?;
let prefill_request = proto_request.clone();
let decode_request = proto_request;
let (prefill_result, decode_result) = tokio::join!(
let (prefill_result, decode_result): (StreamResult, StreamResult) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
......@@ -100,7 +103,7 @@ impl RequestExecutionStage {
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
return Err(error::internal_error(format!(
"Prefill worker failed to start: {}",
e
)));
......@@ -111,7 +114,7 @@ impl RequestExecutionStage {
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
return Err(error::internal_error(format!(
"Decode worker failed to start: {}",
e
)));
......
......@@ -11,7 +11,7 @@ use axum::response::Response;
use super::PipelineStage;
use crate::routers::grpc::{
context::{FinalResponse, RequestContext, RequestType},
processing, streaming, utils,
error, processing, streaming,
};
/// Response processing stage: Handles both streaming and non-streaming responses
......@@ -42,7 +42,7 @@ impl PipelineStage for ResponseProcessingStage {
match &ctx.input.request_type {
RequestType::Chat(_) => self.process_chat_response(ctx).await,
RequestType::Generate(_) => self.process_generate_response(ctx).await,
RequestType::Responses(_) => Err(utils::bad_request_error(
RequestType::Responses(_) => Err(error::bad_request(
"Responses API processing must be handled by responses handler".to_string(),
)),
}
......@@ -66,14 +66,14 @@ impl ResponseProcessingStage {
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
.ok_or_else(|| error::internal_error("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
.ok_or_else(|| error::internal_error("Dispatch metadata not set"))?
.clone();
if is_streaming {
......@@ -100,7 +100,7 @@ impl ResponseProcessingStage {
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
.ok_or_else(|| error::internal_error("Stop decoder not initialized"))?;
let response = self
.processor
......@@ -132,14 +132,14 @@ impl ResponseProcessingStage {
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
.ok_or_else(|| error::internal_error("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
.ok_or_else(|| error::internal_error("Dispatch metadata not set"))?
.clone();
if is_streaming {
......@@ -162,7 +162,7 @@ impl ResponseProcessingStage {
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
.ok_or_else(|| error::internal_error("Stop decoder not initialized"))?;
let result_array = self
.processor
......
......@@ -12,7 +12,7 @@ use crate::{
policies::PolicyRegistry,
routers::grpc::{
context::{RequestContext, WorkerSelection},
utils,
error,
},
};
......@@ -51,7 +51,7 @@ impl PipelineStage for WorkerSelectionStage {
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation stage not completed"))?;
.ok_or_else(|| error::internal_error("Preparation stage not completed"))?;
// For Harmony, use selection_text produced during Harmony encoding
// Otherwise, use original_text from regular preparation
......@@ -66,7 +66,7 @@ impl PipelineStage for WorkerSelectionStage {
match self.select_single_worker(ctx.input.model_id.as_deref(), text) {
Some(w) => WorkerSelection::Single { worker: w },
None => {
return Err(utils::service_unavailable_error(format!(
return Err(error::service_unavailable(format!(
"No available workers for model: {:?}",
ctx.input.model_id
)));
......@@ -77,7 +77,7 @@ impl PipelineStage for WorkerSelectionStage {
match self.select_pd_pair(ctx.input.model_id.as_deref(), text) {
Some((prefill, decode)) => WorkerSelection::Dual { prefill, decode },
None => {
return Err(utils::service_unavailable_error(format!(
return Err(error::service_unavailable(format!(
"No available PD worker pairs for model: {:?}",
ctx.input.model_id
)));
......
......@@ -2,17 +2,13 @@
use std::{collections::HashMap, sync::Arc};
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use axum::response::Response;
use futures::StreamExt;
use serde_json::{json, Map, Value};
use tracing::{error, warn};
use uuid::Uuid;
use super::ProcessedMessages;
use super::{error, ProcessedMessages};
pub use crate::tokenizer::StopSequenceDecoder;
use crate::{
core::Worker,
......@@ -40,8 +36,8 @@ pub async fn get_grpc_client_from_worker(
let client_arc = worker
.get_grpc_client()
.await
.map_err(|e| internal_error_message(format!("Failed to get gRPC client: {}", e)))?
.ok_or_else(|| internal_error_static("Selected worker is not configured for gRPC"))?;
.map_err(|e| error::internal_error(format!("Failed to get gRPC client: {}", e)))?
.ok_or_else(|| error::internal_error("Selected worker is not configured for gRPC"))?;
Ok((*client_arc).clone())
}
......@@ -433,67 +429,6 @@ pub fn process_chat_messages(
})
}
/// Error response helpers (shared between regular and PD routers)
pub fn internal_error_static(msg: &'static str) -> Response {
error!("{}", msg);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": msg,
"type": "internal_error",
"code": 500
}
})),
)
.into_response()
}
pub fn internal_error_message(message: String) -> Response {
error!("{}", message);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": message,
"type": "internal_error",
"code": 500
}
})),
)
.into_response()
}
pub fn bad_request_error(message: String) -> Response {
error!("{}", message);
(
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": message,
"type": "invalid_request_error",
"code": 400
}
})),
)
.into_response()
}
pub fn service_unavailable_error(message: String) -> Response {
warn!("{}", message);
(
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": {
"message": message,
"type": "service_unavailable",
"code": 503
}
})),
)
.into_response()
}
/// Create a StopSequenceDecoder from stop parameters
pub fn create_stop_decoder(
tokenizer: &Arc<dyn Tokenizer>,
......@@ -646,7 +581,7 @@ pub async fn collect_stream_responses(
Some(Error(err)) => {
error!("{} error: {}", worker_name, err.message);
// Don't mark as completed - let Drop send abort for error cases
return Err(internal_error_message(format!(
return Err(error::internal_error(format!(
"{} generation failed: {}",
worker_name, err.message
)));
......@@ -662,7 +597,7 @@ pub async fn collect_stream_responses(
Err(e) => {
error!("{} stream error: {:?}", worker_name, e);
// Don't mark as completed - let Drop send abort for error cases
return Err(internal_error_message(format!(
return Err(error::internal_error(format!(
"{} stream failed: {}",
worker_name, e
)));
......