Unverified Commit d05a968b authored by Chang Su, committed by GitHub

[router][grpc] Add `ResponsesContext` and fix error propagation in responses api (#12164)

parent 94aad0de
@@ -1098,10 +1098,15 @@ impl RequestPipeline {
         }
     }
-    /// Execute chat pipeline for responses endpoint (Result-based for easier composition)
+    /// Execute chat pipeline for responses endpoint
     ///
-    /// This is used by the responses module and returns Result instead of Response.
-    /// It also supports background mode cancellation via background_tasks.
+    /// TODO: The support for background tasks is not scalable. Consider replacing this with
+    /// a better design in the future.
+    /// Used by ALL non-streaming /v1/responses requests (both sync and background modes).
+    /// Uses the same 7 pipeline stages as execute_chat(), with three differences:
+    /// 1. Returns Result<ChatCompletionResponse, Response> for tool_loop composition
+    /// 2. Disallows streaming (responses endpoint uses different SSE format)
+    /// 3. Injects hooks for background task cancellation (only active when response_id provided)
     pub async fn execute_chat_for_responses(
         &self,
         request: Arc<ChatCompletionRequest>,
@@ -1110,7 +1115,7 @@ impl RequestPipeline {
         components: Arc<SharedComponents>,
         response_id: Option<String>,
         background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
-    ) -> Result<ChatCompletionResponse, String> {
+    ) -> Result<ChatCompletionResponse, Response> {
        let mut ctx = RequestContext::for_chat(request, headers, model_id, components);
        // Execute each stage in sequence
@@ -1118,7 +1123,9 @@ impl RequestPipeline {
            match stage.execute(&mut ctx).await {
                Ok(Some(_response)) => {
                    // Streaming not supported for responses sync mode
-                    return Err("Streaming is not supported in this context".to_string());
+                    return Err(utils::bad_request_error(
+                        "Streaming is not supported in this context".to_string(),
+                    ));
                }
                Ok(None) => {
                    let stage_name = stage.name();
@@ -1158,14 +1165,14 @@ impl RequestPipeline {
                    continue;
                }
                Err(response) => {
-                    // Error occurred
+                    // Error occurred - return the response as-is to preserve HTTP status codes
                    error!(
                        "Stage {} ({}) failed with status {}",
                        idx + 1,
                        stage.name(),
                        response.status()
                    );
-                    return Err(format!("Pipeline stage {} failed", stage.name()));
+                    return Err(response);
                }
            }
        }
@@ -1173,10 +1180,10 @@ impl RequestPipeline {
        // Extract final response
        match ctx.state.response.final_response {
            Some(FinalResponse::Chat(response)) => Ok(response),
-            Some(FinalResponse::Generate(_)) => {
-                Err("Internal error: wrong response type".to_string())
-            }
-            None => Err("No response produced".to_string()),
+            Some(FinalResponse::Generate(_)) => Err(utils::internal_error_static(
+                "Internal error: wrong response type",
+            )),
+            None => Err(utils::internal_error_static("No response produced")),
        }
    }
}
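Editor's note: the error-type change from `String` to axum's `Response` is the core of the fix here. A failed stage now hands back a fully formed HTTP response, so callers can bubble it up with `?` instead of flattening everything into a 500 with a stringified message. A minimal self-contained sketch of the pattern, using axum directly; `ChatReply` and `execute_stage` are illustrative stand-ins, not types from this codebase:

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;

// Stand-in for ChatCompletionResponse (assumes serde's derive feature).
#[derive(serde::Serialize)]
struct ChatReply {
    content: String,
}

// The Err side is a complete HTTP response, so the original status
// code (400 here) survives all the way back to the handler.
async fn execute_stage(streaming: bool) -> Result<ChatReply, Response> {
    if streaming {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(json!({"error": {"message": "Streaming is not supported in this context"}})),
        )
            .into_response());
    }
    Ok(ChatReply { content: "ok".into() })
}

async fn handler() -> Response {
    match execute_stage(false).await {
        Ok(reply) => Json(reply).into_response(),
        // No remapping: the error already carries the right status code.
        Err(response) => response,
    }
}
```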
//! Context for /v1/responses endpoint handlers
//!
//! Bundles all dependencies needed by responses handlers to avoid passing
//! 10+ parameters to every function (fixes clippy::too_many_arguments).

use std::{collections::HashMap, sync::Arc};

use tokio::sync::RwLock;

use super::types::BackgroundTaskInfo;
use crate::{
    core::WorkerRegistry,
    data_connector::{
        SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
    },
    mcp::McpManager,
    routers::grpc::{context::SharedComponents, pipeline::RequestPipeline},
};

/// Context for /v1/responses endpoint
///
/// All fields are Arc/shared references, so cloning this context is cheap.
#[derive(Clone)]
pub struct ResponsesContext {
    /// Chat pipeline for executing requests
    pub pipeline: Arc<RequestPipeline>,
    /// Shared components (tokenizer, parsers, worker_registry)
    pub components: Arc<SharedComponents>,
    /// Worker registry for validation
    pub worker_registry: Arc<WorkerRegistry>,
    /// Response storage backend
    pub response_storage: SharedResponseStorage,
    /// Conversation storage backend
    pub conversation_storage: SharedConversationStorage,
    /// Conversation item storage backend
    pub conversation_item_storage: SharedConversationItemStorage,
    /// MCP manager for tool support
    pub mcp_manager: Arc<McpManager>,
    /// Background task handles for cancellation support
    pub background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
}

impl ResponsesContext {
    /// Create a new responses context
    pub fn new(
        pipeline: Arc<RequestPipeline>,
        components: Arc<SharedComponents>,
        worker_registry: Arc<WorkerRegistry>,
        response_storage: SharedResponseStorage,
        conversation_storage: SharedConversationStorage,
        conversation_item_storage: SharedConversationItemStorage,
        mcp_manager: Arc<McpManager>,
    ) -> Self {
        Self {
            pipeline,
            components,
            worker_registry,
            response_storage,
            conversation_storage,
            conversation_item_storage,
            mcp_manager,
            background_tasks: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}
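Editor's note: because every field is an `Arc` (or an Arc-backed storage alias), `Clone` only bumps reference counts; that is what lets the background path hand a clone to `tokio::spawn` while both copies share one `background_tasks` map. A small runnable sketch of that property, with placeholder field types standing in for the real ones:

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

// Shape mirror of ResponsesContext: every field is Arc-wrapped,
// so Clone is a handful of refcount increments, not a deep copy.
#[derive(Clone)]
struct Ctx {
    pipeline: Arc<String>, // placeholder for Arc<RequestPipeline>
    background_tasks: Arc<RwLock<HashMap<String, u64>>>, // placeholder for BackgroundTaskInfo
}

#[tokio::main]
async fn main() {
    let ctx = Ctx {
        pipeline: Arc::new("pipeline".to_string()),
        background_tasks: Arc::new(RwLock::new(HashMap::new())),
    };

    // A clone moved into a spawned task still shares the same map.
    let ctx_clone = ctx.clone();
    tokio::spawn(async move {
        ctx_clone.background_tasks.write().await.insert("resp_1".into(), 42);
    })
    .await
    .unwrap();

    assert_eq!(ctx.background_tasks.read().await.len(), 1);
}
```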
 //! Handler functions for /v1/responses endpoints
 //!
-//! This module contains all the actual implementation logic for:
-//! - POST /v1/responses (route_responses)
-//! - GET /v1/responses/{response_id} (get_response_impl)
-//! - POST /v1/responses/{response_id}/cancel (cancel_response_impl)
+//! # Public API
+//!
+//! - `route_responses()` - POST /v1/responses (main entry point)
+//! - `get_response_impl()` - GET /v1/responses/{response_id}
+//! - `cancel_response_impl()` - POST /v1/responses/{response_id}/cancel
+//!
+//! # Architecture
+//!
+//! This module orchestrates all request handling for the /v1/responses endpoint.
+//! It supports three execution modes:
+//!
+//! 1. **Synchronous** - Returns complete response immediately
+//! 2. **Background** - Returns queued response, executes in background task
+//! 3. **Streaming** - Returns SSE stream with real-time events
+//!
+//! # Request Flow
+//!
+//! ```text
+//! route_responses()
+//!   ├─► route_responses_sync() → route_responses_internal()
+//!   ├─► route_responses_background() → spawn(route_responses_internal())
+//!   └─► route_responses_streaming() → convert_chat_stream_to_responses_stream()
+//!
+//! route_responses_internal()
+//!   ├─► load_conversation_history()
+//!   ├─► execute_tool_loop() (if MCP tools)
+//!   │     └─► pipeline.execute_chat_for_responses() [loop]
+//!   └─► execute_without_mcp() (if no MCP tools)
+//!         └─► pipeline.execute_chat_for_responses()
+//! ```
 use std::{
-    collections::HashMap,
     sync::Arc,
     time::{SystemTime, UNIX_EPOCH},
 };
@@ -27,7 +52,7 @@ use uuid::Uuid;
 use super::{
     conversions,
     streaming::ResponseStreamEventEmitter,
-    tool_loop::{create_mcp_manager_from_request, execute_tool_loop, execute_tool_loop_streaming},
+    tool_loop::{execute_tool_loop, execute_tool_loop_streaming},
     types::BackgroundTaskInfo,
 };
 use crate::{
@@ -42,10 +67,7 @@ use crate::{
             ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage,
         },
     },
-    routers::{
-        grpc::{context::SharedComponents, pipeline::RequestPipeline},
-        openai::conversations::persist_conversation_items,
-    },
+    routers::openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
 };
 // ============================================================================
@@ -55,19 +77,39 @@ use crate::{
 /// Main handler for POST /v1/responses
 ///
 /// Validates request, determines execution mode (sync/async/streaming), and delegates
-#[allow(clippy::too_many_arguments)]
 pub async fn route_responses(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     request: Arc<ResponsesRequest>,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    response_storage: SharedResponseStorage,
-    conversation_storage: SharedConversationStorage,
-    conversation_item_storage: SharedConversationItemStorage,
-    mcp_manager: Arc<crate::mcp::McpManager>,
-    background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
 ) -> Response {
+    // 0. Fast worker validation (fail-fast before expensive operations)
+    let requested_model: Option<&str> = model_id.as_deref().or(Some(request.model.as_str()));
+    if let Some(model) = requested_model {
+        // Check if any workers support this model
+        let available_models = ctx.worker_registry.get_models();
+        if !available_models.contains(&model.to_string()) {
+            return (
+                StatusCode::SERVICE_UNAVAILABLE,
+                axum::Json(json!({
+                    "error": {
+                        "message": format!(
+                            "No workers available for model '{}'. Available models: {}",
+                            model,
+                            available_models.join(", ")
+                        ),
+                        "type": "service_unavailable",
+                        "param": "model",
+                        "code": "no_available_workers"
+                    }
+                })),
+            )
+                .into_response();
+        }
+    }
     // 1. Validate mutually exclusive parameters
     if request.previous_response_id.is_some() && request.conversation.is_some() {
         return (
@@ -105,47 +147,11 @@ pub async fn route_responses(
     // 3. Route based on execution mode
     if is_streaming {
-        route_responses_streaming(
-            pipeline,
-            request,
-            headers,
-            model_id,
-            components,
-            response_storage,
-            conversation_storage,
-            conversation_item_storage,
-            mcp_manager,
-        )
-        .await
+        route_responses_streaming(ctx, request, headers, model_id).await
     } else if is_background {
-        route_responses_background(
-            pipeline,
-            request,
-            headers,
-            model_id,
-            components,
-            response_storage,
-            conversation_storage,
-            conversation_item_storage,
-            mcp_manager,
-            background_tasks,
-        )
-        .await
+        route_responses_background(ctx, request, headers, model_id).await
     } else {
-        route_responses_sync(
-            pipeline,
-            request,
-            headers,
-            model_id,
-            components,
-            response_storage,
-            conversation_storage,
-            conversation_item_storage,
-            mcp_manager,
-            None, // No response_id for sync
-            None, // No background_tasks for sync
-        )
-        .await
+        route_responses_sync(ctx, request, headers, model_id, None).await
     }
 }
@@ -161,120 +167,71 @@ pub async fn route_responses(
 /// 3. Executes chat pipeline
 /// 4. Converts back to ResponsesResponse
 /// 5. Persists to storage
-#[allow(clippy::too_many_arguments)]
 async fn route_responses_sync(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     request: Arc<ResponsesRequest>,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    response_storage: SharedResponseStorage,
-    conversation_storage: SharedConversationStorage,
-    conversation_item_storage: SharedConversationItemStorage,
-    mcp_manager: Arc<crate::mcp::McpManager>,
     response_id: Option<String>,
-    background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
 ) -> Response {
-    match route_responses_internal(
-        pipeline,
-        request,
-        headers,
-        model_id,
-        components,
-        response_storage,
-        conversation_storage,
-        conversation_item_storage,
-        mcp_manager,
-        response_id,
-        background_tasks,
-    )
-    .await
-    {
+    match route_responses_internal(ctx, request, headers, model_id, response_id).await {
         Ok(responses_response) => axum::Json(responses_response).into_response(),
-        Err(e) => (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            axum::Json(json!({
-                "error": {
-                    "message": e,
-                    "type": "internal_error"
-                }
-            })),
-        )
-            .into_response(),
+        Err(response) => response, // Already a Response with proper status code
     }
 }
 /// Internal implementation that returns Result for background task compatibility
-#[allow(clippy::too_many_arguments)]
 async fn route_responses_internal(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     request: Arc<ResponsesRequest>,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    response_storage: SharedResponseStorage,
-    conversation_storage: SharedConversationStorage,
-    conversation_item_storage: SharedConversationItemStorage,
-    mcp_manager: Arc<crate::mcp::McpManager>,
     response_id: Option<String>,
-    background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
-) -> Result<ResponsesResponse, String> {
+) -> Result<ResponsesResponse, Response> {
     // 1. Load conversation history and build modified request
-    let modified_request = load_conversation_history(
-        &request,
-        &response_storage,
-        &conversation_storage,
-        &conversation_item_storage,
-    )
-    .await?;
+    let modified_request = load_conversation_history(ctx, &request).await?;
     // 2. Check if request has MCP tools - if so, use tool loop
     let responses_response = if let Some(tools) = &request.tools {
-        // Try to create dynamic MCP client from request tools using the manager
-        if let Some(request_mcp_manager) =
-            create_mcp_manager_from_request(&mcp_manager, tools).await
+        // Ensure dynamic MCP client is registered for request-scoped tools
+        if ensure_request_mcp_client(&ctx.mcp_manager, tools)
+            .await
+            .is_some()
        {
            debug!("MCP tools detected, using tool loop");
            // Execute with MCP tool loop
            execute_tool_loop(
-                pipeline,
+                ctx,
                modified_request,
                &request,
                headers,
                model_id,
-                components,
-                request_mcp_manager,
                response_id.clone(),
-                background_tasks,
            )
            .await?
        } else {
            debug!("Failed to create MCP client from request tools");
            // Fall through to non-MCP execution
            execute_without_mcp(
-                pipeline,
+                ctx,
                &modified_request,
                &request,
                headers,
                model_id,
-                components,
                response_id.clone(),
-                background_tasks,
            )
            .await?
        }
    } else {
        // No tools, execute normally
        execute_without_mcp(
-            pipeline,
+            ctx,
            &modified_request,
            &request,
            headers,
            model_id,
-            components,
            response_id.clone(),
-            background_tasks,
        )
        .await?
    };
@@ -283,9 +240,9 @@ async fn route_responses_internal(
     if request.store.unwrap_or(true) {
         if let Ok(response_json) = serde_json::to_value(&responses_response) {
             if let Err(e) = persist_conversation_items(
-                conversation_storage,
-                conversation_item_storage,
-                response_storage,
+                ctx.conversation_storage.clone(),
+                ctx.conversation_item_storage.clone(),
+                ctx.response_storage.clone(),
                 &response_json,
                 &request,
             )
@@ -306,16 +263,10 @@ async fn route_responses_internal(
 /// Execute responses request in background mode
 #[allow(clippy::too_many_arguments)]
 async fn route_responses_background(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     request: Arc<ResponsesRequest>,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    response_storage: SharedResponseStorage,
-    conversation_storage: SharedConversationStorage,
-    conversation_item_storage: SharedConversationItemStorage,
-    mcp_manager: Arc<crate::mcp::McpManager>,
-    background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
 ) -> Response {
     // Generate response_id for background tracking
     let response_id = format!("resp_{}", Uuid::new_v4());
@@ -356,9 +307,9 @@ async fn route_responses_background(
     // Persist queued response to storage
     if let Ok(response_json) = serde_json::to_value(&queued_response) {
         if let Err(e) = persist_conversation_items(
-            conversation_storage.clone(),
-            conversation_item_storage.clone(),
-            response_storage.clone(),
+            ctx.conversation_storage.clone(),
+            ctx.conversation_item_storage.clone(),
+            ctx.response_storage.clone(),
             &response_json,
             &request,
         )
@@ -369,17 +320,11 @@ async fn route_responses_background(
     }
     // Spawn background task
-    let pipeline = pipeline.clone();
+    let ctx_clone = ctx.clone();
     let request_clone = request.clone();
     let headers_clone = headers.clone();
     let model_id_clone = model_id.clone();
-    let components_clone = components.clone();
-    let response_storage_clone = response_storage.clone();
-    let conversation_storage_clone = conversation_storage.clone();
-    let conversation_item_storage_clone = conversation_item_storage.clone();
-    let mcp_manager_clone = mcp_manager.clone();
     let response_id_clone = response_id.clone();
-    let background_tasks_clone = background_tasks.clone();
     let handle = tokio::task::spawn(async move {
         // Execute synchronously (set background=false to prevent recursion)
@@ -387,17 +332,11 @@ async fn route_responses_background(
         background_request.background = Some(false);
         match route_responses_internal(
-            &pipeline,
+            &ctx_clone,
             Arc::new(background_request),
             headers_clone,
             model_id_clone,
-            components_clone,
-            response_storage_clone,
-            conversation_storage_clone,
-            conversation_item_storage_clone,
-            mcp_manager_clone,
             Some(response_id_clone.clone()),
-            Some(background_tasks_clone.clone()),
         )
         .await
         {
@@ -407,20 +346,25 @@ async fn route_responses_background(
                     response_id_clone
                 );
             }
-            Err(e) => {
-                warn!("Background response {} failed: {}", response_id_clone, e);
+            Err(response) => {
+                warn!(
+                    "Background response {} failed with status {}",
+                    response_id_clone,
+                    response.status()
+                );
             }
         }
         // Clean up task handle when done
-        background_tasks_clone
+        ctx_clone
+            .background_tasks
             .write()
             .await
             .remove(&response_id_clone);
     });
     // Store task info for cancellation support
-    background_tasks.write().await.insert(
+    ctx.background_tasks.write().await.insert(
         response_id.clone(),
         BackgroundTaskInfo {
             handle,
@@ -440,61 +384,28 @@ async fn route_responses_background(
 /// Execute streaming responses request
 #[allow(clippy::too_many_arguments)]
 async fn route_responses_streaming(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     request: Arc<ResponsesRequest>,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    response_storage: SharedResponseStorage,
-    conversation_storage: SharedConversationStorage,
-    conversation_item_storage: SharedConversationItemStorage,
-    mcp_manager: Arc<crate::mcp::McpManager>,
 ) -> Response {
     // 1. Load conversation history
-    let modified_request = match load_conversation_history(
-        &request,
-        &response_storage,
-        &conversation_storage,
-        &conversation_item_storage,
-    )
-    .await
-    {
+    let modified_request = match load_conversation_history(ctx, &request).await {
         Ok(req) => req,
-        Err(e) => {
-            return (
-                StatusCode::BAD_REQUEST,
-                axum::Json(json!({
-                    "error": {
-                        "message": e,
-                        "type": "invalid_request_error"
-                    }
-                })),
-            )
-                .into_response();
-        }
+        Err(response) => return response, // Already a Response with proper status code
     };
     // 2. Check if request has MCP tools - if so, use streaming tool loop
     if let Some(tools) = &request.tools {
-        // Try to create dynamic MCP client from request tools using the manager
-        if let Some(request_mcp_manager) =
-            create_mcp_manager_from_request(&mcp_manager, tools).await
+        // Ensure dynamic MCP client is registered for request-scoped tools
+        if ensure_request_mcp_client(&ctx.mcp_manager, tools)
+            .await
+            .is_some()
        {
            debug!("MCP tools detected in streaming mode, using streaming tool loop");
-            return execute_tool_loop_streaming(
-                pipeline,
-                modified_request,
-                &request,
-                headers,
-                model_id,
-                components,
-                request_mcp_manager,
-                response_storage,
-                conversation_storage,
-                conversation_item_storage,
-            )
-            .await;
+            return execute_tool_loop_streaming(ctx, modified_request, &request, headers, model_id)
+                .await;
        }
    }
@@ -516,18 +427,7 @@ async fn route_responses_streaming(
    };
    // 4. Execute chat pipeline and convert streaming format (no MCP tools)
-    convert_chat_stream_to_responses_stream(
-        pipeline,
-        chat_request,
-        headers,
-        model_id,
-        components,
-        &request,
-        response_storage,
-        conversation_storage,
-        conversation_item_storage,
-    )
-    .await
+    convert_chat_stream_to_responses_stream(ctx, chat_request, headers, model_id, &request).await
 }
 /// Convert chat streaming response to responses streaming format
@@ -540,21 +440,23 @@ async fn route_responses_streaming(
 /// 5. Emits transformed SSE events in responses format
 #[allow(clippy::too_many_arguments)]
 async fn convert_chat_stream_to_responses_stream(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     chat_request: Arc<crate::protocols::chat::ChatCompletionRequest>,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
     original_request: &ResponsesRequest,
-    response_storage: SharedResponseStorage,
-    _conversation_storage: SharedConversationStorage,
-    _conversation_item_storage: SharedConversationItemStorage,
 ) -> Response {
     debug!("Converting chat SSE stream to responses SSE format");
     // Get chat streaming response
-    let chat_response = pipeline
-        .execute_chat(chat_request.clone(), headers, model_id, components)
+    let chat_response = ctx
+        .pipeline
+        .execute_chat(
+            chat_request.clone(),
+            headers,
+            model_id,
+            ctx.components.clone(),
+        )
         .await;
     // Extract body and headers from chat response
@@ -566,18 +468,18 @@ async fn convert_chat_stream_to_responses_stream(
     // Spawn background task to transform stream
     let original_request_clone = original_request.clone();
     let chat_request_clone = chat_request.clone();
-    let response_storage_clone = response_storage.clone();
-    let conversation_storage_clone = _conversation_storage.clone();
-    let conversation_item_storage_clone = _conversation_item_storage.clone();
+    let response_storage = ctx.response_storage.clone();
+    let conversation_storage = ctx.conversation_storage.clone();
+    let conversation_item_storage = ctx.conversation_item_storage.clone();
     tokio::spawn(async move {
         if let Err(e) = process_and_transform_sse_stream(
             body,
             original_request_clone,
             chat_request_clone,
-            response_storage_clone,
-            conversation_storage_clone,
-            conversation_item_storage_clone,
+            response_storage,
+            conversation_storage,
+            conversation_item_storage,
             tx.clone(),
         )
         .await
@@ -710,9 +612,9 @@ async fn process_and_transform_sse_stream(
     if let Ok(response_json) = serde_json::to_value(&final_response) {
         if let Err(e) = persist_conversation_items(
-            conversation_storage,
-            conversation_item_storage,
-            response_storage,
+            conversation_storage.clone(),
+            conversation_item_storage.clone(),
+            response_storage.clone(),
             &response_json,
             &original_request,
         )
@@ -925,53 +827,55 @@ impl StreamingResponseAccumulator {
 // ============================================================================
 /// Execute request without MCP tool loop (simple pipeline execution)
-#[allow(clippy::too_many_arguments)]
 async fn execute_without_mcp(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     modified_request: &ResponsesRequest,
     original_request: &ResponsesRequest,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
     response_id: Option<String>,
-    background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
-) -> Result<ResponsesResponse, String> {
+) -> Result<ResponsesResponse, Response> {
+    use crate::routers::grpc::utils;
     // Convert ResponsesRequest → ChatCompletionRequest
     let chat_request = conversions::responses_to_chat(modified_request)
-        .map_err(|e| format!("Failed to convert request: {}", e))?;
+        .map_err(|e| utils::bad_request_error(format!("Failed to convert request: {}", e)))?;
-    // Execute chat pipeline
-    let chat_response = pipeline
+    // Execute chat pipeline (errors already have proper HTTP status codes)
+    let chat_response = ctx
+        .pipeline
         .execute_chat_for_responses(
             Arc::new(chat_request),
             headers,
             model_id,
-            components,
+            ctx.components.clone(),
             response_id.clone(),
-            background_tasks,
+            Some(ctx.background_tasks.clone()),
         )
-        .await
-        .map_err(|e| format!("Pipeline execution failed: {}", e))?;
+        .await?; // Preserve the Response error as-is
     // Convert ChatCompletionResponse → ResponsesResponse
-    conversions::chat_to_responses(&chat_response, original_request, response_id)
-        .map_err(|e| format!("Failed to convert to responses format: {}", e))
+    conversions::chat_to_responses(&chat_response, original_request, response_id).map_err(|e| {
+        utils::internal_error_message(format!("Failed to convert to responses format: {}", e))
+    })
 }
 /// Load conversation history and response chains, returning modified request
 async fn load_conversation_history(
+    ctx: &super::context::ResponsesContext,
     request: &ResponsesRequest,
-    response_storage: &SharedResponseStorage,
-    conversation_storage: &SharedConversationStorage,
-    conversation_item_storage: &SharedConversationItemStorage,
-) -> Result<ResponsesRequest, String> {
+) -> Result<ResponsesRequest, Response> {
     let mut modified_request = request.clone();
     let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
     // Handle previous_response_id by loading response chain
     if let Some(ref prev_id_str) = modified_request.previous_response_id {
         let prev_id = ResponseId::from(prev_id_str.as_str());
-        match response_storage.get_response_chain(&prev_id, None).await {
+        match ctx
+            .response_storage
+            .get_response_chain(&prev_id, None)
+            .await
+        {
             Ok(chain) => {
                 let mut items = Vec::new();
                 for stored in chain.responses.iter() {
@@ -1026,7 +930,7 @@ async fn load_conversation_history(
         let conv_id = ConversationId::from(conv_id_str.as_str());
         // Auto-create conversation if it doesn't exist (OpenAI behavior)
-        if let Ok(None) = conversation_storage.get_conversation(&conv_id).await {
+        if let Ok(None) = ctx.conversation_storage.get_conversation(&conv_id).await {
             debug!(
                 "Creating new conversation with user-provided ID: {}",
                 conv_id_str
@@ -1043,10 +947,15 @@ async fn load_conversation_history(
                 id: Some(conv_id.clone()), // Use user-provided conversation ID
                 metadata,
             };
-            conversation_storage
+            ctx.conversation_storage
                 .create_conversation(new_conv)
                 .await
-                .map_err(|e| format!("Failed to create conversation: {}", e))?;
+                .map_err(|e| {
+                    crate::routers::grpc::utils::internal_error_message(format!(
+                        "Failed to create conversation: {}",
+                        e
+                    ))
+                })?;
         }
         // Load conversation history
@@ -1057,7 +966,11 @@ async fn load_conversation_history(
             after: None,
         };
-        match conversation_item_storage.list_items(&conv_id, params).await {
+        match ctx
+            .conversation_item_storage
+            .list_items(&conv_id, params)
+            .await
+        {
             Ok(stored_items) => {
                 let mut items: Vec<ResponseInputOutputItem> = Vec::new();
                 for item in stored_items.into_iter() {
@@ -1142,13 +1055,13 @@ async fn load_conversation_history(
 /// Implementation for GET /v1/responses/{response_id}
 pub async fn get_response_impl(
-    response_storage: &SharedResponseStorage,
+    ctx: &super::context::ResponsesContext,
     response_id: &str,
 ) -> Response {
     let resp_id = ResponseId::from(response_id);
     // Retrieve response from storage
-    match response_storage.get_response(&resp_id).await {
+    match ctx.response_storage.get_response(&resp_id).await {
         Ok(Some(stored_response)) => axum::Json(stored_response.raw_response).into_response(),
         Ok(None) => (
             StatusCode::NOT_FOUND,
@@ -1180,14 +1093,13 @@ pub async fn get_response_impl(
 /// Implementation for POST /v1/responses/{response_id}/cancel
 pub async fn cancel_response_impl(
-    response_storage: &SharedResponseStorage,
-    background_tasks: &Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
+    ctx: &super::context::ResponsesContext,
     response_id: &str,
 ) -> Response {
     let resp_id = ResponseId::from(response_id);
     // Retrieve response from storage to check if it exists and get current status
-    match response_storage.get_response(&resp_id).await {
+    match ctx.response_storage.get_response(&resp_id).await {
         Ok(Some(stored_response)) => {
             // Check current status - only queued or in_progress responses can be cancelled
             let current_status = stored_response
@@ -1199,7 +1111,7 @@ pub async fn cancel_response_impl(
             match current_status {
                 "queued" | "in_progress" => {
                     // Attempt to abort the background task
-                    let mut tasks = background_tasks.write().await;
+                    let mut tasks = ctx.background_tasks.write().await;
                     if let Some(task_info) = tasks.remove(response_id) {
                         // Abort the Rust task immediately
                         task_info.handle.abort();
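Editor's note: for orientation, here is a compact, self-contained sketch of the background-task lifecycle the handlers implement: spawn, register the `JoinHandle` under the response id, have the task deregister itself on completion, and abort it on cancel. `Tasks` is a simplified stand-in for the `background_tasks` map; the real `BackgroundTaskInfo` carries more than the handle (per the router comments, a gRPC client for aborting the Python side):

```rust
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{sync::RwLock, task::JoinHandle};

// Simplified stand-in for Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>.
type Tasks = Arc<RwLock<HashMap<String, JoinHandle<()>>>>;

// Mirrors route_responses_background: spawn the work, register the
// handle for cancellation, and clean up when the task finishes.
async fn spawn_background(tasks: Tasks, response_id: String) {
    let tasks_for_cleanup = tasks.clone();
    let id = response_id.clone();
    let handle = tokio::spawn(async move {
        tokio::time::sleep(Duration::from_secs(60)).await; // stands in for the real work
        tasks_for_cleanup.write().await.remove(&id); // deregister on completion
    });
    tasks.write().await.insert(response_id, handle);
}

// Mirrors cancel_response_impl: only a task still in the map
// (i.e. queued/in_progress) can be cancelled.
async fn cancel(tasks: &Tasks, response_id: &str) -> bool {
    match tasks.write().await.remove(response_id) {
        Some(handle) => {
            handle.abort(); // abort the Rust task immediately
            true
        }
        None => false,
    }
}

#[tokio::main]
async fn main() {
    let tasks: Tasks = Arc::new(RwLock::new(HashMap::new()));
    spawn_background(tasks.clone(), "resp_1".into()).await;
    assert!(cancel(&tasks, "resp_1").await);
    assert!(!cancel(&tasks, "resp_1").await); // second cancel is a no-op
}
```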
@@ -9,6 +9,7 @@
 //! - Response persistence
 // Module declarations
+pub mod context;
 mod conversions;
 mod handlers;
 pub mod streaming;
@@ -16,5 +17,6 @@ pub mod tool_loop;
 pub mod types;
 // Public exports
+pub use context::ResponsesContext;
 pub use handlers::{cancel_response_impl, get_response_impl, route_responses};
 pub use types::BackgroundTaskInfo;
@@ -13,7 +13,7 @@ use axum::{
 };
 use bytes::Bytes;
 use serde_json::json;
-use tokio::sync::{mpsc, RwLock};
+use tokio::sync::mpsc;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, warn};
 use uuid::Uuid;
@@ -21,24 +21,14 @@ use uuid::Uuid;
 use super::{
     conversions,
     streaming::{OutputItemType, ResponseStreamEventEmitter},
-    types::BackgroundTaskInfo,
 };
-/// This is a re-export of the shared implementation from openai::mcp
-pub(super) use crate::routers::openai::mcp::ensure_request_mcp_client as create_mcp_manager_from_request;
-use crate::{
-    data_connector::{
-        SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
-    },
-    protocols::{
-        chat::ChatCompletionResponse,
-        common::{Tool, ToolChoice, ToolChoiceValue},
-        responses::{
-            McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
-            ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest,
-            ResponsesResponse,
-        },
-    },
-    routers::grpc::{context::SharedComponents, pipeline::RequestPipeline},
-};
+use crate::protocols::{
+    chat::ChatCompletionResponse,
+    common::{Tool, ToolChoice, ToolChoiceValue},
+    responses::{
+        McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
+        ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest, ResponsesResponse,
+    },
+};
 /// Extract function call from a chat completion response
@@ -221,18 +211,14 @@ fn build_mcp_call_item(
 /// 2. Checks if response has tool calls
 /// 3. If yes, executes MCP tools and builds resume request
 /// 4. Repeats until no more tool calls or limit reached
-#[allow(clippy::too_many_arguments)]
 pub(super) async fn execute_tool_loop(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     mut current_request: ResponsesRequest,
     original_request: &ResponsesRequest,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    mcp_manager: Arc<crate::mcp::McpManager>,
     response_id: Option<String>,
-    background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
-) -> Result<ResponsesResponse, String> {
+) -> Result<ResponsesResponse, Response> {
     // Get server label from original request tools
     let server_label = original_request
         .tools
@@ -257,31 +243,35 @@ pub(super) async fn execute_tool_loop(
     );
     // Get MCP tools and convert to chat format (do this once before loop)
-    let mcp_tools = mcp_manager.list_tools();
+    let mcp_tools = ctx.mcp_manager.list_tools();
     let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
     debug!("Converted {} MCP tools to chat format", chat_tools.len());
     loop {
         // Convert to chat request
-        let mut chat_request = conversions::responses_to_chat(&current_request)
-            .map_err(|e| format!("Failed to convert request: {}", e))?;
+        let mut chat_request = conversions::responses_to_chat(&current_request).map_err(|e| {
+            crate::routers::grpc::utils::bad_request_error(format!(
+                "Failed to convert request: {}",
+                e
+            ))
+        })?;
         // Add MCP tools to chat request so LLM knows about them
         chat_request.tools = Some(chat_tools.clone());
         chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
-        // Execute chat pipeline
-        let chat_response = pipeline
+        // Execute chat pipeline (errors already have proper HTTP status codes)
+        let chat_response = ctx
+            .pipeline
            .execute_chat_for_responses(
                Arc::new(chat_request),
                headers.clone(),
                model_id.clone(),
-                components.clone(),
+                ctx.components.clone(),
                response_id.clone(),
-                background_tasks.clone(),
+                Some(ctx.background_tasks.clone()),
            )
-            .await
-            .map_err(|e| format!("Pipeline execution failed: {}", e))?;
        // Check for function calls
        if let Some((call_id, tool_name, args_json_str)) =
+            .await?;
@@ -312,7 +302,12 @@ pub(super) async fn execute_tool_loop(
                 original_request,
                 response_id.clone(),
             )
-            .map_err(|e| format!("Failed to convert to responses format: {}", e))?;
+            .map_err(|e| {
+                crate::routers::grpc::utils::internal_error_message(format!(
+                    "Failed to convert to responses format: {}",
+                    e
+                ))
+            })?;
             // Mark as completed but with incomplete details
             responses_response.status = ResponseStatus::Completed;
@@ -329,7 +324,8 @@ pub(super) async fn execute_tool_loop(
                 "Calling MCP tool '{}' with args: {}",
                 tool_name, args_json_str
             );
-            let (output_str, success, error) = match mcp_manager
+            let (output_str, success, error) = match ctx
+                .mcp_manager
                 .call_tool(tool_name.as_str(), args_json_str.as_str())
                 .await
             {
@@ -428,12 +424,17 @@ pub(super) async fn execute_tool_loop(
                 original_request,
                 response_id.clone(),
             )
-            .map_err(|e| format!("Failed to convert to responses format: {}", e))?;
+            .map_err(|e| {
+                crate::routers::grpc::utils::internal_error_message(format!(
+                    "Failed to convert to responses format: {}",
+                    e
+                ))
+            })?;
             // Inject MCP metadata into output
             if state.total_calls > 0 {
                 // Prepend mcp_list_tools item
-                let mcp_list_tools = build_mcp_list_tools_item(&mcp_manager, &server_label);
+                let mcp_list_tools = build_mcp_list_tools_item(&ctx.mcp_manager, &server_label);
                 responses_response.output.insert(0, mcp_list_tools);
                 // Append all mcp_call items at the end
@@ -455,52 +456,28 @@ pub(super) async fn execute_tool_loop(
 /// This streams each iteration's response to the client while accumulating
 /// to check for tool calls. If tool calls are found, executes them and
 /// continues with the next streaming iteration.
-#[allow(clippy::too_many_arguments)]
 pub(super) async fn execute_tool_loop_streaming(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     current_request: ResponsesRequest,
     original_request: &ResponsesRequest,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    mcp_manager: Arc<crate::mcp::McpManager>,
-    response_storage: SharedResponseStorage,
-    conversation_storage: SharedConversationStorage,
-    conversation_item_storage: SharedConversationItemStorage,
 ) -> Response {
-    // Get server label
-    let server_label = original_request
-        .tools
-        .as_ref()
-        .and_then(|tools| {
-            tools
-                .iter()
-                .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-                .and_then(|t| t.server_label.clone())
-        })
-        .unwrap_or_else(|| "request-mcp".to_string());
     // Create SSE channel for client
     let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();
     // Clone data for background task
-    let pipeline_clone = pipeline.clone();
+    let ctx_clone = ctx.clone();
     let original_request_clone = original_request.clone();
     // Spawn background task for tool loop
     tokio::spawn(async move {
         let result = execute_tool_loop_streaming_internal(
-            &pipeline_clone,
+            &ctx_clone,
             current_request,
             &original_request_clone,
             headers,
             model_id,
-            components,
-            mcp_manager,
-            server_label,
-            response_storage,
-            conversation_storage,
-            conversation_item_storage,
             tx.clone(),
         )
         .await;
@@ -546,21 +523,26 @@ pub(super) async fn execute_tool_loop_streaming(
 }
 /// Internal streaming tool loop implementation
-#[allow(clippy::too_many_arguments)]
 async fn execute_tool_loop_streaming_internal(
-    pipeline: &RequestPipeline,
+    ctx: &super::context::ResponsesContext,
     mut current_request: ResponsesRequest,
     original_request: &ResponsesRequest,
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
-    components: Arc<SharedComponents>,
-    mcp_manager: Arc<crate::mcp::McpManager>,
-    server_label: String,
-    _response_storage: SharedResponseStorage,
-    _conversation_storage: SharedConversationStorage,
-    _conversation_item_storage: SharedConversationItemStorage,
     tx: mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
 ) -> Result<(), String> {
+    // Extract server label from original request tools
+    let server_label = original_request
+        .tools
+        .as_ref()
+        .and_then(|tools| {
+            tools
+                .iter()
+                .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                .and_then(|t| t.server_label.clone())
+        })
+        .unwrap_or_else(|| "request-mcp".to_string());
     const MAX_ITERATIONS: usize = 10;
     let mut state = ToolLoopState::new(original_request.input.clone(), server_label.clone());
     let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);
@@ -581,7 +563,7 @@ async fn execute_tool_loop_streaming_internal(
     emitter.send_event(&event, &tx)?;
     // Get MCP tools and convert to chat format (do this once before loop)
-    let mcp_tools = mcp_manager.list_tools();
+    let mcp_tools = ctx.mcp_manager.list_tools();
     let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
     debug!(
         "Streaming: Converted {} MCP tools to chat format",
@@ -670,12 +652,13 @@ async fn execute_tool_loop_streaming_internal(
         chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
         // Execute chat streaming
-        let response = pipeline
+        let response = ctx
+            .pipeline
             .execute_chat(
                 Arc::new(chat_request),
                 headers.clone(),
                 model_id.clone(),
-                components.clone(),
+                ctx.components.clone(),
             )
             .await;
@@ -758,7 +741,8 @@ async fn execute_tool_loop_streaming_internal(
             "Calling MCP tool '{}' with args: {}",
             tool_name, args_json_str
         );
-        let (output_str, success, error) = match mcp_manager
+        let (output_str, success, error) = match ctx
+            .mcp_manager
             .call_tool(tool_name.as_str(), args_json_str.as_str())
             .await
         {
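Editor's note: the control flow of `execute_tool_loop` is easier to see stripped of streaming and MCP plumbing. A self-contained sketch of the same loop shape, where `run_chat` and `call_mcp_tool` are hypothetical stand-ins for `pipeline.execute_chat_for_responses` and `ctx.mcp_manager.call_tool`:

```rust
const MAX_ITERATIONS: usize = 10;

// (call_id, tool_name, json_args) extracted from the model reply, if any.
struct Reply {
    function_call: Option<(String, String, String)>,
}

// Hypothetical stand-in for pipeline.execute_chat_for_responses:
// asks for a tool on the first turn, then answers directly.
fn run_chat(turn: usize, _conversation: &[String]) -> Reply {
    if turn == 0 {
        Reply {
            function_call: Some(("call_1".into(), "get_weather".into(), "{}".into())),
        }
    } else {
        Reply { function_call: None }
    }
}

// Hypothetical stand-in for ctx.mcp_manager.call_tool.
fn call_mcp_tool(_name: &str, _args: &str) -> String {
    "{\"temp_c\": 21}".to_string()
}

fn tool_loop(mut conversation: Vec<String>) -> Result<Reply, String> {
    for turn in 0..MAX_ITERATIONS {
        let reply = run_chat(turn, &conversation);
        match reply.function_call {
            // No tool call: the model is done, return the final reply.
            None => return Ok(reply),
            // Tool call: execute it, append the output to the conversation,
            // and loop so the model can continue with the result in context.
            Some((call_id, name, args)) => {
                let output = call_mcp_tool(&name, &args);
                conversation.push(format!("tool {call_id} ({name}) -> {output}"));
            }
        }
    }
    Err("max tool-loop iterations reached".to_string())
}

fn main() {
    let reply = tool_loop(vec!["user: what's the weather?".into()]).unwrap();
    assert!(reply.function_call.is_none());
}
```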
 // gRPC Router Implementation
-use std::{collections::HashMap, sync::Arc};
+use std::sync::Arc;
 use async_trait::async_trait;
 use axum::{
@@ -9,22 +9,13 @@ use axum::{
     http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
 };
-use tokio::sync::RwLock;
 use tracing::debug;
-use super::{
-    context::SharedComponents,
-    pipeline::RequestPipeline,
-    responses::{self, BackgroundTaskInfo},
-};
+use super::{context::SharedComponents, pipeline::RequestPipeline, responses};
 use crate::{
     app_context::AppContext,
     config::types::RetryConfig,
     core::WorkerRegistry,
-    data_connector::{
-        SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
-    },
-    mcp::McpManager,
     policies::PolicyRegistry,
     protocols::{
         chat::ChatCompletionRequest,
@@ -57,13 +48,8 @@ pub struct GrpcRouter {
     configured_tool_parser: Option<String>,
     pipeline: RequestPipeline,
     shared_components: Arc<SharedComponents>,
-    // Storage backends for /v1/responses support
-    response_storage: SharedResponseStorage,
-    conversation_storage: SharedConversationStorage,
-    conversation_item_storage: SharedConversationItemStorage,
-    mcp_manager: Arc<McpManager>,
-    // Background task handles for cancellation support (includes gRPC client for Python abort)
-    background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
+    // Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
+    responses_context: responses::ResponsesContext,
 }
 impl GrpcRouter {
@@ -89,18 +75,6 @@ impl GrpcRouter {
         let worker_registry = ctx.worker_registry.clone();
         let policy_registry = ctx.policy_registry.clone();
-        // Extract storage backends from context
-        let response_storage = ctx.response_storage.clone();
-        let conversation_storage = ctx.conversation_storage.clone();
-        let conversation_item_storage = ctx.conversation_item_storage.clone();
-        // Get MCP manager from app context
-        let mcp_manager = ctx
-            .mcp_manager
-            .get()
-            .ok_or_else(|| "gRPC router requires MCP manager".to_string())?
-            .clone();
         // Create shared components for pipeline
         let shared_components = Arc::new(SharedComponents {
             tokenizer: tokenizer.clone(),
@@ -119,6 +93,20 @@ impl GrpcRouter {
             ctx.configured_reasoning_parser.clone(),
         );
+        // Create responses context with all dependencies
+        let responses_context = responses::ResponsesContext::new(
+            Arc::new(pipeline.clone()),
+            shared_components.clone(),
+            worker_registry.clone(),
+            ctx.response_storage.clone(),
+            ctx.conversation_storage.clone(),
+            ctx.conversation_item_storage.clone(),
+            ctx.mcp_manager
+                .get()
+                .ok_or_else(|| "gRPC router requires MCP manager".to_string())?
+                .clone(),
+        );
         Ok(GrpcRouter {
             worker_registry,
             policy_registry,
@@ -132,11 +120,7 @@ impl GrpcRouter {
             configured_tool_parser: ctx.configured_tool_parser.clone(),
             pipeline,
             shared_components,
-            response_storage,
-            conversation_storage,
-            conversation_item_storage,
-            mcp_manager,
-            background_tasks: Arc::new(RwLock::new(HashMap::new())),
+            responses_context,
         })
     }
@@ -254,26 +238,11 @@ impl RouterTrait for GrpcRouter {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
-        // Use responses module for ALL requests (streaming and non-streaming)
-        // Responses module handles:
-        // - Request validation (previous_response_id XOR conversation)
-        // - Loading response chain / conversation history from storage
-        // - Conversion: ResponsesRequest → ChatCompletionRequest
-        // - Execution through chat pipeline stages
-        // - Conversion: ChatCompletionResponse → ResponsesResponse
-        // - Response persistence
-        // - MCP tool loop wrapper (future)
         responses::route_responses(
-            &self.pipeline,
+            &self.responses_context,
             Arc::new(body.clone()),
             headers.cloned(),
             model_id.map(|s| s.to_string()),
-            self.shared_components.clone(),
-            self.response_storage.clone(),
-            self.conversation_storage.clone(),
-            self.conversation_item_storage.clone(),
-            self.mcp_manager.clone(),
-            self.background_tasks.clone(),
         )
         .await
     }
@@ -284,12 +253,11 @@ impl RouterTrait for GrpcRouter {
         response_id: &str,
         _params: &ResponsesGetParams,
     ) -> Response {
-        responses::get_response_impl(&self.response_storage, response_id).await
+        responses::get_response_impl(&self.responses_context, response_id).await
     }
     async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
-        responses::cancel_response_impl(&self.response_storage, &self.background_tasks, response_id)
-            .await
+        responses::cancel_response_impl(&self.responses_context, response_id).await
     }
     async fn route_classify(
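Editor's note: the fail-fast model check that `route_responses` now performs before any storage or pipeline work can be exercised in isolation. A hedged sketch using axum; `get_models` stands in for `WorkerRegistry::get_models`, which the real code reaches through `ctx.worker_registry`:

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;

// Stand-in for a registry snapshot of models with live workers.
fn get_models() -> Vec<String> {
    vec!["llama-3".to_string()]
}

// Returns Err with a ready-to-send 503 when no worker serves the model,
// matching the error body shape used in route_responses.
fn validate_model(model: &str) -> Result<(), Response> {
    let available = get_models();
    if available.contains(&model.to_string()) {
        return Ok(());
    }
    Err((
        StatusCode::SERVICE_UNAVAILABLE,
        Json(json!({
            "error": {
                "message": format!(
                    "No workers available for model '{}'. Available models: {}",
                    model,
                    available.join(", ")
                ),
                "type": "service_unavailable",
                "param": "model",
                "code": "no_available_workers"
            }
        })),
    )
        .into_response())
}

fn main() {
    assert!(validate_model("llama-3").is_ok());
    assert!(validate_model("gpt-x").is_err());
}
```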