Unverified Commit 74243dff authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

Revert "[router] web_search_preview tool basic implementation" (#12716)

parent 4ea4c48b
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
//! This module contains all MCP-related functionality for the OpenAI router: //! This module contains all MCP-related functionality for the OpenAI router:
//! - Tool loop state management for multi-turn tool calling //! - Tool loop state management for multi-turn tool calling
//! - MCP tool execution and result handling //! - MCP tool execution and result handling
//! - Output item builders for MCP-specific response formats (including web_search_call) //! - Output item builders for MCP-specific response formats
//! - SSE event generation for streaming MCP operations //! - SSE event generation for streaming MCP operations
//! - Payload transformation for MCP tool interception //! - Payload transformation for MCP tool interception
//! - Metadata injection for MCP operations //! - Metadata injection for MCP operations
//! - Web search preview tool handling (simplified MCP interface)
use std::{io, sync::Arc}; use std::{io, sync::Arc};
...@@ -17,7 +16,7 @@ use serde_json::{json, to_value, Value}; ...@@ -17,7 +16,7 @@ use serde_json::{json, to_value, Value};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use super::utils::{event_types, web_search_constants, ToolContext}; use super::utils::event_types;
use crate::{ use crate::{
mcp, mcp,
protocols::responses::{ protocols::responses::{
...@@ -37,16 +36,11 @@ pub(crate) struct McpLoopConfig { ...@@ -37,16 +36,11 @@ pub(crate) struct McpLoopConfig {
/// Maximum iterations as safety limit (internal only, default: 10) /// Maximum iterations as safety limit (internal only, default: 10)
/// Prevents infinite loops when max_tool_calls is not set /// Prevents infinite loops when max_tool_calls is not set
pub max_iterations: usize, pub max_iterations: usize,
/// Tool context for handling web_search_preview vs regular tools
pub tool_context: ToolContext,
} }
impl Default for McpLoopConfig { impl Default for McpLoopConfig {
fn default() -> Self { fn default() -> Self {
Self { Self { max_iterations: 10 }
max_iterations: 10,
tool_context: ToolContext::Regular,
}
} }
} }
...@@ -164,13 +158,6 @@ pub async fn ensure_request_mcp_client( ...@@ -164,13 +158,6 @@ pub async fn ensure_request_mcp_client(
.server_label .server_label
.clone() .clone()
.unwrap_or_else(|| "request-mcp".to_string()); .unwrap_or_else(|| "request-mcp".to_string());
// Validate that web_search_preview is not used as it's a reserved name
if name == web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME {
warn!("Rejecting request MCP with reserved server name: {}", name);
return None;
}
let token = tool.authorization.clone(); let token = tool.authorization.clone();
let transport = if server_url.contains("/sse") { let transport = if server_url.contains("/sse") {
mcp::McpTransport::Sse { mcp::McpTransport::Sse {
...@@ -215,7 +202,6 @@ pub(super) async fn execute_streaming_tool_calls( ...@@ -215,7 +202,6 @@ pub(super) async fn execute_streaming_tool_calls(
state: &mut ToolLoopState, state: &mut ToolLoopState,
server_label: &str, server_label: &str,
sequence_number: &mut u64, sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool { ) -> bool {
// Execute all pending tool calls (sequential, as PR3 is skipped) // Execute all pending tool calls (sequential, as PR3 is skipped)
for call in pending_calls { for call in pending_calls {
...@@ -272,7 +258,6 @@ pub(super) async fn execute_streaming_tool_calls( ...@@ -272,7 +258,6 @@ pub(super) async fn execute_streaming_tool_calls(
success, success,
error_msg.as_deref(), error_msg.as_deref(),
sequence_number, sequence_number,
tool_context,
) { ) {
// Client disconnected, no point continuing tool execution // Client disconnected, no point continuing tool execution
return false; return false;
...@@ -292,7 +277,6 @@ pub(super) async fn execute_streaming_tool_calls( ...@@ -292,7 +277,6 @@ pub(super) async fn execute_streaming_tool_calls(
pub(super) fn prepare_mcp_payload_for_streaming( pub(super) fn prepare_mcp_payload_for_streaming(
payload: &mut Value, payload: &mut Value,
active_mcp: &Arc<mcp::McpManager>, active_mcp: &Arc<mcp::McpManager>,
tool_context: ToolContext,
) { ) {
if let Some(obj) = payload.as_object_mut() { if let Some(obj) = payload.as_object_mut() {
// Remove any non-function tools from outgoing payload // Remove any non-function tools from outgoing payload
...@@ -307,27 +291,10 @@ pub(super) fn prepare_mcp_payload_for_streaming( ...@@ -307,27 +291,10 @@ pub(super) fn prepare_mcp_payload_for_streaming(
} }
} }
// Build function tools for discovered MCP tools // Build function tools for all discovered MCP tools
let mut tools_json = Vec::new(); let mut tools_json = Vec::new();
let tools = active_mcp.list_tools();
// Get tools with server names from inventory for t in tools {
// Returns Vec<(tool_name, server_name, Tool)>
let tools = active_mcp.inventory().list_tools();
// Filter tools based on context
let filtered_tools: Vec<_> = if tool_context.is_web_search() {
// Only include tools from web_search_preview server
tools
.into_iter()
.filter(|(_, server_name, _)| {
server_name == web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME
})
.collect()
} else {
tools
};
for (_, _, t) in filtered_tools {
let parameters = Value::Object((*t.input_schema).clone()); let parameters = Value::Object((*t.input_schema).clone());
let tool = serde_json::json!({ let tool = serde_json::json!({
"type": event_types::ITEM_TYPE_FUNCTION, "type": event_types::ITEM_TYPE_FUNCTION,
...@@ -499,7 +466,6 @@ pub(super) fn send_mcp_list_tools_events( ...@@ -499,7 +466,6 @@ pub(super) fn send_mcp_list_tools_events(
/// Send mcp_call completion events after tool execution /// Send mcp_call completion events after tool execution
/// Returns false if client disconnected /// Returns false if client disconnected
#[allow(clippy::too_many_arguments)]
pub(super) fn send_mcp_call_completion_events_with_error( pub(super) fn send_mcp_call_completion_events_with_error(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
call: &FunctionCallInProgress, call: &FunctionCallInProgress,
...@@ -508,7 +474,6 @@ pub(super) fn send_mcp_call_completion_events_with_error( ...@@ -508,7 +474,6 @@ pub(super) fn send_mcp_call_completion_events_with_error(
success: bool, success: bool,
error_msg: Option<&str>, error_msg: Option<&str>,
sequence_number: &mut u64, sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool { ) -> bool {
let effective_output_index = call.effective_output_index(); let effective_output_index = call.effective_output_index();
...@@ -520,24 +485,17 @@ pub(super) fn send_mcp_call_completion_events_with_error( ...@@ -520,24 +485,17 @@ pub(super) fn send_mcp_call_completion_events_with_error(
server_label, server_label,
success, success,
error_msg, error_msg,
tool_context,
); );
// Get the item_id // Get the mcp_call item_id
let item_id = mcp_call_item let item_id = mcp_call_item
.get("id") .get("id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or(""); .unwrap_or("");
// Event 1: response.{web_search_call|mcp_call}.completed // Event 1: response.mcp_call.completed
let completed_event_type = if tool_context.is_web_search() {
event_types::WEB_SEARCH_CALL_COMPLETED
} else {
event_types::MCP_CALL_COMPLETED
};
let completed_payload = json!({ let completed_payload = json!({
"type": completed_event_type, "type": event_types::MCP_CALL_COMPLETED,
"sequence_number": *sequence_number, "sequence_number": *sequence_number,
"output_index": effective_output_index, "output_index": effective_output_index,
"item_id": item_id "item_id": item_id
...@@ -546,7 +504,8 @@ pub(super) fn send_mcp_call_completion_events_with_error( ...@@ -546,7 +504,8 @@ pub(super) fn send_mcp_call_completion_events_with_error(
let completed_event = format!( let completed_event = format!(
"event: {}\ndata: {}\n\n", "event: {}\ndata: {}\n\n",
completed_event_type, completed_payload event_types::MCP_CALL_COMPLETED,
completed_payload
); );
if tx.send(Ok(Bytes::from(completed_event))).is_err() { if tx.send(Ok(Bytes::from(completed_event))).is_err() {
return false; return false;
...@@ -579,40 +538,28 @@ pub(super) fn inject_mcp_metadata_streaming( ...@@ -579,40 +538,28 @@ pub(super) fn inject_mcp_metadata_streaming(
state: &ToolLoopState, state: &ToolLoopState,
mcp: &Arc<mcp::McpManager>, mcp: &Arc<mcp::McpManager>,
server_label: &str, server_label: &str,
tool_context: ToolContext,
) { ) {
if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) { if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) {
output_array.retain(|item| { output_array.retain(|item| {
item.get("type").and_then(|t| t.as_str()) != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS) item.get("type").and_then(|t| t.as_str()) != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS)
}); });
let mut insert_pos = 0; let list_tools_item = build_mcp_list_tools_item(mcp, server_label);
output_array.insert(0, list_tools_item);
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
let mcp_call_items = let mcp_call_items =
build_executed_mcp_call_items(&state.conversation_history, server_label, tool_context); build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
for item in mcp_call_items { for item in mcp_call_items {
output_array.insert(insert_pos, item); output_array.insert(insert_pos, item);
insert_pos += 1; insert_pos += 1;
} }
} else if let Some(obj) = response.as_object_mut() { } else if let Some(obj) = response.as_object_mut() {
let mut output_items = Vec::new(); let mut output_items = Vec::new();
output_items.push(build_mcp_list_tools_item(mcp, server_label));
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
output_items.push(build_mcp_list_tools_item(mcp, server_label));
}
output_items.extend(build_executed_mcp_call_items( output_items.extend(build_executed_mcp_call_items(
&state.conversation_history, &state.conversation_history,
server_label, server_label,
tool_context,
)); ));
obj.insert("output".to_string(), Value::Array(output_items)); obj.insert("output".to_string(), Value::Array(output_items));
} }
...@@ -711,7 +658,6 @@ pub(super) async fn execute_tool_loop( ...@@ -711,7 +658,6 @@ pub(super) async fn execute_tool_loop(
"max_tool_calls", "max_tool_calls",
active_mcp, active_mcp,
original_body, original_body,
config.tool_context,
); );
} }
...@@ -770,28 +716,22 @@ pub(super) async fn execute_tool_loop( ...@@ -770,28 +716,22 @@ pub(super) async fn execute_tool_loop(
}) })
.unwrap_or("mcp"); .unwrap_or("mcp");
// Build mcp_list_tools item
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
// Insert at beginning of output array // Insert at beginning of output array
if let Some(output_array) = response_json if let Some(output_array) = response_json
.get_mut("output") .get_mut("output")
.and_then(|v| v.as_array_mut()) .and_then(|v| v.as_array_mut())
{ {
let mut insert_pos = 0; output_array.insert(0, list_tools_item);
// Only add mcp_list_tools for non-web-search cases // Build mcp_call items using helper function
if !config.tool_context.is_web_search() { let mcp_call_items =
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label); build_executed_mcp_call_items(&state.conversation_history, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
// Build mcp_call items (will be web_search_call for web search tools) // Insert mcp_call items after mcp_list_tools using mutable position
let mcp_call_items = build_executed_mcp_call_items( let mut insert_pos = 1;
&state.conversation_history,
server_label,
config.tool_context,
);
// Insert call items after mcp_list_tools (if present)
for item in mcp_call_items { for item in mcp_call_items {
output_array.insert(insert_pos, item); output_array.insert(insert_pos, item);
insert_pos += 1; insert_pos += 1;
...@@ -811,17 +751,13 @@ pub(super) fn build_incomplete_response( ...@@ -811,17 +751,13 @@ pub(super) fn build_incomplete_response(
reason: &str, reason: &str,
active_mcp: &Arc<mcp::McpManager>, active_mcp: &Arc<mcp::McpManager>,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
tool_context: ToolContext,
) -> Result<Value, String> { ) -> Result<Value, String> {
let obj = response let obj = response
.as_object_mut() .as_object_mut()
.ok_or_else(|| "response not an object".to_string())?; .ok_or_else(|| "response not an object".to_string())?;
// Set status to completed (not failed - partial success) // Set status to completed (not failed - partial success)
obj.insert( obj.insert("status".to_string(), Value::String("completed".to_string()));
"status".to_string(),
Value::String(web_search_constants::STATUS_COMPLETED.to_string()),
);
// Set incomplete_details // Set incomplete_details
obj.insert( obj.insert(
...@@ -863,7 +799,6 @@ pub(super) fn build_incomplete_response( ...@@ -863,7 +799,6 @@ pub(super) fn build_incomplete_response(
server_label, server_label,
false, // Not successful false, // Not successful
Some("Not executed - response stopped due to limit"), Some("Not executed - response stopped due to limit"),
tool_context,
); );
mcp_call_items.push(mcp_call_item); mcp_call_items.push(mcp_call_item);
} }
...@@ -871,28 +806,20 @@ pub(super) fn build_incomplete_response( ...@@ -871,28 +806,20 @@ pub(super) fn build_incomplete_response(
// Add mcp_list_tools and executed mcp_call items at the beginning // Add mcp_list_tools and executed mcp_call items at the beginning
if state.total_calls > 0 || !mcp_call_items.is_empty() { if state.total_calls > 0 || !mcp_call_items.is_empty() {
let mut insert_pos = 0; let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
// Only add mcp_list_tools for non-web-search cases
if !tool_context.is_web_search() {
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
insert_pos = 1;
}
// Add mcp_call items for executed calls (will be web_search_call for web search) // Add mcp_call items for executed calls using helper
let executed_items = build_executed_mcp_call_items( let executed_items =
&state.conversation_history, build_executed_mcp_call_items(&state.conversation_history, server_label);
server_label,
tool_context,
);
let mut insert_pos = 1;
for item in executed_items { for item in executed_items {
output_array.insert(insert_pos, item); output_array.insert(insert_pos, item);
insert_pos += 1; insert_pos += 1;
} }
// Add incomplete mcp_call items (will be web_search_call for web search) // Add incomplete mcp_call items
for item in mcp_call_items { for item in mcp_call_items {
output_array.insert(insert_pos, item); output_array.insert(insert_pos, item);
insert_pos += 1; insert_pos += 1;
...@@ -920,67 +847,6 @@ pub(super) fn build_incomplete_response( ...@@ -920,67 +847,6 @@ pub(super) fn build_incomplete_response(
Ok(response) Ok(response)
} }
// ============================================================================
// Web Search Preview Helpers
// ============================================================================
/// Detect if request has web_search_preview tool
pub(super) fn has_web_search_preview_tool(tools: &[ResponseTool]) -> bool {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::WebSearchPreview))
}
/// Check if MCP server "web_search_preview" is available
pub(super) async fn is_web_search_mcp_available(mcp_manager: &Arc<mcp::McpManager>) -> bool {
mcp_manager
.get_client(web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME)
.await
.is_some()
}
/// Build a web_search_call output item (MVP - status only)
///
/// The MCP search results are passed to the LLM internally via function_call_output,
/// but we don't expose them in the web_search_call item to the client.
fn build_web_search_call_item(query: Option<String>) -> Value {
let mut action = serde_json::Map::new();
action.insert(
"type".to_string(),
Value::String(web_search_constants::ACTION_TYPE_SEARCH.to_string()),
);
if let Some(q) = query {
action.insert("query".to_string(), Value::String(q));
}
json!({
"id": generate_id("ws"),
"type": event_types::ITEM_TYPE_WEB_SEARCH_CALL,
"status": web_search_constants::STATUS_COMPLETED,
"action": action
})
}
/// Build a failed web_search_call output item
fn build_web_search_call_item_failed(error: &str, query: Option<String>) -> Value {
let mut action = serde_json::Map::new();
action.insert(
"type".to_string(),
Value::String(web_search_constants::ACTION_TYPE_SEARCH.to_string()),
);
if let Some(q) = query {
action.insert("query".to_string(), Value::String(q));
}
json!({
"id": generate_id("ws"),
"type": event_types::ITEM_TYPE_WEB_SEARCH_CALL,
"status": web_search_constants::STATUS_FAILED,
"action": action,
"error": error
})
}
// ============================================================================ // ============================================================================
// Output Item Builders // Output Item Builders
// ============================================================================ // ============================================================================
...@@ -1018,47 +884,24 @@ pub(super) fn build_mcp_call_item( ...@@ -1018,47 +884,24 @@ pub(super) fn build_mcp_call_item(
server_label: &str, server_label: &str,
success: bool, success: bool,
error: Option<&str>, error: Option<&str>,
tool_context: ToolContext,
) -> Value { ) -> Value {
// Check if this is a web_search_preview context - if so, build web_search_call format json!({
if tool_context.is_web_search() { "id": generate_id("mcp"),
// Extract query from arguments for web_search_call "type": event_types::ITEM_TYPE_MCP_CALL,
let query = serde_json::from_str::<Value>(arguments).ok().and_then(|v| { "status": if success { "completed" } else { "failed" },
v.get("query") "approval_request_id": Value::Null,
.and_then(|q| q.as_str().map(|s| s.to_string())) "arguments": arguments,
}); "error": error,
"name": tool_name,
// Build web_search_call item (MVP - status only, no results) "output": output,
if success { "server_label": server_label
build_web_search_call_item(query) })
} else {
build_web_search_call_item_failed(error.unwrap_or("Tool execution failed"), query)
}
} else {
// Regular mcp_call item
json!({
"id": generate_id("mcp"),
"type": event_types::ITEM_TYPE_MCP_CALL,
"status": if success {
web_search_constants::STATUS_COMPLETED
} else {
web_search_constants::STATUS_FAILED
},
"approval_request_id": Value::Null,
"arguments": arguments,
"error": error,
"name": tool_name,
"output": output,
"server_label": server_label
})
}
} }
/// Helper function to build mcp_call items from executed tool calls in conversation history /// Helper function to build mcp_call items from executed tool calls in conversation history
pub(super) fn build_executed_mcp_call_items( pub(super) fn build_executed_mcp_call_items(
conversation_history: &[Value], conversation_history: &[Value],
server_label: &str, server_label: &str,
tool_context: ToolContext,
) -> Vec<Value> { ) -> Vec<Value> {
let mut mcp_call_items = Vec::new(); let mut mcp_call_items = Vec::new();
...@@ -1097,7 +940,6 @@ pub(super) fn build_executed_mcp_call_items( ...@@ -1097,7 +940,6 @@ pub(super) fn build_executed_mcp_call_items(
} else { } else {
None None
}, },
tool_context,
); );
mcp_call_items.push(mcp_call_item); mcp_call_items.push(mcp_call_item);
} }
......
...@@ -5,7 +5,7 @@ use std::collections::HashMap; ...@@ -5,7 +5,7 @@ use std::collections::HashMap;
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::warn; use tracing::warn;
use super::utils::{event_types, web_search_constants}; use super::utils::event_types;
use crate::{ use crate::{
data_connector::{ResponseId, StoredResponse}, data_connector::{ResponseId, StoredResponse},
protocols::responses::{ResponseToolType, ResponsesRequest}, protocols::responses::{ResponseToolType, ResponsesRequest},
...@@ -276,67 +276,41 @@ pub(super) fn rewrite_streaming_block( ...@@ -276,67 +276,41 @@ pub(super) fn rewrite_streaming_block(
/// Mask function tools as MCP tools in response for client /// Mask function tools as MCP tools in response for client
pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) { pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
// Check for MCP tool
let mcp_tool = original_body.tools.as_ref().and_then(|tools| { let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
tools tools
.iter() .iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some()) .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
}); });
let Some(t) = mcp_tool else {
// Check for web_search_preview tool
let has_web_search = original_body
.tools
.as_ref()
.map(|tools| crate::routers::openai::mcp::has_web_search_preview_tool(tools))
.unwrap_or(false);
// If neither MCP nor web_search_preview, return early
if mcp_tool.is_none() && !has_web_search {
return; return;
} };
let mut response_tools = Vec::new();
// Add MCP tool if present let mut m = serde_json::Map::new();
if let Some(t) = mcp_tool { m.insert("type".to_string(), Value::String("mcp".to_string()));
let mut m = serde_json::Map::new(); if let Some(label) = &t.server_label {
m.insert("type".to_string(), Value::String("mcp".to_string())); m.insert("server_label".to_string(), Value::String(label.clone()));
if let Some(label) = &t.server_label {
m.insert("server_label".to_string(), Value::String(label.clone()));
}
if let Some(url) = &t.server_url {
m.insert("server_url".to_string(), Value::String(url.clone()));
}
if let Some(desc) = &t.server_description {
m.insert(
"server_description".to_string(),
Value::String(desc.clone()),
);
}
if let Some(req) = &t.require_approval {
m.insert("require_approval".to_string(), Value::String(req.clone()));
}
if let Some(allowed) = &t.allowed_tools {
m.insert(
"allowed_tools".to_string(),
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
);
}
response_tools.push(Value::Object(m));
} }
if let Some(url) = &t.server_url {
// Add web_search_preview tool if present m.insert("server_url".to_string(), Value::String(url.clone()));
if has_web_search { }
let mut ws = serde_json::Map::new(); if let Some(desc) = &t.server_description {
ws.insert( m.insert(
"type".to_string(), "server_description".to_string(),
Value::String(web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME.to_string()), Value::String(desc.clone()),
);
}
if let Some(req) = &t.require_approval {
m.insert("require_approval".to_string(), Value::String(req.clone()));
}
if let Some(allowed) = &t.allowed_tools {
m.insert(
"allowed_tools".to_string(),
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
); );
response_tools.push(Value::Object(ws));
} }
if let Some(obj) = resp.as_object_mut() { if let Some(obj) = resp.as_object_mut() {
obj.insert("tools".to_string(), Value::Array(response_tools)); obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)]));
obj.entry("tool_choice") obj.entry("tool_choice")
.or_insert(Value::String("auto".to_string())); .or_insert(Value::String("auto".to_string()));
} }
......
...@@ -30,12 +30,12 @@ use super::conversations::{ ...@@ -30,12 +30,12 @@ use super::conversations::{
}; };
use super::{ use super::{
mcp::{ mcp::{
ensure_request_mcp_client, execute_tool_loop, has_web_search_preview_tool, ensure_request_mcp_client, execute_tool_loop, prepare_mcp_payload_for_streaming,
is_web_search_mcp_available, prepare_mcp_payload_for_streaming, McpLoopConfig, McpLoopConfig,
}, },
responses::{mask_tools_as_mcp, patch_streaming_response_json}, responses::{mask_tools_as_mcp, patch_streaming_response_json},
streaming::handle_streaming_response, streaming::handle_streaming_response,
utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model, ToolContext}, utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
}; };
use crate::{ use crate::{
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}, core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
...@@ -248,7 +248,6 @@ impl OpenAIRouter { ...@@ -248,7 +248,6 @@ impl OpenAIRouter {
mut payload: Value, mut payload: Value,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
original_previous_response_id: Option<String>, original_previous_response_id: Option<String>,
tool_context: ToolContext,
) -> Response { ) -> Response {
// Check if MCP is active for this request // Check if MCP is active for this request
// Ensure dynamic client is created if needed // Ensure dynamic client is created if needed
...@@ -267,13 +266,10 @@ impl OpenAIRouter { ...@@ -267,13 +266,10 @@ impl OpenAIRouter {
// If MCP is active, execute tool loop // If MCP is active, execute tool loop
if let Some(mcp) = active_mcp { if let Some(mcp) = active_mcp {
let config = McpLoopConfig { let config = McpLoopConfig::default();
tool_context,
..Default::default()
};
// Transform MCP tools to function tools // Transform MCP tools to function tools
prepare_mcp_payload_for_streaming(&mut payload, mcp, tool_context); prepare_mcp_payload_for_streaming(&mut payload, mcp);
match execute_tool_loop( match execute_tool_loop(
&self.client, &self.client,
...@@ -699,35 +695,6 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -699,35 +695,6 @@ impl crate::routers::RouterTrait for OpenAIRouter {
let url = format!("{}/v1/responses", base_url); let url = format!("{}/v1/responses", base_url);
// Detect web_search_preview tool and verify MCP server availability
let tool_context = if let Some(ref tools) = body.tools {
if has_web_search_preview_tool(tools) {
ToolContext::WebSearchPreview
} else {
ToolContext::Regular
}
} else {
ToolContext::Regular
};
if tool_context.is_web_search() {
// Check if web_search_preview MCP server is available
if !is_web_search_mcp_available(&self.mcp_manager).await {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": "Web search preview is currently unavailable. Please contact your server administrator.",
"type": "invalid_request_error",
"param": "tools",
"code": "web_search_unavailable"
}
})),
)
.into_response();
}
}
// Validate mutually exclusive params: previous_response_id and conversation // Validate mutually exclusive params: previous_response_id and conversation
// TODO: this validation logic should move the right place, also we need a proper error message module // TODO: this validation logic should move the right place, also we need a proper error message module
if body.previous_response_id.is_some() && body.conversation.is_some() { if body.previous_response_id.is_some() && body.conversation.is_some() {
...@@ -1055,7 +1022,6 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -1055,7 +1022,6 @@ impl crate::routers::RouterTrait for OpenAIRouter {
payload, payload,
body, body,
original_previous_response_id, original_previous_response_id,
tool_context,
) )
.await .await
} else { } else {
...@@ -1065,7 +1031,6 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -1065,7 +1031,6 @@ impl crate::routers::RouterTrait for OpenAIRouter {
payload, payload,
body, body,
original_previous_response_id, original_previous_response_id,
tool_context,
) )
.await .await
} }
......
...@@ -30,10 +30,7 @@ use super::{ ...@@ -30,10 +30,7 @@ use super::{
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState, send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
}, },
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block}, responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
utils::{ utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction},
event_types, web_search_constants, FunctionCallInProgress, OutputIndexMapper, StreamAction,
ToolContext,
},
}; };
use crate::{ use crate::{
data_connector::{ConversationItemStorage, ConversationStorage, ResponseStorage}, data_connector::{ConversationItemStorage, ConversationStorage, ResponseStorage},
...@@ -556,7 +553,6 @@ pub(super) fn apply_event_transformations_inplace( ...@@ -556,7 +553,6 @@ pub(super) fn apply_event_transformations_inplace(
server_label: &str, server_label: &str,
original_request: &ResponsesRequest, original_request: &ResponsesRequest,
previous_response_id: Option<&str>, previous_response_id: Option<&str>,
tool_context: ToolContext,
) -> bool { ) -> bool {
let mut changed = false; let mut changed = false;
...@@ -602,35 +598,23 @@ pub(super) fn apply_event_transformations_inplace( ...@@ -602,35 +598,23 @@ pub(super) fn apply_event_transformations_inplace(
// Mask tools from function to MCP format (optimized without cloning) // Mask tools from function to MCP format (optimized without cloning)
if response_obj.get("tools").is_some() { if response_obj.get("tools").is_some() {
// For web_search_preview, always use simplified tool format let requested_mcp = original_request
if tool_context.is_web_search() { .tools
let web_search_tool = .as_ref()
json!([{"type": web_search_constants::WEB_SEARCH_PREVIEW_SERVER_NAME}]); .map(|tools| {
response_obj.insert("tools".to_string(), web_search_tool); tools
response_obj .iter()
.entry("tool_choice".to_string()) .any(|t| matches!(t.r#type, ResponseToolType::Mcp))
.or_insert(Value::String("auto".to_string())); })
changed = true; .unwrap_or(false);
} else {
// Regular MCP tools - only if requested if requested_mcp {
let requested_mcp = original_request if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
.tools response_obj.insert("tools".to_string(), mcp_tools);
.as_ref() response_obj
.map(|tools| { .entry("tool_choice".to_string())
tools .or_insert(Value::String("auto".to_string()));
.iter() changed = true;
.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
if requested_mcp {
if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
response_obj.insert("tools".to_string(), mcp_tools);
response_obj
.entry("tool_choice".to_string())
.or_insert(Value::String("auto".to_string()));
changed = true;
}
} }
} }
} }
...@@ -645,30 +629,13 @@ pub(super) fn apply_event_transformations_inplace( ...@@ -645,30 +629,13 @@ pub(super) fn apply_event_transformations_inplace(
if item_type == event_types::ITEM_TYPE_FUNCTION_CALL if item_type == event_types::ITEM_TYPE_FUNCTION_CALL
|| item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL || item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
{ {
// Use web_search_call for web_search_preview, mcp_call for regular MCP item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL);
if tool_context.is_web_search() { item["server_label"] = json!(server_label);
item["type"] = json!(event_types::ITEM_TYPE_WEB_SEARCH_CALL);
// Don't include server_label for web_search_call
// Remove internal implementation fields
if let Some(obj) = item.as_object_mut() {
obj.remove("arguments");
obj.remove("call_id");
obj.remove("name");
}
} else {
item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL);
item["server_label"] = json!(server_label);
}
// Transform ID from fc_* to ws_* or mcp_* // Transform ID from fc_* to mcp_*
if let Some(id) = item.get("id").and_then(|v| v.as_str()) { if let Some(id) = item.get("id").and_then(|v| v.as_str()) {
if let Some(stripped) = id.strip_prefix("fc_") { if let Some(stripped) = id.strip_prefix("fc_") {
let prefix = if tool_context.is_web_search() { let new_id = format!("mcp_{}", stripped);
"ws"
} else {
"mcp"
};
let new_id = format!("{}_{}", prefix, stripped);
item["id"] = json!(new_id); item["id"] = json!(new_id);
} }
} }
...@@ -726,7 +693,6 @@ pub(super) fn forward_streaming_event( ...@@ -726,7 +693,6 @@ pub(super) fn forward_streaming_event(
original_request: &ResponsesRequest, original_request: &ResponsesRequest,
previous_response_id: Option<&str>, previous_response_id: Option<&str>,
sequence_number: &mut u64, sequence_number: &mut u64,
tool_context: ToolContext,
) -> bool { ) -> bool {
// Skip individual function_call_arguments.delta events - we'll send them as one // Skip individual function_call_arguments.delta events - we'll send them as one
if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DELTA) { if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DELTA) {
...@@ -791,40 +757,37 @@ pub(super) fn forward_streaming_event( ...@@ -791,40 +757,37 @@ pub(super) fn forward_streaming_event(
}; };
// Emit a synthetic MCP arguments delta event before the done event // Emit a synthetic MCP arguments delta event before the done event
// Skip for web_search_preview - we don't expose tool call arguments let mut delta_event = json!({
if !tool_context.is_web_search() { "type": event_types::MCP_CALL_ARGUMENTS_DELTA,
let mut delta_event = json!({ "sequence_number": *sequence_number,
"type": event_types::MCP_CALL_ARGUMENTS_DELTA, "output_index": assigned_index,
"sequence_number": *sequence_number, "item_id": mcp_item_id,
"output_index": assigned_index, "delta": arguments_value,
"item_id": mcp_item_id, });
"delta": arguments_value,
}); if let Some(obfuscation) = call.last_obfuscation.as_ref() {
if let Some(obj) = delta_event.as_object_mut() {
if let Some(obfuscation) = call.last_obfuscation.as_ref() { obj.insert(
if let Some(obj) = delta_event.as_object_mut() { "obfuscation".to_string(),
obj.insert( Value::String(obfuscation.clone()),
"obfuscation".to_string(), );
Value::String(obfuscation.clone()),
);
}
} else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() {
if let Some(obj) = delta_event.as_object_mut() {
obj.insert("obfuscation".to_string(), obfuscation);
}
} }
} else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() {
let delta_block = format!( if let Some(obj) = delta_event.as_object_mut() {
"event: {}\ndata: {}\n\n", obj.insert("obfuscation".to_string(), obfuscation);
event_types::MCP_CALL_ARGUMENTS_DELTA,
delta_event
);
if tx.send(Ok(Bytes::from(delta_block))).is_err() {
return false;
} }
}
*sequence_number += 1; let delta_block = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_ARGUMENTS_DELTA,
delta_event
);
if tx.send(Ok(Bytes::from(delta_block))).is_err() {
return false;
} }
*sequence_number += 1;
} }
} }
} }
...@@ -850,7 +813,6 @@ pub(super) fn forward_streaming_event( ...@@ -850,7 +813,6 @@ pub(super) fn forward_streaming_event(
server_label, server_label,
original_request, original_request,
previous_response_id, previous_response_id,
tool_context,
); );
if let Some(response_obj) = parsed_data if let Some(response_obj) = parsed_data
...@@ -882,24 +844,16 @@ pub(super) fn forward_streaming_event( ...@@ -882,24 +844,16 @@ pub(super) fn forward_streaming_event(
let mut final_block = String::new(); let mut final_block = String::new();
if let Some(evt) = event_name { if let Some(evt) = event_name {
// Update event name for function_call_arguments events // Update event name for function_call_arguments events
// Skip for web_search_preview - we don't expose tool call arguments if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA {
if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA && !tool_context.is_web_search() {
final_block.push_str(&format!( final_block.push_str(&format!(
"event: {}\n", "event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DELTA event_types::MCP_CALL_ARGUMENTS_DELTA
)); ));
} else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE && !tool_context.is_web_search() } else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE {
{
final_block.push_str(&format!( final_block.push_str(&format!(
"event: {}\n", "event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DONE event_types::MCP_CALL_ARGUMENTS_DONE
)); ));
} else if (evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA
|| evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE)
&& tool_context.is_web_search()
{
// Skip these events entirely for web_search_preview
return true;
} else { } else {
final_block.push_str(&format!("event: {}\n", evt)); final_block.push_str(&format!("event: {}\n", evt));
} }
...@@ -911,62 +865,30 @@ pub(super) fn forward_streaming_event( ...@@ -911,62 +865,30 @@ pub(super) fn forward_streaming_event(
return false; return false;
} }
// After sending output_item.added for mcp_call/web_search_call, inject in_progress event // After sending output_item.added for mcp_call, inject mcp_call.in_progress event
if event_name == Some(event_types::OUTPUT_ITEM_ADDED) { if event_name == Some(event_types::OUTPUT_ITEM_ADDED) {
if let Some(item) = parsed_data.get("item") { if let Some(item) = parsed_data.get("item") {
let item_type = item.get("type").and_then(|v| v.as_str()); if item.get("type").and_then(|v| v.as_str()) == Some(event_types::ITEM_TYPE_MCP_CALL) {
// Already transformed to mcp_call
// Check if it's an mcp_call or web_search_call
let is_mcp_or_web_search = item_type == Some(event_types::ITEM_TYPE_MCP_CALL)
|| item_type == Some(event_types::ITEM_TYPE_WEB_SEARCH_CALL);
if is_mcp_or_web_search {
if let (Some(item_id), Some(output_index)) = ( if let (Some(item_id), Some(output_index)) = (
item.get("id").and_then(|v| v.as_str()), item.get("id").and_then(|v| v.as_str()),
parsed_data.get("output_index").and_then(|v| v.as_u64()), parsed_data.get("output_index").and_then(|v| v.as_u64()),
) { ) {
// Choose event type based on tool_context
let in_progress_event_type = if tool_context.is_web_search() {
event_types::WEB_SEARCH_CALL_IN_PROGRESS
} else {
event_types::MCP_CALL_IN_PROGRESS
};
let in_progress_event = json!({ let in_progress_event = json!({
"type": in_progress_event_type, "type": event_types::MCP_CALL_IN_PROGRESS,
"sequence_number": *sequence_number, "sequence_number": *sequence_number,
"output_index": output_index, "output_index": output_index,
"item_id": item_id "item_id": item_id
}); });
*sequence_number += 1; *sequence_number += 1;
let in_progress_block = format!( let in_progress_block = format!(
"event: {}\ndata: {}\n\n", "event: {}\ndata: {}\n\n",
in_progress_event_type, in_progress_event event_types::MCP_CALL_IN_PROGRESS,
in_progress_event
); );
if tx.send(Ok(Bytes::from(in_progress_block))).is_err() { if tx.send(Ok(Bytes::from(in_progress_block))).is_err() {
return false; return false;
} }
// For web_search_call, also send a "searching" event
if tool_context.is_web_search() {
let searching_event = json!({
"type": event_types::WEB_SEARCH_CALL_SEARCHING,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let searching_block = format!(
"event: {}\ndata: {}\n\n",
event_types::WEB_SEARCH_CALL_SEARCHING,
searching_event
);
if tx.send(Ok(Bytes::from(searching_block))).is_err() {
return false;
}
}
} }
} }
} }
...@@ -987,7 +909,6 @@ pub(super) fn send_final_response_event( ...@@ -987,7 +909,6 @@ pub(super) fn send_final_response_event(
original_request: &ResponsesRequest, original_request: &ResponsesRequest,
previous_response_id: Option<&str>, previous_response_id: Option<&str>,
server_label: &str, server_label: &str,
tool_context: ToolContext,
) -> bool { ) -> bool {
let mut final_response = match handler.snapshot_final_response() { let mut final_response = match handler.snapshot_final_response() {
Some(resp) => resp, Some(resp) => resp,
...@@ -1004,7 +925,7 @@ pub(super) fn send_final_response_event( ...@@ -1004,7 +925,7 @@ pub(super) fn send_final_response_event(
} }
if let Some(mcp) = active_mcp { if let Some(mcp) = active_mcp {
inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label, tool_context); inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label);
} }
mask_tools_as_mcp(&mut final_response, original_request); mask_tools_as_mcp(&mut final_response, original_request);
...@@ -1216,10 +1137,9 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1216,10 +1137,9 @@ pub(super) async fn handle_streaming_with_tool_interception(
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
original_previous_response_id: Option<String>, original_previous_response_id: Option<String>,
active_mcp: &Arc<crate::mcp::McpManager>, active_mcp: &Arc<crate::mcp::McpManager>,
tool_context: ToolContext,
) -> Response { ) -> Response {
// Transform MCP tools to function tools in payload // Transform MCP tools to function tools in payload
prepare_mcp_payload_for_streaming(&mut payload, active_mcp, tool_context); prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>(); let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store.unwrap_or(false); let should_store = original_body.store.unwrap_or(false);
...@@ -1236,10 +1156,7 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1236,10 +1156,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
// Spawn the streaming loop task // Spawn the streaming loop task
tokio::spawn(async move { tokio::spawn(async move {
let mut state = ToolLoopState::new(original_request.input.clone()); let mut state = ToolLoopState::new(original_request.input.clone());
let loop_config = McpLoopConfig { let loop_config = McpLoopConfig::default();
tool_context,
..Default::default()
};
let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize); let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);
let tools_json = payload_clone.get("tools").cloned().unwrap_or(json!([])); let tools_json = payload_clone.get("tools").cloned().unwrap_or(json!([]));
let base_payload = payload_clone.clone(); let base_payload = payload_clone.clone();
...@@ -1358,7 +1275,6 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1358,7 +1275,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request, &original_request,
previous_response_id.as_deref(), previous_response_id.as_deref(),
&mut sequence_number, &mut sequence_number,
loop_config.tool_context,
) { ) {
// Client disconnected // Client disconnected
return; return;
...@@ -1374,10 +1290,7 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1374,10 +1290,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
== Some(event_types::RESPONSE_IN_PROGRESS) == Some(event_types::RESPONSE_IN_PROGRESS)
{ {
seen_in_progress = true; seen_in_progress = true;
// Skip mcp_list_tools for web_search_preview if !mcp_list_tools_sent {
if !mcp_list_tools_sent
&& !loop_config.tool_context.is_web_search()
{
let list_tools_index = let list_tools_index =
handler.allocate_synthetic_output_index(); handler.allocate_synthetic_output_index();
if !send_mcp_list_tools_events( if !send_mcp_list_tools_events(
...@@ -1410,7 +1323,6 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1410,7 +1323,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request, &original_request,
previous_response_id.as_deref(), previous_response_id.as_deref(),
&mut sequence_number, &mut sequence_number,
loop_config.tool_context,
) { ) {
// Client disconnected // Client disconnected
return; return;
...@@ -1449,7 +1361,6 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1449,7 +1361,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&original_request, &original_request,
previous_response_id.as_deref(), previous_response_id.as_deref(),
server_label, server_label,
tool_context,
) { ) {
return; return;
} }
...@@ -1471,7 +1382,6 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1471,7 +1382,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&state, &state,
&active_mcp_clone, &active_mcp_clone,
server_label, server_label,
tool_context,
); );
mask_tools_as_mcp(&mut response_json, &original_request); mask_tools_as_mcp(&mut response_json, &original_request);
...@@ -1533,7 +1443,6 @@ pub(super) async fn handle_streaming_with_tool_interception( ...@@ -1533,7 +1443,6 @@ pub(super) async fn handle_streaming_with_tool_interception(
&mut state, &mut state,
server_label, server_label,
&mut sequence_number, &mut sequence_number,
tool_context,
) )
.await .await
{ {
...@@ -1589,7 +1498,6 @@ pub(super) async fn handle_streaming_response( ...@@ -1589,7 +1498,6 @@ pub(super) async fn handle_streaming_response(
payload: Value, payload: Value,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
original_previous_response_id: Option<String>, original_previous_response_id: Option<String>,
tool_context: ToolContext,
) -> Response { ) -> Response {
// Check if MCP is active for this request // Check if MCP is active for this request
// Ensure dynamic client is created if needed // Ensure dynamic client is created if needed
...@@ -1637,7 +1545,6 @@ pub(super) async fn handle_streaming_response( ...@@ -1637,7 +1545,6 @@ pub(super) async fn handle_streaming_response(
original_body, original_body,
original_previous_response_id, original_previous_response_id,
active_mcp, active_mcp,
tool_context,
) )
.await .await
} }
...@@ -32,59 +32,12 @@ pub(crate) mod event_types { ...@@ -32,59 +32,12 @@ pub(crate) mod event_types {
pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress"; pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress";
pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed"; pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed";
// Web Search Call events (for web_search_preview)
pub const WEB_SEARCH_CALL_IN_PROGRESS: &str = "response.web_search_call.in_progress";
pub const WEB_SEARCH_CALL_SEARCHING: &str = "response.web_search_call.searching";
pub const WEB_SEARCH_CALL_COMPLETED: &str = "response.web_search_call.completed";
// Item types // Item types
pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call"; pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call";
pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call"; pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call";
pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call"; pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call";
pub const ITEM_TYPE_FUNCTION: &str = "function"; pub const ITEM_TYPE_FUNCTION: &str = "function";
pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools"; pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools";
pub const ITEM_TYPE_WEB_SEARCH_CALL: &str = "web_search_call";
}
// ============================================================================
// Web Search Constants
// ============================================================================
/// Constants for web search preview feature
pub(crate) mod web_search_constants {
/// MCP server name for web search preview
pub const WEB_SEARCH_PREVIEW_SERVER_NAME: &str = "web_search_preview";
/// Status constants
pub const STATUS_COMPLETED: &str = "completed";
pub const STATUS_FAILED: &str = "failed";
/// Action type for web search
pub const ACTION_TYPE_SEARCH: &str = "search";
}
// ============================================================================
// Tool Context Enum
// ============================================================================
/// Represents the context for tool handling strategy
///
/// This enum replaces boolean flags for better type safety and clarity.
/// It makes the code more maintainable and easier to extend with new
/// tool handling strategies in the future.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ToolContext {
/// Regular MCP tool handling with full mcp_call and mcp_list_tools items
Regular,
/// Web search preview handling with simplified web_search_call items
WebSearchPreview,
}
impl ToolContext {
/// Check if this is web search preview context
pub fn is_web_search(&self) -> bool {
matches!(self, ToolContext::WebSearchPreview)
}
} }
// ============================================================================ // ============================================================================
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment