"src/vscode:/vscode.git/clone" did not exist on "91f7ff8e5284dadaa6d463a03c600fe5ed329980"
Unverified Commit 96ac24c0 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] web_search_preview tool basic implementation (#12290)

parent c0d02cf4
......@@ -3,10 +3,11 @@
//! This module contains all MCP-related functionality for the OpenAI router:
//! - Tool loop state management for multi-turn tool calling
//! - MCP tool execution and result handling
//! - Output item builders for MCP-specific response formats
//! - Output item builders for MCP-specific response formats (including web_search_call)
//! - SSE event generation for streaming MCP operations
//! - Payload transformation for MCP tool interception
//! - Metadata injection for MCP operations
//! - Web search preview tool handling (simplified MCP interface)
use std::{io, sync::Arc};
......@@ -16,7 +17,7 @@ use serde_json::{json, to_value, Value};
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
use super::utils::event_types;
use super::utils::{event_types, web_search_constants, ToolContext};
use crate::{
mcp,
protocols::responses::{
......@@ -36,11 +37,16 @@ pub(crate) struct McpLoopConfig {
/// Maximum iterations as safety limit (internal only, default: 10)
/// Prevents infinite loops when max_tool_calls is not set
pub max_iterations: usize,
/// Tool context for handling web_search_preview vs regular tools
pub tool_context: ToolContext,
}
impl Default for McpLoopConfig {
fn default() -> Self {
Self { max_iterations: 10 }
Self {
max_iterations: 10,
tool_context: ToolContext::Regular,
}
}
}
......@@ -158,6 +164,13 @@ pub async fn ensure_request_mcp_client(
.server_label
.clone()
.unwrap_or_else(|| "request-mcp".to_string());
// Validate that web_search_preview is not used as it's a reserved name
if name == web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME {
warn!("Rejecting request MCP with reserved server name: {}", name);
return None;
}
let token = tool.authorization.clone();
let transport = if server_url.contains("/sse") {
mcp::McpTransport::Sse {
......@@ -202,6 +215,7 @@ pub(super) async fn execute_streaming_tool_calls(
state: &mut ToolLoopState,
server_label: &str,
sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool {
// Execute all pending tool calls (sequential, as PR3 is skipped)
for call in pending_calls {
......@@ -258,6 +272,7 @@ pub(super) async fn execute_streaming_tool_calls(
success,
error_msg.as_deref(),
sequence_number,
tool_context,
) {
// Client disconnected, no point continuing tool execution
return false;
......@@ -277,6 +292,7 @@ pub(super) async fn execute_streaming_tool_calls(
pub(super) fn prepare_mcp_payload_for_streaming(
payload: &mut Value,
active_mcp: &Arc<mcp::McpManager>,
tool_context: ToolContext,
) {
if let Some(obj) = payload.as_object_mut() {
// Remove any non-function tools from outgoing payload
......@@ -291,10 +307,27 @@ pub(super) fn prepare_mcp_payload_for_streaming(
}
}
// Build function tools for all discovered MCP tools
// Build function tools for discovered MCP tools
let mut tools_json = Vec::new();
let tools = active_mcp.list_tools();
for t in tools {
// Get tools with server names from inventory
// Returns Vec<(tool_name, server_name, Tool)>
let tools = active_mcp.inventory().list_tools();
// Filter tools based on context
let filtered_tools: Vec<_> = if tool_context.is_web_search() {
// Only include tools from web_search_preview server
tools
.into_iter()
.filter(|(_, server_name, _)| {
server_name == web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME
})
.collect()
} else {
tools
};
for (_, _, t) in filtered_tools {
let parameters = Value::Object((*t.input_schema).clone());
let tool = serde_json::json!({
"type": event_types::ITEM_TYPE_FUNCTION,
......@@ -466,6 +499,7 @@ pub(super) fn send_mcp_list_tools_events(
/// Send mcp_call completion events after tool execution
/// Returns false if client disconnected
#[allow(clippy::too_many_arguments)]
pub(super) fn send_mcp_call_completion_events_with_error(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
call: &FunctionCallInProgress,
......@@ -474,6 +508,7 @@ pub(super) fn send_mcp_call_completion_events_with_error(
success: bool,
error_msg: Option<&str>,
sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool {
let effective_output_index = call.effective_output_index();
......@@ -485,17 +520,24 @@ pub(super) fn send_mcp_call_completion_events_with_error(
server_label,
success,
error_msg,
tool_context,
);
// Get the mcp_call item_id
// Get the item_id
let item_id = mcp_call_item
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("");
// Event 1: response.mcp_call.completed
// Event 1: response.{web_search_call|mcp_call}.completed
let completed_event_type = if tool_context.is_web_search() {
event_types::WEB_SEARCH_CALL_COMPLETED
} else {
event_types::MCP_CALL_COMPLETED
};
let completed_payload = json!({
"type": event_types::MCP_CALL_COMPLETED,
"type": completed_event_type,
"sequence_number": *sequence_number,
"output_index": effective_output_index,
"item_id": item_id
......@@ -504,8 +546,7 @@ pub(super) fn send_mcp_call_completion_events_with_error(
let completed_event = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_COMPLETED,
completed_payload
completed_event_type, completed_payload
);
if tx.send(Ok(Bytes::from(completed_event))).is_err() {
return false;
......@@ -538,28 +579,40 @@ pub(super) fn inject_mcp_metadata_streaming(
state: &ToolLoopState,
mcp: &Arc<mcp::McpManager>,
server_label: &str,
tool_context: ToolContext,
) {
if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) {
output_array.retain(|item| {
item.get("type").and_then(|t| t.as_str()) != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS)
});
let list_tools_item = build_mcp_list_tools_item(mcp, server_label);
output_array.insert(0, list_tools_item);
let mut insert_pos = 0;
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
let mcp_call_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
build_executed_mcp_call_items(&state.conversation_history, server_label, tool_context);
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
} else if let Some(obj) = response.as_object_mut() {
let mut output_items = Vec::new();
output_items.push(build_mcp_list_tools_item(mcp, server_label));
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
output_items.push(build_mcp_list_tools_item(mcp, server_label));
}
output_items.extend(build_executed_mcp_call_items(
&state.conversation_history,
server_label,
tool_context,
));
obj.insert("output".to_string(), Value::Array(output_items));
}
......@@ -658,6 +711,7 @@ pub(super) async fn execute_tool_loop(
"max_tool_calls",
active_mcp,
original_body,
config.tool_context,
);
}
......@@ -716,22 +770,28 @@ pub(super) async fn execute_tool_loop(
})
.unwrap_or("mcp");
// Build mcp_list_tools item
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
// Insert at beginning of output array
if let Some(output_array) = response_json
.get_mut("output")
.and_then(|v| v.as_array_mut())
{
output_array.insert(0, list_tools_item);
let mut insert_pos = 0;
// Build mcp_call items using helper function
let mcp_call_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
// Only add mcp_list_tools for non-web-search cases
if !config.tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
// Insert mcp_call items after mcp_list_tools using mutable position
let mut insert_pos = 1;
// Build mcp_call items (will be web_search_call for web search tools)
let mcp_call_items = build_executed_mcp_call_items(
&state.conversation_history,
server_label,
config.tool_context,
);
// Insert call items after mcp_list_tools (if present)
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
......@@ -751,13 +811,17 @@ pub(super) fn build_incomplete_response(
reason: &str,
active_mcp: &Arc<mcp::McpManager>,
original_body: &ResponsesRequest,
tool_context: ToolContext,
) -> Result<Value, String> {
let obj = response
.as_object_mut()
.ok_or_else(|| "response not an object".to_string())?;
// Set status to completed (not failed - partial success)
obj.insert("status".to_string(), Value::String("completed".to_string()));
obj.insert(
"status".to_string(),
Value::String(web_search_constants::STATUS_COMPLETED.to_string()),
);
// Set incomplete_details
obj.insert(
......@@ -799,6 +863,7 @@ pub(super) fn build_incomplete_response(
server_label,
false, // Not successful
Some("Not executed - response stopped due to limit"),
tool_context,
);
mcp_call_items.push(mcp_call_item);
}
......@@ -806,20 +871,28 @@ pub(super) fn build_incomplete_response(
// Add mcp_list_tools and executed mcp_call items at the beginning
if state.total_calls > 0 || !mcp_call_items.is_empty() {
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
let mut insert_pos = 0;
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
// Add mcp_call items for executed calls using helper
let executed_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
// Add mcp_call items for executed calls (will be web_search_call for web search)
let executed_items = build_executed_mcp_call_items(
&state.conversation_history,
server_label,
tool_context,
);
let mut insert_pos = 1;
for item in executed_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
// Add incomplete mcp_call items
// Add incomplete mcp_call items (will be web_search_call for web search)
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
......@@ -847,6 +920,67 @@ pub(super) fn build_incomplete_response(
Ok(response)
}
// ============================================================================
// Web Search Preview Helpers
// ============================================================================
/// Report whether any tool in the request is of type `web_search_preview`.
///
/// Returns `false` for an empty tool list.
pub(super) fn has_web_search_preview_tool(tools: &[ResponseTool]) -> bool {
    tools
        .iter()
        .map(|tool| &tool.r#type)
        .any(|kind| matches!(kind, ResponseToolType::WebSearchPreview))
}
/// Report whether the MCP manager can hand out a client for the
/// `web_search_preview` server.
///
/// Returns `true` when `get_client` yields `Some(_)` for
/// `web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME`, `false` otherwise.
pub(super) async fn is_web_search_mcp_available(mcp_manager: &Arc<mcp::McpManager>) -> bool {
    let server_name = web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME;
    let client = mcp_manager.get_client(server_name).await;
    client.is_some()
}
/// Build the `action` object shared by every `web_search_call` item.
///
/// The object always carries `"type"` set to
/// `web_search_constants::ACTION_TYPE_SEARCH`; a `"query"` entry is added only
/// when a query string is supplied. Keeping this in one place guarantees the
/// completed and failed item builders emit the same `action` shape.
fn build_web_search_action(query: Option<String>) -> Value {
    let mut action = serde_json::Map::new();
    action.insert(
        "type".to_string(),
        Value::String(web_search_constants::ACTION_TYPE_SEARCH.to_string()),
    );
    if let Some(q) = query {
        action.insert("query".to_string(), Value::String(q));
    }
    Value::Object(action)
}

/// Build a completed `web_search_call` output item (MVP - status only).
///
/// The MCP search results are passed to the LLM internally via
/// `function_call_output`, but they are not exposed in the `web_search_call`
/// item sent to the client.
///
/// * `query` - search query to echo in the item's `action`, if one is known.
///
/// Returns a JSON object with a freshly generated `ws`-prefixed `id`, the
/// `web_search_call` item type, a completed status and the `action` object.
fn build_web_search_call_item(query: Option<String>) -> Value {
    json!({
        "id": generate_id("ws"),
        "type": event_types::ITEM_TYPE_WEB_SEARCH_CALL,
        "status": web_search_constants::STATUS_COMPLETED,
        "action": build_web_search_action(query)
    })
}

/// Build a failed `web_search_call` output item.
///
/// * `error` - message stored under the item's `"error"` key.
/// * `query` - search query to echo in the item's `action`, if one is known.
///
/// Returns the same shape as [`build_web_search_call_item`] but with a failed
/// status and the additional `"error"` field.
fn build_web_search_call_item_failed(error: &str, query: Option<String>) -> Value {
    json!({
        "id": generate_id("ws"),
        "type": event_types::ITEM_TYPE_WEB_SEARCH_CALL,
        "status": web_search_constants::STATUS_FAILED,
        "action": build_web_search_action(query),
        "error": error
    })
}
// ============================================================================
// Output Item Builders
// ============================================================================
......@@ -884,24 +1018,47 @@ pub(super) fn build_mcp_call_item(
server_label: &str,
success: bool,
error: Option<&str>,
tool_context: ToolContext,
) -> Value {
json!({
"id": generate_id("mcp"),
"type": event_types::ITEM_TYPE_MCP_CALL,
"status": if success { "completed" } else { "failed" },
"approval_request_id": Value::Null,
"arguments": arguments,
"error": error,
"name": tool_name,
"output": output,
"server_label": server_label
})
// Check if this is a web_search_preview context - if so, build web_search_call format
if tool_context.is_web_search() {
// Extract query from arguments for web_search_call
let query = serde_json::from_str::<Value>(arguments).ok().and_then(|v| {
v.get("query")
.and_then(|q| q.as_str().map(|s| s.to_string()))
});
// Build web_search_call item (MVP - status only, no results)
if success {
build_web_search_call_item(query)
} else {
build_web_search_call_item_failed(error.unwrap_or("Tool execution failed"), query)
}
} else {
// Regular mcp_call item
json!({
"id": generate_id("mcp"),
"type": event_types::ITEM_TYPE_MCP_CALL,
"status": if success {
web_search_constants::STATUS_COMPLETED
} else {
web_search_constants::STATUS_FAILED
},
"approval_request_id": Value::Null,
"arguments": arguments,
"error": error,
"name": tool_name,
"output": output,
"server_label": server_label
})
}
}
/// Helper function to build mcp_call items from executed tool calls in conversation history
pub(super) fn build_executed_mcp_call_items(
conversation_history: &[Value],
server_label: &str,
tool_context: ToolContext,
) -> Vec<Value> {
let mut mcp_call_items = Vec::new();
......@@ -940,6 +1097,7 @@ pub(super) fn build_executed_mcp_call_items(
} else {
None
},
tool_context,
);
mcp_call_items.push(mcp_call_item);
}
......
......@@ -5,7 +5,7 @@ use std::collections::HashMap;
use serde_json::{json, Value};
use tracing::warn;
use super::utils::event_types;
use super::utils::{event_types, web_search_constants};
use crate::{
data_connector::{ResponseId, StoredResponse},
protocols::responses::{ResponseToolType, ResponsesRequest},
......@@ -276,41 +276,67 @@ pub(super) fn rewrite_streaming_block(
/// Mask function tools as MCP tools in response for client
pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
// Check for MCP tool
let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
});
let Some(t) = mcp_tool else {
return;
};
let mut m = serde_json::Map::new();
m.insert("type".to_string(), Value::String("mcp".to_string()));
if let Some(label) = &t.server_label {
m.insert("server_label".to_string(), Value::String(label.clone()));
}
if let Some(url) = &t.server_url {
m.insert("server_url".to_string(), Value::String(url.clone()));
}
if let Some(desc) = &t.server_description {
m.insert(
"server_description".to_string(),
Value::String(desc.clone()),
);
// Check for web_search_preview tool
let has_web_search = original_body
.tools
.as_ref()
.map(|tools| crate::routers::openai::mcp::has_web_search_preview_tool(tools))
.unwrap_or(false);
// If neither MCP nor web_search_preview, return early
if mcp_tool.is_none() && !has_web_search {
return;
}
if let Some(req) = &t.require_approval {
m.insert("require_approval".to_string(), Value::String(req.clone()));
let mut response_tools = Vec::new();
// Add MCP tool if present
if let Some(t) = mcp_tool {
let mut m = serde_json::Map::new();
m.insert("type".to_string(), Value::String("mcp".to_string()));
if let Some(label) = &t.server_label {
m.insert("server_label".to_string(), Value::String(label.clone()));
}
if let Some(url) = &t.server_url {
m.insert("server_url".to_string(), Value::String(url.clone()));
}
if let Some(desc) = &t.server_description {
m.insert(
"server_description".to_string(),
Value::String(desc.clone()),
);
}
if let Some(req) = &t.require_approval {
m.insert("require_approval".to_string(), Value::String(req.clone()));
}
if let Some(allowed) = &t.allowed_tools {
m.insert(
"allowed_tools".to_string(),
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
);
}
response_tools.push(Value::Object(m));
}
if let Some(allowed) = &t.allowed_tools {
m.insert(
"allowed_tools".to_string(),
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
// Add web_search_preview tool if present
if has_web_search {
let mut ws = serde_json::Map::new();
ws.insert(
"type".to_string(),
Value::String(web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME.to_string()),
);
response_tools.push(Value::Object(ws));
}
if let Some(obj) = resp.as_object_mut() {
obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)]));
obj.insert("tools".to_string(), Value::Array(response_tools));
obj.entry("tool_choice")
.or_insert(Value::String("auto".to_string()));
}
......
......@@ -30,12 +30,12 @@ use super::conversations::{
};
use super::{
mcp::{
ensure_request_mcp_client, execute_tool_loop, prepare_mcp_payload_for_streaming,
McpLoopConfig,
ensure_request_mcp_client, execute_tool_loop, has_web_search_preview_tool,
is_web_search_mcp_available, prepare_mcp_payload_for_streaming, McpLoopConfig,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json},
streaming::handle_streaming_response,
utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model, ToolContext},
};
use crate::{
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
......@@ -248,6 +248,7 @@ impl OpenAIRouter {
mut payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
tool_context: ToolContext,
) -> Response {
// Check if MCP is active for this request
// Ensure dynamic client is created if needed
......@@ -266,10 +267,13 @@ impl OpenAIRouter {
// If MCP is active, execute tool loop
if let Some(mcp) = active_mcp {
let config = McpLoopConfig::default();
let config = McpLoopConfig {
tool_context,
..Default::default()
};
// Transform MCP tools to function tools
prepare_mcp_payload_for_streaming(&mut payload, mcp);
prepare_mcp_payload_for_streaming(&mut payload, mcp, tool_context);
match execute_tool_loop(
&self.client,
......@@ -695,6 +699,35 @@ impl crate::routers::RouterTrait for OpenAIRouter {
let url = format!("{}/v1/responses", base_url);
// Detect web_search_preview tool and verify MCP server availability
let tool_context = if let Some(ref tools) = body.tools {
if has_web_search_preview_tool(tools) {
ToolContext::WebSearchPreview
} else {
ToolContext::Regular
}
} else {
ToolContext::Regular
};
if tool_context.is_web_search() {
// Check if web_search_preview MCP server is available
if !is_web_search_mcp_available(&self.mcp_manager).await {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": "Web search preview is currently unavailable. Please contact your server administrator.",
"type": "invalid_request_error",
"param": "tools",
"code": "web_search_unavailable"
}
})),
)
.into_response();
}
}
// Validate mutually exclusive params: previous_response_id and conversation
// TODO: this validation logic should move to the right place; we also need a proper error message module
if body.previous_response_id.is_some() && body.conversation.is_some() {
......@@ -1022,6 +1055,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
payload,
body,
original_previous_response_id,
tool_context,
)
.await
} else {
......@@ -1031,6 +1065,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
payload,
body,
original_previous_response_id,
tool_context,
)
.await
}
......
......@@ -30,7 +30,10 @@ use super::{
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction},
utils::{
event_types, web_search_constants, FunctionCallInProgress, OutputIndexMapper, StreamAction,
ToolContext,
},
};
use crate::{
data_connector::{ConversationItemStorage, ConversationStorage, ResponseStorage},
......@@ -553,6 +556,7 @@ pub(super) fn apply_event_transformations_inplace(
server_label: &str,
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
tool_context: ToolContext,
) -> bool {
let mut changed = false;
......@@ -598,23 +602,35 @@ pub(super) fn apply_event_transformations_inplace(
// Mask tools from function to MCP format (optimized without cloning)
if response_obj.get("tools").is_some() {
let requested_mcp = original_request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
if requested_mcp {
if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
response_obj.insert("tools".to_string(), mcp_tools);
response_obj
.entry("tool_choice".to_string())
.or_insert(Value::String("auto".to_string()));
changed = true;
// For web_search_preview, always use simplified tool format
if tool_context.is_web_search() {
let web_search_tool =
json!([{"type": web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME}]);
response_obj.insert("tools".to_string(), web_search_tool);
response_obj
.entry("tool_choice".to_string())
.or_insert(Value::String("auto".to_string()));
changed = true;
} else {
// Regular MCP tools - only if requested
let requested_mcp = original_request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
if requested_mcp {
if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
response_obj.insert("tools".to_string(), mcp_tools);
response_obj
.entry("tool_choice".to_string())
.or_insert(Value::String("auto".to_string()));
changed = true;
}
}
}
}
......@@ -629,13 +645,30 @@ pub(super) fn apply_event_transformations_inplace(
if item_type == event_types::ITEM_TYPE_FUNCTION_CALL
|| item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
{
item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL);
item["server_label"] = json!(server_label);
// Use web_search_call for web_search_preview, mcp_call for regular MCP
if tool_context.is_web_search() {
item["type"] = json!(event_types::ITEM_TYPE_WEB_SEARCH_CALL);
// Don't include server_label for web_search_call
// Remove internal implementation fields
if let Some(obj) = item.as_object_mut() {
obj.remove("arguments");
obj.remove("call_id");
obj.remove("name");
}
} else {
item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL);
item["server_label"] = json!(server_label);
}
// Transform ID from fc_* to mcp_*
// Transform ID from fc_* to ws_* or mcp_*
if let Some(id) = item.get("id").and_then(|v| v.as_str()) {
if let Some(stripped) = id.strip_prefix("fc_") {
let new_id = format!("mcp_{}", stripped);
let prefix = if tool_context.is_web_search() {
"ws"
} else {
"mcp"
};
let new_id = format!("{}_{}", prefix, stripped);
item["id"] = json!(new_id);
}
}
......@@ -693,6 +726,7 @@ pub(super) fn forward_streaming_event(
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool {
// Skip individual function_call_arguments.delta events - we'll send them as one
if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DELTA) {
......@@ -757,37 +791,40 @@ pub(super) fn forward_streaming_event(
};
// Emit a synthetic MCP arguments delta event before the done event
let mut delta_event = json!({
"type": event_types::MCP_CALL_ARGUMENTS_DELTA,
"sequence_number": *sequence_number,
"output_index": assigned_index,
"item_id": mcp_item_id,
"delta": arguments_value,
});
if let Some(obfuscation) = call.last_obfuscation.as_ref() {
if let Some(obj) = delta_event.as_object_mut() {
obj.insert(
"obfuscation".to_string(),
Value::String(obfuscation.clone()),
);
// Skip for web_search_preview - we don't expose tool call arguments
if !tool_context.is_web_search() {
let mut delta_event = json!({
"type": event_types::MCP_CALL_ARGUMENTS_DELTA,
"sequence_number": *sequence_number,
"output_index": assigned_index,
"item_id": mcp_item_id,
"delta": arguments_value,
});
if let Some(obfuscation) = call.last_obfuscation.as_ref() {
if let Some(obj) = delta_event.as_object_mut() {
obj.insert(
"obfuscation".to_string(),
Value::String(obfuscation.clone()),
);
}
} else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() {
if let Some(obj) = delta_event.as_object_mut() {
obj.insert("obfuscation".to_string(), obfuscation);
}
}
} else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() {
if let Some(obj) = delta_event.as_object_mut() {
obj.insert("obfuscation".to_string(), obfuscation);
let delta_block = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_ARGUMENTS_DELTA,
delta_event
);
if tx.send(Ok(Bytes::from(delta_block))).is_err() {
return false;
}
}
let delta_block = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_ARGUMENTS_DELTA,
delta_event
);
if tx.send(Ok(Bytes::from(delta_block))).is_err() {
return false;
*sequence_number += 1;
}
*sequence_number += 1;
}
}
}
......@@ -813,6 +850,7 @@ pub(super) fn forward_streaming_event(
server_label,
original_request,
previous_response_id,
tool_context,
);
if let Some(response_obj) = parsed_data
......@@ -844,16 +882,24 @@ pub(super) fn forward_streaming_event(
let mut final_block = String::new();
if let Some(evt) = event_name {
// Update event name for function_call_arguments events
if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA {
// Skip for web_search_preview - we don't expose tool call arguments
if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA && !tool_context.is_web_search() {
final_block.push_str(&format!(
"event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DELTA
));
} else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE {
} else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE && !tool_context.is_web_search()
{
final_block.push_str(&format!(
"event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DONE
));
} else if (evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA
|| evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE)
&& tool_context.is_web_search()
{
// Skip these events entirely for web_search_preview
return true;
} else {
final_block.push_str(&format!("event: {}\n", evt));
}
......@@ -865,30 +911,62 @@ pub(super) fn forward_streaming_event(
return false;
}
// After sending output_item.added for mcp_call, inject mcp_call.in_progress event
// After sending output_item.added for mcp_call/web_search_call, inject in_progress event
if event_name == Some(event_types::OUTPUT_ITEM_ADDED) {
if let Some(item) = parsed_data.get("item") {
if item.get("type").and_then(|v| v.as_str()) == Some(event_types::ITEM_TYPE_MCP_CALL) {
// Already transformed to mcp_call
let item_type = item.get("type").and_then(|v| v.as_str());
// Check if it's an mcp_call or web_search_call
let is_mcp_or_web_search = item_type == Some(event_types::ITEM_TYPE_MCP_CALL)
|| item_type == Some(event_types::ITEM_TYPE_WEB_SEARCH_CALL);
if is_mcp_or_web_search {
if let (Some(item_id), Some(output_index)) = (
item.get("id").and_then(|v| v.as_str()),
parsed_data.get("output_index").and_then(|v| v.as_u64()),
) {
// Choose event type based on tool_context
let in_progress_event_type = if tool_context.is_web_search() {
event_types::WEB_SEARCH_CALL_IN_PROGRESS
} else {
event_types::MCP_CALL_IN_PROGRESS
};
let in_progress_event = json!({
"type": event_types::MCP_CALL_IN_PROGRESS,
"type": in_progress_event_type,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let in_progress_block = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_IN_PROGRESS,
in_progress_event
in_progress_event_type, in_progress_event
);
if tx.send(Ok(Bytes::from(in_progress_block))).is_err() {
return false;
}
// For web_search_call, also send a "searching" event
if tool_context.is_web_search() {
let searching_event = json!({
"type": event_types::WEB_SEARCH_CALL_SEARCHING,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let searching_block = format!(
"event: {}\ndata: {}\n\n",
event_types::WEB_SEARCH_CALL_SEARCHING,
searching_event
);
if tx.send(Ok(Bytes::from(searching_block))).is_err() {
return false;
}
}
}
}
}
......@@ -909,6 +987,7 @@ pub(super) fn send_final_response_event(
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
server_label: &str,
tool_context: ToolContext,
) -> bool {
let mut final_response = match handler.snapshot_final_response() {
Some(resp) => resp,
......@@ -925,7 +1004,7 @@ pub(super) fn send_final_response_event(
}
if let Some(mcp) = active_mcp {
inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label);
inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label, tool_context);
}
mask_tools_as_mcp(&mut final_response, original_request);
......@@ -1137,9 +1216,10 @@ pub(super) async fn handle_streaming_with_tool_interception(
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
active_mcp: &Arc<crate::mcp::McpManager>,
tool_context: ToolContext,
) -> Response {
// Transform MCP tools to function tools in payload
prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
prepare_mcp_payload_for_streaming(&mut payload, active_mcp, tool_context);
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store.unwrap_or(false);
......@@ -1156,7 +1236,10 @@ pub(super) async fn handle_streaming_with_tool_interception(
// Spawn the streaming loop task
tokio::spawn(async move {
let mut state = ToolLoopState::new(original_request.input.clone());
let loop_config = McpLoopConfig::default();
let loop_config = McpLoopConfig {
tool_context,
..Default::default()
};
let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);
let tools_json = payload_clone.get("tools").cloned().unwrap_or(json!([]));
let base_payload = payload_clone.clone();
......@@ -1275,6 +1358,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request,
previous_response_id.as_deref(),
&mut sequence_number,
loop_config.tool_context,
) {
// Client disconnected
return;
......@@ -1290,7 +1374,10 @@ pub(super) async fn handle_streaming_with_tool_interception(
== Some(event_types::RESPONSE_IN_PROGRESS)
{
seen_in_progress = true;
if !mcp_list_tools_sent {
// Skip mcp_list_tools for web_search_preview
if !mcp_list_tools_sent
&& !loop_config.tool_context.is_web_search()
{
let list_tools_index =
handler.allocate_synthetic_output_index();
if !send_mcp_list_tools_events(
......@@ -1323,6 +1410,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request,
previous_response_id.as_deref(),
&mut sequence_number,
loop_config.tool_context,
) {
// Client disconnected
return;
......@@ -1361,6 +1449,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request,
previous_response_id.as_deref(),
server_label,
tool_context,
) {
return;
}
......@@ -1382,6 +1471,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
&state,
&active_mcp_clone,
server_label,
tool_context,
);
mask_tools_as_mcp(&mut response_json, &original_request);
......@@ -1443,6 +1533,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
&mut state,
server_label,
&mut sequence_number,
tool_context,
)
.await
{
......@@ -1498,6 +1589,7 @@ pub(super) async fn handle_streaming_response(
payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
tool_context: ToolContext,
) -> Response {
// Check if MCP is active for this request
// Ensure dynamic client is created if needed
......@@ -1545,6 +1637,7 @@ pub(super) async fn handle_streaming_response(
original_body,
original_previous_response_id,
active_mcp,
tool_context,
)
.await
}
......@@ -32,12 +32,59 @@ pub(crate) mod event_types {
pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress";
pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed";
// Web Search Call events (for web_search_preview)
pub const WEB_SEARCH_CALL_IN_PROGRESS: &str = "response.web_search_call.in_progress";
pub const WEB_SEARCH_CALL_SEARCHING: &str = "response.web_search_call.searching";
pub const WEB_SEARCH_CALL_COMPLETED: &str = "response.web_search_call.completed";
// Item types
pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call";
pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call";
pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call";
pub const ITEM_TYPE_FUNCTION: &str = "function";
pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools";
pub const ITEM_TYPE_WEB_SEARCH_CALL: &str = "web_search_call";
}
// ============================================================================
// Web Search Constants
// ============================================================================
/// Constants for web search preview feature
///
/// String literals shared by the web search preview handling code so the
/// same spelling is used everywhere a name or status value is emitted.
pub(crate) mod web_search_constants {
/// MCP server name for web search preview
pub const WEB_SEARCH_PREVIEW_SERVER_NAME: &str = "web_search_preview";
/// Status value for an item or response that completed
pub const STATUS_COMPLETED: &str = "completed";
/// Status value for an item whose execution failed
pub const STATUS_FAILED: &str = "failed";
/// Action type for web search
pub const ACTION_TYPE_SEARCH: &str = "search";
}
// ============================================================================
// Tool Context Enum
// ============================================================================
/// Selects which tool handling strategy applies to a request.
///
/// Used in place of a boolean flag so call sites read clearly and so further
/// strategies can be added later as new variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ToolContext {
    /// Regular MCP tool handling with full mcp_call and mcp_list_tools items
    Regular,
    /// Web search preview handling with simplified web_search_call items
    WebSearchPreview,
}

impl ToolContext {
    /// Returns `true` only for [`ToolContext::WebSearchPreview`].
    pub fn is_web_search(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly
        // whether it counts as web search.
        match self {
            Self::WebSearchPreview => true,
            Self::Regular => false,
        }
    }
}
// ============================================================================
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment