Unverified commit 77258ce0, authored by Keyang Ru, committed by GitHub

[router] Support multiple worker URLs for OpenAI router (#11723)
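
OpenAI routing mode now accepts several worker URLs instead of exactly one. The router discovers which backend serves a requested model by probing each endpoint's /v1/models/{model}, caches the winning endpoint for an hour, and aggregates /v1/models across all endpoints. A caller-side sketch mirroring the updated tests (the xAI URL is only an illustrative second endpoint):

    let router = OpenAIRouter::new(
        vec![
            "https://api.openai.com".to_string(),
            "https://api.x.ai".to_string(),
        ],
        None, // circuit breaker config
        response_storage,
        conversation_storage,
        conversation_item_storage,
    )
    .await?;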

parent 1d097aac
@@ -165,18 +165,14 @@ impl ConfigValidator {
                 }
             }
             RoutingMode::OpenAI { worker_urls } => {
-                // Require exactly one worker URL for OpenAI router
-                if worker_urls.len() != 1 {
+                // Require at least one worker URL for OpenAI router
+                if worker_urls.is_empty() {
                     return Err(ConfigError::ValidationFailed {
-                        reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
+                        reason: "OpenAI mode requires at least one --worker-urls entry".to_string(),
                     });
                 }
-                // Validate URL format
-                if let Err(e) = url::Url::parse(&worker_urls[0]) {
-                    return Err(ConfigError::ValidationFailed {
-                        reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
-                    });
-                }
+                // Validate URLs
+                Self::validate_urls(worker_urls)?;
             }
         }
         Ok(())
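
Self::validate_urls is referenced above but defined outside this hunk; a plausible shape, assuming it keeps the url::Url::parse check that the inline code used to perform on the single URL:

    fn validate_urls(worker_urls: &[String]) -> Result<(), ConfigError> {
        for url in worker_urls {
            if let Err(e) = url::Url::parse(url) {
                return Err(ConfigError::ValidationFailed {
                    reason: format!("Invalid worker URL '{}': {}", url, e),
                });
            }
        }
        Ok(())
    }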
...
@@ -8,8 +8,8 @@ use serde_json::Value;
 
 // Import shared types from common module
 use super::common::{
-    default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo, StringOrArray, ToolChoice,
-    UsageInfo,
+    default_model, default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo,
+    StringOrArray, ToolChoice, UsageInfo,
 };
 
 // ============================================================================
@@ -452,9 +452,9 @@ pub struct ResponsesRequest {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub metadata: Option<HashMap<String, Value>>,
 
-    /// Model to use (optional to match vLLM)
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub model: Option<String>,
+    /// Model to use
+    #[serde(default = "default_model")]
+    pub model: String,
 
     /// Optional conversation id to persist input/output as items
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -565,7 +565,7 @@ impl Default for ResponsesRequest {
             max_output_tokens: None,
             max_tool_calls: None,
             metadata: None,
-            model: None,
+            model: default_model(),
             conversation: None,
             parallel_tool_calls: None,
             previous_response_id: None,
@@ -598,7 +598,7 @@ impl GenerationRequest for ResponsesRequest {
     }
 
     fn get_model(&self) -> Option<&str> {
-        self.model.as_deref()
+        Some(self.model.as_str())
     }
 
     fn extract_text_for_routing(&self) -> String {
...
@@ -55,7 +55,7 @@ impl RouterFactory {
                 )
                 .await
             }
-            RoutingMode::OpenAI { worker_urls, .. } => {
+            RoutingMode::OpenAI { worker_urls } => {
                 Self::create_openai_router(worker_urls.clone(), ctx).await
             }
         },
@@ -122,13 +122,12 @@
         worker_urls: Vec<String>,
         ctx: &Arc<AppContext>,
     ) -> Result<Box<dyn RouterTrait>, String> {
-        let base_url = worker_urls
-            .first()
-            .cloned()
-            .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
+        if worker_urls.is_empty() {
+            return Err("OpenAI mode requires at least one worker URL".to_string());
+        }
 
         let router = OpenAIRouter::new(
-            base_url,
+            worker_urls,
             Some(ctx.router_config.circuit_breaker.clone()),
             ctx.response_storage.clone(),
             ctx.conversation_storage.clone(),
...
@@ -39,7 +39,7 @@ pub(super) fn build_stored_response(
         .get("model")
         .and_then(|v| v.as_str())
         .map(|s| s.to_string())
-        .or_else(|| original_body.model.clone());
+        .or_else(|| Some(original_body.model.clone()));
 
     stored_response.user = response_json
         .get("user")
@@ -143,9 +143,10 @@ pub(super) fn patch_streaming_response_json(
             .map(|s| s.is_empty())
            .unwrap_or(true)
        {
-            if let Some(model) = &original_body.model {
-                obj.insert("model".to_string(), Value::String(model.clone()));
-            }
+            obj.insert(
+                "model".to_string(),
+                Value::String(original_body.model.clone()),
+            );
        }
 
        if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
...
@@ -3,6 +3,7 @@
 use std::{
     any::Any,
     sync::{atomic::AtomicBool, Arc},
+    time::{Duration, Instant},
 };
 
 use axum::{
@@ -12,6 +13,7 @@ use axum::{
     response::{IntoResponse, Response},
     Json,
 };
+use dashmap::DashMap;
 use futures_util::StreamExt;
 use serde_json::{json, to_value, Value};
 use tokio::sync::mpsc;
@@ -31,6 +33,7 @@
     },
     responses::{mask_tools_as_mcp, patch_streaming_response_json},
     streaming::handle_streaming_response,
+    utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
 };
 use crate::{
     config::CircuitBreakerConfig,
@@ -59,12 +62,21 @@
 // OpenAIRouter Struct
 // ============================================================================
 
+/// Cached endpoint information
+#[derive(Clone, Debug)]
+struct CachedEndpoint {
+    url: String,
+    cached_at: Instant,
+}
+
 /// Router for OpenAI backend
 pub struct OpenAIRouter {
     /// HTTP client for upstream OpenAI-compatible API
     client: reqwest::Client,
-    /// Base URL for identification (no trailing slash)
-    base_url: String,
+    /// Multiple OpenAI-compatible API endpoints (OpenAI, xAI, etc.)
+    worker_urls: Vec<String>,
+    /// Model cache: model_id -> endpoint URL
+    model_cache: Arc<DashMap<String, CachedEndpoint>>,
     /// Circuit breaker
     circuit_breaker: CircuitBreaker,
     /// Health status
@@ -82,7 +94,7 @@ pub struct OpenAIRouter {
 impl std::fmt::Debug for OpenAIRouter {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("OpenAIRouter")
-            .field("base_url", &self.base_url)
+            .field("worker_urls", &self.worker_urls)
             .field("healthy", &self.healthy)
             .finish()
     }
@@ -92,28 +104,35 @@ impl OpenAIRouter {
     /// Maximum number of conversation items to attach as input when a conversation is provided
     const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
 
+    /// Model discovery cache TTL (1 hour)
+    const MODEL_CACHE_TTL_SECS: u64 = 3600;
+
     /// Create a new OpenAI router
     pub async fn new(
-        base_url: String,
+        worker_urls: Vec<String>,
         circuit_breaker_config: Option<CircuitBreakerConfig>,
         response_storage: SharedResponseStorage,
         conversation_storage: SharedConversationStorage,
         conversation_item_storage: SharedConversationItemStorage,
     ) -> Result<Self, String> {
         let client = reqwest::Client::builder()
-            .timeout(std::time::Duration::from_secs(300))
+            .timeout(Duration::from_secs(300))
             .build()
             .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
 
-        let base_url = base_url.trim_end_matches('/').to_string();
+        // Normalize URLs (remove trailing slashes)
+        let worker_urls: Vec<String> = worker_urls
+            .into_iter()
+            .map(|url| url.trim_end_matches('/').to_string())
+            .collect();
 
         // Convert circuit breaker config
         let core_cb_config = circuit_breaker_config
             .map(|cb| CoreCircuitBreakerConfig {
                 failure_threshold: cb.failure_threshold,
                 success_threshold: cb.success_threshold,
-                timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
-                window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
+                timeout_duration: Duration::from_secs(cb.timeout_duration_secs),
+                window_duration: Duration::from_secs(cb.window_duration_secs),
             })
             .unwrap_or_default();
@@ -141,7 +160,8 @@ impl OpenAIRouter {
 
         Ok(Self {
             client,
-            base_url,
+            worker_urls,
+            model_cache: Arc::new(DashMap::new()),
             circuit_breaker,
             healthy: AtomicBool::new(true),
             response_storage,
@@ -151,6 +171,67 @@ impl OpenAIRouter {
         })
     }
 
+    /// Discover which endpoint has the model
+    async fn find_endpoint_for_model(
+        &self,
+        model_id: &str,
+        auth_header: Option<&str>,
+    ) -> Result<String, Response> {
+        // Single endpoint - fast path
+        if self.worker_urls.len() == 1 {
+            return Ok(self.worker_urls[0].clone());
+        }
+
+        // Check cache
+        if let Some(entry) = self.model_cache.get(model_id) {
+            if entry.cached_at.elapsed() < Duration::from_secs(Self::MODEL_CACHE_TTL_SECS) {
+                return Ok(entry.url.clone());
+            }
+        }
+
+        // Probe all endpoints in parallel
+        let mut handles = vec![];
+        let model = model_id.to_string();
+        let auth = auth_header.map(|s| s.to_string());
+        for url in &self.worker_urls {
+            let handle = tokio::spawn(probe_endpoint_for_model(
+                self.client.clone(),
+                url.clone(),
+                model.clone(),
+                auth.clone(),
+            ));
+            handles.push(handle);
+        }
+
+        // Return first successful endpoint
+        for handle in handles {
+            if let Ok(Ok(url)) = handle.await {
+                // Cache it
+                self.model_cache.insert(
+                    model_id.to_string(),
+                    CachedEndpoint {
+                        url: url.clone(),
+                        cached_at: Instant::now(),
+                    },
+                );
+                return Ok(url);
+            }
+        }
+
+        // Model not found on any endpoint
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(json!({
+                "error": {
+                    "message": format!("Model '{}' not found on any endpoint", model_id),
+                    "type": "model_not_found",
+                }
+            })),
+        )
+            .into_response())
+    }
+
     /// Handle non-streaming response with optional MCP tool loop
     async fn handle_non_streaming_response(
         &self,
@@ -282,85 +363,145 @@ impl crate::routers::RouterTrait for OpenAIRouter {
     }
 
     async fn health_generate(&self, _req: Request<Body>) -> Response {
-        // Simple upstream probe: GET {base}/v1/models without auth
-        let url = format!("{}/v1/models", self.base_url);
-        match self
-            .client
-            .get(&url)
-            .timeout(std::time::Duration::from_secs(2))
-            .send()
-            .await
-        {
-            Ok(resp) => {
-                let code = resp.status();
-                // Treat success and auth-required as healthy (endpoint reachable)
-                if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
-                    (StatusCode::OK, "OK").into_response()
-                } else {
-                    (
-                        StatusCode::SERVICE_UNAVAILABLE,
-                        format!("Upstream status: {}", code),
-                    )
-                        .into_response()
-                }
-            }
-            Err(e) => (
-                StatusCode::SERVICE_UNAVAILABLE,
-                format!("Upstream error: {}", e),
-            )
-                .into_response(),
-        }
+        // Check all endpoints in parallel - only healthy if ALL are healthy
+        if self.worker_urls.is_empty() {
+            return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
+        }
+
+        let mut handles = vec![];
+        for url in &self.worker_urls {
+            let url = url.clone();
+            let client = self.client.clone();
+
+            let handle = tokio::spawn(async move {
+                let probe_url = format!("{}/v1/models", url);
+                match client
+                    .get(&probe_url)
+                    .timeout(Duration::from_secs(2))
+                    .send()
+                    .await
+                {
+                    Ok(resp) => {
+                        let code = resp.status();
+                        // Treat success and auth-required as healthy (endpoint reachable)
+                        if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
+                            Ok(())
+                        } else {
+                            Err(format!("Endpoint {} returned status {}", url, code))
+                        }
+                    }
+                    Err(e) => Err(format!("Endpoint {} error: {}", url, e)),
+                }
+            });
+            handles.push(handle);
+        }
+
+        // Collect all results
+        let mut errors = Vec::new();
+        for handle in handles {
+            match handle.await {
+                Ok(Ok(())) => (),
+                Ok(Err(e)) => errors.push(e),
+                Err(e) => errors.push(format!("Task join error: {}", e)),
+            }
+        }
+
+        if errors.is_empty() {
+            (StatusCode::OK, "OK").into_response()
+        } else {
+            (
+                StatusCode::SERVICE_UNAVAILABLE,
+                format!("Some endpoints unhealthy: {}", errors.join(", ")),
+            )
+                .into_response()
+        }
     }
 
     async fn get_server_info(&self, _req: Request<Body>) -> Response {
         let info = json!({
             "router_type": "openai",
-            "workers": 1,
-            "base_url": &self.base_url
+            "workers": self.worker_urls.len(),
+            "worker_urls": &self.worker_urls
         });
         (StatusCode::OK, info.to_string()).into_response()
     }
 
     async fn get_models(&self, req: Request<Body>) -> Response {
-        // Proxy to upstream /v1/models; forward Authorization header if provided
-        let headers = req.headers();
-
-        let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
-        if let Some(auth) = headers
-            .get("authorization")
-            .or_else(|| headers.get("Authorization"))
-        {
-            upstream = upstream.header("Authorization", auth);
-        }
-
-        match upstream.send().await {
-            Ok(res) => {
-                let status = StatusCode::from_u16(res.status().as_u16())
-                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
-                let content_type = res.headers().get(CONTENT_TYPE).cloned();
-                match res.bytes().await {
-                    Ok(body) => {
-                        let mut response = Response::new(Body::from(body));
-                        *response.status_mut() = status;
-                        if let Some(ct) = content_type {
-                            response.headers_mut().insert(CONTENT_TYPE, ct);
-                        }
-                        response
-                    }
-                    Err(e) => (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read upstream response: {}", e),
-                    )
-                        .into_response(),
-                }
-            }
-            Err(e) => (
-                StatusCode::BAD_GATEWAY,
-                format!("Failed to contact upstream: {}", e),
-            )
-                .into_response(),
-        }
+        // Aggregate models from all endpoints
+        if self.worker_urls.is_empty() {
+            return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
+        }
+
+        let headers = req.headers();
+        let auth = headers
+            .get("authorization")
+            .or_else(|| headers.get("Authorization"));
+
+        // Query all endpoints in parallel
+        let mut handles = vec![];
+        for url in &self.worker_urls {
+            let url = url.clone();
+            let client = self.client.clone();
+            let auth = auth.cloned();
+            let handle = tokio::spawn(async move {
+                let models_url = format!("{}/v1/models", url);
+                let req = client.get(&models_url);
+                // Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
+                let req = apply_provider_headers(req, &url, auth.as_ref());
+                match req.send().await {
+                    Ok(res) => {
+                        if res.status().is_success() {
+                            match res.json::<Value>().await {
+                                Ok(json) => Ok(json),
+                                Err(e) => {
+                                    tracing::warn!(
+                                        "Failed to parse models response from '{}': {}",
+                                        url,
+                                        e
+                                    );
+                                    Err(())
+                                }
+                            }
+                        } else {
+                            tracing::warn!(
+                                "Getting models from '{}' failed with status: {}",
+                                url,
+                                res.status()
+                            );
+                            Err(())
+                        }
+                    }
+                    Err(e) => {
+                        tracing::warn!("Request to get models from '{}' failed: {}", url, e);
+                        Err(())
+                    }
+                }
+            });
+            handles.push(handle);
+        }
+
+        // Collect all model lists
+        let mut all_models = Vec::new();
+        for handle in handles {
+            if let Ok(Ok(json)) = handle.await {
+                if let Some(data) = json.get("data").and_then(|v| v.as_array()) {
+                    all_models.extend_from_slice(data);
+                }
+            }
+        }
+
+        // Return aggregated models
+        let response_json = json!({
+            "object": "list",
+            "data": all_models
+        });
+        (StatusCode::OK, Json(response_json)).into_response()
     }
 
     async fn get_model_info(&self, _req: Request<Body>) -> Response {
@@ -396,6 +537,18 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
         }
 
+        // Extract auth header
+        let auth = extract_auth_header(headers);
+
+        // Find endpoint for model
+        let base_url = match self
+            .find_endpoint_for_model(body.model.as_str(), auth)
+            .await
+        {
+            Ok(url) => url,
+            Err(response) => return response,
+        };
+
         // Serialize request body, removing SGLang-only fields
         let mut payload = match to_value(body) {
             Ok(v) => v,
@@ -431,9 +584,14 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             ] {
                 obj.remove(key);
             }
+
+            // Remove logprobs if false (Gemini doesn't accept it)
+            if obj.get("logprobs").and_then(|v| v.as_bool()) == Some(false) {
+                obj.remove("logprobs");
+            }
         }
 
-        let url = format!("{}/v1/chat/completions", self.base_url);
+        let url = format!("{}/v1/chat/completions", base_url);
         let mut req = self.client.post(&url).json(&payload);
 
         // Forward Authorization header if provided
@@ -534,7 +692,17 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
-        let url = format!("{}/v1/responses", self.base_url);
+        // Extract auth header
+        let auth = extract_auth_header(headers);
+
+        // Find endpoint for model (use model_id if provided, otherwise use body.model)
+        let model = model_id.unwrap_or(body.model.as_str());
+        let base_url = match self.find_endpoint_for_model(model, auth).await {
+            Ok(url) => url,
+            Err(response) => return response,
+        };
+
+        let url = format!("{}/v1/responses", base_url);
 
         // Validate mutually exclusive params: previous_response_id and conversation
         // TODO: this validation logic should move to the right place; we also need a proper error message module
@@ -556,7 +724,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         // Clone the body for validation and logic, but we'll build payload differently
         let mut request_body = body.clone();
         if let Some(model) = model_id {
-            request_body.model = Some(model.to_string());
+            request_body.model = model.to_string();
         }
         // Do not forward conversation field upstream; retain for local persistence only
         request_body.conversation = None;
@@ -847,34 +1015,12 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         }
     }
 
-    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
-        // Forward cancellation to upstream
-        let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
-        let mut req = self.client.post(&url);
-
-        if let Some(h) = headers {
-            req = apply_request_headers(h, req, false);
-        }
-
-        match req.send().await {
-            Ok(resp) => {
-                let status = StatusCode::from_u16(resp.status().as_u16())
-                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
-                match resp.text().await {
-                    Ok(body) => (status, body).into_response(),
-                    Err(e) => (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read response: {}", e),
-                    )
-                        .into_response(),
-                }
-            }
-            Err(e) => (
-                StatusCode::BAD_GATEWAY,
-                format!("Failed to contact upstream: {}", e),
-            )
-                .into_response(),
-        }
-    }
+    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Cancel response not implemented for OpenAI router",
+        )
+            .into_response()
+    }
 
     async fn route_embeddings(
...
@@ -2,6 +2,8 @@
 
 use std::collections::HashMap;
 
+use axum::http::{HeaderMap, HeaderValue};
+
 // ============================================================================
 // SSE Event Type Constants
 // ============================================================================
@@ -93,6 +95,131 @@ impl OutputIndexMapper {
     }
 }
 
+// ============================================================================
+// Provider Detection and Header Handling
+// ============================================================================
+
+/// Extract authorization header from request headers
+/// Checks both "authorization" and "Authorization" (case variations)
+pub fn extract_auth_header(headers: Option<&HeaderMap>) -> Option<&str> {
+    headers.and_then(|h| {
+        h.get("authorization")
+            .or_else(|| h.get("Authorization"))
+            .and_then(|v| v.to_str().ok())
+    })
+}
+
+/// API provider types
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ApiProvider {
+    Anthropic,
+    Xai,
+    OpenAi,
+    Gemini,
+    Generic,
+}
+
+impl ApiProvider {
+    /// Detect provider type from URL
+    pub fn from_url(url: &str) -> Self {
+        if url.contains("anthropic") {
+            ApiProvider::Anthropic
+        } else if url.contains("x.ai") {
+            ApiProvider::Xai
+        } else if url.contains("openai.com") {
+            ApiProvider::OpenAi
+        } else if url.contains("googleapis.com") {
+            ApiProvider::Gemini
+        } else {
+            ApiProvider::Generic
+        }
+    }
+}
+
+/// Apply provider-specific headers to request
+pub fn apply_provider_headers(
+    mut req: reqwest::RequestBuilder,
+    url: &str,
+    auth_header: Option<&HeaderValue>,
+) -> reqwest::RequestBuilder {
+    let provider = ApiProvider::from_url(url);
+
+    match provider {
+        ApiProvider::Anthropic => {
+            // Anthropic requires x-api-key instead of Authorization
+            // Extract Bearer token and use as x-api-key
+            if let Some(auth) = auth_header {
+                if let Ok(auth_str) = auth.to_str() {
+                    let api_key = auth_str.strip_prefix("Bearer ").unwrap_or(auth_str);
+                    req = req
+                        .header("x-api-key", api_key)
+                        .header("anthropic-version", "2023-06-01");
+                }
+            }
+        }
+        ApiProvider::Gemini | ApiProvider::Xai | ApiProvider::OpenAi | ApiProvider::Generic => {
+            // Standard OpenAI-compatible: use Authorization header as-is
+            if let Some(auth) = auth_header {
+                req = req.header("Authorization", auth);
+            }
+        }
+    }
+
+    req
+}
+
+/// Probe a single endpoint to check if it has the model
+/// Returns Ok(url) if model found, Err(()) otherwise
+pub async fn probe_endpoint_for_model(
+    client: reqwest::Client,
+    url: String,
+    model: String,
+    auth: Option<String>,
+) -> Result<String, ()> {
+    use tracing::debug;
+
+    let probe_url = format!("{}/v1/models/{}", url, model);
+    let req = client
+        .get(&probe_url)
+        .timeout(std::time::Duration::from_secs(5));
+
+    // Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
+    let auth_header_value = auth.as_ref().and_then(|a| HeaderValue::from_str(a).ok());
+    let req = apply_provider_headers(req, &url, auth_header_value.as_ref());
+
+    match req.send().await {
+        Ok(resp) => {
+            let status = resp.status();
+            if status.is_success() {
+                debug!(
+                    url = %url,
+                    model = %model,
+                    status = %status,
+                    "Model found on endpoint"
+                );
+                Ok(url)
+            } else {
+                debug!(
+                    url = %url,
+                    model = %model,
+                    status = %status,
+                    "Model not found on endpoint (unsuccessful status)"
+                );
+                Err(())
+            }
+        }
+        Err(e) => {
+            debug!(
+                url = %url,
+                model = %model,
+                error = %e,
+                "Probe request to endpoint failed"
+            );
+            Err(())
+        }
+    }
+}
+
 // ============================================================================
 // Re-export FunctionCallInProgress from mcp module
 // ============================================================================
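
The substring checks in ApiProvider::from_url imply the following mapping (a sanity-check sketch, not a test from this change; the Anthropic and Google hostnames are assumed examples):

    assert_eq!(ApiProvider::from_url("https://api.anthropic.com"), ApiProvider::Anthropic);
    assert_eq!(ApiProvider::from_url("https://api.x.ai"), ApiProvider::Xai);
    assert_eq!(ApiProvider::from_url("https://api.openai.com"), ApiProvider::OpenAi);
    assert_eq!(
        ApiProvider::from_url("https://generativelanguage.googleapis.com"),
        ApiProvider::Gemini
    );
    assert_eq!(ApiProvider::from_url("http://localhost:8000"), ApiProvider::Generic);

Note that "anthropic" is checked first, so an Anthropic URL never falls through to the Authorization-style generic arm.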
...
@@ -410,7 +410,7 @@ impl RouterTrait for RouterManager {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
-        let selected_model = body.model.as_deref().or(model_id);
+        let selected_model = model_id.or(Some(body.model.as_str()));
 
         let router = self.select_router_for_request(headers, selected_model);
         if let Some(router) = router {
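
The selection precedence flips here: an explicit model_id now wins over the request body's model (previously the body won). A minimal illustration of the new expression:

    let model_id: Option<&str> = Some("gpt-4o");
    let body_model = "gpt-4o-mini";
    // New behavior: model_id, when present, overrides the body's model.
    assert_eq!(model_id.or(Some(body_model)), Some("gpt-4o"));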
...
@@ -100,7 +100,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
         max_output_tokens: Some(64),
         max_tool_calls: None,
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -134,7 +134,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
     };
 
     let resp = router
-        .route_responses(None, &req, req.model.as_deref())
+        .route_responses(None, &req, Some(req.model.as_str()))
         .await;
 
     assert_eq!(resp.status(), StatusCode::OK);
@@ -349,7 +349,7 @@ fn test_responses_request_creation() {
         max_output_tokens: Some(100),
         max_tool_calls: None,
         metadata: None,
-        model: Some("test-model".to_string()),
+        model: "test-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
@@ -397,7 +397,7 @@ fn test_responses_request_sglang_extensions() {
         max_output_tokens: Some(50),
         max_tool_calls: None,
         metadata: None,
-        model: Some("test-model".to_string()),
+        model: "test-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -506,7 +506,7 @@ fn test_json_serialization() {
         max_output_tokens: Some(200),
         max_tool_calls: Some(5),
         metadata: None,
-        model: Some("gpt-4".to_string()),
+        model: "gpt-4".to_string(),
         parallel_tool_calls: Some(false),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
@@ -545,7 +545,7 @@ fn test_json_serialization() {
         parsed.request_id,
         Some("resp_comprehensive_test".to_string())
     );
-    assert_eq!(parsed.model, Some("gpt-4".to_string()));
+    assert_eq!(parsed.model, "gpt-4");
     assert_eq!(parsed.background, Some(true));
     assert_eq!(parsed.stream, Some(true));
     assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
@@ -636,7 +636,7 @@ async fn test_multi_turn_loop_with_mcp() {
         max_output_tokens: Some(128),
         max_tool_calls: None, // No limit - test unlimited
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -812,7 +812,7 @@ async fn test_max_tool_calls_limit() {
         max_output_tokens: Some(128),
         max_tool_calls: Some(1), // Limit to 1 call
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -1006,7 +1006,7 @@ async fn test_streaming_with_mcp_tool_calls() {
         max_output_tokens: Some(256),
         max_tool_calls: Some(3),
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -1287,7 +1287,7 @@ async fn test_streaming_multi_turn_with_mcp() {
         max_output_tokens: Some(512),
         max_tool_calls: Some(5), // Allow multiple rounds
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
...
@@ -99,7 +99,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
 #[tokio::test]
 async fn test_openai_router_creation() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -118,7 +118,7 @@ async fn test_openai_router_creation() {
 #[tokio::test]
 async fn test_openai_router_server_info() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -149,7 +149,7 @@ async fn test_openai_router_models() {
     // Use mock server for deterministic models response
     let mock_server = MockOpenAIServer::new().await;
     let router = OpenAIRouter::new(
-        mock_server.base_url(),
+        vec![mock_server.base_url()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -229,7 +229,7 @@ async fn test_openai_router_responses_with_mock() {
     let storage = Arc::new(MemoryResponseStorage::new());
 
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         storage.clone(),
         Arc::new(MemoryConversationStorage::new()),
@@ -239,7 +239,7 @@ async fn test_openai_router_responses_with_mock() {
     .unwrap();
 
     let request1 = ResponsesRequest {
-        model: Some("gpt-4o-mini".to_string()),
+        model: "gpt-4o-mini".to_string(),
         input: ResponseInput::Text("Say hi".to_string()),
         store: Some(true),
         ..Default::default()
@@ -255,7 +255,7 @@ async fn test_openai_router_responses_with_mock() {
     assert_eq!(body1["previous_response_id"], serde_json::Value::Null);
 
     let request2 = ResponsesRequest {
-        model: Some("gpt-4o-mini".to_string()),
+        model: "gpt-4o-mini".to_string(),
         input: ResponseInput::Text("Thanks".to_string()),
         store: Some(true),
         previous_response_id: Some(resp1_id.clone()),
@@ -490,7 +490,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
     storage.store_response(previous).await.unwrap();
 
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         storage.clone(),
         Arc::new(MemoryConversationStorage::new()),
@@ -503,7 +503,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
     metadata.insert("topic".to_string(), json!("unicorns"));
 
     let request = ResponsesRequest {
-        model: Some("gpt-5-nano".to_string()),
+        model: "gpt-5-nano".to_string(),
         input: ResponseInput::Text("Tell me a bedtime story.".to_string()),
         instructions: Some("Be kind".to_string()),
         metadata: Some(metadata),
@@ -595,7 +595,7 @@ async fn test_router_factory_openai_mode() {
 #[tokio::test]
 async fn test_unsupported_endpoints() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -660,7 +660,7 @@ async fn test_openai_router_chat_completion_with_mock() {
     // Create router pointing to mock server
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -702,7 +702,7 @@ async fn test_openai_e2e_with_server() {
     // Create router
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -773,7 +773,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
     let mock_server = MockOpenAIServer::new().await;
     let base_url = mock_server.base_url();
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -827,7 +827,7 @@ async fn test_openai_router_circuit_breaker() {
     };
 
     let router = OpenAIRouter::new(
-        "http://invalid-url-that-will-fail".to_string(),
+        vec!["http://invalid-url-that-will-fail".to_string()],
         Some(cb_config),
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -856,7 +856,7 @@ async fn test_openai_router_models_auth_forwarding() {
     let expected_auth = "Bearer test-token".to_string();
     let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
     let router = OpenAIRouter::new(
-        mock_server.base_url(),
+        vec![mock_server.base_url()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -865,7 +865,8 @@ async fn test_openai_router_models_auth_forwarding() {
     .await
     .unwrap();
 
-    // 1) Without auth header -> expect 401
+    // 1) Without auth header -> expect 200 with empty model list
+    // (multi-endpoint aggregation silently skips failed endpoints)
     let req = Request::builder()
         .method(Method::GET)
         .uri("/models")
@@ -873,7 +874,13 @@ async fn test_openai_router_models_auth_forwarding() {
     .unwrap();
 
     let response = router.get_models(req).await;
-    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
+    assert_eq!(response.status(), StatusCode::OK);
+    let (_, body) = response.into_parts();
+    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+    let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
+    assert_eq!(models["object"], "list");
+    assert_eq!(models["data"].as_array().unwrap().len(), 0); // Empty when auth fails
 
     // 2) With auth header -> expect 200
     let req = Request::builder()
...