Unverified commit d05a968b, authored by Chang Su, committed by GitHub

[router][grpc] Add `ResponsesContext` and fix error propagation in responses api (#12164)

parent 94aad0de
@@ -1098,10 +1098,15 @@ impl RequestPipeline {
}
}
/// Execute chat pipeline for responses endpoint (Result-based for easier composition)
/// Execute chat pipeline for responses endpoint
///
/// This is used by the responses module and returns Result instead of Response.
/// It also supports background mode cancellation via background_tasks.
/// TODO: The support for background tasks is not scalable. Consider replacing this with
/// a better design in the future.
/// Used by ALL non-streaming /v1/responses requests (both sync and background modes).
/// Uses the same 7 pipeline stages as execute_chat(), with three differences:
/// 1. Returns Result<ChatCompletionResponse, Response> for tool_loop composition
/// 2. Disallows streaming (responses endpoint uses different SSE format)
/// 3. Injects hooks for background task cancellation (only active when response_id provided)
pub async fn execute_chat_for_responses(
&self,
request: Arc<ChatCompletionRequest>,
@@ -1110,7 +1115,7 @@ impl RequestPipeline {
components: Arc<SharedComponents>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ChatCompletionResponse, String> {
) -> Result<ChatCompletionResponse, Response> {
let mut ctx = RequestContext::for_chat(request, headers, model_id, components);
// Execute each stage in sequence
@@ -1118,7 +1123,9 @@
match stage.execute(&mut ctx).await {
Ok(Some(_response)) => {
// Streaming not supported for responses sync mode
return Err("Streaming is not supported in this context".to_string());
return Err(utils::bad_request_error(
"Streaming is not supported in this context".to_string(),
));
}
Ok(None) => {
let stage_name = stage.name();
@@ -1158,14 +1165,14 @@
continue;
}
Err(response) => {
// Error occurred
// Error occurred - return the response as-is to preserve HTTP status codes
error!(
"Stage {} ({}) failed with status {}",
idx + 1,
stage.name(),
response.status()
);
return Err(format!("Pipeline stage {} failed", stage.name()));
return Err(response);
}
}
}
@@ -1173,10 +1180,10 @@
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Chat(response)) => Ok(response),
Some(FinalResponse::Generate(_)) => {
Err("Internal error: wrong response type".to_string())
}
None => Err("No response produced".to_string()),
Some(FinalResponse::Generate(_)) => Err(utils::internal_error_static(
"Internal error: wrong response type",
)),
None => Err(utils::internal_error_static("No response produced")),
}
}
}
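
The signature change above is the heart of the error-propagation fix: the old `Result<ChatCompletionResponse, String>` collapsed every stage failure into a bare string, so callers had to pick an HTTP status themselves when converting it back into a response. Returning the axum `Response` lets callers forward errors with `?` while the original status code survives. A minimal sketch of a hypothetical caller (the function itself is illustrative, not part of this commit):

async fn run_sync_response(
    ctx: &ResponsesContext,
    request: Arc<ChatCompletionRequest>,
) -> Result<ChatCompletionResponse, Response> {
    // `?` forwards the error Response unchanged: a 400 from request
    // validation stays a 400, a 500 from a worker stays a 500.
    let chat_response = ctx
        .pipeline
        .execute_chat_for_responses(
            request,
            None, // headers
            None, // model_id
            ctx.components.clone(),
            None, // response_id: no background-cancellation hooks
            None, // background_tasks
        )
        .await?;
    Ok(chat_response)
}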
//! Context for /v1/responses endpoint handlers
//!
//! Bundles all dependencies needed by responses handlers to avoid passing
//! 10+ parameters to every function (fixes clippy::too_many_arguments).
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use super::types::BackgroundTaskInfo;
use crate::{
core::WorkerRegistry,
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
mcp::McpManager,
routers::grpc::{context::SharedComponents, pipeline::RequestPipeline},
};
/// Context for /v1/responses endpoint
///
/// All fields are Arc/shared references, so cloning this context is cheap.
#[derive(Clone)]
pub struct ResponsesContext {
/// Chat pipeline for executing requests
pub pipeline: Arc<RequestPipeline>,
/// Shared components (tokenizer, parsers, worker_registry)
pub components: Arc<SharedComponents>,
/// Worker registry for validation
pub worker_registry: Arc<WorkerRegistry>,
/// Response storage backend
pub response_storage: SharedResponseStorage,
/// Conversation storage backend
pub conversation_storage: SharedConversationStorage,
/// Conversation item storage backend
pub conversation_item_storage: SharedConversationItemStorage,
/// MCP manager for tool support
pub mcp_manager: Arc<McpManager>,
/// Background task handles for cancellation support
pub background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
}
impl ResponsesContext {
/// Create a new responses context
pub fn new(
pipeline: Arc<RequestPipeline>,
components: Arc<SharedComponents>,
worker_registry: Arc<WorkerRegistry>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<McpManager>,
) -> Self {
Self {
pipeline,
components,
worker_registry,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
background_tasks: Arc::new(RwLock::new(HashMap::new())),
}
}
}
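
Every field of `ResponsesContext` is an `Arc` (or an `Arc`-backed `Shared*` storage alias), so the `#[derive(Clone)]` above copies eight reference counts rather than any storage, pipeline, or MCP state. That is what makes handing the context to a spawned background task cheap, e.g. (hypothetical helper, for illustration only):

use tracing::debug;

/// Shows why Clone matters: the context can be moved into a spawned
/// task while the request handler keeps its own copy.
fn watch_in_flight(ctx: &ResponsesContext) {
    let ctx = ctx.clone(); // eight Arc clones, no deep copies
    tokio::spawn(async move {
        let tasks = ctx.background_tasks.read().await;
        debug!("{} background responses in flight", tasks.len());
    });
}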
@@ -9,6 +9,7 @@
//! - Response persistence
// Module declarations
pub mod context;
mod conversions;
mod handlers;
pub mod streaming;
@@ -16,5 +17,6 @@ pub mod tool_loop;
pub mod types;
// Public exports
pub use context::ResponsesContext;
pub use handlers::{cancel_response_impl, get_response_impl, route_responses};
pub use types::BackgroundTaskInfo;
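
With `context` public and `ResponsesContext` re-exported, downstream code can pull the whole responses API surface from the module root, e.g.:

use crate::routers::grpc::responses::{
    cancel_response_impl, get_response_impl, route_responses, BackgroundTaskInfo,
    ResponsesContext,
};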
@@ -13,7 +13,7 @@ use axum::{
};
use bytes::Bytes;
use serde_json::json;
use tokio::sync::{mpsc, RwLock};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, warn};
use uuid::Uuid;
@@ -21,24 +21,14 @@ use uuid::Uuid;
use super::{
conversions,
streaming::{OutputItemType, ResponseStreamEventEmitter},
types::BackgroundTaskInfo,
};
/// This is a re-export of the shared implementation from openai::mcp
pub(super) use crate::routers::openai::mcp::ensure_request_mcp_client as create_mcp_manager_from_request;
use crate::{
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
protocols::{
use crate::protocols::{
chat::ChatCompletionResponse,
common::{Tool, ToolChoice, ToolChoiceValue},
responses::{
McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest,
ResponsesResponse,
},
ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest, ResponsesResponse,
},
routers::grpc::{context::SharedComponents, pipeline::RequestPipeline},
};
/// Extract function call from a chat completion response
@@ -221,18 +211,14 @@ fn build_mcp_call_item(
/// 2. Checks if response has tool calls
/// 3. If yes, executes MCP tools and builds resume request
/// 4. Repeats until no more tool calls or limit reached
#[allow(clippy::too_many_arguments)]
pub(super) async fn execute_tool_loop(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
mut current_request: ResponsesRequest,
original_request: &ResponsesRequest,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
mcp_manager: Arc<crate::mcp::McpManager>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ResponsesResponse, String> {
) -> Result<ResponsesResponse, Response> {
// Get server label from original request tools
let server_label = original_request
.tools
@@ -257,31 +243,35 @@ pub(super) async fn execute_tool_loop(
);
// Get MCP tools and convert to chat format (do this once before loop)
let mcp_tools = mcp_manager.list_tools();
let mcp_tools = ctx.mcp_manager.list_tools();
let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
debug!("Converted {} MCP tools to chat format", chat_tools.len());
loop {
// Convert to chat request
let mut chat_request = conversions::responses_to_chat(&current_request)
.map_err(|e| format!("Failed to convert request: {}", e))?;
let mut chat_request = conversions::responses_to_chat(&current_request).map_err(|e| {
crate::routers::grpc::utils::bad_request_error(format!(
"Failed to convert request: {}",
e
))
})?;
// Add MCP tools to chat request so LLM knows about them
chat_request.tools = Some(chat_tools.clone());
chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
// Execute chat pipeline
let chat_response = pipeline
// Execute chat pipeline (errors already have proper HTTP status codes)
let chat_response = ctx
.pipeline
.execute_chat_for_responses(
Arc::new(chat_request),
headers.clone(),
model_id.clone(),
components.clone(),
ctx.components.clone(),
response_id.clone(),
background_tasks.clone(),
Some(ctx.background_tasks.clone()),
)
.await
.map_err(|e| format!("Pipeline execution failed: {}", e))?;
.await?;
// Check for function calls
if let Some((call_id, tool_name, args_json_str)) =
@@ -312,7 +302,12 @@
original_request,
response_id.clone(),
)
.map_err(|e| format!("Failed to convert to responses format: {}", e))?;
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to convert to responses format: {}",
e
))
})?;
// Mark as completed but with incomplete details
responses_response.status = ResponseStatus::Completed;
@@ -329,7 +324,8 @@
"Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let (output_str, success, error) = match mcp_manager
let (output_str, success, error) = match ctx
.mcp_manager
.call_tool(tool_name.as_str(), args_json_str.as_str())
.await
{
@@ -428,12 +424,17 @@
original_request,
response_id.clone(),
)
.map_err(|e| format!("Failed to convert to responses format: {}", e))?;
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to convert to responses format: {}",
e
))
})?;
// Inject MCP metadata into output
if state.total_calls > 0 {
// Prepend mcp_list_tools item
let mcp_list_tools = build_mcp_list_tools_item(&mcp_manager, &server_label);
let mcp_list_tools = build_mcp_list_tools_item(&ctx.mcp_manager, &server_label);
responses_response.output.insert(0, mcp_list_tools);
// Append all mcp_call items at the end
@@ -455,52 +456,28 @@
/// This streams each iteration's response to the client while accumulating
/// to check for tool calls. If tool calls are found, executes them and
/// continues with the next streaming iteration.
#[allow(clippy::too_many_arguments)]
pub(super) async fn execute_tool_loop_streaming(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
current_request: ResponsesRequest,
original_request: &ResponsesRequest,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
mcp_manager: Arc<crate::mcp::McpManager>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
) -> Response {
// Get server label
let server_label = original_request
.tools
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.clone())
})
.unwrap_or_else(|| "request-mcp".to_string());
// Create SSE channel for client
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();
// Clone data for background task
let pipeline_clone = pipeline.clone();
let ctx_clone = ctx.clone();
let original_request_clone = original_request.clone();
// Spawn background task for tool loop
tokio::spawn(async move {
let result = execute_tool_loop_streaming_internal(
&pipeline_clone,
&ctx_clone,
current_request,
&original_request_clone,
headers,
model_id,
components,
mcp_manager,
server_label,
response_storage,
conversation_storage,
conversation_item_storage,
tx.clone(),
)
.await;
@@ -546,21 +523,26 @@
}
/// Internal streaming tool loop implementation
#[allow(clippy::too_many_arguments)]
async fn execute_tool_loop_streaming_internal(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
mut current_request: ResponsesRequest,
original_request: &ResponsesRequest,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
mcp_manager: Arc<crate::mcp::McpManager>,
server_label: String,
_response_storage: SharedResponseStorage,
_conversation_storage: SharedConversationStorage,
_conversation_item_storage: SharedConversationItemStorage,
tx: mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) -> Result<(), String> {
// Extract server label from original request tools
let server_label = original_request
.tools
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.clone())
})
.unwrap_or_else(|| "request-mcp".to_string());
const MAX_ITERATIONS: usize = 10;
let mut state = ToolLoopState::new(original_request.input.clone(), server_label.clone());
let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);
@@ -581,7 +563,7 @@ async fn execute_tool_loop_streaming_internal(
emitter.send_event(&event, &tx)?;
// Get MCP tools and convert to chat format (do this once before loop)
let mcp_tools = mcp_manager.list_tools();
let mcp_tools = ctx.mcp_manager.list_tools();
let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
debug!(
"Streaming: Converted {} MCP tools to chat format",
@@ -670,12 +652,13 @@
chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
// Execute chat streaming
let response = pipeline
let response = ctx
.pipeline
.execute_chat(
Arc::new(chat_request),
headers.clone(),
model_id.clone(),
components.clone(),
ctx.components.clone(),
)
.await;
@@ -758,7 +741,8 @@
"Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let (output_str, success, error) = match mcp_manager
let (output_str, success, error) = match ctx
.mcp_manager
.call_tool(tool_name.as_str(), args_json_str.as_str())
.await
{
......
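
Condensed, the loop that both `execute_tool_loop` and its streaming twin implement looks like the sketch below; the streaming variant additionally forwards each iteration's SSE events over the `mpsc` channel while it accumulates the response. `append_tool_output` is an illustrative helper (not a function in this commit), and the request conversion and error mapping shown earlier in the diff are elided:

const MAX_ITERATIONS: usize = 10;
let mut iterations = 0;

loop {
    iterations += 1;
    if iterations > MAX_ITERATIONS {
        break; // cap reached: finish with incomplete_details, as above
    }

    // 1. Run one pass through the chat pipeline with the MCP tools attached.
    //    (conversions::responses_to_chat() builds `chat_request` from
    //    `current_request`; mapping its error to a 400 Response is elided.)
    let chat_response = ctx
        .pipeline
        .execute_chat_for_responses(
            Arc::new(chat_request),
            headers.clone(),
            model_id.clone(),
            ctx.components.clone(),
            response_id.clone(),
            Some(ctx.background_tasks.clone()),
        )
        .await?;

    // 2. No function call in the output means the model is done.
    let Some((call_id, tool_name, args)) = extract_function_call(&chat_response) else {
        break;
    };

    // 3. Execute the MCP tool, then splice the call and its output into the
    //    next request so the model sees the result on the next iteration.
    let result = ctx.mcp_manager.call_tool(&tool_name, &args).await;
    current_request = append_tool_output(current_request, &call_id, result);
}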
// gRPC Router Implementation
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use async_trait::async_trait;
use axum::{
@@ -9,22 +9,13 @@ use axum::{
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use tokio::sync::RwLock;
use tracing::debug;
use super::{
context::SharedComponents,
pipeline::RequestPipeline,
responses::{self, BackgroundTaskInfo},
};
use super::{context::SharedComponents, pipeline::RequestPipeline, responses};
use crate::{
app_context::AppContext,
config::types::RetryConfig,
core::WorkerRegistry,
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
mcp::McpManager,
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
@@ -57,13 +48,8 @@ pub struct GrpcRouter {
configured_tool_parser: Option<String>,
pipeline: RequestPipeline,
shared_components: Arc<SharedComponents>,
// Storage backends for /v1/responses support
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<McpManager>,
// Background task handles for cancellation support (includes gRPC client for Python abort)
background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
// Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
responses_context: responses::ResponsesContext,
}
impl GrpcRouter {
@@ -89,18 +75,6 @@ impl GrpcRouter {
let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
// Extract storage backends from context
let response_storage = ctx.response_storage.clone();
let conversation_storage = ctx.conversation_storage.clone();
let conversation_item_storage = ctx.conversation_item_storage.clone();
// Get MCP manager from app context
let mcp_manager = ctx
.mcp_manager
.get()
.ok_or_else(|| "gRPC router requires MCP manager".to_string())?
.clone();
// Create shared components for pipeline
let shared_components = Arc::new(SharedComponents {
tokenizer: tokenizer.clone(),
@@ -119,6 +93,20 @@ impl GrpcRouter {
ctx.configured_reasoning_parser.clone(),
);
// Create responses context with all dependencies
let responses_context = responses::ResponsesContext::new(
Arc::new(pipeline.clone()),
shared_components.clone(),
worker_registry.clone(),
ctx.response_storage.clone(),
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.mcp_manager
.get()
.ok_or_else(|| "gRPC router requires MCP manager".to_string())?
.clone(),
);
Ok(GrpcRouter {
worker_registry,
policy_registry,
@@ -132,11 +120,7 @@
configured_tool_parser: ctx.configured_tool_parser.clone(),
pipeline,
shared_components,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
background_tasks: Arc::new(RwLock::new(HashMap::new())),
responses_context,
})
}
@@ -254,26 +238,11 @@ impl RouterTrait for GrpcRouter {
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
// Use responses module for ALL requests (streaming and non-streaming)
// Responses module handles:
// - Request validation (previous_response_id XOR conversation)
// - Loading response chain / conversation history from storage
// - Conversion: ResponsesRequest → ChatCompletionRequest
// - Execution through chat pipeline stages
// - Conversion: ChatCompletionResponse → ResponsesResponse
// - Response persistence
// - MCP tool loop wrapper (future)
responses::route_responses(
&self.pipeline,
&self.responses_context,
Arc::new(body.clone()),
headers.cloned(),
model_id.map(|s| s.to_string()),
self.shared_components.clone(),
self.response_storage.clone(),
self.conversation_storage.clone(),
self.conversation_item_storage.clone(),
self.mcp_manager.clone(),
self.background_tasks.clone(),
)
.await
}
@@ -284,12 +253,11 @@
response_id: &str,
_params: &ResponsesGetParams,
) -> Response {
responses::get_response_impl(&self.response_storage, response_id).await
responses::get_response_impl(&self.responses_context, response_id).await
}
async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
responses::cancel_response_impl(&self.response_storage, &self.background_tasks, response_id)
.await
responses::cancel_response_impl(&self.responses_context, response_id).await
}
async fn route_classify(
......
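
For completeness, the cancellation path that `ResponsesContext` now carries end-to-end: `cancel_response_impl` can look a task up by response id in the shared `background_tasks` map and abort it. A sketch assuming `BackgroundTaskInfo` exposes an abortable task handle (the `handle` field name is illustrative; per the comment removed above, the real struct also carries the gRPC client used to abort the request on the Python side):

async fn cancel_by_id(ctx: &ResponsesContext, response_id: &str) -> bool {
    // Remove the entry so a second cancel for the same id is a no-op.
    let mut tasks = ctx.background_tasks.write().await;
    match tasks.remove(response_id) {
        Some(info) => {
            info.handle.abort(); // illustrative: stop the spawned tokio task
            true
        }
        None => false, // unknown id, or the task already finished
    }
}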