Unverified commit 0e82fd3d authored by Chang Su, committed by GitHub


[router][grpc] Fix model validation, tool call check, streaming logic and misc in responses  (#12616)
parent b7d70411
@@ -2,6 +2,8 @@
pub mod handlers;
pub mod streaming;
+pub mod utils;
pub use handlers::{cancel_response_impl, get_response_impl};
-pub use streaming::{OutputItemType, ResponseStreamEventEmitter};
+pub use streaming::{build_sse_response, OutputItemType, ResponseStreamEventEmitter};
+pub use utils::ensure_mcp_connection;
@@ -2,9 +2,11 @@
use std::collections::HashMap;
+use axum::{body::Body, http::StatusCode, response::Response};
use bytes::Bytes;
use serde_json::json;
use tokio::sync::mpsc;
+use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
use crate::{mcp, protocols::chat::ChatCompletionStreamResponse};
@@ -13,6 +15,7 @@ pub enum OutputItemType {
Message,
McpListTools,
McpCall,
+FunctionCall,
Reasoning,
}
@@ -342,6 +345,40 @@ impl ResponseStreamEventEmitter {
})
}
// ========================================================================
// Function Call Event Emission Methods
// ========================================================================
pub fn emit_function_call_arguments_delta(
&mut self,
output_index: usize,
item_id: &str,
delta: &str,
) -> serde_json::Value {
json!({
"type": "response.function_call_arguments.delta",
"sequence_number": self.next_sequence(),
"output_index": output_index,
"item_id": item_id,
"delta": delta
})
}
pub fn emit_function_call_arguments_done(
&mut self,
output_index: usize,
item_id: &str,
arguments: &str,
) -> serde_json::Value {
json!({
"type": "response.function_call_arguments.done",
"sequence_number": self.next_sequence(),
"output_index": output_index,
"item_id": item_id,
"arguments": arguments
})
}
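// For reference: each emitter above yields one SSE event object built by the
// json! bodies shown; field values here are illustrative, not from this diff:
//   {"type": "response.function_call_arguments.delta", "sequence_number": 7,
//    "output_index": 0, "item_id": "fc_...", "delta": "{\"city\":"}
//   {"type": "response.function_call_arguments.done", "sequence_number": 8,
//    "output_index": 0, "item_id": "fc_...", "arguments": "{\"city\":\"Paris\"}"}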
// ========================================================================
// Output Item Wrapper Events
// ========================================================================
@@ -387,6 +424,7 @@ impl ResponseStreamEventEmitter {
let id_prefix = match &item_type {
OutputItemType::McpListTools => "mcpl",
OutputItemType::McpCall => "mcp",
+OutputItemType::FunctionCall => "fc",
OutputItemType::Message => "msg",
OutputItemType::Reasoning => "rs",
};
@@ -582,4 +620,40 @@ impl ResponseStreamEventEmitter {
}
}
}
/// Emit an error event
///
/// Creates and sends an error event with the given error message.
/// Uses OpenAI's error event format.
/// Use this for terminal errors that should abort the streaming response.
pub fn emit_error(
&mut self,
error_msg: &str,
error_code: Option<&str>,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) {
let event = json!({
"type": "error",
"code": error_code.unwrap_or("internal_error"),
"message": error_msg,
"param": null,
"sequence_number": self.next_sequence()
});
let sse_data = format!("data: {}\n\n", serde_json::to_string(&event).unwrap());
let _ = tx.send(Ok(Bytes::from(sse_data)));
}
}
/// Build a Server-Sent Events (SSE) response
///
/// Creates a Response with proper SSE headers and streaming body.
pub fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, std::io::Error>>) -> Response {
let stream = UnboundedReceiverStream::new(rx);
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(Body::from_stream(stream))
.unwrap()
}
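// A minimal usage sketch (not from this diff; everything except emit_error and
// build_sse_response is assumed): a handler creates the channel, spawns a task
// that emits events, and returns the SSE response immediately.
//
// let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();
// tokio::spawn(async move {
//     // ... emit response.created, deltas, response.completed via the emitter ...
//     // On a terminal failure, abort the stream with an OpenAI-style error event:
//     // emitter.emit_error("backend unavailable", Some("internal_error"), &tx);
// });
// build_sse_response(rx)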
//! Utility functions for /v1/responses endpoint
use std::sync::Arc;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use serde_json::json;
use crate::{
core::WorkerRegistry,
mcp::McpManager,
protocols::responses::{ResponseTool, ResponseToolType},
routers::{grpc::error, openai::mcp::ensure_request_mcp_client},
};
/// Ensure MCP connection succeeds if MCP tools are declared
///
/// Checks if request declares MCP tools, and if so, validates that
/// the MCP client can be created and connected.
pub async fn ensure_mcp_connection(
mcp_manager: &Arc<McpManager>,
tools: Option<&[ResponseTool]>,
) -> Result<bool, Response> {
let has_mcp_tools = tools
.map(|t| {
t.iter()
.any(|tool| matches!(tool.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
if has_mcp_tools {
if let Some(tools) = tools {
if ensure_request_mcp_client(mcp_manager, tools)
.await
.is_none()
{
return Err(error::failed_dependency(
"Failed to connect to MCP server. Check server_url and authorization.",
));
}
}
}
Ok(has_mcp_tools)
}
/// Validate that workers are available for the requested model
pub fn validate_worker_availability(
worker_registry: &Arc<WorkerRegistry>,
model: &str,
) -> Option<Response> {
let available_models = worker_registry.get_models();
if !available_models.contains(&model.to_string()) {
return Some(
(
StatusCode::SERVICE_UNAVAILABLE,
axum::Json(json!({
"error": {
"message": format!(
"No workers available for model '{}'. Available models: {}",
model,
available_models.join(", ")
),
"type": "service_unavailable",
"param": "model",
"code": "no_available_workers"
}
})),
)
.into_response(),
);
}
None
}
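// A minimal sketch (handler shape assumed) of how these two helpers compose at
// the top of a /v1/responses handler, mirroring the call sites later in this diff:
//
// if let Some(err) = validate_worker_availability(&ctx.worker_registry, &request.model) {
//     return Err(err); // 503 with code "no_available_workers"
// }
// let has_mcp_tools =
//     ensure_mcp_connection(&ctx.mcp_manager, request.tools.as_deref()).await?; // 424 on failure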
@@ -107,6 +107,30 @@ pub fn service_unavailable(message: impl Into<String>) -> Response {
.into_response()
}
/// Create a 424 Failed Dependency response
///
/// Use this when an external dependency (like MCP server) fails.
///
/// # Example
/// ```ignore
/// return Err(failed_dependency("Failed to connect to MCP server"));
/// ```
pub fn failed_dependency(message: impl Into<String>) -> Response {
let msg = message.into();
warn!("{}", msg);
(
StatusCode::FAILED_DEPENDENCY,
Json(json!({
"error": {
"message": msg,
"type": "external_connector_error",
"code": 424
}
})),
)
.into_response()
}
#[cfg(test)]
mod tests {
use super::*;
...
@@ -124,8 +124,10 @@ pub enum ResponsesIterationResult {
/// Tool calls found in commentary channel - continue MCP loop
ToolCallsFound {
tool_calls: Vec<ToolCall>,
-analysis: Option<String>, // For streaming emission
-partial_text: String, // For streaming emission
+analysis: Option<String>, // For streaming emission or reasoning output
+partial_text: String, // For streaming emission or message output
+usage: Usage, // Token usage from this iteration
+request_id: String, // Request ID from dispatch
},
/// No tool calls - return final ResponsesResponse
Completed {
@@ -206,6 +208,9 @@ impl HarmonyResponseProcessor {
);
}
+// Build usage (needed for both ToolCallsFound and Completed)
+let usage = response_formatting::build_usage(std::slice::from_ref(complete));
// Check for tool calls in commentary channel
if let Some(tool_calls) = parsed.commentary {
// Tool calls found - return for MCP loop execution
@@ -213,6 +218,8 @@ impl HarmonyResponseProcessor {
tool_calls,
analysis: parsed.analysis,
partial_text: parsed.final_text,
+usage,
+request_id: dispatch.request_id.clone(),
});
}
@@ -245,9 +252,6 @@ impl HarmonyResponseProcessor {
output.push(message_item);
}
-// Build usage
-let usage = response_formatting::build_usage(std::slice::from_ref(complete));
// Build ResponsesResponse with all required fields
let response = ResponsesResponse {
id: dispatch.request_id.clone(),
...
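// Sketch of how a caller consumes the enriched variant (field names from the
// hunk above; control flow condensed from the MCP loop later in this diff):
//
// match iteration_result {
//     ResponsesIterationResult::ToolCallsFound { tool_calls, usage, request_id, .. } => {
//         // execute tools, append results to the request, run another iteration
//     }
//     ResponsesIterationResult::Completed { response, .. } => return Ok(*response),
// }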
@@ -37,16 +37,14 @@
//! for complete architecture, rationale, and implementation details.
use std::{
-io,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
-use axum::{body::Body, http::StatusCode, response::Response};
+use axum::response::Response;
use bytes::Bytes;
use serde_json::{from_str, from_value, json, to_string, to_value, Value};
use tokio::sync::mpsc;
-use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, warn};
use uuid::Uuid;
@@ -54,23 +52,24 @@ use crate::{
data_connector::{ResponseId, ResponseStorage},
mcp::{self, McpManager},
protocols::{
-common::{Function, ToolCall},
+common::{Function, ToolCall, Usage},
responses::{
McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
-ResponseOutputItem, ResponseReasoningContent, ResponseTool, ResponseToolType,
-ResponsesRequest, ResponsesResponse, StringOrContentParts,
+ResponseOutputItem, ResponseReasoningContent, ResponseStatus, ResponseTool,
+ResponseToolType, ResponseUsage, ResponsesRequest, ResponsesResponse, ResponsesUsage,
+StringOrContentParts,
},
},
-routers::{
-grpc::{
-common::responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
+routers::grpc::{
+common::responses::{
+build_sse_response, ensure_mcp_connection,
+streaming::{OutputItemType, ResponseStreamEventEmitter},
+},
context::SharedComponents,
error,
-harmony::processor::ResponsesIterationResult,
+harmony::{processor::ResponsesIterationResult, streaming::HarmonyStreamingProcessor},
pipeline::RequestPipeline,
},
-openai::mcp::ensure_request_mcp_client,
-},
};
/// Maximum number of tool execution iterations to prevent infinite loops
@@ -239,34 +238,34 @@ pub async fn serve_harmony_responses(
request: ResponsesRequest,
) -> Result<ResponsesResponse, Response> {
// Load previous conversation history if previous_response_id is set
-let mut current_request = load_previous_messages(ctx, request).await?;
-let mut iteration_count = 0;
-let has_mcp_tools = current_request
-.tools
-.as_ref()
-.map(|tools| {
-tools
-.iter()
-.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
-})
-.unwrap_or(false);
-// Initialize MCP call tracking (will be passed to processor for final response)
-let mut mcp_tracking = if has_mcp_tools {
-Some(McpCallTracking::new("sglang-mcp".to_string()))
-} else {
-None
-};
+let current_request = load_previous_messages(ctx, request).await?;
+// Check MCP connection and get whether MCP tools are present
+let has_mcp_tools =
+ensure_mcp_connection(&ctx.mcp_manager, current_request.tools.as_deref()).await?;
if has_mcp_tools {
-// Ensure dynamic MCP client is registered for request-scoped tools
-if let Some(tools) = &current_request.tools {
-ensure_request_mcp_client(&ctx.mcp_manager, tools).await;
+execute_with_mcp_loop(ctx, current_request).await
+} else {
+// No MCP tools - execute pipeline once (may have function tools or no tools)
+execute_without_mcp_loop(ctx, current_request).await
}
+}
/// Execute Harmony Responses with MCP tool loop
///
/// Automatically executes MCP tools in a loop until no more tool calls or max iterations
async fn execute_with_mcp_loop(
ctx: &HarmonyResponsesContext,
mut current_request: ResponsesRequest,
) -> Result<ResponsesResponse, Response> {
let mut iteration_count = 0;
let mut mcp_tracking = McpCallTracking::new("sglang-mcp".to_string());
// Extract user's max_tool_calls limit (if set)
let max_tool_calls = current_request.max_tool_calls.map(|n| n as usize);
// Add static MCP tools from inventory to the request
-// (similar to non-Harmony pipeline pattern)
let mcp_tools = ctx.mcp_manager.list_tools();
if !mcp_tools.is_empty() {
let mcp_response_tools = convert_mcp_tools_to_response_tools(&mcp_tools);
@@ -278,10 +277,9 @@ pub async fn serve_harmony_responses(
debug!(
mcp_tool_count = mcp_tools.len(),
total_tool_count = current_request.tools.as_ref().map(|t| t.len()).unwrap_or(0),
-"Request has MCP tools - added static MCP tools to Harmony Responses request"
+"MCP client available - added static MCP tools to Harmony Responses request"
);
}
-}
loop {
iteration_count += 1;
@@ -317,30 +315,60 @@ pub async fn serve_harmony_responses(
tool_calls,
analysis,
partial_text,
+usage,
+request_id,
} => {
debug!(
tool_call_count = tool_calls.len(),
has_analysis = analysis.is_some(),
partial_text_len = partial_text.len(),
-"Tool calls found in commentary channel"
+"Tool calls found - checking limits before executing MCP tools"
);
-// TODO: Streaming support - emit intermediate chunks
-// if let Some(tx) = &ctx.stream_tx {
-// emit_intermediate_chunks(tx, &analysis, &partial_text, iteration_count).await?;
-// }
-// Execute MCP tools via MCP manager
-// If tools don't exist, call_tool() will return error naturally
-let tool_results = if let Some(ref mut tracking) = mcp_tracking {
-execute_mcp_tools(&ctx.mcp_manager, &tool_calls, tracking).await?
-} else {
-// Should never happen (we only get tool_calls when has_mcp_tools=true)
-return Err(error::internal_error(
-"Tool calls found but MCP tracking not initialized",
-));
+// Check combined limit (user's max_tool_calls vs safety limit)
+let effective_limit = match max_tool_calls {
+Some(user_max) => user_max.min(MAX_TOOL_ITERATIONS),
+None => MAX_TOOL_ITERATIONS,
};
// Check if we would exceed the limit with these new tool calls
let total_calls_after = mcp_tracking.total_calls() + tool_calls.len();
if total_calls_after > effective_limit {
warn!(
current_calls = mcp_tracking.total_calls(),
new_calls = tool_calls.len(),
total_after = total_calls_after,
effective_limit = effective_limit,
user_max = ?max_tool_calls,
"Reached tool call limit - returning incomplete response"
);
// Build response with incomplete status
let mut response = build_function_tool_response(
tool_calls,
analysis,
partial_text,
usage,
request_id,
Arc::new(current_request),
);
// Mark as completed with incomplete_details
response.status = ResponseStatus::Completed;
response.incomplete_details = Some(json!({ "reason": "max_tool_calls" }));
// Inject MCP metadata if any calls were executed
if mcp_tracking.total_calls() > 0 {
inject_mcp_metadata(&mut response, &mcp_tracking, &ctx.mcp_manager);
}
return Ok(response);
}
// Execute MCP tools
let tool_results =
execute_mcp_tools(&ctx.mcp_manager, &tool_calls, &mut mcp_tracking).await?;
// Build next request with appended history
current_request = build_next_request_with_tools(
current_request,
@@ -361,30 +389,71 @@ pub async fn serve_harmony_responses(
output_items = response.output.len(),
input_tokens = usage.prompt_tokens,
output_tokens = usage.completion_tokens,
-has_mcp_tracking = mcp_tracking.is_some(),
-"Harmony Responses serving completed - no more tool calls"
+"MCP loop completed - no more tool calls"
);
-// Inject MCP output items if MCP tools were available
-// (even if no tools were called, we still list available tools)
-if let Some(tracking) = mcp_tracking {
-inject_mcp_metadata(&mut response, &tracking, &ctx.mcp_manager);
+// Inject MCP metadata into final response
+inject_mcp_metadata(&mut response, &mcp_tracking, &ctx.mcp_manager);
debug!(
-mcp_calls = tracking.total_calls(),
+mcp_calls = mcp_tracking.total_calls(),
output_items_after = response.output.len(),
"Injected MCP metadata into final response"
);
-}
// No tool calls - this is the final response
+// TODO: Accumulate usage across all iterations if needed
return Ok(*response);
}
}
}
}
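// Worked example of the limit check above (values assumed): with
// max_tool_calls = Some(5) and MAX_TOOL_ITERATIONS = 10, effective_limit = 5.
// If 4 calls have already run and the model requests 2 more, total_calls_after
// = 6 > 5, so the loop stops and returns a response carrying
// incomplete_details = {"reason": "max_tool_calls"}.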
/// Execute Harmony Responses without MCP loop (single execution)
///
/// For function tools or no tools - executes pipeline once and returns
async fn execute_without_mcp_loop(
ctx: &HarmonyResponsesContext,
current_request: ResponsesRequest,
) -> Result<ResponsesResponse, Response> {
debug!("Executing Harmony Responses without MCP loop");
// Execute pipeline once
let iteration_result = ctx
.pipeline
.execute_harmony_responses(&current_request, ctx)
.await?;
match iteration_result {
ResponsesIterationResult::ToolCallsFound {
tool_calls,
analysis,
partial_text,
usage,
request_id,
} => {
// Function tool calls found - return to caller for execution
debug!(
tool_call_count = tool_calls.len(),
"Function tool calls found - returning to caller"
);
Ok(build_function_tool_response(
tool_calls,
analysis,
partial_text,
usage,
request_id,
Arc::new(current_request),
))
}
ResponsesIterationResult::Completed { response, usage: _ } => {
// No tool calls - return completed response
debug!("No tool calls - returning completed response");
Ok(*response)
}
}
}
/// Serve Harmony Responses API with streaming (SSE)
///
/// This is the streaming equivalent of `serve_harmony_responses()`.
@@ -412,14 +481,20 @@ pub async fn serve_harmony_responses_stream(
request: ResponsesRequest,
) -> Response {
// Load previous conversation history if previous_response_id is set
-let mut current_request = match load_previous_messages(ctx, request).await {
+let current_request = match load_previous_messages(ctx, request).await {
Ok(req) => req,
Err(err_response) => return err_response,
};
+// Check MCP connection BEFORE starting stream and get whether MCP tools are present
+let has_mcp_tools =
+match ensure_mcp_connection(&ctx.mcp_manager, current_request.tools.as_deref()).await {
+Ok(has_mcp) => has_mcp,
+Err(response) => return response,
+};
// Create SSE channel
let (tx, rx) = mpsc::unbounded_channel();
-let stream = UnboundedReceiverStream::new(rx);
// Create response event emitter
let response_id = format!("resp_{}", Uuid::new_v4());
@@ -437,24 +512,6 @@ pub async fn serve_harmony_responses_stream(
tokio::spawn(async move {
let ctx = &ctx_clone;
-// Clone response_id for closure to avoid borrow conflicts
-let response_id_for_error = response_id.clone();
-// Helper to emit error and return
-let emit_error = |tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, error_msg: &str| {
-// Create error event manually since emit_failed doesn't exist
-let event = json!({
-"type": "response.failed",
-"response_id": response_id_for_error,
-"error": {
-"message": error_msg,
-"type": "internal_error"
-}
-});
-let sse_data = format!("data: {}\n\n", to_string(&event).unwrap());
-let _ = tx.send(Ok(Bytes::from(sse_data)));
-};
// Emit initial response.created and response.in_progress events
let event = emitter.emit_created();
if emitter.send_event(&event, &tx).is_err() {
@@ -465,30 +522,35 @@ pub async fn serve_harmony_responses_stream(
return;
}
-// Check if request has MCP tools
-let has_mcp_tools = current_request
-.tools
-.as_ref()
-.map(|tools| {
-tools
-.iter()
-.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
-})
-.unwrap_or(false);
+if has_mcp_tools {
+execute_mcp_tool_loop_streaming(ctx, current_request, &mut emitter, &tx).await;
+} else {
+execute_without_mcp_streaming(ctx, &current_request, &mut emitter, &tx).await;
+}
+});
+// Return SSE stream response
+build_sse_response(rx)
+}
/// Execute MCP tool loop with streaming
///
/// Handles the full MCP workflow:
/// - Adds static MCP tools to request
/// - Emits mcp_list_tools events
/// - Loops through tool execution iterations
/// - Emits final response.completed event
async fn execute_mcp_tool_loop_streaming(
ctx: &HarmonyResponsesContext,
mut current_request: ResponsesRequest,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) {
// Initialize MCP call tracking
-let mut mcp_tracking = if has_mcp_tools {
-Some(McpCallTracking::new("sglang-mcp".to_string()))
-} else {
-None
-};
+let mut mcp_tracking = McpCallTracking::new("sglang-mcp".to_string());
-// Setup MCP tools if needed
-if has_mcp_tools {
-// Ensure dynamic MCP client is registered
-if let Some(tools) = &current_request.tools {
-ensure_request_mcp_client(&ctx.mcp_manager, tools).await;
-}
+// Extract user's max_tool_calls limit (if set)
+let max_tool_calls = current_request.max_tool_calls.map(|n| n as usize);
// Add static MCP tools from inventory
let mcp_tools = ctx.mcp_manager.list_tools();
@@ -501,16 +563,12 @@ pub async fn serve_harmony_responses_stream(
debug!(
mcp_tool_count = mcp_tools.len(),
total_tool_count = current_request.tools.as_ref().map(|t| t.len()).unwrap_or(0),
-"Added static MCP tools to Harmony Responses streaming request"
+"MCP client available - added static MCP tools to Harmony Responses streaming request"
);
}
-}
-// Emit mcp_list_tools on first iteration (only if MCP tools available)
-if has_mcp_tools {
-let mcp_tools = ctx.mcp_manager.list_tools();
-let (output_index, item_id) =
-emitter.allocate_output_index(OutputItemType::McpListTools);
+// Emit mcp_list_tools on first iteration
+let (output_index, item_id) = emitter.allocate_output_index(OutputItemType::McpListTools);
// Build tools list for item structure
let tool_items: Vec<_> = mcp_tools
@@ -533,19 +591,19 @@ pub async fn serve_harmony_responses_stream(
"tools": []
});
let event = emitter.emit_output_item_added(output_index, &item);
-if emitter.send_event(&event, &tx).is_err() {
+if emitter.send_event(&event, tx).is_err() {
return;
}
// Emit mcp_list_tools.in_progress
let event = emitter.emit_mcp_list_tools_in_progress(output_index);
-if emitter.send_event(&event, &tx).is_err() {
+if emitter.send_event(&event, tx).is_err() {
return;
}
// Emit mcp_list_tools.completed
let event = emitter.emit_mcp_list_tools_completed(output_index, &mcp_tools);
-if emitter.send_event(&event, &tx).is_err() {
+if emitter.send_event(&event, tx).is_err() {
return;
}
@@ -558,7 +616,7 @@ pub async fn serve_harmony_responses_stream(
"tools": tool_items
});
let event = emitter.emit_output_item_done(output_index, &item_done);
-if emitter.send_event(&event, &tx).is_err() {
+if emitter.send_event(&event, tx).is_err() {
return;
}
@@ -568,18 +626,19 @@ pub async fn serve_harmony_responses_stream(
tool_count = mcp_tools.len(),
"Emitted mcp_list_tools on first iteration"
);
-}
-// Tool loop (max 10 iterations)
+// MCP tool loop (max 10 iterations)
let mut iteration_count = 0;
loop {
iteration_count += 1;
// Safety check: prevent infinite loops
if iteration_count > MAX_TOOL_ITERATIONS {
-let error_msg =
-format!("Maximum tool iterations ({}) exceeded", MAX_TOOL_ITERATIONS);
-emit_error(&tx, &error_msg);
+emitter.emit_error(
+&format!("Maximum tool iterations ({}) exceeded", MAX_TOOL_ITERATIONS),
+Some("max_iterations_exceeded"),
+tx,
+);
return;
}
@@ -588,7 +647,7 @@ pub async fn serve_harmony_responses_stream(
"Harmony Responses streaming iteration"
);
-// Execute through pipeline and get raw stream
+// Execute pipeline and get stream
let execution_result = match ctx
.pipeline
.execute_harmony_responses_streaming(&current_request, ctx)
@@ -596,23 +655,27 @@ pub async fn serve_harmony_responses_stream(
{
Ok(result) => result,
Err(err_response) => {
-let error_msg = format!("Pipeline execution failed: {:?}", err_response);
-emit_error(&tx, &error_msg);
+emitter.emit_error(
+&format!("Pipeline execution failed: {:?}", err_response),
+Some("pipeline_error"),
+tx,
+);
return;
}
};
-// Process stream with token-level streaming using HarmonyStreamingProcessor
-let iteration_result = match super::streaming::HarmonyStreamingProcessor::process_responses_iteration_stream(
+// Process stream with token-level streaming (MCP path - emits mcp_call.* events)
+let iteration_result =
+match HarmonyStreamingProcessor::process_responses_iteration_stream_mcp(
execution_result,
-&mut emitter,
-&tx,
+emitter,
+tx,
)
.await
{
Ok(result) => result,
Err(err_msg) => {
-emit_error(&tx, &err_msg);
+emitter.emit_error(&err_msg, Some("processing_error"), tx);
return;
}
};
@@ -623,29 +686,60 @@ pub async fn serve_harmony_responses_stream(
tool_calls,
analysis,
partial_text,
+usage,
+request_id: _,
} => {
debug!(
tool_call_count = tool_calls.len(),
has_analysis = analysis.is_some(),
partial_text_len = partial_text.len(),
-"Tool calls found in commentary channel"
+"MCP tool calls found in commentary channel - checking limits"
);
-// Execute MCP tools
-let tool_results = if let Some(ref mut tracking) = mcp_tracking {
-match execute_mcp_tools(&ctx.mcp_manager, &tool_calls, tracking).await {
+// Check combined limit (user's max_tool_calls vs safety limit)
+let effective_limit = match max_tool_calls {
+Some(user_max) => user_max.min(MAX_TOOL_ITERATIONS),
+None => MAX_TOOL_ITERATIONS,
+};
// Check if we would exceed the limit with these new tool calls
let total_calls_after = mcp_tracking.total_calls() + tool_calls.len();
if total_calls_after > effective_limit {
warn!(
current_calls = mcp_tracking.total_calls(),
new_calls = tool_calls.len(),
total_after = total_calls_after,
effective_limit = effective_limit,
user_max = ?max_tool_calls,
"Reached tool call limit in streaming - emitting completion with incomplete_details"
);
// Emit response.completed with incomplete_details and usage
let incomplete_details = json!({ "reason": "max_tool_calls" });
let usage_json = json!({
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
"total_tokens": usage.total_tokens,
"incomplete_details": incomplete_details,
});
let event = emitter.emit_completed(Some(&usage_json));
emitter.send_event_best_effort(&event, tx);
return;
}
// Execute MCP tools and continue loop
let tool_results =
match execute_mcp_tools(&ctx.mcp_manager, &tool_calls, &mut mcp_tracking).await
{
Ok(results) => results,
Err(err_response) => {
-let error_msg =
-format!("MCP tool execution failed: {:?}", err_response);
-emit_error(&tx, &error_msg);
+emitter.emit_error(
+&format!("MCP tool execution failed: {:?}", err_response),
+Some("mcp_tool_error"),
+tx,
+);
return;
}
-}
-} else {
-let error_msg = "Tool calls found but MCP tracking not initialized";
-emit_error(&tx, error_msg);
-return;
};
// Build next request with appended history
@@ -658,8 +752,11 @@ pub async fn serve_harmony_responses_stream(
) {
Ok(req) => req,
Err(e) => {
-let error_msg = format!("Failed to build next request: {:?}", e);
-emit_error(&tx, &error_msg);
+emitter.emit_error(
+&format!("Failed to build next request: {:?}", e),
+Some("request_building_error"),
+tx,
+);
return;
}
};
@@ -681,24 +778,157 @@ pub async fn serve_harmony_responses_stream(
"total_tokens": usage.total_tokens,
});
let event = emitter.emit_completed(Some(&usage_json));
-emitter.send_event_best_effort(&event, &tx);
+emitter.send_event_best_effort(&event, tx);
return;
}
}
}
}
-// Close channel
-drop(tx);
+/// Execute without MCP tool loop (single execution with streaming)
+///
/// For function tools or no tools - executes pipeline once and emits completion.
/// The streaming processor handles all output items (reasoning, message, function tool calls).
async fn execute_without_mcp_streaming(
ctx: &HarmonyResponsesContext,
current_request: &ResponsesRequest,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) {
debug!("No MCP tools - executing single iteration");
// Execute pipeline and get stream
let execution_result = match ctx
.pipeline
.execute_harmony_responses_streaming(current_request, ctx)
.await
{
Ok(result) => result,
Err(err_response) => {
emitter.emit_error(
&format!("Pipeline execution failed: {:?}", err_response),
Some("pipeline_error"),
tx,
);
return;
}
};
// Process stream (emits all output items during streaming - function tool path emits function_call_arguments.* events)
if let Err(err_msg) = HarmonyStreamingProcessor::process_responses_iteration_stream_function(
execution_result,
emitter,
tx,
)
.await
{
emitter.emit_error(&err_msg, Some("processing_error"), tx);
return;
}
// Emit response.completed
let event = emitter.emit_completed(None);
emitter.send_event_best_effort(&event, tx);
}
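// For orientation, the event order this non-MCP streaming path produces
// (derived from the emitters used above; exact payloads vary):
//   response.created -> response.in_progress
//   -> response.output_item.added (type "function_call")
//   -> response.function_call_arguments.delta (repeated)
//   -> response.function_call_arguments.done
//   -> response.output_item.done -> response.completed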
/// Build ResponsesResponse with function tool calls for caller to execute
///
/// When tool calls are found but no MCP client is available (function tools only),
/// this builds a response with status=Completed and tool calls without output field.
/// The absence of output signals the caller should execute tools and resume.
///
/// TODO: Refactor to use builder pattern
fn build_function_tool_response(
tool_calls: Vec<ToolCall>,
analysis: Option<String>,
partial_text: String,
usage: Usage,
request_id: String,
responses_request: Arc<ResponsesRequest>,
) -> ResponsesResponse {
let mut output: Vec<ResponseOutputItem> = Vec::new();
// Add reasoning output item if analysis exists
if let Some(analysis_text) = analysis {
output.push(ResponseOutputItem::Reasoning {
id: format!("reasoning_{}", request_id),
summary: vec![],
content: vec![ResponseReasoningContent::ReasoningText {
text: analysis_text,
}],
status: Some("completed".to_string()),
});
} }
// Add message output item if partial text exists
if !partial_text.is_empty() {
output.push(ResponseOutputItem::Message {
id: format!("msg_{}", request_id),
role: "assistant".to_string(),
content: vec![ResponseContentPart::OutputText {
text: partial_text,
annotations: vec![],
logprobs: None,
}],
status: "completed".to_string(),
});
}
// Add function tool calls as completed output items (no output field = needs execution)
for tool_call in tool_calls {
output.push(ResponseOutputItem::FunctionToolCall {
id: tool_call.id.clone(),
call_id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
arguments: tool_call.function.arguments.clone().unwrap_or_default(),
output: None, // No output = tool needs execution by caller
status: "completed".to_string(),
});
}
-// Return SSE stream response
-Response::builder()
-.status(StatusCode::OK)
-.header("Content-Type", "text/event-stream")
-.header("Cache-Control", "no-cache")
-.header("Connection", "keep-alive")
-.body(Body::from_stream(stream))
-.unwrap()
+// Build ResponsesResponse with Completed status
+// The presence of FunctionToolCall items without output signals tool execution needed
+let created_at = SystemTime::now()
+.duration_since(UNIX_EPOCH)
+.unwrap()
+.as_secs() as i64;
ResponsesResponse {
id: request_id,
object: "response".to_string(),
created_at,
status: ResponseStatus::Completed,
error: None,
incomplete_details: None,
instructions: responses_request.instructions.clone(),
max_output_tokens: responses_request.max_output_tokens,
model: responses_request.model.clone(),
output,
parallel_tool_calls: responses_request.parallel_tool_calls.unwrap_or(true),
previous_response_id: responses_request.previous_response_id.clone(),
reasoning: None,
store: responses_request.store.unwrap_or(true),
temperature: responses_request.temperature,
text: None,
tool_choice: responses_request
.tool_choice
.as_ref()
.map(|tc| to_string(tc).unwrap_or_else(|_| "auto".to_string()))
.unwrap_or_else(|| "auto".to_string()),
tools: responses_request.tools.clone().unwrap_or_default(),
top_p: responses_request.top_p,
truncation: None,
usage: Some(ResponsesUsage::Modern(ResponseUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
input_tokens_details: None,
output_tokens_details: None,
})),
user: None,
safety_identifier: responses_request.user.clone(),
metadata: responses_request.metadata.clone().unwrap_or_default(),
}
}
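// Shape of one function-call output item this builder pushes (serialized form
// is a sketch; the exact "type" tag depends on ResponseOutputItem's serde
// attributes, and the absent "output" is the execute-and-resume signal):
//   {"type": "function_call", "id": "call_abc", "call_id": "call_abc",
//    "name": "get_weather", "arguments": "{\"city\":\"Paris\"}",
//    "status": "completed"}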
/// Execute MCP tools and collect results
@@ -758,8 +988,7 @@ async fn execute_mcp_tools(
// Extract content from MCP result
let output = if let Some(content) = mcp_result.content.first() {
-// TODO: Handle different content types (text, image, resource)
-// For now, serialize the entire content item
+// Serialize the entire content item
to_value(content)
.unwrap_or_else(|_| json!({"error": "Failed to serialize tool result"}))
} else {
...
@@ -214,7 +214,7 @@ impl HarmonyPreparationStage {
let params_schema = &tool.function.parameters;
tags.push(json!({
-"begin": format!("<|channel|>commentary to=functions.{}<|constrain|>json<|message|>", tool_name),
+"begin": format!("<|start|>assistant<|channel|>commentary to=functions.{}<|constrain|>json<|message|>", tool_name),
"content": {
"type": "json_schema",
"json_schema": params_schema
@@ -228,7 +228,7 @@ impl HarmonyPreparationStage {
let structural_tag = json!({
"format": {
"type": "triggered_tags",
-"triggers": ["<|channel|>commentary"],
+"triggers": ["<|start|>assistant"],
"tags": tags,
"at_least_one": true,
"stop_after_first": stop_after_first
...
@@ -35,6 +35,71 @@ use crate::{
context,
},
};
/// Mode for tool call event emission
#[derive(Debug, Clone, Copy)]
enum ToolCallMode {
/// MCP tool calls (emit .in_progress and .completed events)
Mcp,
/// Function tool calls (no status events, only arguments streaming)
Function,
}
impl ToolCallMode {
/// Get the output item type for this mode
fn output_item_type(&self) -> OutputItemType {
match self {
Self::Mcp => OutputItemType::McpCall,
Self::Function => OutputItemType::FunctionCall,
}
}
/// Get the type string for JSON output
fn type_str(&self) -> &'static str {
match self {
Self::Mcp => "mcp_call",
Self::Function => "function_call",
}
}
/// Whether this mode emits status events (.in_progress, .completed)
fn emits_status_events(&self) -> bool {
matches!(self, Self::Mcp)
}
/// Emit arguments delta event
fn emit_arguments_delta(
&self,
emitter: &mut ResponseStreamEventEmitter,
output_index: usize,
item_id: &str,
delta: &str,
) -> serde_json::Value {
match self {
Self::Mcp => emitter.emit_mcp_call_arguments_delta(output_index, item_id, delta),
Self::Function => {
emitter.emit_function_call_arguments_delta(output_index, item_id, delta)
}
}
}
/// Emit arguments done event
fn emit_arguments_done(
&self,
emitter: &mut ResponseStreamEventEmitter,
output_index: usize,
item_id: &str,
arguments: &str,
) -> serde_json::Value {
match self {
Self::Mcp => emitter.emit_mcp_call_arguments_done(output_index, item_id, arguments),
Self::Function => {
emitter.emit_function_call_arguments_done(output_index, item_id, arguments)
}
}
}
}
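// Dispatch sketch: the decode-stream processor below threads a ToolCallMode
// through and stays agnostic of the concrete event family, e.g.:
//
// let event = mode.emit_arguments_delta(emitter, output_index, &item_id, delta);
// emitter.send_event_best_effort(&event, tx);
// if mode.emits_status_events() {
//     // mcp_call.in_progress / mcp_call.completed are emitted only in Mcp mode
// }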
/// Processor for streaming Harmony responses
///
/// Returns an SSE stream that parses Harmony tokens incrementally and
@@ -531,14 +596,14 @@ impl HarmonyStreamingProcessor {
Ok(())
}
-/// Common decode stream processing logic for both single and dual stream modes
+/// Decode stream processing for tool loops
///
-/// This helper function contains the shared logic for processing the decode stream,
-/// parsing Harmony tokens, emitting SSE events, and tracking state.
-async fn process_decode_stream_common(
+/// Emits tool call events based on the mode (MCP or Function).
+async fn process_decode_stream(
mut decode_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
+mode: ToolCallMode,
) -> Result<ResponsesIterationResult, String> {
// Initialize Harmony parser for this iteration
let mut parser =
@@ -555,12 +620,14 @@ impl HarmonyStreamingProcessor {
let mut message_item_id: Option<String> = None;
let mut has_emitted_content_part_added = false;
-// MCP tool call tracking (call_index -> (output_index, item_id))
-let mut mcp_call_tracking: HashMap<usize, (usize, String)> = HashMap::new();
+// Tool call tracking (call_index -> (output_index, item_id))
+let mut tool_call_tracking: HashMap<usize, (usize, String)> = HashMap::new();
// Metadata from Complete message
let mut finish_reason = String::from("stop");
let mut matched_stop: Option<serde_json::Value> = None;
+let mut prompt_tokens: u32 = 0;
+let mut completion_tokens: u32 = 0;
// Process stream
let mut chunk_count = 0;
@@ -646,29 +713,53 @@ impl HarmonyStreamingProcessor {
}
}
-// Commentary channel → MCP tool call streaming
+// Commentary channel → Tool call streaming
if let Some(tc_delta) = &delta.commentary_delta {
let call_index = tc_delta.index;
// Check if this is a new tool call (has id and name)
if tc_delta.id.is_some() {
-// NEW MCP CALL: Allocate output item and emit in_progress
+// NEW TOOL CALL: Allocate output item
let (output_index, item_id) =
-emitter.allocate_output_index(OutputItemType::McpCall);
+emitter.allocate_output_index(mode.output_item_type());
// Store tracking info
-mcp_call_tracking
+tool_call_tracking
.insert(call_index, (output_index, item_id.clone()));
-// Emit mcp_call.in_progress
+// Get tool name
let tool_name = tc_delta
.function
.as_ref()
.and_then(|f| f.name.as_ref())
.map(|n| n.as_str())
.unwrap_or("");
// Emit output_item.added wrapper event
let call_id = tc_delta.id.as_ref().unwrap();
let item = json!({
"id": item_id,
"type": mode.type_str(),
"name": tool_name,
"call_id": call_id,
"arguments": "",
"status": "in_progress"
});
let event = emitter.emit_output_item_added(output_index, &item);
emitter.send_event_best_effort(&event, tx);
// Emit status event if mode supports it (MCP only)
if mode.emits_status_events() {
let event =
emitter.emit_mcp_call_in_progress(output_index, &item_id);
emitter.send_event_best_effort(&event, tx);
}
-// If we have function name, emit initial mcp_call_arguments.delta
+// If we have function name, emit initial arguments delta
if let Some(func) = &tc_delta.function {
if func.name.is_some() {
-let event = emitter.emit_mcp_call_arguments_delta(
+let event = mode.emit_arguments_delta(
+emitter,
output_index,
&item_id,
"",
@@ -677,9 +768,9 @@ impl HarmonyStreamingProcessor {
}
}
} else {
-// CONTINUING MCP CALL: Emit arguments delta
+// CONTINUING TOOL CALL: Emit arguments delta
if let Some((output_index, item_id)) =
-mcp_call_tracking.get(&call_index)
+tool_call_tracking.get(&call_index)
{
if let Some(args) = tc_delta
.function
@@ -687,7 +778,8 @@ impl HarmonyStreamingProcessor {
.and_then(|f| f.arguments.as_ref())
.filter(|a| !a.is_empty())
{
-let event = emitter.emit_mcp_call_arguments_delta(
+let event = mode.emit_arguments_delta(
+emitter,
*output_index,
item_id,
args,
@@ -704,12 +796,14 @@ impl HarmonyStreamingProcessor {
finish_reason = complete.finish_reason.clone();
matched_stop = complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
-serde_json::json!(id)
+json!(id)
}
MatchedStopStr(s) => {
-serde_json::json!(s)
+json!(s)
}
});
+prompt_tokens = complete.prompt_tokens as u32;
+completion_tokens = complete.completion_tokens as u32;
// Finalize parser and get complete output
let final_output = parser
@@ -719,23 +813,42 @@ impl HarmonyStreamingProcessor {
// Store finalized tool calls
accumulated_tool_calls = final_output.commentary.clone();
-// Complete all MCP tool calls if we have commentary
+// Complete all tool calls if we have commentary
if let Some(ref tool_calls) = accumulated_tool_calls {
for (call_idx, tool_call) in tool_calls.iter().enumerate() {
-if let Some((output_index, item_id)) = mcp_call_tracking.get(&call_idx)
+if let Some((output_index, item_id)) = tool_call_tracking.get(&call_idx)
{
+let tool_name = &tool_call.function.name;
-// Emit mcp_call_arguments.done with final arguments
+// Emit arguments done with final arguments
let args_str =
tool_call.function.arguments.as_deref().unwrap_or("");
-let event = emitter.emit_mcp_call_arguments_done(
+let event = mode.emit_arguments_done(
+emitter,
*output_index,
item_id,
args_str,
);
emitter.send_event_best_effort(&event, tx);
-// Emit mcp_call.completed
-let event = emitter.emit_mcp_call_completed(*output_index, item_id);
+// Emit status event if mode supports it (MCP only)
+if mode.emits_status_events() {
+let event =
+emitter.emit_mcp_call_completed(*output_index, item_id);
+emitter.send_event_best_effort(&event, tx);
+}
+// Emit output_item.done wrapper event
let item = json!({
"id": item_id,
"type": mode.type_str(),
"name": tool_name,
"call_id": &tool_call.id,
"arguments": args_str,
"status": "completed"
});
let event = emitter.emit_output_item_done(*output_index, &item);
emitter.send_event_best_effort(&event, tx);
// Mark output item as completed
@@ -811,23 +924,25 @@ impl HarmonyStreamingProcessor {
final_text_extracted.len()
);
-// Complete any pending MCP tool calls with data from completed messages
+// Complete any pending tool calls with data from completed messages
if let Some(ref tool_calls) = accumulated_tool_calls {
for (call_idx, tool_call) in tool_calls.iter().enumerate() {
-if let Some((output_index, item_id)) = mcp_call_tracking.get(&call_idx) {
+if let Some((output_index, item_id)) = tool_call_tracking.get(&call_idx) {
-// Emit mcp_call_arguments.done with final arguments
+// Emit arguments done with final arguments
let args_str = tool_call.function.arguments.as_deref().unwrap_or("");
let event =
-emitter.emit_mcp_call_arguments_done(*output_index, item_id, args_str);
+mode.emit_arguments_done(emitter, *output_index, item_id, args_str);
emitter.send_event_best_effort(&event, tx);
-// Emit mcp_call.completed
+// Emit status event if mode supports it (MCP only)
+if mode.emits_status_events() {
let event = emitter.emit_mcp_call_completed(*output_index, item_id);
emitter.send_event_best_effort(&event, tx);
}
}
}
}
+}
// Mark stream as completed successfully to prevent abort on drop
decode_stream.mark_completed();
@@ -848,6 +963,13 @@ impl HarmonyStreamingProcessor {
tool_calls,
analysis: analysis_content,
partial_text: accumulated_final_text,
+usage: Usage {
+prompt_tokens,
+completion_tokens,
+total_tokens: prompt_tokens + completion_tokens,
+completion_tokens_details: None,
+},
+request_id: emitter.response_id.clone(),
});
}
}
@@ -857,7 +979,7 @@ impl HarmonyStreamingProcessor {
// Return a placeholder Completed result (caller ignores these fields in streaming mode)
Ok(ResponsesIterationResult::Completed {
response: Box::new(ResponsesResponse {
-id: String::new(),
+id: emitter.response_id.clone(),
object: "response".to_string(),
created_at: 0,
status: ResponseStatus::Completed,
@@ -881,76 +1003,136 @@ impl HarmonyStreamingProcessor {
safety_identifier: None,
metadata: HashMap::new(),
usage: Some(ResponsesUsage::Modern(ResponseUsage {
-input_tokens: 0,
-output_tokens: 0,
-total_tokens: 0,
+input_tokens: prompt_tokens,
+output_tokens: completion_tokens,
+total_tokens: prompt_tokens + completion_tokens,
input_tokens_details: None,
output_tokens_details: None,
})),
}),
usage: Usage {
-prompt_tokens: 0,
-completion_tokens: 0,
-total_tokens: 0,
+prompt_tokens,
+completion_tokens,
+total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
},
})
}
-/// Process streaming chunks for Responses API iteration
+/// Process streaming chunks for Responses API iteration - MCP loop
///
/// Emits mcp_call.* events for all tool calls
pub async fn process_responses_iteration_stream_mcp(
execution_result: context::ExecutionResult,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
match execution_result {
context::ExecutionResult::Single { stream } => {
debug!("Processing Responses API single stream mode (MCP)");
Self::process_responses_single_stream_mcp(stream, emitter, tx).await
}
context::ExecutionResult::Dual { prefill, decode } => {
debug!("Processing Responses API dual stream mode (MCP)");
Self::process_responses_dual_stream_mcp(prefill, *decode, emitter, tx).await
}
}
}
/// Process streaming chunks for Responses API iteration - Function tools
///
-/// Returns ResponsesIterationResult indicating whether tool calls were found
-/// (requiring MCP loop continuation) or if the iteration is complete.
-pub async fn process_responses_iteration_stream(
+/// Emits function_call_arguments.* events for all tool calls
+pub async fn process_responses_iteration_stream_function(
execution_result: context::ExecutionResult,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
match execution_result {
context::ExecutionResult::Single { stream } => {
-debug!("Processing Responses API single stream mode");
-Self::process_responses_single_stream(stream, emitter, tx).await
+debug!("Processing Responses API single stream mode (Function)");
+Self::process_responses_single_stream_function(stream, emitter, tx).await
}
context::ExecutionResult::Dual { prefill, decode } => {
-debug!("Processing Responses API dual stream mode");
-Self::process_responses_dual_stream(prefill, *decode, emitter, tx).await
+debug!("Processing Responses API dual stream mode (Function)");
+Self::process_responses_dual_stream_function(prefill, *decode, emitter, tx).await
}
}
}
-/// Process streaming chunks from a single stream (Responses API)
-async fn process_responses_single_stream(
+/// Process streaming chunks from a single stream - MCP loop
+async fn process_responses_single_stream_mcp(
grpc_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
-// Delegate to common helper
-Self::process_decode_stream_common(grpc_stream, emitter, tx).await
+Self::process_decode_stream(grpc_stream, emitter, tx, ToolCallMode::Mcp).await
}
-/// Process streaming chunks from dual streams (Responses API)
+/// Process streaming chunks from a single stream - Function tools
async fn process_responses_single_stream_function(
grpc_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
Self::process_decode_stream(grpc_stream, emitter, tx, ToolCallMode::Function).await
}
/// Process streaming chunks from dual streams (common implementation)
async fn process_responses_dual_stream(
mut prefill_stream: AbortOnDropStream,
decode_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
+mode: ToolCallMode,
) -> Result<ResponsesIterationResult, String> {
// Phase 1: Process prefill stream (collect metadata, no output)
while let Some(result) = prefill_stream.next().await {
let _response = result.map_err(|e| format!("Prefill stream error: {}", e))?;
-// No-op for prefill in Responses API (just metadata collection)
}
// Phase 2: Process decode stream using common helper
-let result = Self::process_decode_stream_common(decode_stream, emitter, tx).await;
+let result = Self::process_decode_stream(decode_stream, emitter, tx, mode).await;
// Mark prefill stream as completed AFTER decode completes successfully
// This ensures that if client disconnects during decode, BOTH streams send abort
prefill_stream.mark_completed();
result
}
/// Process streaming chunks from dual streams - MCP loop
async fn process_responses_dual_stream_mcp(
prefill_stream: AbortOnDropStream,
decode_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
Self::process_responses_dual_stream(
prefill_stream,
decode_stream,
emitter,
tx,
ToolCallMode::Mcp,
)
.await
}
/// Process streaming chunks from dual streams - Function tools
async fn process_responses_dual_stream_function(
prefill_stream: AbortOnDropStream,
decode_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
Self::process_responses_dual_stream(
prefill_stream,
decode_stream,
emitter,
tx,
ToolCallMode::Function,
)
.await
}
/// Build SSE response from receiver /// Build SSE response from receiver
fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, io::Error>>) -> Response { fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, io::Error>>) -> Response {
let stream = UnboundedReceiverStream::new(rx); let stream = UnboundedReceiverStream::new(rx);
......
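Everything above funnels into a single decode-stream helper parameterized by tool-call mode, so the MCP loop and client-side function tools share one streaming path. The `ToolCallMode` definition itself is outside this diff; the sketch below is an illustrative stand-in showing only the dispatch shape, not the crate's actual code.

    // Hypothetical stand-in for the router's ToolCallMode; only the two
    // variant names (Mcp, Function) are taken from the diff above.
    #[derive(Clone, Copy, Debug)]
    enum ToolCallMode {
        /// Router executes tool calls itself and loops until the model finishes.
        Mcp,
        /// Router surfaces tool calls to the client as function_call output items.
        Function,
    }

    fn describe(mode: ToolCallMode) -> &'static str {
        match mode {
            ToolCallMode::Mcp => "execute via MCP loop, feed results back to the model",
            ToolCallMode::Function => "emit response.function_call_arguments.* events",
        }
    }

    fn main() {
        // One helper plus a mode flag replaces the duplicated *_mcp/*_function bodies.
        for mode in [ToolCallMode::Mcp, ToolCallMode::Function] {
            println!("{:?}: {}", mode, describe(mode));
        }
    }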
@@ -226,7 +226,7 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
         parallel_tool_calls: req.parallel_tool_calls,
         top_logprobs: req.top_logprobs,
         top_p: req.top_p,
-        skip_special_tokens: true, // Always skip special tokens // TODO: except for gpt-oss
+        skip_special_tokens: true,
         // Note: tools and tool_choice will be handled separately for MCP transformation
         tools: None, // Will be set by caller if needed
         tool_choice: None, // Will be set by caller if needed
...
@@ -36,14 +36,13 @@ use std::sync::Arc;

 use axum::{
     body::Body,
-    http::{self, header, StatusCode},
+    http::{self, StatusCode},
     response::{IntoResponse, Response},
 };
 use bytes::Bytes;
 use futures_util::StreamExt;
 use serde_json::json;
 use tokio::sync::mpsc;
-use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, warn};
 use uuid::Uuid;
 use validator::Validate;
@@ -67,8 +66,13 @@ use crate::{
         },
     },
     routers::{
-        grpc::{common::responses::streaming::ResponseStreamEventEmitter, error},
-        openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
+        grpc::{
+            common::responses::{
+                build_sse_response, ensure_mcp_connection, streaming::ResponseStreamEventEmitter,
+            },
+            error,
+        },
+        openai::conversations::persist_conversation_items,
     },
 };
@@ -81,33 +85,6 @@ pub async fn route_responses(
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
 ) -> Response {
-    // 0. Fast worker validation (fail-fast before expensive operations)
-    let requested_model: Option<&str> = model_id.as_deref().or(Some(request.model.as_str()));
-    if let Some(model) = requested_model {
-        // Check if any workers support this model
-        let available_models = ctx.worker_registry.get_models();
-        if !available_models.contains(&model.to_string()) {
-            return (
-                StatusCode::SERVICE_UNAVAILABLE,
-                axum::Json(json!({
-                    "error": {
-                        "message": format!(
-                            "No workers available for model '{}'. Available models: {}",
-                            model,
-                            available_models.join(", ")
-                        ),
-                        "type": "service_unavailable",
-                        "param": "model",
-                        "code": "no_available_workers"
-                    }
-                })),
-            )
-                .into_response();
-        }
-    }
-
     // 1. Validate request (includes conversation ID format)
     if let Err(validation_errors) = request.validate() {
         // Extract the first error message for conversation field
@@ -171,7 +148,10 @@ pub async fn route_responses(

     if is_streaming {
         route_responses_streaming(ctx, request, headers, model_id).await
     } else {
-        route_responses_sync(ctx, request, headers, model_id, None).await
+        // Generate response ID for synchronous execution
+        // TODO: we may remove this when we have builder pattern for responses
+        let response_id = Some(format!("resp_{}", Uuid::new_v4()));
+        route_responses_sync(ctx, request, headers, model_id, response_id).await
     }
 }
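The sync path now mints the response ID up front instead of passing `None` and deferring the fallback to downstream code. The ID shape is exactly what the added lines construct; a self-contained illustration (uuid crate, as used in the diff):

    use uuid::Uuid;

    fn main() {
        // Same construction as the diff: "resp_" + a random v4 UUID.
        let response_id = format!("resp_{}", Uuid::new_v4());
        assert!(response_id.starts_with("resp_"));
        println!("{response_id}"); // e.g. resp_67e55044-10b1-426f-9247-bb680e5fe0c8
    }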
@@ -211,13 +191,10 @@ async fn route_responses_internal(
     // 1. Load conversation history and build modified request
     let modified_request = load_conversation_history(ctx, &request).await?;

-    // 2. Check if request has MCP tools - if so, use tool loop
-    let responses_response = if let Some(tools) = &request.tools {
-        // Ensure dynamic MCP client is registered for request-scoped tools
-        if ensure_request_mcp_client(&ctx.mcp_manager, tools)
-            .await
-            .is_some()
-        {
+    // 2. Check MCP connection and get whether MCP tools are present
+    let has_mcp_tools = ensure_mcp_connection(&ctx.mcp_manager, request.tools.as_deref()).await?;
+
+    let responses_response = if has_mcp_tools {
         debug!("MCP tools detected, using tool loop");
         // Execute with MCP tool loop
@@ -231,20 +208,7 @@
         )
         .await?
     } else {
-        debug!("Failed to create MCP client from request tools");
-        // Fall through to non-MCP execution
-        execute_without_mcp(
-            ctx,
-            &modified_request,
-            &request,
-            headers,
-            model_id,
-            response_id.clone(),
-        )
-        .await?
-        }
-    } else {
-        // No tools, execute normally
+        // No MCP tools - execute without MCP (may have function tools or no tools)
         execute_without_mcp(
             ctx,
             &modified_request,
@@ -289,19 +253,19 @@ async fn route_responses_streaming(
         Err(response) => return response, // Already a Response with proper status code
     };

-    // 2. Check if request has MCP tools - if so, use streaming tool loop
-    if let Some(tools) = &request.tools {
-        // Ensure dynamic MCP client is registered for request-scoped tools
-        if ensure_request_mcp_client(&ctx.mcp_manager, tools)
-            .await
-            .is_some()
-        {
+    // 2. Check MCP connection and get whether MCP tools are present
+    let has_mcp_tools =
+        match ensure_mcp_connection(&ctx.mcp_manager, request.tools.as_deref()).await {
+            Ok(has_mcp) => has_mcp,
+            Err(response) => return response,
+        };
+
+    if has_mcp_tools {
         debug!("MCP tools detected in streaming mode, using streaming tool loop");
         return execute_tool_loop_streaming(ctx, modified_request, &request, headers, model_id)
             .await;
     }
-    }

     // 3. Convert ResponsesRequest → ChatCompletionRequest
     let chat_request = match conversions::responses_to_chat(&modified_request) {
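Both the sync and streaming paths now go through `ensure_mcp_connection`, which folds the old `ensure_request_mcp_client(...).await.is_some()` dance and the tools-presence test into one call. Only the call shape is visible in this diff: it takes the manager plus `request.tools.as_deref()` and yields `Result<bool, Response>`. A rough sketch of that contract, using stand-in types since the real ones are not shown here:

    // Stand-ins: the real manager and tool types live in the router crate.
    struct McpClientManager;
    struct ResponseTool {
        is_mcp: bool,
    }
    type ErrorResponse = String; // the real code returns an axum Response

    impl McpClientManager {
        // Stand-in for registering a request-scoped MCP client.
        async fn connect(&self, _tools: &[&ResponseTool]) -> Result<(), ErrorResponse> {
            Ok(())
        }
    }

    /// Ok(true)  -> MCP tools present and connection ready (use the tool loop)
    /// Ok(false) -> no MCP tools (plain execution; function tools may still exist)
    /// Err(resp) -> connection failed; the caller returns the error response as-is
    async fn ensure_mcp_connection(
        manager: &McpClientManager,
        tools: Option<&[ResponseTool]>,
    ) -> Result<bool, ErrorResponse> {
        let mcp_tools: Vec<&ResponseTool> = tools
            .unwrap_or(&[])
            .iter()
            .filter(|t| t.is_mcp)
            .collect();
        if mcp_tools.is_empty() {
            return Ok(false);
        }
        manager.connect(&mcp_tools).await?;
        Ok(true)
    }

This also fixes the old control flow, where a request whose tools were all function tools still took the "MCP client creation failed" branch.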
@@ -352,8 +316,8 @@ async fn convert_chat_stream_to_responses_stream(
     )
     .await;

-    // Extract body and headers from chat response
-    let (parts, body) = chat_response.into_parts();
+    // Extract body from chat response
+    let (_parts, body) = chat_response.into_parts();

     // Create channel for transformed SSE events
     let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();

@@ -392,29 +356,7 @@ });
     });

     // Build SSE response with transformed stream
-    let stream = UnboundedReceiverStream::new(rx);
-    let body = Body::from_stream(stream);
-    let mut response = Response::builder().status(parts.status).body(body).unwrap();
-
-    // Copy headers from original chat response
-    *response.headers_mut() = parts.headers;
-
-    // Ensure SSE headers are set
-    response.headers_mut().insert(
-        header::CONTENT_TYPE,
-        header::HeaderValue::from_static("text/event-stream"),
-    );
-    response.headers_mut().insert(
-        header::CACHE_CONTROL,
-        header::HeaderValue::from_static("no-cache"),
-    );
-    response.headers_mut().insert(
-        header::CONNECTION,
-        header::HeaderValue::from_static("keep-alive"),
-    );
-
-    response
+    build_sse_response(rx)
 }

 /// Process chat SSE stream and transform to responses format
...
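The hand-rolled SSE assembly deleted above is exactly what the shared `build_sse_response` helper (now exported from `common::responses`) replaces. The helper's own body is outside this diff, so the following is a minimal reconstruction from the deleted lines, not necessarily the crate's actual code:

    use axum::{body::Body, http::header, response::Response};
    use bytes::Bytes;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, std::io::Error>>) -> Response {
        // Adapt the channel receiver into a streaming response body
        let body = Body::from_stream(UnboundedReceiverStream::new(rx));

        // Same SSE headers the deleted inline code set explicitly
        Response::builder()
            .header(header::CONTENT_TYPE, "text/event-stream")
            .header(header::CACHE_CONTROL, "no-cache")
            .header(header::CONNECTION, "keep-alive")
            .body(body)
            .unwrap()
    }

One behavioral change is visible at the call site: the old code propagated `parts.status` and the upstream headers from the chat response, while the helper takes only the receiver, so the transformed stream always goes out as a fresh SSE response (hence the switch to `_parts`).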
@@ -10,7 +10,10 @@ use axum::{
 use tracing::debug;

 use super::{
-    common::responses::handlers::{cancel_response_impl, get_response_impl},
+    common::responses::{
+        handlers::{cancel_response_impl, get_response_impl},
+        utils::validate_worker_availability,
+    },
     context::SharedComponents,
     harmony::{
         serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector,
@@ -191,14 +194,18 @@ impl GrpcRouter {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
+        // 0. Fast worker validation (fail-fast before expensive operations)
+        let requested_model: Option<&str> = model_id.or(Some(body.model.as_str()));
+        if let Some(error_response) = requested_model
+            .and_then(|model| validate_worker_availability(&self.worker_registry, model))
+        {
+            return error_response;
+        }
+
         // Choose implementation based on Harmony model detection
         let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
-        debug!(
-            "Processing responses request for model: {:?}, using_harmony={}",
-            model_id, is_harmony
-        );

         if is_harmony {
             debug!(
                 "Processing Harmony responses request for model: {:?}, streaming: {:?}",
...
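`validate_worker_availability` in `common::responses::utils` is the reusable form of the fail-fast model check this PR deletes from `route_responses` (the removed block earlier in this diff), relocated so it also runs before Harmony detection. The call site implies it returns `Option<Response>`: `None` when the model is served, `Some` with a ready-made 503 otherwise. A reconstruction from the removed inline code, with a stand-in registry type since `WorkerRegistry` itself is not shown here:

    use axum::{
        http::StatusCode,
        response::{IntoResponse, Response},
    };
    use serde_json::json;

    // Stand-in: the real WorkerRegistry lives in the router crate; only the
    // get_models() call below is taken from the removed inline check.
    pub struct WorkerRegistry {
        models: Vec<String>,
    }

    impl WorkerRegistry {
        pub fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }
    }

    /// Returns None when at least one worker serves `model`, otherwise a 503
    /// response mirroring the error shape of the deleted inline code.
    pub fn validate_worker_availability(registry: &WorkerRegistry, model: &str) -> Option<Response> {
        let available_models = registry.get_models();
        if available_models.contains(&model.to_string()) {
            return None;
        }
        Some(
            (
                StatusCode::SERVICE_UNAVAILABLE,
                axum::Json(json!({
                    "error": {
                        "message": format!(
                            "No workers available for model '{}'. Available models: {}",
                            model,
                            available_models.join(", ")
                        ),
                        "type": "service_unavailable",
                        "param": "model",
                        "code": "no_available_workers"
                    }
                })),
            )
                .into_response(),
        )
    }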