Unverified Commit 74243dff authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

Revert "[router] web_search_preview tool basic implementation" (#12716)

parent 4ea4c48b
......@@ -3,11 +3,10 @@
//! This module contains all MCP-related functionality for the OpenAI router:
//! - Tool loop state management for multi-turn tool calling
//! - MCP tool execution and result handling
//! - Output item builders for MCP-specific response formats (including web_search_call)
//! - Output item builders for MCP-specific response formats
//! - SSE event generation for streaming MCP operations
//! - Payload transformation for MCP tool interception
//! - Metadata injection for MCP operations
//! - Web search preview tool handling (simplified MCP interface)
use std::{io, sync::Arc};
......@@ -17,7 +16,7 @@ use serde_json::{json, to_value, Value};
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
use super::utils::{event_types, web_search_constants, ToolContext};
use super::utils::event_types;
use crate::{
mcp,
protocols::responses::{
......@@ -37,16 +36,11 @@ pub(crate) struct McpLoopConfig {
/// Maximum iterations as safety limit (internal only, default: 10)
/// Prevents infinite loops when max_tool_calls is not set
pub max_iterations: usize,
/// Tool context for handling web_search_preview vs regular tools
pub tool_context: ToolContext,
}
impl Default for McpLoopConfig {
fn default() -> Self {
Self {
max_iterations: 10,
tool_context: ToolContext::Regular,
}
Self { max_iterations: 10 }
}
}
......@@ -164,13 +158,6 @@ pub async fn ensure_request_mcp_client(
.server_label
.clone()
.unwrap_or_else(|| "request-mcp".to_string());
// Validate that web_search_preview is not used as it's a reserved name
if name == web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME {
warn!("Rejecting request MCP with reserved server name: {}", name);
return None;
}
let token = tool.authorization.clone();
let transport = if server_url.contains("/sse") {
mcp::McpTransport::Sse {
......@@ -215,7 +202,6 @@ pub(super) async fn execute_streaming_tool_calls(
state: &mut ToolLoopState,
server_label: &str,
sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool {
// Execute all pending tool calls (sequential, as PR3 is skipped)
for call in pending_calls {
......@@ -272,7 +258,6 @@ pub(super) async fn execute_streaming_tool_calls(
success,
error_msg.as_deref(),
sequence_number,
tool_context,
) {
// Client disconnected, no point continuing tool execution
return false;
......@@ -292,7 +277,6 @@ pub(super) async fn execute_streaming_tool_calls(
pub(super) fn prepare_mcp_payload_for_streaming(
payload: &mut Value,
active_mcp: &Arc<mcp::McpManager>,
tool_context: ToolContext,
) {
if let Some(obj) = payload.as_object_mut() {
// Remove any non-function tools from outgoing payload
......@@ -307,27 +291,10 @@ pub(super) fn prepare_mcp_payload_for_streaming(
}
}
// Build function tools for discovered MCP tools
// Build function tools for all discovered MCP tools
let mut tools_json = Vec::new();
// Get tools with server names from inventory
// Returns Vec<(tool_name, server_name, Tool)>
let tools = active_mcp.inventory().list_tools();
// Filter tools based on context
let filtered_tools: Vec<_> = if tool_context.is_web_search() {
// Only include tools from web_search_preview server
tools
.into_iter()
.filter(|(_, server_name, _)| {
server_name == web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME
})
.collect()
} else {
tools
};
for (_, _, t) in filtered_tools {
let tools = active_mcp.list_tools();
for t in tools {
let parameters = Value::Object((*t.input_schema).clone());
let tool = serde_json::json!({
"type": event_types::ITEM_TYPE_FUNCTION,
......@@ -499,7 +466,6 @@ pub(super) fn send_mcp_list_tools_events(
/// Send mcp_call completion events after tool execution
/// Returns false if client disconnected
#[allow(clippy::too_many_arguments)]
pub(super) fn send_mcp_call_completion_events_with_error(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
call: &FunctionCallInProgress,
......@@ -508,7 +474,6 @@ pub(super) fn send_mcp_call_completion_events_with_error(
success: bool,
error_msg: Option<&str>,
sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool {
let effective_output_index = call.effective_output_index();
......@@ -520,24 +485,17 @@ pub(super) fn send_mcp_call_completion_events_with_error(
server_label,
success,
error_msg,
tool_context,
);
// Get the item_id
// Get the mcp_call item_id
let item_id = mcp_call_item
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("");
// Event 1: response.{web_search_call|mcp_call}.completed
let completed_event_type = if tool_context.is_web_search() {
event_types::WEB_SEARCH_CALL_COMPLETED
} else {
event_types::MCP_CALL_COMPLETED
};
// Event 1: response.mcp_call.completed
let completed_payload = json!({
"type": completed_event_type,
"type": event_types::MCP_CALL_COMPLETED,
"sequence_number": *sequence_number,
"output_index": effective_output_index,
"item_id": item_id
......@@ -546,7 +504,8 @@ pub(super) fn send_mcp_call_completion_events_with_error(
let completed_event = format!(
"event: {}\ndata: {}\n\n",
completed_event_type, completed_payload
event_types::MCP_CALL_COMPLETED,
completed_payload
);
if tx.send(Ok(Bytes::from(completed_event))).is_err() {
return false;
......@@ -579,40 +538,28 @@ pub(super) fn inject_mcp_metadata_streaming(
state: &ToolLoopState,
mcp: &Arc<mcp::McpManager>,
server_label: &str,
tool_context: ToolContext,
) {
if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) {
output_array.retain(|item| {
item.get("type").and_then(|t| t.as_str()) != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS)
});
let mut insert_pos = 0;
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
let mcp_call_items =
build_executed_mcp_call_items(&state.conversation_history, server_label, tool_context);
build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
} else if let Some(obj) = response.as_object_mut() {
let mut output_items = Vec::new();
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
output_items.push(build_mcp_list_tools_item(mcp, server_label));
}
output_items.extend(build_executed_mcp_call_items(
&state.conversation_history,
server_label,
tool_context,
));
obj.insert("output".to_string(), Value::Array(output_items));
}
......@@ -711,7 +658,6 @@ pub(super) async fn execute_tool_loop(
"max_tool_calls",
active_mcp,
original_body,
config.tool_context,
);
}
......@@ -770,28 +716,22 @@ pub(super) async fn execute_tool_loop(
})
.unwrap_or("mcp");
// Build mcp_list_tools item
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
// Insert at beginning of output array
if let Some(output_array) = response_json
.get_mut("output")
.and_then(|v| v.as_array_mut())
{
let mut insert_pos = 0;
// Only add mcp_list_tools for non-web-search cases
if !config.tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
// Build mcp_call items (will be web_search_call for web search tools)
let mcp_call_items = build_executed_mcp_call_items(
&state.conversation_history,
server_label,
config.tool_context,
);
// Build mcp_call items using helper function
let mcp_call_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
// Insert call items after mcp_list_tools (if present)
// Insert mcp_call items after mcp_list_tools using mutable position
let mut insert_pos = 1;
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
......@@ -811,17 +751,13 @@ pub(super) fn build_incomplete_response(
reason: &str,
active_mcp: &Arc<mcp::McpManager>,
original_body: &ResponsesRequest,
tool_context: ToolContext,
) -> Result<Value, String> {
let obj = response
.as_object_mut()
.ok_or_else(|| "response not an object".to_string())?;
// Set status to completed (not failed - partial success)
obj.insert(
"status".to_string(),
Value::String(web_search_constants::STATUS_COMPLETED.to_string()),
);
obj.insert("status".to_string(), Value::String("completed".to_string()));
// Set incomplete_details
obj.insert(
......@@ -863,7 +799,6 @@ pub(super) fn build_incomplete_response(
server_label,
false, // Not successful
Some("Not executed - response stopped due to limit"),
tool_context,
);
mcp_call_items.push(mcp_call_item);
}
......@@ -871,28 +806,20 @@ pub(super) fn build_incomplete_response(
// Add mcp_list_tools and executed mcp_call items at the beginning
if state.total_calls > 0 || !mcp_call_items.is_empty() {
let mut insert_pos = 0;
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
// Add mcp_call items for executed calls (will be web_search_call for web search)
let executed_items = build_executed_mcp_call_items(
&state.conversation_history,
server_label,
tool_context,
);
// Add mcp_call items for executed calls using helper
let executed_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
for item in executed_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
// Add incomplete mcp_call items (will be web_search_call for web search)
// Add incomplete mcp_call items
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
......@@ -920,67 +847,6 @@ pub(super) fn build_incomplete_response(
Ok(response)
}
// ============================================================================
// Web Search Preview Helpers
// ============================================================================
/// Detect if request has web_search_preview tool
pub(super) fn has_web_search_preview_tool(tools: &[ResponseTool]) -> bool {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::WebSearchPreview))
}
/// Check if MCP server "web_search_preview" is available
pub(super) async fn is_web_search_mcp_available(mcp_manager: &Arc<mcp::McpManager>) -> bool {
mcp_manager
.get_client(web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME)
.await
.is_some()
}
/// Build a web_search_call output item (MVP - status only)
///
/// The MCP search results are passed to the LLM internally via function_call_output,
/// but we don't expose them in the web_search_call item to the client.
fn build_web_search_call_item(query: Option<String>) -> Value {
let mut action = serde_json::Map::new();
action.insert(
"type".to_string(),
Value::String(web_search_constants::ACTION_TYPE_SEARCH.to_string()),
);
if let Some(q) = query {
action.insert("query".to_string(), Value::String(q));
}
json!({
"id": generate_id("ws"),
"type": event_types::ITEM_TYPE_WEB_SEARCH_CALL,
"status": web_search_constants::STATUS_COMPLETED,
"action": action
})
}
/// Build a failed web_search_call output item
fn build_web_search_call_item_failed(error: &str, query: Option<String>) -> Value {
let mut action = serde_json::Map::new();
action.insert(
"type".to_string(),
Value::String(web_search_constants::ACTION_TYPE_SEARCH.to_string()),
);
if let Some(q) = query {
action.insert("query".to_string(), Value::String(q));
}
json!({
"id": generate_id("ws"),
"type": event_types::ITEM_TYPE_WEB_SEARCH_CALL,
"status": web_search_constants::STATUS_FAILED,
"action": action,
"error": error
})
}
// ============================================================================
// Output Item Builders
// ============================================================================
......@@ -1018,32 +884,11 @@ pub(super) fn build_mcp_call_item(
server_label: &str,
success: bool,
error: Option<&str>,
tool_context: ToolContext,
) -> Value {
// Check if this is a web_search_preview context - if so, build web_search_call format
if tool_context.is_web_search() {
// Extract query from arguments for web_search_call
let query = serde_json::from_str::<Value>(arguments).ok().and_then(|v| {
v.get("query")
.and_then(|q| q.as_str().map(|s| s.to_string()))
});
// Build web_search_call item (MVP - status only, no results)
if success {
build_web_search_call_item(query)
} else {
build_web_search_call_item_failed(error.unwrap_or("Tool execution failed"), query)
}
} else {
// Regular mcp_call item
json!({
"id": generate_id("mcp"),
"type": event_types::ITEM_TYPE_MCP_CALL,
"status": if success {
web_search_constants::STATUS_COMPLETED
} else {
web_search_constants::STATUS_FAILED
},
"status": if success { "completed" } else { "failed" },
"approval_request_id": Value::Null,
"arguments": arguments,
"error": error,
......@@ -1051,14 +896,12 @@ pub(super) fn build_mcp_call_item(
"output": output,
"server_label": server_label
})
}
}
/// Helper function to build mcp_call items from executed tool calls in conversation history
pub(super) fn build_executed_mcp_call_items(
conversation_history: &[Value],
server_label: &str,
tool_context: ToolContext,
) -> Vec<Value> {
let mut mcp_call_items = Vec::new();
......@@ -1097,7 +940,6 @@ pub(super) fn build_executed_mcp_call_items(
} else {
None
},
tool_context,
);
mcp_call_items.push(mcp_call_item);
}
......
......@@ -5,7 +5,7 @@ use std::collections::HashMap;
use serde_json::{json, Value};
use tracing::warn;
use super::utils::{event_types, web_search_constants};
use super::utils::event_types;
use crate::{
data_connector::{ResponseId, StoredResponse},
protocols::responses::{ResponseToolType, ResponsesRequest},
......@@ -276,29 +276,15 @@ pub(super) fn rewrite_streaming_block(
/// Mask function tools as MCP tools in response for client
pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
// Check for MCP tool
let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
});
// Check for web_search_preview tool
let has_web_search = original_body
.tools
.as_ref()
.map(|tools| crate::routers::openai::mcp::has_web_search_preview_tool(tools))
.unwrap_or(false);
// If neither MCP nor web_search_preview, return early
if mcp_tool.is_none() && !has_web_search {
let Some(t) = mcp_tool else {
return;
}
let mut response_tools = Vec::new();
};
// Add MCP tool if present
if let Some(t) = mcp_tool {
let mut m = serde_json::Map::new();
m.insert("type".to_string(), Value::String("mcp".to_string()));
if let Some(label) = &t.server_label {
......@@ -322,21 +308,9 @@ pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesReque
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
);
}
response_tools.push(Value::Object(m));
}
// Add web_search_preview tool if present
if has_web_search {
let mut ws = serde_json::Map::new();
ws.insert(
"type".to_string(),
Value::String(web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME.to_string()),
);
response_tools.push(Value::Object(ws));
}
if let Some(obj) = resp.as_object_mut() {
obj.insert("tools".to_string(), Value::Array(response_tools));
obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)]));
obj.entry("tool_choice")
.or_insert(Value::String("auto".to_string()));
}
......
......@@ -30,12 +30,12 @@ use super::conversations::{
};
use super::{
mcp::{
ensure_request_mcp_client, execute_tool_loop, has_web_search_preview_tool,
is_web_search_mcp_available, prepare_mcp_payload_for_streaming, McpLoopConfig,
ensure_request_mcp_client, execute_tool_loop, prepare_mcp_payload_for_streaming,
McpLoopConfig,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json},
streaming::handle_streaming_response,
utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model, ToolContext},
utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
};
use crate::{
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
......@@ -248,7 +248,6 @@ impl OpenAIRouter {
mut payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
tool_context: ToolContext,
) -> Response {
// Check if MCP is active for this request
// Ensure dynamic client is created if needed
......@@ -267,13 +266,10 @@ impl OpenAIRouter {
// If MCP is active, execute tool loop
if let Some(mcp) = active_mcp {
let config = McpLoopConfig {
tool_context,
..Default::default()
};
let config = McpLoopConfig::default();
// Transform MCP tools to function tools
prepare_mcp_payload_for_streaming(&mut payload, mcp, tool_context);
prepare_mcp_payload_for_streaming(&mut payload, mcp);
match execute_tool_loop(
&self.client,
......@@ -699,35 +695,6 @@ impl crate::routers::RouterTrait for OpenAIRouter {
let url = format!("{}/v1/responses", base_url);
// Detect web_search_preview tool and verify MCP server availability
let tool_context = if let Some(ref tools) = body.tools {
if has_web_search_preview_tool(tools) {
ToolContext::WebSearchPreview
} else {
ToolContext::Regular
}
} else {
ToolContext::Regular
};
if tool_context.is_web_search() {
// Check if web_search_preview MCP server is available
if !is_web_search_mcp_available(&self.mcp_manager).await {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": "Web search preview is currently unavailable. Please contact your server administrator.",
"type": "invalid_request_error",
"param": "tools",
"code": "web_search_unavailable"
}
})),
)
.into_response();
}
}
// Validate mutually exclusive params: previous_response_id and conversation
// TODO: this validation logic should move the right place, also we need a proper error message module
if body.previous_response_id.is_some() && body.conversation.is_some() {
......@@ -1055,7 +1022,6 @@ impl crate::routers::RouterTrait for OpenAIRouter {
payload,
body,
original_previous_response_id,
tool_context,
)
.await
} else {
......@@ -1065,7 +1031,6 @@ impl crate::routers::RouterTrait for OpenAIRouter {
payload,
body,
original_previous_response_id,
tool_context,
)
.await
}
......
......@@ -30,10 +30,7 @@ use super::{
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
utils::{
event_types, web_search_constants, FunctionCallInProgress, OutputIndexMapper, StreamAction,
ToolContext,
},
utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction},
};
use crate::{
data_connector::{ConversationItemStorage, ConversationStorage, ResponseStorage},
......@@ -556,7 +553,6 @@ pub(super) fn apply_event_transformations_inplace(
server_label: &str,
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
tool_context: ToolContext,
) -> bool {
let mut changed = false;
......@@ -602,17 +598,6 @@ pub(super) fn apply_event_transformations_inplace(
// Mask tools from function to MCP format (optimized without cloning)
if response_obj.get("tools").is_some() {
// For web_search_preview, always use simplified tool format
if tool_context.is_web_search() {
let web_search_tool =
json!([{"type": web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME}]);
response_obj.insert("tools".to_string(), web_search_tool);
response_obj
.entry("tool_choice".to_string())
.or_insert(Value::String("auto".to_string()));
changed = true;
} else {
// Regular MCP tools - only if requested
let requested_mcp = original_request
.tools
.as_ref()
......@@ -635,7 +620,6 @@ pub(super) fn apply_event_transformations_inplace(
}
}
}
}
// 2. Apply transform_streaming_event logic (function_call → mcp_call)
match event_type.as_str() {
......@@ -645,30 +629,13 @@ pub(super) fn apply_event_transformations_inplace(
if item_type == event_types::ITEM_TYPE_FUNCTION_CALL
|| item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
{
// Use web_search_call for web_search_preview, mcp_call for regular MCP
if tool_context.is_web_search() {
item["type"] = json!(event_types::ITEM_TYPE_WEB_SEARCH_CALL);
// Don't include server_label for web_search_call
// Remove internal implementation fields
if let Some(obj) = item.as_object_mut() {
obj.remove("arguments");
obj.remove("call_id");
obj.remove("name");
}
} else {
item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL);
item["server_label"] = json!(server_label);
}
// Transform ID from fc_* to ws_* or mcp_*
// Transform ID from fc_* to mcp_*
if let Some(id) = item.get("id").and_then(|v| v.as_str()) {
if let Some(stripped) = id.strip_prefix("fc_") {
let prefix = if tool_context.is_web_search() {
"ws"
} else {
"mcp"
};
let new_id = format!("{}_{}", prefix, stripped);
let new_id = format!("mcp_{}", stripped);
item["id"] = json!(new_id);
}
}
......@@ -726,7 +693,6 @@ pub(super) fn forward_streaming_event(
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool {
// Skip individual function_call_arguments.delta events - we'll send them as one
if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DELTA) {
......@@ -791,8 +757,6 @@ pub(super) fn forward_streaming_event(
};
// Emit a synthetic MCP arguments delta event before the done event
// Skip for web_search_preview - we don't expose tool call arguments
if !tool_context.is_web_search() {
let mut delta_event = json!({
"type": event_types::MCP_CALL_ARGUMENTS_DELTA,
"sequence_number": *sequence_number,
......@@ -827,7 +791,6 @@ pub(super) fn forward_streaming_event(
}
}
}
}
// Remap output_index (if present) so downstream sees sequential indices
if mapped_output_index.is_none() {
......@@ -850,7 +813,6 @@ pub(super) fn forward_streaming_event(
server_label,
original_request,
previous_response_id,
tool_context,
);
if let Some(response_obj) = parsed_data
......@@ -882,24 +844,16 @@ pub(super) fn forward_streaming_event(
let mut final_block = String::new();
if let Some(evt) = event_name {
// Update event name for function_call_arguments events
// Skip for web_search_preview - we don't expose tool call arguments
if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA && !tool_context.is_web_search() {
if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA {
final_block.push_str(&format!(
"event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DELTA
));
} else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE && !tool_context.is_web_search()
{
} else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE {
final_block.push_str(&format!(
"event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DONE
));
} else if (evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA
|| evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE)
&& tool_context.is_web_search()
{
// Skip these events entirely for web_search_preview
return true;
} else {
final_block.push_str(&format!("event: {}\n", evt));
}
......@@ -911,62 +865,30 @@ pub(super) fn forward_streaming_event(
return false;
}
// After sending output_item.added for mcp_call/web_search_call, inject in_progress event
// After sending output_item.added for mcp_call, inject mcp_call.in_progress event
if event_name == Some(event_types::OUTPUT_ITEM_ADDED) {
if let Some(item) = parsed_data.get("item") {
let item_type = item.get("type").and_then(|v| v.as_str());
// Check if it's an mcp_call or web_search_call
let is_mcp_or_web_search = item_type == Some(event_types::ITEM_TYPE_MCP_CALL)
|| item_type == Some(event_types::ITEM_TYPE_WEB_SEARCH_CALL);
if is_mcp_or_web_search {
if item.get("type").and_then(|v| v.as_str()) == Some(event_types::ITEM_TYPE_MCP_CALL) {
// Already transformed to mcp_call
if let (Some(item_id), Some(output_index)) = (
item.get("id").and_then(|v| v.as_str()),
parsed_data.get("output_index").and_then(|v| v.as_u64()),
) {
// Choose event type based on tool_context
let in_progress_event_type = if tool_context.is_web_search() {
event_types::WEB_SEARCH_CALL_IN_PROGRESS
} else {
event_types::MCP_CALL_IN_PROGRESS
};
let in_progress_event = json!({
"type": in_progress_event_type,
"type": event_types::MCP_CALL_IN_PROGRESS,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let in_progress_block = format!(
"event: {}\ndata: {}\n\n",
in_progress_event_type, in_progress_event
event_types::MCP_CALL_IN_PROGRESS,
in_progress_event
);
if tx.send(Ok(Bytes::from(in_progress_block))).is_err() {
return false;
}
// For web_search_call, also send a "searching" event
if tool_context.is_web_search() {
let searching_event = json!({
"type": event_types::WEB_SEARCH_CALL_SEARCHING,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let searching_block = format!(
"event: {}\ndata: {}\n\n",
event_types::WEB_SEARCH_CALL_SEARCHING,
searching_event
);
if tx.send(Ok(Bytes::from(searching_block))).is_err() {
return false;
}
}
}
}
}
......@@ -987,7 +909,6 @@ pub(super) fn send_final_response_event(
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
server_label: &str,
tool_context: ToolContext,
) -> bool {
let mut final_response = match handler.snapshot_final_response() {
Some(resp) => resp,
......@@ -1004,7 +925,7 @@ pub(super) fn send_final_response_event(
}
if let Some(mcp) = active_mcp {
inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label, tool_context);
inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label);
}
mask_tools_as_mcp(&mut final_response, original_request);
......@@ -1216,10 +1137,9 @@ pub(super) async fn handle_streaming_with_tool_interception(
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
active_mcp: &Arc<crate::mcp::McpManager>,
tool_context: ToolContext,
) -> Response {
// Transform MCP tools to function tools in payload
prepare_mcp_payload_for_streaming(&mut payload, active_mcp, tool_context);
prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store.unwrap_or(false);
......@@ -1236,10 +1156,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
// Spawn the streaming loop task
tokio::spawn(async move {
let mut state = ToolLoopState::new(original_request.input.clone());
let loop_config = McpLoopConfig {
tool_context,
..Default::default()
};
let loop_config = McpLoopConfig::default();
let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);
let tools_json = payload_clone.get("tools").cloned().unwrap_or(json!([]));
let base_payload = payload_clone.clone();
......@@ -1358,7 +1275,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request,
previous_response_id.as_deref(),
&mut sequence_number,
loop_config.tool_context,
) {
// Client disconnected
return;
......@@ -1374,10 +1290,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
== Some(event_types::RESPONSE_IN_PROGRESS)
{
seen_in_progress = true;
// Skip mcp_list_tools for web_search_preview
if !mcp_list_tools_sent
&& !loop_config.tool_context.is_web_search()
{
if !mcp_list_tools_sent {
let list_tools_index =
handler.allocate_synthetic_output_index();
if !send_mcp_list_tools_events(
......@@ -1410,7 +1323,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request,
previous_response_id.as_deref(),
&mut sequence_number,
loop_config.tool_context,
) {
// Client disconnected
return;
......@@ -1449,7 +1361,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request,
previous_response_id.as_deref(),
server_label,
tool_context,
) {
return;
}
......@@ -1471,7 +1382,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&state,
&active_mcp_clone,
server_label,
tool_context,
);
mask_tools_as_mcp(&mut response_json, &original_request);
......@@ -1533,7 +1443,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&mut state,
server_label,
&mut sequence_number,
tool_context,
)
.await
{
......@@ -1589,7 +1498,6 @@ pub(super) async fn handle_streaming_response(
payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
tool_context: ToolContext,
) -> Response {
// Check if MCP is active for this request
// Ensure dynamic client is created if needed
......@@ -1637,7 +1545,6 @@ pub(super) async fn handle_streaming_response(
original_body,
original_previous_response_id,
active_mcp,
tool_context,
)
.await
}
......@@ -32,59 +32,12 @@ pub(crate) mod event_types {
pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress";
pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed";
// Web Search Call events (for web_search_preview)
pub const WEB_SEARCH_CALL_IN_PROGRESS: &str = "response.web_search_call.in_progress";
pub const WEB_SEARCH_CALL_SEARCHING: &str = "response.web_search_call.searching";
pub const WEB_SEARCH_CALL_COMPLETED: &str = "response.web_search_call.completed";
// Item types
pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call";
pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call";
pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call";
pub const ITEM_TYPE_FUNCTION: &str = "function";
pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools";
pub const ITEM_TYPE_WEB_SEARCH_CALL: &str = "web_search_call";
}
// ============================================================================
// Web Search Constants
// ============================================================================
/// Constants for web search preview feature
pub(crate) mod web_search_constants {
/// MCP server name for web search preview
pub const WEB_SEARCH_PREVIEW_SERVER_NAME: &str = "web_search_preview";
/// Status constants
pub const STATUS_COMPLETED: &str = "completed";
pub const STATUS_FAILED: &str = "failed";
/// Action type for web search
pub const ACTION_TYPE_SEARCH: &str = "search";
}
// ============================================================================
// Tool Context Enum
// ============================================================================
/// Represents the context for tool handling strategy
///
/// This enum replaces boolean flags for better type safety and clarity.
/// It makes the code more maintainable and easier to extend with new
/// tool handling strategies in the future.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ToolContext {
/// Regular MCP tool handling with full mcp_call and mcp_list_tools items
Regular,
/// Web search preview handling with simplified web_search_call items
WebSearchPreview,
}
impl ToolContext {
/// Check if this is web search preview context
pub fn is_web_search(&self) -> bool {
matches!(self, ToolContext::WebSearchPreview)
}
}
// ============================================================================
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment