Unverified Commit d05a968b authored by Chang Su, committed by GitHub

[router][grpc] Add `ResponsesContext` and fix error propagation in responses api (#12164)

parent 94aad0de
......@@ -1098,10 +1098,15 @@ impl RequestPipeline {
}
}
/// Execute chat pipeline for responses endpoint (Result-based for easier composition)
/// Execute chat pipeline for responses endpoint
///
/// This is used by the responses module and returns Result instead of Response.
/// It also supports background mode cancellation via background_tasks.
/// TODO: The support for background tasks is not scalable. Consider replacing this with
/// a better design in the future.
/// Used by ALL non-streaming /v1/responses requests (both sync and background modes).
/// Uses the same 7 pipeline stages as execute_chat(), with three differences:
/// 1. Returns Result<ChatCompletionResponse, Response> for tool_loop composition
/// 2. Disallows streaming (responses endpoint uses different SSE format)
/// 3. Injects hooks for background task cancellation (only active when response_id provided)
pub async fn execute_chat_for_responses(
&self,
request: Arc<ChatCompletionRequest>,
......@@ -1110,7 +1115,7 @@ impl RequestPipeline {
components: Arc<SharedComponents>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ChatCompletionResponse, String> {
) -> Result<ChatCompletionResponse, Response> {
let mut ctx = RequestContext::for_chat(request, headers, model_id, components);
// Execute each stage in sequence
......@@ -1118,7 +1123,9 @@ impl RequestPipeline {
match stage.execute(&mut ctx).await {
Ok(Some(_response)) => {
// Streaming not supported for responses sync mode
return Err("Streaming is not supported in this context".to_string());
return Err(utils::bad_request_error(
"Streaming is not supported in this context".to_string(),
));
}
Ok(None) => {
let stage_name = stage.name();
......@@ -1158,14 +1165,14 @@ impl RequestPipeline {
continue;
}
Err(response) => {
// Error occurred
// Error occurred - return the response as-is to preserve HTTP status codes
error!(
"Stage {} ({}) failed with status {}",
idx + 1,
stage.name(),
response.status()
);
return Err(format!("Pipeline stage {} failed", stage.name()));
return Err(response);
}
}
}
......@@ -1173,10 +1180,10 @@ impl RequestPipeline {
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Chat(response)) => Ok(response),
Some(FinalResponse::Generate(_)) => {
Err("Internal error: wrong response type".to_string())
}
None => Err("No response produced".to_string()),
Some(FinalResponse::Generate(_)) => Err(utils::internal_error_static(
"Internal error: wrong response type",
)),
None => Err(utils::internal_error_static("No response produced")),
}
}
}
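The core of this change is switching `execute_chat_for_responses` from `Result<ChatCompletionResponse, String>` to `Result<ChatCompletionResponse, Response>`, so a stage failure keeps its original HTTP status instead of being flattened into a stringly-typed 500. A minimal sketch of the pattern, with a hypothetical `bad_request_error` helper standing in for the router's `utils` error builders (not the real implementation):

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;

// Hypothetical stand-in for the router's `utils::bad_request_error`: build an
// error Response that already carries its HTTP status and an OpenAI-style body.
fn bad_request_error(message: String) -> Response {
    (
        StatusCode::BAD_REQUEST,
        Json(json!({ "error": { "message": message, "type": "invalid_request_error" } })),
    )
        .into_response()
}

// A pipeline-like step that fails with a typed Response instead of a String.
fn run_stage(streaming: bool) -> Result<&'static str, Response> {
    if streaming {
        // The 400 status travels with the error; nothing upstream re-wraps it.
        return Err(bad_request_error(
            "Streaming is not supported in this context".to_string(),
        ));
    }
    Ok("chat completion")
}

// Callers compose with `?` (or a match) and forward the error Response as-is.
fn handle(streaming: bool) -> Response {
    match run_stage(streaming) {
        Ok(body) => (StatusCode::OK, body).into_response(),
        Err(response) => response, // status code and body preserved end to end
    }
}

fn main() {
    assert_eq!(handle(true).status(), StatusCode::BAD_REQUEST);
    assert_eq!(handle(false).status(), StatusCode::OK);
}
```

In the actual pipeline the Ok branch carries a ChatCompletionResponse that the responses module converts further; the important part is that `Err(response)` reaches the client untouched.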
//! Context for /v1/responses endpoint handlers
//!
//! Bundles all dependencies needed by responses handlers to avoid passing
//! 10+ parameters to every function (fixes clippy::too_many_arguments).
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use super::types::BackgroundTaskInfo;
use crate::{
core::WorkerRegistry,
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
mcp::McpManager,
routers::grpc::{context::SharedComponents, pipeline::RequestPipeline},
};
/// Context for /v1/responses endpoint
///
/// All fields are Arc/shared references, so cloning this context is cheap.
#[derive(Clone)]
pub struct ResponsesContext {
/// Chat pipeline for executing requests
pub pipeline: Arc<RequestPipeline>,
/// Shared components (tokenizer, parsers, worker_registry)
pub components: Arc<SharedComponents>,
/// Worker registry for validation
pub worker_registry: Arc<WorkerRegistry>,
/// Response storage backend
pub response_storage: SharedResponseStorage,
/// Conversation storage backend
pub conversation_storage: SharedConversationStorage,
/// Conversation item storage backend
pub conversation_item_storage: SharedConversationItemStorage,
/// MCP manager for tool support
pub mcp_manager: Arc<McpManager>,
/// Background task handles for cancellation support
pub background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
}
impl ResponsesContext {
/// Create a new responses context
pub fn new(
pipeline: Arc<RequestPipeline>,
components: Arc<SharedComponents>,
worker_registry: Arc<WorkerRegistry>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<McpManager>,
) -> Self {
Self {
pipeline,
components,
worker_registry,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
background_tasks: Arc::new(RwLock::new(HashMap::new())),
}
}
}
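Since every field is an Arc or shared handle, handlers can take a single `&ResponsesContext`, and background tasks only need `ctx.clone()`, instead of threading eight parameters through each call. A rough before/after illustration using simplified hypothetical types rather than the router's real pipeline and storages:

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

// Hypothetical, pared-down context: the real one also bundles the pipeline,
// shared components, conversation storages and the MCP manager.
#[derive(Clone)]
struct DemoContext {
    response_storage: Arc<RwLock<HashMap<String, String>>>,
    background_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
}

// Before: each dependency was a separate parameter on every handler.
async fn get_response_before(
    response_storage: Arc<RwLock<HashMap<String, String>>>,
    _background_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
    id: &str,
) -> Option<String> {
    response_storage.read().await.get(id).cloned()
}

// After: one cheap-to-clone context carries everything the handler needs.
async fn get_response_after(ctx: &DemoContext, id: &str) -> Option<String> {
    ctx.response_storage.read().await.get(id).cloned()
}

#[tokio::main]
async fn main() {
    let ctx = DemoContext {
        response_storage: Arc::new(RwLock::new(HashMap::new())),
        background_tasks: Arc::new(RwLock::new(HashMap::new())),
    };
    ctx.response_storage
        .write()
        .await
        .insert("resp_1".into(), "completed".into());

    // Spawning a background task only needs a single clone of the context.
    let bg = ctx.clone();
    let handle = tokio::spawn(async move { bg.response_storage.read().await.len() });
    assert_eq!(handle.await.unwrap(), 1);

    assert_eq!(get_response_after(&ctx, "resp_1").await.as_deref(), Some("completed"));
    let _ = get_response_before(ctx.response_storage.clone(), ctx.background_tasks.clone(), "resp_1").await;
}
```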
//! Handler functions for /v1/responses endpoints
//!
//! This module contains all the actual implementation logic for:
//! - POST /v1/responses (route_responses)
//! - GET /v1/responses/{response_id} (get_response_impl)
//! - POST /v1/responses/{response_id}/cancel (cancel_response_impl)
//! # Public API
//!
//! - `route_responses()` - POST /v1/responses (main entry point)
//! - `get_response_impl()` - GET /v1/responses/{response_id}
//! - `cancel_response_impl()` - POST /v1/responses/{response_id}/cancel
//!
//! # Architecture
//!
//! This module orchestrates all request handling for the /v1/responses endpoint.
//! It supports three execution modes:
//!
//! 1. **Synchronous** - Returns complete response immediately
//! 2. **Background** - Returns queued response, executes in background task
//! 3. **Streaming** - Returns SSE stream with real-time events
//!
//! # Request Flow
//!
//! ```text
//! route_responses()
//! ├─► route_responses_sync() → route_responses_internal()
//! ├─► route_responses_background() → spawn(route_responses_internal())
//! └─► route_responses_streaming() → convert_chat_stream_to_responses_stream()
//!
//! route_responses_internal()
//! ├─► load_conversation_history()
//! ├─► execute_tool_loop() (if MCP tools)
//! │ └─► pipeline.execute_chat_for_responses() [loop]
//! └─► execute_without_mcp() (if no MCP tools)
//! └─► pipeline.execute_chat_for_responses()
//! ```
use std::{
collections::HashMap,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
......@@ -27,7 +52,7 @@ use uuid::Uuid;
use super::{
conversions,
streaming::ResponseStreamEventEmitter,
tool_loop::{create_mcp_manager_from_request, execute_tool_loop, execute_tool_loop_streaming},
tool_loop::{execute_tool_loop, execute_tool_loop_streaming},
types::BackgroundTaskInfo,
};
use crate::{
......@@ -42,10 +67,7 @@ use crate::{
ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage,
},
},
routers::{
grpc::{context::SharedComponents, pipeline::RequestPipeline},
openai::conversations::persist_conversation_items,
},
routers::openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
};
// ============================================================================
......@@ -55,19 +77,39 @@ use crate::{
/// Main handler for POST /v1/responses
///
/// Validates the request, determines the execution mode (sync/background/streaming), and delegates to the matching handler
#[allow(clippy::too_many_arguments)]
pub async fn route_responses(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
request: Arc<ResponsesRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<crate::mcp::McpManager>,
background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
) -> Response {
// 0. Fast worker validation (fail-fast before expensive operations)
let requested_model: Option<&str> = model_id.as_deref().or(Some(request.model.as_str()));
if let Some(model) = requested_model {
// Check if any workers support this model
let available_models = ctx.worker_registry.get_models();
if !available_models.contains(&model.to_string()) {
return (
StatusCode::SERVICE_UNAVAILABLE,
axum::Json(json!({
"error": {
"message": format!(
"No workers available for model '{}'. Available models: {}",
model,
available_models.join(", ")
),
"type": "service_unavailable",
"param": "model",
"code": "no_available_workers"
}
})),
)
.into_response();
}
}
// 1. Validate mutually exclusive parameters
if request.previous_response_id.is_some() && request.conversation.is_some() {
return (
......@@ -105,47 +147,11 @@ pub async fn route_responses(
// 3. Route based on execution mode
if is_streaming {
route_responses_streaming(
pipeline,
request,
headers,
model_id,
components,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
)
.await
route_responses_streaming(ctx, request, headers, model_id).await
} else if is_background {
route_responses_background(
pipeline,
request,
headers,
model_id,
components,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
background_tasks,
)
.await
route_responses_background(ctx, request, headers, model_id).await
} else {
route_responses_sync(
pipeline,
request,
headers,
model_id,
components,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
None, // No response_id for sync
None, // No background_tasks for sync
)
.await
route_responses_sync(ctx, request, headers, model_id, None).await
}
}
......@@ -161,120 +167,71 @@ pub async fn route_responses(
/// 3. Executes chat pipeline
/// 4. Converts back to ResponsesResponse
/// 5. Persists to storage
#[allow(clippy::too_many_arguments)]
async fn route_responses_sync(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
request: Arc<ResponsesRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<crate::mcp::McpManager>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Response {
match route_responses_internal(
pipeline,
request,
headers,
model_id,
components,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
response_id,
background_tasks,
)
.await
{
match route_responses_internal(ctx, request, headers, model_id, response_id).await {
Ok(responses_response) => axum::Json(responses_response).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": e,
"type": "internal_error"
}
})),
)
.into_response(),
Err(response) => response, // Already a Response with proper status code
}
}
/// Internal implementation that returns Result for background task compatibility
#[allow(clippy::too_many_arguments)]
async fn route_responses_internal(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
request: Arc<ResponsesRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<crate::mcp::McpManager>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ResponsesResponse, String> {
) -> Result<ResponsesResponse, Response> {
// 1. Load conversation history and build modified request
let modified_request = load_conversation_history(
&request,
&response_storage,
&conversation_storage,
&conversation_item_storage,
)
.await?;
let modified_request = load_conversation_history(ctx, &request).await?;
// 2. Check if request has MCP tools - if so, use tool loop
let responses_response = if let Some(tools) = &request.tools {
// Try to create dynamic MCP client from request tools using the manager
if let Some(request_mcp_manager) =
create_mcp_manager_from_request(&mcp_manager, tools).await
// Ensure dynamic MCP client is registered for request-scoped tools
if ensure_request_mcp_client(&ctx.mcp_manager, tools)
.await
.is_some()
{
debug!("MCP tools detected, using tool loop");
// Execute with MCP tool loop
execute_tool_loop(
pipeline,
ctx,
modified_request,
&request,
headers,
model_id,
components,
request_mcp_manager,
response_id.clone(),
background_tasks,
)
.await?
} else {
debug!("Failed to create MCP client from request tools");
// Fall through to non-MCP execution
execute_without_mcp(
pipeline,
ctx,
&modified_request,
&request,
headers,
model_id,
components,
response_id.clone(),
background_tasks,
)
.await?
}
} else {
// No tools, execute normally
execute_without_mcp(
pipeline,
ctx,
&modified_request,
&request,
headers,
model_id,
components,
response_id.clone(),
background_tasks,
)
.await?
};
......@@ -283,9 +240,9 @@ async fn route_responses_internal(
if request.store.unwrap_or(true) {
if let Ok(response_json) = serde_json::to_value(&responses_response) {
if let Err(e) = persist_conversation_items(
conversation_storage,
conversation_item_storage,
response_storage,
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&response_json,
&request,
)
......@@ -306,16 +263,10 @@ async fn route_responses_internal(
/// Execute responses request in background mode
#[allow(clippy::too_many_arguments)]
async fn route_responses_background(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
request: Arc<ResponsesRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<crate::mcp::McpManager>,
background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
) -> Response {
// Generate response_id for background tracking
let response_id = format!("resp_{}", Uuid::new_v4());
......@@ -356,9 +307,9 @@ async fn route_responses_background(
// Persist queued response to storage
if let Ok(response_json) = serde_json::to_value(&queued_response) {
if let Err(e) = persist_conversation_items(
conversation_storage.clone(),
conversation_item_storage.clone(),
response_storage.clone(),
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&response_json,
&request,
)
......@@ -369,17 +320,11 @@ async fn route_responses_background(
}
// Spawn background task
let pipeline = pipeline.clone();
let ctx_clone = ctx.clone();
let request_clone = request.clone();
let headers_clone = headers.clone();
let model_id_clone = model_id.clone();
let components_clone = components.clone();
let response_storage_clone = response_storage.clone();
let conversation_storage_clone = conversation_storage.clone();
let conversation_item_storage_clone = conversation_item_storage.clone();
let mcp_manager_clone = mcp_manager.clone();
let response_id_clone = response_id.clone();
let background_tasks_clone = background_tasks.clone();
let handle = tokio::task::spawn(async move {
// Execute synchronously (set background=false to prevent recursion)
......@@ -387,17 +332,11 @@ async fn route_responses_background(
background_request.background = Some(false);
match route_responses_internal(
&pipeline,
&ctx_clone,
Arc::new(background_request),
headers_clone,
model_id_clone,
components_clone,
response_storage_clone,
conversation_storage_clone,
conversation_item_storage_clone,
mcp_manager_clone,
Some(response_id_clone.clone()),
Some(background_tasks_clone.clone()),
)
.await
{
......@@ -407,20 +346,25 @@ async fn route_responses_background(
response_id_clone
);
}
Err(e) => {
warn!("Background response {} failed: {}", response_id_clone, e);
Err(response) => {
warn!(
"Background response {} failed with status {}",
response_id_clone,
response.status()
);
}
}
// Clean up task handle when done
background_tasks_clone
ctx_clone
.background_tasks
.write()
.await
.remove(&response_id_clone);
});
// Store task info for cancellation support
background_tasks.write().await.insert(
ctx.background_tasks.write().await.insert(
response_id.clone(),
BackgroundTaskInfo {
handle,
......@@ -440,61 +384,28 @@ async fn route_responses_background(
/// Execute streaming responses request
#[allow(clippy::too_many_arguments)]
async fn route_responses_streaming(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
request: Arc<ResponsesRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<crate::mcp::McpManager>,
) -> Response {
// 1. Load conversation history
let modified_request = match load_conversation_history(
&request,
&response_storage,
&conversation_storage,
&conversation_item_storage,
)
.await
{
let modified_request = match load_conversation_history(ctx, &request).await {
Ok(req) => req,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
axum::Json(json!({
"error": {
"message": e,
"type": "invalid_request_error"
}
})),
)
.into_response();
}
Err(response) => return response, // Already a Response with proper status code
};
// 2. Check if request has MCP tools - if so, use streaming tool loop
if let Some(tools) = &request.tools {
// Try to create dynamic MCP client from request tools using the manager
if let Some(request_mcp_manager) =
create_mcp_manager_from_request(&mcp_manager, tools).await
// Ensure dynamic MCP client is registered for request-scoped tools
if ensure_request_mcp_client(&ctx.mcp_manager, tools)
.await
.is_some()
{
debug!("MCP tools detected in streaming mode, using streaming tool loop");
return execute_tool_loop_streaming(
pipeline,
modified_request,
&request,
headers,
model_id,
components,
request_mcp_manager,
response_storage,
conversation_storage,
conversation_item_storage,
)
.await;
return execute_tool_loop_streaming(ctx, modified_request, &request, headers, model_id)
.await;
}
}
......@@ -516,18 +427,7 @@ async fn route_responses_streaming(
};
// 4. Execute chat pipeline and convert streaming format (no MCP tools)
convert_chat_stream_to_responses_stream(
pipeline,
chat_request,
headers,
model_id,
components,
&request,
response_storage,
conversation_storage,
conversation_item_storage,
)
.await
convert_chat_stream_to_responses_stream(ctx, chat_request, headers, model_id, &request).await
}
/// Convert chat streaming response to responses streaming format
......@@ -540,21 +440,23 @@ async fn route_responses_streaming(
/// 5. Emits transformed SSE events in responses format
#[allow(clippy::too_many_arguments)]
async fn convert_chat_stream_to_responses_stream(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
chat_request: Arc<crate::protocols::chat::ChatCompletionRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
original_request: &ResponsesRequest,
response_storage: SharedResponseStorage,
_conversation_storage: SharedConversationStorage,
_conversation_item_storage: SharedConversationItemStorage,
) -> Response {
debug!("Converting chat SSE stream to responses SSE format");
// Get chat streaming response
let chat_response = pipeline
.execute_chat(chat_request.clone(), headers, model_id, components)
let chat_response = ctx
.pipeline
.execute_chat(
chat_request.clone(),
headers,
model_id,
ctx.components.clone(),
)
.await;
// Extract body and headers from chat response
......@@ -566,18 +468,18 @@ async fn convert_chat_stream_to_responses_stream(
// Spawn background task to transform stream
let original_request_clone = original_request.clone();
let chat_request_clone = chat_request.clone();
let response_storage_clone = response_storage.clone();
let conversation_storage_clone = _conversation_storage.clone();
let conversation_item_storage_clone = _conversation_item_storage.clone();
let response_storage = ctx.response_storage.clone();
let conversation_storage = ctx.conversation_storage.clone();
let conversation_item_storage = ctx.conversation_item_storage.clone();
tokio::spawn(async move {
if let Err(e) = process_and_transform_sse_stream(
body,
original_request_clone,
chat_request_clone,
response_storage_clone,
conversation_storage_clone,
conversation_item_storage_clone,
response_storage,
conversation_storage,
conversation_item_storage,
tx.clone(),
)
.await
......@@ -710,9 +612,9 @@ async fn process_and_transform_sse_stream(
if let Ok(response_json) = serde_json::to_value(&final_response) {
if let Err(e) = persist_conversation_items(
conversation_storage,
conversation_item_storage,
response_storage,
conversation_storage.clone(),
conversation_item_storage.clone(),
response_storage.clone(),
&response_json,
&original_request,
)
......@@ -925,53 +827,55 @@ impl StreamingResponseAccumulator {
// ============================================================================
/// Execute request without MCP tool loop (simple pipeline execution)
#[allow(clippy::too_many_arguments)]
async fn execute_without_mcp(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
modified_request: &ResponsesRequest,
original_request: &ResponsesRequest,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ResponsesResponse, String> {
) -> Result<ResponsesResponse, Response> {
use crate::routers::grpc::utils;
// Convert ResponsesRequest → ChatCompletionRequest
let chat_request = conversions::responses_to_chat(modified_request)
.map_err(|e| format!("Failed to convert request: {}", e))?;
.map_err(|e| utils::bad_request_error(format!("Failed to convert request: {}", e)))?;
// Execute chat pipeline
let chat_response = pipeline
// Execute chat pipeline (errors already have proper HTTP status codes)
let chat_response = ctx
.pipeline
.execute_chat_for_responses(
Arc::new(chat_request),
headers,
model_id,
components,
ctx.components.clone(),
response_id.clone(),
background_tasks,
Some(ctx.background_tasks.clone()),
)
.await
.map_err(|e| format!("Pipeline execution failed: {}", e))?;
.await?; // Preserve the Response error as-is
// Convert ChatCompletionResponse → ResponsesResponse
conversions::chat_to_responses(&chat_response, original_request, response_id)
.map_err(|e| format!("Failed to convert to responses format: {}", e))
conversions::chat_to_responses(&chat_response, original_request, response_id).map_err(|e| {
utils::internal_error_message(format!("Failed to convert to responses format: {}", e))
})
}
/// Load conversation history and response chains, returning the modified request
async fn load_conversation_history(
ctx: &super::context::ResponsesContext,
request: &ResponsesRequest,
response_storage: &SharedResponseStorage,
conversation_storage: &SharedConversationStorage,
conversation_item_storage: &SharedConversationItemStorage,
) -> Result<ResponsesRequest, String> {
) -> Result<ResponsesRequest, Response> {
let mut modified_request = request.clone();
let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
// Handle previous_response_id by loading response chain
if let Some(ref prev_id_str) = modified_request.previous_response_id {
let prev_id = ResponseId::from(prev_id_str.as_str());
match response_storage.get_response_chain(&prev_id, None).await {
match ctx
.response_storage
.get_response_chain(&prev_id, None)
.await
{
Ok(chain) => {
let mut items = Vec::new();
for stored in chain.responses.iter() {
......@@ -1026,7 +930,7 @@ async fn load_conversation_history(
let conv_id = ConversationId::from(conv_id_str.as_str());
// Auto-create conversation if it doesn't exist (OpenAI behavior)
if let Ok(None) = conversation_storage.get_conversation(&conv_id).await {
if let Ok(None) = ctx.conversation_storage.get_conversation(&conv_id).await {
debug!(
"Creating new conversation with user-provided ID: {}",
conv_id_str
......@@ -1043,10 +947,15 @@ async fn load_conversation_history(
id: Some(conv_id.clone()), // Use user-provided conversation ID
metadata,
};
conversation_storage
ctx.conversation_storage
.create_conversation(new_conv)
.await
.map_err(|e| format!("Failed to create conversation: {}", e))?;
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to create conversation: {}",
e
))
})?;
}
// Load conversation history
......@@ -1057,7 +966,11 @@ async fn load_conversation_history(
after: None,
};
match conversation_item_storage.list_items(&conv_id, params).await {
match ctx
.conversation_item_storage
.list_items(&conv_id, params)
.await
{
Ok(stored_items) => {
let mut items: Vec<ResponseInputOutputItem> = Vec::new();
for item in stored_items.into_iter() {
......@@ -1142,13 +1055,13 @@ async fn load_conversation_history(
/// Implementation for GET /v1/responses/{response_id}
pub async fn get_response_impl(
response_storage: &SharedResponseStorage,
ctx: &super::context::ResponsesContext,
response_id: &str,
) -> Response {
let resp_id = ResponseId::from(response_id);
// Retrieve response from storage
match response_storage.get_response(&resp_id).await {
match ctx.response_storage.get_response(&resp_id).await {
Ok(Some(stored_response)) => axum::Json(stored_response.raw_response).into_response(),
Ok(None) => (
StatusCode::NOT_FOUND,
......@@ -1180,14 +1093,13 @@ pub async fn get_response_impl(
/// Implementation for POST /v1/responses/{response_id}/cancel
pub async fn cancel_response_impl(
response_storage: &SharedResponseStorage,
background_tasks: &Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
ctx: &super::context::ResponsesContext,
response_id: &str,
) -> Response {
let resp_id = ResponseId::from(response_id);
// Retrieve response from storage to check if it exists and get current status
match response_storage.get_response(&resp_id).await {
match ctx.response_storage.get_response(&resp_id).await {
Ok(Some(stored_response)) => {
// Check current status - only queued or in_progress responses can be cancelled
let current_status = stored_response
......@@ -1199,7 +1111,7 @@ pub async fn cancel_response_impl(
match current_status {
"queued" | "in_progress" => {
// Attempt to abort the background task
let mut tasks = background_tasks.write().await;
let mut tasks = ctx.background_tasks.write().await;
if let Some(task_info) = tasks.remove(response_id) {
// Abort the Rust task immediately
task_info.handle.abort();
......
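Cancellation in the background path hinges on keeping each task's tokio `JoinHandle` in `background_tasks`: cancel removes the entry and aborts the handle, and the task removes its own entry when it finishes normally. A minimal sketch of that registry pattern, with types simplified for illustration (the real `BackgroundTaskInfo` also carries the gRPC client used to abort the request on the worker side):

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::{sync::RwLock, task::JoinHandle};

// Hypothetical minimal registry: response_id -> running background task.
type BackgroundTasks = Arc<RwLock<HashMap<String, JoinHandle<()>>>>;

/// Abort a queued/in-progress background response. Returns false if there is
/// nothing to abort (the task already finished and removed its own entry).
async fn cancel(tasks: &BackgroundTasks, response_id: &str) -> bool {
    if let Some(handle) = tasks.write().await.remove(response_id) {
        handle.abort(); // stop the Rust task immediately
        true
    } else {
        false
    }
}

#[tokio::main]
async fn main() {
    let tasks: BackgroundTasks = Arc::new(RwLock::new(HashMap::new()));

    // Simulate spawning a long-running background response.
    let handle = tokio::spawn(async {
        tokio::time::sleep(std::time::Duration::from_secs(60)).await;
    });
    tasks.write().await.insert("resp_123".to_string(), handle);

    assert!(cancel(&tasks, "resp_123").await);  // aborted
    assert!(!cancel(&tasks, "resp_123").await); // already gone
}
```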
......@@ -9,6 +9,7 @@
//! - Response persistence
// Module declarations
pub mod context;
mod conversions;
mod handlers;
pub mod streaming;
......@@ -16,5 +17,6 @@ pub mod tool_loop;
pub mod types;
// Public exports
pub use context::ResponsesContext;
pub use handlers::{cancel_response_impl, get_response_impl, route_responses};
pub use types::BackgroundTaskInfo;
......@@ -13,7 +13,7 @@ use axum::{
};
use bytes::Bytes;
use serde_json::json;
use tokio::sync::{mpsc, RwLock};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, warn};
use uuid::Uuid;
......@@ -21,24 +21,14 @@ use uuid::Uuid;
use super::{
conversions,
streaming::{OutputItemType, ResponseStreamEventEmitter},
types::BackgroundTaskInfo,
};
/// This is a re-export of the shared implementation from openai::mcp
pub(super) use crate::routers::openai::mcp::ensure_request_mcp_client as create_mcp_manager_from_request;
use crate::{
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
use crate::protocols::{
chat::ChatCompletionResponse,
common::{Tool, ToolChoice, ToolChoiceValue},
responses::{
McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest, ResponsesResponse,
},
protocols::{
chat::ChatCompletionResponse,
common::{Tool, ToolChoice, ToolChoiceValue},
responses::{
McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest,
ResponsesResponse,
},
},
routers::grpc::{context::SharedComponents, pipeline::RequestPipeline},
};
/// Extract function call from a chat completion response
......@@ -221,18 +211,14 @@ fn build_mcp_call_item(
/// 2. Checks if response has tool calls
/// 3. If yes, executes MCP tools and builds resume request
/// 4. Repeats until no more tool calls or limit reached
#[allow(clippy::too_many_arguments)]
pub(super) async fn execute_tool_loop(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
mut current_request: ResponsesRequest,
original_request: &ResponsesRequest,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
mcp_manager: Arc<crate::mcp::McpManager>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ResponsesResponse, String> {
) -> Result<ResponsesResponse, Response> {
// Get server label from original request tools
let server_label = original_request
.tools
......@@ -257,31 +243,35 @@ pub(super) async fn execute_tool_loop(
);
// Get MCP tools and convert to chat format (do this once before loop)
let mcp_tools = mcp_manager.list_tools();
let mcp_tools = ctx.mcp_manager.list_tools();
let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
debug!("Converted {} MCP tools to chat format", chat_tools.len());
loop {
// Convert to chat request
let mut chat_request = conversions::responses_to_chat(&current_request)
.map_err(|e| format!("Failed to convert request: {}", e))?;
let mut chat_request = conversions::responses_to_chat(&current_request).map_err(|e| {
crate::routers::grpc::utils::bad_request_error(format!(
"Failed to convert request: {}",
e
))
})?;
// Add MCP tools to chat request so LLM knows about them
chat_request.tools = Some(chat_tools.clone());
chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
// Execute chat pipeline
let chat_response = pipeline
// Execute chat pipeline (errors already have proper HTTP status codes)
let chat_response = ctx
.pipeline
.execute_chat_for_responses(
Arc::new(chat_request),
headers.clone(),
model_id.clone(),
components.clone(),
ctx.components.clone(),
response_id.clone(),
background_tasks.clone(),
Some(ctx.background_tasks.clone()),
)
.await
.map_err(|e| format!("Pipeline execution failed: {}", e))?;
.await?;
// Check for function calls
if let Some((call_id, tool_name, args_json_str)) =
......@@ -312,7 +302,12 @@ pub(super) async fn execute_tool_loop(
original_request,
response_id.clone(),
)
.map_err(|e| format!("Failed to convert to responses format: {}", e))?;
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to convert to responses format: {}",
e
))
})?;
// Mark as completed but with incomplete details
responses_response.status = ResponseStatus::Completed;
......@@ -329,7 +324,8 @@ pub(super) async fn execute_tool_loop(
"Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let (output_str, success, error) = match mcp_manager
let (output_str, success, error) = match ctx
.mcp_manager
.call_tool(tool_name.as_str(), args_json_str.as_str())
.await
{
......@@ -428,12 +424,17 @@ pub(super) async fn execute_tool_loop(
original_request,
response_id.clone(),
)
.map_err(|e| format!("Failed to convert to responses format: {}", e))?;
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to convert to responses format: {}",
e
))
})?;
// Inject MCP metadata into output
if state.total_calls > 0 {
// Prepend mcp_list_tools item
let mcp_list_tools = build_mcp_list_tools_item(&mcp_manager, &server_label);
let mcp_list_tools = build_mcp_list_tools_item(&ctx.mcp_manager, &server_label);
responses_response.output.insert(0, mcp_list_tools);
// Append all mcp_call items at the end
......@@ -455,52 +456,28 @@ pub(super) async fn execute_tool_loop(
/// This streams each iteration's response to the client while accumulating
/// to check for tool calls. If tool calls are found, it executes them and
/// continues with the next streaming iteration.
#[allow(clippy::too_many_arguments)]
pub(super) async fn execute_tool_loop_streaming(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
current_request: ResponsesRequest,
original_request: &ResponsesRequest,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
mcp_manager: Arc<crate::mcp::McpManager>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
) -> Response {
// Get server label
let server_label = original_request
.tools
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.clone())
})
.unwrap_or_else(|| "request-mcp".to_string());
// Create SSE channel for client
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();
// Clone data for background task
let pipeline_clone = pipeline.clone();
let ctx_clone = ctx.clone();
let original_request_clone = original_request.clone();
// Spawn background task for tool loop
tokio::spawn(async move {
let result = execute_tool_loop_streaming_internal(
&pipeline_clone,
&ctx_clone,
current_request,
&original_request_clone,
headers,
model_id,
components,
mcp_manager,
server_label,
response_storage,
conversation_storage,
conversation_item_storage,
tx.clone(),
)
.await;
......@@ -546,21 +523,26 @@ pub(super) async fn execute_tool_loop_streaming(
}
/// Internal streaming tool loop implementation
#[allow(clippy::too_many_arguments)]
async fn execute_tool_loop_streaming_internal(
pipeline: &RequestPipeline,
ctx: &super::context::ResponsesContext,
mut current_request: ResponsesRequest,
original_request: &ResponsesRequest,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
mcp_manager: Arc<crate::mcp::McpManager>,
server_label: String,
_response_storage: SharedResponseStorage,
_conversation_storage: SharedConversationStorage,
_conversation_item_storage: SharedConversationItemStorage,
tx: mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) -> Result<(), String> {
// Extract server label from original request tools
let server_label = original_request
.tools
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.clone())
})
.unwrap_or_else(|| "request-mcp".to_string());
const MAX_ITERATIONS: usize = 10;
let mut state = ToolLoopState::new(original_request.input.clone(), server_label.clone());
let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);
......@@ -581,7 +563,7 @@ async fn execute_tool_loop_streaming_internal(
emitter.send_event(&event, &tx)?;
// Get MCP tools and convert to chat format (do this once before loop)
let mcp_tools = mcp_manager.list_tools();
let mcp_tools = ctx.mcp_manager.list_tools();
let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
debug!(
"Streaming: Converted {} MCP tools to chat format",
......@@ -670,12 +652,13 @@ async fn execute_tool_loop_streaming_internal(
chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
// Execute chat streaming
let response = pipeline
let response = ctx
.pipeline
.execute_chat(
Arc::new(chat_request),
headers.clone(),
model_id.clone(),
components.clone(),
ctx.components.clone(),
)
.await;
......@@ -758,7 +741,8 @@ async fn execute_tool_loop_streaming_internal(
"Calling MCP tool '{}' with args: {}",
tool_name, args_json_str
);
let (output_str, success, error) = match mcp_manager
let (output_str, success, error) = match ctx
.mcp_manager
.call_tool(tool_name.as_str(), args_json_str.as_str())
.await
{
......
// gRPC Router Implementation
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use async_trait::async_trait;
use axum::{
......@@ -9,22 +9,13 @@ use axum::{
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use tokio::sync::RwLock;
use tracing::debug;
use super::{
context::SharedComponents,
pipeline::RequestPipeline,
responses::{self, BackgroundTaskInfo},
};
use super::{context::SharedComponents, pipeline::RequestPipeline, responses};
use crate::{
app_context::AppContext,
config::types::RetryConfig,
core::WorkerRegistry,
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
mcp::McpManager,
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
......@@ -57,13 +48,8 @@ pub struct GrpcRouter {
configured_tool_parser: Option<String>,
pipeline: RequestPipeline,
shared_components: Arc<SharedComponents>,
// Storage backends for /v1/responses support
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
mcp_manager: Arc<McpManager>,
// Background task handles for cancellation support (includes gRPC client for Python abort)
background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
// Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
responses_context: responses::ResponsesContext,
}
impl GrpcRouter {
......@@ -89,18 +75,6 @@ impl GrpcRouter {
let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
// Extract storage backends from context
let response_storage = ctx.response_storage.clone();
let conversation_storage = ctx.conversation_storage.clone();
let conversation_item_storage = ctx.conversation_item_storage.clone();
// Get MCP manager from app context
let mcp_manager = ctx
.mcp_manager
.get()
.ok_or_else(|| "gRPC router requires MCP manager".to_string())?
.clone();
// Create shared components for pipeline
let shared_components = Arc::new(SharedComponents {
tokenizer: tokenizer.clone(),
......@@ -119,6 +93,20 @@ impl GrpcRouter {
ctx.configured_reasoning_parser.clone(),
);
// Create responses context with all dependencies
let responses_context = responses::ResponsesContext::new(
Arc::new(pipeline.clone()),
shared_components.clone(),
worker_registry.clone(),
ctx.response_storage.clone(),
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.mcp_manager
.get()
.ok_or_else(|| "gRPC router requires MCP manager".to_string())?
.clone(),
);
Ok(GrpcRouter {
worker_registry,
policy_registry,
......@@ -132,11 +120,7 @@ impl GrpcRouter {
configured_tool_parser: ctx.configured_tool_parser.clone(),
pipeline,
shared_components,
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
background_tasks: Arc::new(RwLock::new(HashMap::new())),
responses_context,
})
}
......@@ -254,26 +238,11 @@ impl RouterTrait for GrpcRouter {
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
// Use responses module for ALL requests (streaming and non-streaming)
// Responses module handles:
// - Request validation (previous_response_id XOR conversation)
// - Loading response chain / conversation history from storage
// - Conversion: ResponsesRequest → ChatCompletionRequest
// - Execution through chat pipeline stages
// - Conversion: ChatCompletionResponse → ResponsesResponse
// - Response persistence
// - MCP tool loop wrapper (future)
responses::route_responses(
&self.pipeline,
&self.responses_context,
Arc::new(body.clone()),
headers.cloned(),
model_id.map(|s| s.to_string()),
self.shared_components.clone(),
self.response_storage.clone(),
self.conversation_storage.clone(),
self.conversation_item_storage.clone(),
self.mcp_manager.clone(),
self.background_tasks.clone(),
)
.await
}
......@@ -284,12 +253,11 @@ impl RouterTrait for GrpcRouter {
response_id: &str,
_params: &ResponsesGetParams,
) -> Response {
responses::get_response_impl(&self.response_storage, response_id).await
responses::get_response_impl(&self.responses_context, response_id).await
}
async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
responses::cancel_response_impl(&self.response_storage, &self.background_tasks, response_id)
.await
responses::cancel_response_impl(&self.responses_context, response_id).await
}
async fn route_classify(
......