Unverified commit 0e82fd3d, authored by Chang Su, committed by GitHub

[router][grpc] Fix model validation, tool call check, streaming logic and misc in responses (#12616)

parent b7d70411
```diff
@@ -2,6 +2,8 @@
 pub mod handlers;
 pub mod streaming;
+pub mod utils;

 pub use handlers::{cancel_response_impl, get_response_impl};
-pub use streaming::{OutputItemType, ResponseStreamEventEmitter};
+pub use streaming::{build_sse_response, OutputItemType, ResponseStreamEventEmitter};
+pub use utils::ensure_mcp_connection;
```
```diff
@@ -2,9 +2,11 @@
 use std::collections::HashMap;

+use axum::{body::Body, http::StatusCode, response::Response};
 use bytes::Bytes;
 use serde_json::json;
 use tokio::sync::mpsc;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use uuid::Uuid;

 use crate::{mcp, protocols::chat::ChatCompletionStreamResponse};
@@ -13,6 +15,7 @@ pub enum OutputItemType {
     Message,
     McpListTools,
     McpCall,
+    FunctionCall,
     Reasoning,
 }
```
```diff
@@ -342,6 +345,40 @@ impl ResponseStreamEventEmitter {
         })
     }

+    // ========================================================================
+    // Function Call Event Emission Methods
+    // ========================================================================
+
+    pub fn emit_function_call_arguments_delta(
+        &mut self,
+        output_index: usize,
+        item_id: &str,
+        delta: &str,
+    ) -> serde_json::Value {
+        json!({
+            "type": "response.function_call_arguments.delta",
+            "sequence_number": self.next_sequence(),
+            "output_index": output_index,
+            "item_id": item_id,
+            "delta": delta
+        })
+    }
+
+    pub fn emit_function_call_arguments_done(
+        &mut self,
+        output_index: usize,
+        item_id: &str,
+        arguments: &str,
+    ) -> serde_json::Value {
+        json!({
+            "type": "response.function_call_arguments.done",
+            "sequence_number": self.next_sequence(),
+            "output_index": output_index,
+            "item_id": item_id,
+            "arguments": arguments
+        })
+    }
+
     // ========================================================================
     // Output Item Wrapper Events
     // ========================================================================
```
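As a hedged illustration of how these two emitters slot into a streaming loop (the helper below and its construction context are hypothetical; only `emit_function_call_arguments_delta`/`_done` come from this diff):

```rust
// Hypothetical driver: emit one `.delta` event per argument fragment,
// then a single `.done` event carrying the full argument string.
fn stream_function_call_args(
    emitter: &mut ResponseStreamEventEmitter,
    output_index: usize,
    item_id: &str, // e.g. "fc_<uuid>", matching the "fc" prefix added below
    fragments: &[&str],
) -> Vec<serde_json::Value> {
    let mut events = Vec::new();
    for frag in fragments {
        // Each delta event gets the next sequence number from the emitter.
        events.push(emitter.emit_function_call_arguments_delta(output_index, item_id, frag));
    }
    // Close the item with the concatenated argument string.
    let full: String = fragments.concat();
    events.push(emitter.emit_function_call_arguments_done(output_index, item_id, &full));
    events
}
```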
```diff
@@ -387,6 +424,7 @@ impl ResponseStreamEventEmitter {
         let id_prefix = match &item_type {
             OutputItemType::McpListTools => "mcpl",
             OutputItemType::McpCall => "mcp",
+            OutputItemType::FunctionCall => "fc",
             OutputItemType::Message => "msg",
             OutputItemType::Reasoning => "rs",
         };
@@ -582,4 +620,40 @@ impl ResponseStreamEventEmitter {
             }
         }
     }
+
+    /// Emit an error event
+    ///
+    /// Creates and sends an error event with the given error message.
+    /// Uses OpenAI's error event format.
+    /// Use this for terminal errors that should abort the streaming response.
+    pub fn emit_error(
+        &mut self,
+        error_msg: &str,
+        error_code: Option<&str>,
+        tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
+    ) {
+        let event = json!({
+            "type": "error",
+            "code": error_code.unwrap_or("internal_error"),
+            "message": error_msg,
+            "param": null,
+            "sequence_number": self.next_sequence()
+        });
+        let sse_data = format!("data: {}\n\n", serde_json::to_string(&event).unwrap());
+        let _ = tx.send(Ok(Bytes::from(sse_data)));
+    }
+}
+
+/// Build a Server-Sent Events (SSE) response
+///
+/// Creates a Response with proper SSE headers and streaming body.
+pub fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, std::io::Error>>) -> Response {
+    let stream = UnboundedReceiverStream::new(rx);
+    Response::builder()
+        .status(StatusCode::OK)
+        .header("Content-Type", "text/event-stream")
+        .header("Cache-Control", "no-cache")
+        .header("Connection", "keep-alive")
+        .body(Body::from_stream(stream))
+        .unwrap()
 }
```
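A minimal usage sketch for `build_sse_response`, assuming an axum handler context; the channel type and function signature come from the diff, while the payload strings are illustrative:

```rust
use bytes::Bytes;
use tokio::sync::mpsc;

async fn example_sse_endpoint() -> axum::response::Response {
    // Channel type matches build_sse_response's parameter.
    let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();
    tokio::spawn(async move {
        // Frames must already be SSE-formatted ("data: ...\n\n"),
        // as emit_error does above.
        let _ = tx.send(Ok(Bytes::from("data: {\"type\":\"response.created\"}\n\n")));
        let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
        // Dropping tx ends the stream and closes the response body.
    });
    build_sse_response(rx)
}
```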
New `utils` module (added in this commit; all lines are new):

```rust
//! Utility functions for /v1/responses endpoint

use std::sync::Arc;

use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
};
use serde_json::json;

use crate::{
    core::WorkerRegistry,
    mcp::McpManager,
    protocols::responses::{ResponseTool, ResponseToolType},
    routers::{grpc::error, openai::mcp::ensure_request_mcp_client},
};

/// Ensure MCP connection succeeds if MCP tools are declared
///
/// Checks if the request declares MCP tools, and if so, validates that
/// the MCP client can be created and connected.
pub async fn ensure_mcp_connection(
    mcp_manager: &Arc<McpManager>,
    tools: Option<&[ResponseTool]>,
) -> Result<bool, Response> {
    let has_mcp_tools = tools
        .map(|t| {
            t.iter()
                .any(|tool| matches!(tool.r#type, ResponseToolType::Mcp))
        })
        .unwrap_or(false);

    if has_mcp_tools {
        if let Some(tools) = tools {
            if ensure_request_mcp_client(mcp_manager, tools)
                .await
                .is_none()
            {
                return Err(error::failed_dependency(
                    "Failed to connect to MCP server. Check server_url and authorization.",
                ));
            }
        }
    }

    Ok(has_mcp_tools)
}

/// Validate that workers are available for the requested model
pub fn validate_worker_availability(
    worker_registry: &Arc<WorkerRegistry>,
    model: &str,
) -> Option<Response> {
    let available_models = worker_registry.get_models();
    if !available_models.contains(&model.to_string()) {
        return Some(
            (
                StatusCode::SERVICE_UNAVAILABLE,
                axum::Json(json!({
                    "error": {
                        "message": format!(
                            "No workers available for model '{}'. Available models: {}",
                            model,
                            available_models.join(", ")
                        ),
                        "type": "service_unavailable",
                        "param": "model",
                        "code": "no_available_workers"
                    }
                })),
            )
                .into_response(),
        );
    }
    None
}
```
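A hedged call-site sketch combining the two helpers; `PreflightCtx` is a hypothetical stand-in for the router's real context type, but the signatures, status codes, and error codes follow the module above (imports as in that module):

```rust
// Hypothetical bundle standing in for the router's real context type.
struct PreflightCtx {
    worker_registry: std::sync::Arc<WorkerRegistry>,
    mcp_manager: std::sync::Arc<McpManager>,
}

async fn preflight(
    ctx: &PreflightCtx,
    model: &str,
    tools: Option<&[ResponseTool]>,
) -> Result<bool, axum::response::Response> {
    // Fail fast with 503 ("no_available_workers") if no worker serves the model.
    if let Some(error_response) = validate_worker_availability(&ctx.worker_registry, model) {
        return Err(error_response);
    }
    // Ok(true) means MCP tools are declared and the client connected;
    // a failed connection surfaces as a 424 Failed Dependency response.
    ensure_mcp_connection(&ctx.mcp_manager, tools).await
}
```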
````diff
@@ -107,6 +107,30 @@ pub fn service_unavailable(message: impl Into<String>) -> Response {
         .into_response()
 }

+/// Create a 424 Failed Dependency response
+///
+/// Use this when an external dependency (like MCP server) fails.
+///
+/// # Example
+/// ```ignore
+/// return Err(failed_dependency("Failed to connect to MCP server"));
+/// ```
+pub fn failed_dependency(message: impl Into<String>) -> Response {
+    let msg = message.into();
+    warn!("{}", msg);
+    (
+        StatusCode::FAILED_DEPENDENCY,
+        Json(json!({
+            "error": {
+                "message": msg,
+                "type": "external_connector_error",
+                "code": 424
+            }
+        })),
+    )
+        .into_response()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
````
```diff
@@ -124,8 +124,10 @@ pub enum ResponsesIterationResult {
     /// Tool calls found in commentary channel - continue MCP loop
     ToolCallsFound {
         tool_calls: Vec<ToolCall>,
-        analysis: Option<String>, // For streaming emission
-        partial_text: String,     // For streaming emission
+        analysis: Option<String>, // For streaming emission or reasoning output
+        partial_text: String,     // For streaming emission or message output
+        usage: Usage,             // Token usage from this iteration
+        request_id: String,       // Request ID from dispatch
     },
     /// No tool calls - return final ResponsesResponse
     Completed {
```
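Illustratively, a loop driver consuming the widened variant might look like this (the surrounding loop is assumed; the fields are those added above):

```rust
match iteration_result {
    ResponsesIterationResult::ToolCallsFound {
        tool_calls,
        analysis,
        partial_text,
        usage,      // new: fold per-iteration token usage into a running total
        request_id, // new: carried from dispatch for the final response id
    } => {
        // execute tool_calls, record analysis/partial_text for streaming,
        // accumulate usage, then run the next Harmony iteration...
    }
    ResponsesIterationResult::Completed { .. } => {
        // build and return the final ResponsesResponse
    }
}
```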
```diff
@@ -206,6 +208,9 @@ impl HarmonyResponseProcessor {
             );
         }

+        // Build usage (needed for both ToolCallsFound and Completed)
+        let usage = response_formatting::build_usage(std::slice::from_ref(complete));
+
         // Check for tool calls in commentary channel
         if let Some(tool_calls) = parsed.commentary {
             // Tool calls found - return for MCP loop execution
@@ -213,6 +218,8 @@ impl HarmonyResponseProcessor {
                 tool_calls,
                 analysis: parsed.analysis,
                 partial_text: parsed.final_text,
+                usage,
+                request_id: dispatch.request_id.clone(),
             });
         }
@@ -245,9 +252,6 @@ impl HarmonyResponseProcessor {
             output.push(message_item);
         }

-        // Build usage
-        let usage = response_formatting::build_usage(std::slice::from_ref(complete));
-
         // Build ResponsesResponse with all required fields
         let response = ResponsesResponse {
             id: dispatch.request_id.clone(),
```
```diff
@@ -214,7 +214,7 @@ impl HarmonyPreparationStage {
             let params_schema = &tool.function.parameters;

             tags.push(json!({
-                "begin": format!("<|channel|>commentary to=functions.{}<|constrain|>json<|message|>", tool_name),
+                "begin": format!("<|start|>assistant<|channel|>commentary to=functions.{}<|constrain|>json<|message|>", tool_name),
                 "content": {
                     "type": "json_schema",
                     "json_schema": params_schema
@@ -228,7 +228,7 @@ impl HarmonyPreparationStage {
         let structural_tag = json!({
             "format": {
                 "type": "triggered_tags",
-                "triggers": ["<|channel|>commentary"],
+                "triggers": ["<|start|>assistant"],
                 "tags": tags,
                 "at_least_one": true,
                 "stop_after_first": stop_after_first
```
```diff
@@ -226,7 +226,7 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
         parallel_tool_calls: req.parallel_tool_calls,
         top_logprobs: req.top_logprobs,
         top_p: req.top_p,
-        skip_special_tokens: true, // Always skip special tokens
+        skip_special_tokens: true, // TODO: except for gpt-oss
         // Note: tools and tool_choice will be handled separately for MCP transformation
         tools: None,       // Will be set by caller if needed
         tool_choice: None, // Will be set by caller if needed
```
```diff
@@ -36,14 +36,13 @@ use std::sync::Arc;

 use axum::{
     body::Body,
-    http::{self, header, StatusCode},
+    http::{self, StatusCode},
     response::{IntoResponse, Response},
 };
 use bytes::Bytes;
 use futures_util::StreamExt;
 use serde_json::json;
 use tokio::sync::mpsc;
-use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, warn};
 use uuid::Uuid;
 use validator::Validate;
@@ -67,8 +66,13 @@ use crate::{
         },
     },
     routers::{
-        grpc::{common::responses::streaming::ResponseStreamEventEmitter, error},
-        openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
+        grpc::{
+            common::responses::{
+                build_sse_response, ensure_mcp_connection, streaming::ResponseStreamEventEmitter,
+            },
+            error,
+        },
+        openai::conversations::persist_conversation_items,
     },
 };
```
```diff
@@ -81,33 +85,6 @@ pub async fn route_responses(
     headers: Option<http::HeaderMap>,
     model_id: Option<String>,
 ) -> Response {
-    // 0. Fast worker validation (fail-fast before expensive operations)
-    let requested_model: Option<&str> = model_id.as_deref().or(Some(request.model.as_str()));
-    if let Some(model) = requested_model {
-        // Check if any workers support this model
-        let available_models = ctx.worker_registry.get_models();
-        if !available_models.contains(&model.to_string()) {
-            return (
-                StatusCode::SERVICE_UNAVAILABLE,
-                axum::Json(json!({
-                    "error": {
-                        "message": format!(
-                            "No workers available for model '{}'. Available models: {}",
-                            model,
-                            available_models.join(", ")
-                        ),
-                        "type": "service_unavailable",
-                        "param": "model",
-                        "code": "no_available_workers"
-                    }
-                })),
-            )
-            .into_response();
-        }
-    }
-
     // 1. Validate request (includes conversation ID format)
     if let Err(validation_errors) = request.validate() {
         // Extract the first error message for conversation field
@@ -171,7 +148,10 @@ pub async fn route_responses(
     if is_streaming {
         route_responses_streaming(ctx, request, headers, model_id).await
     } else {
-        route_responses_sync(ctx, request, headers, model_id, None).await
+        // Generate response ID for synchronous execution
+        // TODO: we may remove this when we have builder pattern for responses
+        let response_id = Some(format!("resp_{}", Uuid::new_v4()));
+        route_responses_sync(ctx, request, headers, model_id, response_id).await
     }
 }
```
```diff
@@ -211,40 +191,24 @@ async fn route_responses_internal(
     // 1. Load conversation history and build modified request
     let modified_request = load_conversation_history(ctx, &request).await?;

-    // 2. Check if request has MCP tools - if so, use tool loop
-    let responses_response = if let Some(tools) = &request.tools {
-        // Ensure dynamic MCP client is registered for request-scoped tools
-        if ensure_request_mcp_client(&ctx.mcp_manager, tools)
-            .await
-            .is_some()
-        {
-            debug!("MCP tools detected, using tool loop");
-
-            // Execute with MCP tool loop
-            execute_tool_loop(
-                ctx,
-                modified_request,
-                &request,
-                headers,
-                model_id,
-                response_id.clone(),
-            )
-            .await?
-        } else {
-            debug!("Failed to create MCP client from request tools");
-            // Fall through to non-MCP execution
-            execute_without_mcp(
-                ctx,
-                &modified_request,
-                &request,
-                headers,
-                model_id,
-                response_id.clone(),
-            )
-            .await?
-        }
+    // 2. Check MCP connection and get whether MCP tools are present
+    let has_mcp_tools = ensure_mcp_connection(&ctx.mcp_manager, request.tools.as_deref()).await?;
+
+    let responses_response = if has_mcp_tools {
+        debug!("MCP tools detected, using tool loop");
+
+        // Execute with MCP tool loop
+        execute_tool_loop(
+            ctx,
+            modified_request,
+            &request,
+            headers,
+            model_id,
+            response_id.clone(),
+        )
+        .await?
     } else {
-        // No tools, execute normally
+        // No MCP tools - execute without MCP (may have function tools or no tools)
         execute_without_mcp(
             ctx,
             &modified_request,
@@ -289,18 +253,18 @@ async fn route_responses_streaming(
         Err(response) => return response, // Already a Response with proper status code
     };

-    // 2. Check if request has MCP tools - if so, use streaming tool loop
-    if let Some(tools) = &request.tools {
-        // Ensure dynamic MCP client is registered for request-scoped tools
-        if ensure_request_mcp_client(&ctx.mcp_manager, tools)
-            .await
-            .is_some()
-        {
-            debug!("MCP tools detected in streaming mode, using streaming tool loop");
-            return execute_tool_loop_streaming(ctx, modified_request, &request, headers, model_id)
-                .await;
-        }
+    // 2. Check MCP connection and get whether MCP tools are present
+    let has_mcp_tools =
+        match ensure_mcp_connection(&ctx.mcp_manager, request.tools.as_deref()).await {
+            Ok(has_mcp) => has_mcp,
+            Err(response) => return response,
+        };
+
+    if has_mcp_tools {
+        debug!("MCP tools detected in streaming mode, using streaming tool loop");
+
+        return execute_tool_loop_streaming(ctx, modified_request, &request, headers, model_id)
+            .await;
     }

     // 3. Convert ResponsesRequest → ChatCompletionRequest
```
```diff
@@ -352,8 +316,8 @@ async fn convert_chat_stream_to_responses_stream(
     )
     .await;

-    // Extract body and headers from chat response
-    let (parts, body) = chat_response.into_parts();
+    // Extract body from chat response
+    let (_parts, body) = chat_response.into_parts();

     // Create channel for transformed SSE events
     let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();
@@ -392,29 +356,7 @@ async fn convert_chat_stream_to_responses_stream(
     });

     // Build SSE response with transformed stream
-    let stream = UnboundedReceiverStream::new(rx);
-    let body = Body::from_stream(stream);
-
-    let mut response = Response::builder().status(parts.status).body(body).unwrap();
-
-    // Copy headers from original chat response
-    *response.headers_mut() = parts.headers;
-
-    // Ensure SSE headers are set
-    response.headers_mut().insert(
-        header::CONTENT_TYPE,
-        header::HeaderValue::from_static("text/event-stream"),
-    );
-    response.headers_mut().insert(
-        header::CACHE_CONTROL,
-        header::HeaderValue::from_static("no-cache"),
-    );
-    response.headers_mut().insert(
-        header::CONNECTION,
-        header::HeaderValue::from_static("keep-alive"),
-    );
-
-    response
+    build_sse_response(rx)
 }

 /// Process chat SSE stream and transform to responses format
```
```diff
@@ -10,7 +10,10 @@ use axum::{
 use tracing::debug;

 use super::{
-    common::responses::handlers::{cancel_response_impl, get_response_impl},
+    common::responses::{
+        handlers::{cancel_response_impl, get_response_impl},
+        utils::validate_worker_availability,
+    },
     context::SharedComponents,
     harmony::{
         serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector,
@@ -191,14 +194,18 @@ impl GrpcRouter {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
+        // 0. Fast worker validation (fail-fast before expensive operations)
+        let requested_model: Option<&str> = model_id.or(Some(body.model.as_str()));
+        if let Some(error_response) = requested_model
+            .and_then(|model| validate_worker_availability(&self.worker_registry, model))
+        {
+            return error_response;
+        }
+
         // Choose implementation based on Harmony model detection
         let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
-        debug!(
-            "Processing responses request for model: {:?}, using_harmony={}",
-            model_id, is_harmony
-        );

         if is_harmony {
             debug!(
                 "Processing Harmony responses request for model: {:?}, streaming: {:?}",
```