//! Utility types and constants for the OpenAI router

use std::collections::HashMap;

5
6
use axum::http::{HeaderMap, HeaderValue};

// ============================================================================
// SSE Event Type Constants
// ============================================================================

/// SSE event type constants - single source of truth for event type strings.
///
/// Besides the `response.*` event names, this module also holds the
/// `ITEM_TYPE_*` strings that identify output item kinds.
pub(crate) mod event_types {
    // Response lifecycle events
    pub const RESPONSE_CREATED: &str = "response.created";
    pub const RESPONSE_IN_PROGRESS: &str = "response.in_progress";
    pub const RESPONSE_COMPLETED: &str = "response.completed";

    // Output item events
    pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
    pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
    pub const OUTPUT_ITEM_DELTA: &str = "response.output_item.delta";

    // Function call events
    pub const FUNCTION_CALL_ARGUMENTS_DELTA: &str = "response.function_call_arguments.delta";
    pub const FUNCTION_CALL_ARGUMENTS_DONE: &str = "response.function_call_arguments.done";

    // MCP call events
    pub const MCP_CALL_ARGUMENTS_DELTA: &str = "response.mcp_call_arguments.delta";
    pub const MCP_CALL_ARGUMENTS_DONE: &str = "response.mcp_call_arguments.done";
    pub const MCP_CALL_IN_PROGRESS: &str = "response.mcp_call.in_progress";
    pub const MCP_CALL_COMPLETED: &str = "response.mcp_call.completed";
    pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress";
    pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed";

    // Web Search Call events (for web_search_preview)
    pub const WEB_SEARCH_CALL_IN_PROGRESS: &str = "response.web_search_call.in_progress";
    pub const WEB_SEARCH_CALL_SEARCHING: &str = "response.web_search_call.searching";
    pub const WEB_SEARCH_CALL_COMPLETED: &str = "response.web_search_call.completed";

    // Item types
    pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call";
    pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call";
    pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call";
    pub const ITEM_TYPE_FUNCTION: &str = "function";
    pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools";
    pub const ITEM_TYPE_WEB_SEARCH_CALL: &str = "web_search_call";
}

// ============================================================================
// Web Search Constants
// ============================================================================

/// Constants for the web search preview feature.
///
/// Groups the string literals used for web search preview handling: the
/// server name, the status values, and the action type.
pub(crate) mod web_search_constants {
    /// MCP server name for web search preview
    pub const WEB_SEARCH_PREVIEW_SERVER_NAME: &str = "web_search_preview";

    /// Status string for a completed web search call.
    pub const STATUS_COMPLETED: &str = "completed";
    /// Status string for a failed web search call.
    pub const STATUS_FAILED: &str = "failed";

    /// Action type for web search
    pub const ACTION_TYPE_SEARCH: &str = "search";
}

// ============================================================================
// Tool Context Enum
// ============================================================================

/// Represents the context for tool handling strategy.
///
/// This enum replaces boolean flags for better type safety and clarity.
/// It makes the code more maintainable and easier to extend with new
/// tool handling strategies in the future.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ToolContext {
    /// Regular MCP tool handling with full mcp_call and mcp_list_tools items
    Regular,
    /// Web search preview handling with simplified web_search_call items
    WebSearchPreview,
}

impl ToolContext {
    /// Returns `true` only for [`ToolContext::WebSearchPreview`].
    pub fn is_web_search(&self) -> bool {
        matches!(self, ToolContext::WebSearchPreview)
    }
}

// ============================================================================
// Stream Action Enum
// ============================================================================

/// Action to take based on streaming event processing
#[derive(Debug)]
pub(crate) enum StreamAction {
    /// Pass event to client
    Forward,
    /// Accumulate for tool execution
    Buffer,
    /// Function call complete, execute now
    ExecuteTools,
}

// ============================================================================
// Output Index Mapper
// ============================================================================

/// Translates upstream output indices into sequentially assigned downstream
/// indices.
///
/// Each upstream index receives a downstream index the first time it is seen
/// and keeps it afterwards. Synthetic slots can be drawn from the same
/// counter without being tied to any upstream index.
#[derive(Debug, Default)]
pub(crate) struct OutputIndexMapper {
    /// Downstream index that the next allocation will hand out.
    next_index: usize,
    /// Upstream output_index -> downstream output_index.
    assigned: HashMap<usize, usize>,
}

impl OutputIndexMapper {
    /// Builds an empty mapper whose first allocation returns `next_index`.
    pub fn with_start(next_index: usize) -> Self {
        Self {
            next_index,
            ..Self::default()
        }
    }

    /// Returns the downstream index for `upstream_index`, assigning the next
    /// free one if this upstream index has not been seen before.
    pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize {
        match self.assigned.entry(upstream_index) {
            std::collections::hash_map::Entry::Occupied(slot) => *slot.get(),
            std::collections::hash_map::Entry::Vacant(slot) => {
                let fresh = self.next_index;
                self.next_index += 1;
                *slot.insert(fresh)
            }
        }
    }

    /// Returns the downstream index already assigned to `upstream_index`,
    /// or `None` when no mapping exists. Never allocates.
    pub fn lookup(&self, upstream_index: usize) -> Option<usize> {
        let found = self.assigned.get(&upstream_index);
        found.copied()
    }

    /// Reserves and returns the next downstream index without recording any
    /// upstream mapping for it.
    pub fn allocate_synthetic(&mut self) -> usize {
        let fresh = self.next_index;
        self.next_index += 1;
        fresh
    }

    /// The downstream index the next allocation would receive.
    pub fn next_index(&self) -> usize {
        self.next_index
    }
}

// ============================================================================
// Provider Detection and Header Handling
// ============================================================================

/// Extract authorization header from request headers
/// Checks both "authorization" and "Authorization" (case variations)
pub fn extract_auth_header(headers: Option<&HeaderMap>) -> Option<&str> {
    headers.and_then(|h| {
        h.get("authorization")
            .or_else(|| h.get("Authorization"))
            .and_then(|v| v.to_str().ok())
    })
}

/// API provider types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiProvider {
    Anthropic,
    Xai,
    OpenAi,
    Gemini,
    Generic,
}

impl ApiProvider {
    /// Detect provider type from URL.
    ///
    /// The URL is searched for a provider-specific marker substring, checked
    /// in this order: `anthropic`, `x.ai`, `openai.com`, `googleapis.com`.
    /// The first marker found wins; a URL containing none of them (including
    /// the empty string) is [`ApiProvider::Generic`].
    ///
    /// Matching ignores ASCII case, because host names are case-insensitive
    /// and a URL such as `https://API.Anthropic.com` must not fall through to
    /// `Generic`.
    pub fn from_url(url: &str) -> Self {
        // Normalize once so every marker comparison is case-insensitive.
        let lowered = url.to_ascii_lowercase();

        if lowered.contains("anthropic") {
            ApiProvider::Anthropic
        } else if lowered.contains("x.ai") {
            ApiProvider::Xai
        } else if lowered.contains("openai.com") {
            ApiProvider::OpenAi
        } else if lowered.contains("googleapis.com") {
            ApiProvider::Gemini
        } else {
            ApiProvider::Generic
        }
    }
}

/// Attach authentication headers in the form the target provider expects.
///
/// The provider is determined from `url` via [`ApiProvider::from_url`]:
///
/// - **Anthropic**: the credential is sent as `x-api-key` (with a leading
///   `Bearer ` removed when present) together with
///   `anthropic-version: 2023-06-01`. If `auth_header` is `None` or its
///   `to_str()` conversion fails, the request is returned untouched.
/// - **Everything else**: `auth_header`, when present, is forwarded unchanged
///   as `Authorization`.
pub fn apply_provider_headers(
    req: reqwest::RequestBuilder,
    url: &str,
    auth_header: Option<&HeaderValue>,
) -> reqwest::RequestBuilder {
    match ApiProvider::from_url(url) {
        ApiProvider::Anthropic => {
            // Anthropic takes the raw key in x-api-key rather than a Bearer token.
            let credential = auth_header.and_then(|value| value.to_str().ok());
            match credential {
                Some(raw) => {
                    let key = raw.strip_prefix("Bearer ").unwrap_or(raw);
                    req.header("x-api-key", key)
                        .header("anthropic-version", "2023-06-01")
                }
                None => req,
            }
        }
        ApiProvider::Gemini | ApiProvider::Xai | ApiProvider::OpenAi | ApiProvider::Generic => {
            // OpenAI-compatible endpoints accept the Authorization header verbatim.
            match auth_header {
                Some(value) => req.header("Authorization", value),
                None => req,
            }
        }
    }
}

/// Probe a single endpoint to check whether it serves `model`.
///
/// Issues `GET {url}/v1/models/{model}` with a 5-second timeout, adding
/// provider-specific auth headers through [`apply_provider_headers`]. An
/// `auth` string that is not a valid header value is dropped, and the probe
/// is sent without credentials.
///
/// Returns `Ok(url)` when the response status is a success, and `Err(())`
/// for any other status or when the request itself fails. Each outcome is
/// logged at debug level.
pub async fn probe_endpoint_for_model(
    client: reqwest::Client,
    url: String,
    model: String,
    auth: Option<String>,
) -> Result<String, ()> {
    use tracing::debug;

    let target = format!("{}/v1/models/{}", url, model);

    // Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
    let credential = match auth.as_deref() {
        Some(raw) => HeaderValue::from_str(raw).ok(),
        None => None,
    };
    let request = apply_provider_headers(
        client
            .get(&target)
            .timeout(std::time::Duration::from_secs(5)),
        &url,
        credential.as_ref(),
    );

    let response = match request.send().await {
        Ok(response) => response,
        Err(err) => {
            debug!(
                url = %url,
                model = %model,
                error = %err,
                "Probe request to endpoint failed"
            );
            return Err(());
        }
    };

    let status = response.status();
    if !status.is_success() {
        debug!(
            url = %url,
            model = %model,
            status = %status,
            "Model not found on endpoint (unsuccessful status)"
        );
        return Err(());
    }

    debug!(
        url = %url,
        model = %model,
        status = %status,
        "Model found on endpoint"
    );
    Ok(url)
}

// ============================================================================
// Re-export FunctionCallInProgress from mcp module
// ============================================================================
pub(crate) use super::mcp::FunctionCallInProgress;