Unverified Commit 98c3b04f authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] responses api POST and GET with local storage (#10581)


Co-authored-by: default avatarkey4ng <rukeyang@gmail.com>
parent ddab4fc7
...@@ -74,13 +74,16 @@ impl ResponseStorage for MemoryResponseStorage { ...@@ -74,13 +74,16 @@ impl ResponseStorage for MemoryResponseStorage {
// Store the response // Store the response
store.responses.insert(response_id.clone(), response); store.responses.insert(response_id.clone(), response);
tracing::info!("memory_store_size" = store.responses.len());
Ok(response_id) Ok(response_id)
} }
async fn get_response(&self, response_id: &ResponseId) -> Result<Option<StoredResponse>> { async fn get_response(&self, response_id: &ResponseId) -> Result<Option<StoredResponse>> {
let store = self.store.read(); let store = self.store.read();
Ok(store.responses.get(response_id).cloned()) let result = store.responses.get(response_id).cloned();
tracing::info!("memory_get_response" = %response_id.0, found = result.is_some());
Ok(result)
} }
async fn delete_response(&self, response_id: &ResponseId) -> Result<()> { async fn delete_response(&self, response_id: &ResponseId) -> Result<()> {
...@@ -200,6 +203,20 @@ pub struct MemoryStoreStats { ...@@ -200,6 +203,20 @@ pub struct MemoryStoreStats {
mod tests { mod tests {
use super::*; use super::*;
#[tokio::test]
async fn test_store_with_custom_id() {
    // Storing a response whose id was assigned by the caller must keep that
    // id, and the response must be retrievable under it afterwards.
    let storage = MemoryResponseStorage::new();

    let custom_id = ResponseId::from_string("resp_custom".to_string());
    let mut stored = StoredResponse::new("Input".to_string(), "Output".to_string(), None);
    stored.id = custom_id.clone();

    storage.store_response(stored.clone()).await.unwrap();

    let fetched = storage.get_response(&custom_id).await.unwrap();
    assert!(fetched.is_some());
    assert_eq!(fetched.unwrap().output, "Output");
}
#[tokio::test] #[tokio::test]
async fn test_memory_store_basic() { async fn test_memory_store_basic() {
let store = MemoryResponseStorage::new(); let store = MemoryResponseStorage::new();
......
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
...@@ -55,6 +56,10 @@ pub struct StoredResponse { ...@@ -55,6 +56,10 @@ pub struct StoredResponse {
/// Model used for generation /// Model used for generation
pub model: Option<String>, pub model: Option<String>,
/// Raw OpenAI response payload
#[serde(default)]
pub raw_response: Value,
} }
impl StoredResponse { impl StoredResponse {
...@@ -70,6 +75,7 @@ impl StoredResponse { ...@@ -70,6 +75,7 @@ impl StoredResponse {
created_at: chrono::Utc::now(), created_at: chrono::Utc::now(),
user: None, user: None,
model: None, model: None,
raw_response: Value::Null,
} }
} }
} }
...@@ -175,3 +181,9 @@ pub trait ResponseStorage: Send + Sync { ...@@ -175,3 +181,9 @@ pub trait ResponseStorage: Send + Sync {
/// Type alias for shared storage /// Type alias for shared storage
pub type SharedResponseStorage = Arc<dyn ResponseStorage>; pub type SharedResponseStorage = Arc<dyn ResponseStorage>;
impl Default for StoredResponse {
fn default() -> Self {
Self::new(String::new(), String::new(), None)
}
}
This diff is collapsed.
...@@ -166,8 +166,12 @@ impl RouterFactory { ...@@ -166,8 +166,12 @@ impl RouterFactory {
.cloned() .cloned()
.ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?; .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
let router = let router = OpenAIRouter::new(
OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?; base_url,
Some(ctx.router_config.circuit_breaker.clone()),
ctx.response_storage.clone(),
)
.await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
......
...@@ -308,7 +308,12 @@ impl RouterTrait for GrpcPDRouter { ...@@ -308,7 +308,12 @@ impl RouterTrait for GrpcPDRouter {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
} }
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { async fn get_response(
&self,
_headers: Option<&HeaderMap>,
_response_id: &str,
_params: &crate::protocols::spec::ResponsesGetParams,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
} }
......
...@@ -237,7 +237,12 @@ impl RouterTrait for GrpcRouter { ...@@ -237,7 +237,12 @@ impl RouterTrait for GrpcRouter {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
} }
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { async fn get_response(
&self,
_headers: Option<&HeaderMap>,
_response_id: &str,
_params: &crate::protocols::spec::ResponsesGetParams,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
} }
......
...@@ -51,3 +51,45 @@ fn should_forward_header(name: &str) -> bool { ...@@ -51,3 +51,45 @@ fn should_forward_header(name: &str) -> bool {
"host" // Should not forward the backend's host header "host" // Should not forward the backend's host header
) )
} }
/// Copy incoming headers onto a `reqwest::RequestBuilder`, dropping headers
/// that must not be forwarded (hop-by-hop headers, `host`) or that reqwest
/// manages itself. When `skip_content_headers` is true, `content-type` and
/// `content-length` are also dropped.
pub fn apply_request_headers(
    headers: &HeaderMap,
    mut request_builder: reqwest::RequestBuilder,
    skip_content_headers: bool,
) -> reqwest::RequestBuilder {
    // Forward the Authorization header up front so it is always present,
    // independent of map iteration order.
    let auth = headers
        .get("authorization")
        .or_else(|| headers.get("Authorization"));
    if let Some(value) = auth {
        request_builder = request_builder.header("Authorization", value.clone());
    }

    for (name, value) in headers.iter() {
        let lowered = name.as_str().to_lowercase();

        // Skip: the auth header handled above, `host`, hop-by-hop headers,
        // and (optionally) content headers that reqwest sets automatically.
        let skip = matches!(
            lowered.as_str(),
            "authorization"
                | "host"
                | "connection"
                | "transfer-encoding"
                | "keep-alive"
                | "te"
                | "trailers"
                | "upgrade"
        ) || (skip_content_headers
            && matches!(lowered.as_str(), "content-type" | "content-length"));

        if !skip {
            request_builder = request_builder.header(name.clone(), value.clone());
        }
    }

    request_builder
}
...@@ -8,7 +8,7 @@ use crate::metrics::RouterMetrics; ...@@ -8,7 +8,7 @@ use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest,
ResponsesRequest, StringOrArray, UserMessageContent, ResponsesGetParams, ResponsesRequest, StringOrArray, UserMessageContent,
}; };
use crate::routers::header_utils; use crate::routers::header_utils;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
...@@ -1424,7 +1424,12 @@ impl RouterTrait for PDRouter { ...@@ -1424,7 +1424,12 @@ impl RouterTrait for PDRouter {
.into_response() .into_response()
} }
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { async fn get_response(
&self,
_headers: Option<&HeaderMap>,
_response_id: &str,
_params: &ResponsesGetParams,
) -> Response {
( (
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_IMPLEMENTED,
"Responses retrieve endpoint not implemented for PD router", "Responses retrieve endpoint not implemented for PD router",
......
...@@ -6,7 +6,7 @@ use crate::metrics::RouterMetrics; ...@@ -6,7 +6,7 @@ use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
RerankRequest, RerankResponse, RerankResult, ResponsesRequest, RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
}; };
use crate::routers::header_utils; use crate::routers::header_utils;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
...@@ -903,7 +903,12 @@ impl RouterTrait for Router { ...@@ -903,7 +903,12 @@ impl RouterTrait for Router {
.await .await
} }
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { async fn get_response(
&self,
headers: Option<&HeaderMap>,
response_id: &str,
_params: &ResponsesGetParams,
) -> Response {
let endpoint = format!("v1/responses/{}", response_id); let endpoint = format!("v1/responses/{}", response_id);
self.route_get_request(headers, &endpoint).await self.route_get_request(headers, &endpoint).await
} }
......
...@@ -11,7 +11,7 @@ use std::fmt::Debug; ...@@ -11,7 +11,7 @@ use std::fmt::Debug;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest, ResponsesGetParams, ResponsesRequest,
}; };
pub mod factory; pub mod factory;
...@@ -82,7 +82,12 @@ pub trait RouterTrait: Send + Sync + Debug { ...@@ -82,7 +82,12 @@ pub trait RouterTrait: Send + Sync + Debug {
) -> Response; ) -> Response;
/// Retrieve a stored/background response by id /// Retrieve a stored/background response by id
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response; async fn get_response(
&self,
headers: Option<&HeaderMap>,
response_id: &str,
params: &ResponsesGetParams,
) -> Response;
/// Cancel a background response by id /// Cancel a background response by id
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response; async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response;
......
...@@ -8,7 +8,7 @@ use crate::config::{ConnectionMode, RoutingMode}; ...@@ -8,7 +8,7 @@ use crate::config::{ConnectionMode, RoutingMode};
use crate::core::{WorkerRegistry, WorkerType}; use crate::core::{WorkerRegistry, WorkerType};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest, ResponsesGetParams, ResponsesRequest,
}; };
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use crate::server::{AppContext, ServerConfig}; use crate::server::{AppContext, ServerConfig};
...@@ -402,10 +402,37 @@ impl RouterTrait for RouterManager { ...@@ -402,10 +402,37 @@ impl RouterTrait for RouterManager {
} }
async fn route_responses( async fn route_responses(
&self,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
let selected_model = body.model.as_deref().or(model_id);
let router = self.select_router_for_request(headers, selected_model);
if let Some(router) = router {
router.route_responses(headers, body, selected_model).await
} else {
(
StatusCode::NOT_FOUND,
"No router available to handle responses request",
)
.into_response()
}
}
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"responses api not yet implemented in inference gateway mode",
)
.into_response()
}
async fn list_response_input_items(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
_body: &ResponsesRequest, _response_id: &str,
_model_id: Option<&str>,
) -> Response { ) -> Response {
( (
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_IMPLEMENTED,
...@@ -414,10 +441,15 @@ impl RouterTrait for RouterManager { ...@@ -414,10 +441,15 @@ impl RouterTrait for RouterManager {
.into_response() .into_response()
} }
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { async fn get_response(
&self,
headers: Option<&HeaderMap>,
response_id: &str,
params: &ResponsesGetParams,
) -> Response {
let router = self.select_router_for_request(headers, None); let router = self.select_router_for_request(headers, None);
if let Some(router) = router { if let Some(router) = router {
router.get_response(headers, response_id).await router.get_response(headers, response_id, params).await
} else { } else {
( (
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
...@@ -440,26 +472,6 @@ impl RouterTrait for RouterManager { ...@@ -440,26 +472,6 @@ impl RouterTrait for RouterManager {
} }
} }
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"responses api not yet implemented in inference gateway mode",
)
.into_response()
}
async fn list_response_input_items(
&self,
_headers: Option<&HeaderMap>,
_response_id: &str,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"responses api not yet implemented in inference gateway mode",
)
.into_response()
}
async fn route_embeddings( async fn route_embeddings(
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
......
...@@ -9,7 +9,7 @@ use crate::{ ...@@ -9,7 +9,7 @@ use crate::{
protocols::{ protocols::{
spec::{ spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
RerankRequest, ResponsesRequest, V1RerankReqInput, RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
}, },
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
}, },
...@@ -224,10 +224,11 @@ async fn v1_responses_get( ...@@ -224,10 +224,11 @@ async fn v1_responses_get(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Path(response_id): Path<String>, Path(response_id): Path<String>,
headers: http::HeaderMap, headers: http::HeaderMap,
Query(params): Query<ResponsesGetParams>,
) -> Response { ) -> Response {
state state
.router .router
.get_response(Some(&headers), &response_id) .get_response(Some(&headers), &response_id, &params)
.await .await
} }
......
...@@ -5,17 +5,23 @@ use axum::{ ...@@ -5,17 +5,23 @@ use axum::{
extract::Request, extract::Request,
http::{Method, StatusCode}, http::{Method, StatusCode},
routing::post, routing::post,
Router, Json, Router,
}; };
use serde_json::json; use serde_json::json;
use sglang_router_rs::{ use sglang_router_rs::{
config::{RouterConfig, RoutingMode}, config::{RouterConfig, RoutingMode},
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage},
protocols::spec::{ protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
ResponsesGetParams, ResponsesRequest, UserMessageContent,
}, },
routers::{openai_router::OpenAIRouter, RouterTrait}, routers::{openai_router::OpenAIRouter, RouterTrait},
}; };
use std::sync::Arc; use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use tokio::net::TcpListener;
use tower::ServiceExt; use tower::ServiceExt;
mod common; mod common;
...@@ -78,7 +84,12 @@ fn create_minimal_completion_request() -> CompletionRequest { ...@@ -78,7 +84,12 @@ fn create_minimal_completion_request() -> CompletionRequest {
/// Test basic OpenAI router creation and configuration /// Test basic OpenAI router creation and configuration
#[tokio::test] #[tokio::test]
async fn test_openai_router_creation() { async fn test_openai_router_creation() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await; let router = OpenAIRouter::new(
"https://api.openai.com".to_string(),
None,
Arc::new(MemoryResponseStorage::new()),
)
.await;
assert!(router.is_ok(), "Router creation should succeed"); assert!(router.is_ok(), "Router creation should succeed");
...@@ -90,7 +101,11 @@ async fn test_openai_router_creation() { ...@@ -90,7 +101,11 @@ async fn test_openai_router_creation() {
/// Test health endpoints /// Test health endpoints
#[tokio::test] #[tokio::test]
async fn test_openai_router_health() { async fn test_openai_router_health() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None) let router = OpenAIRouter::new(
"https://api.openai.com".to_string(),
None,
Arc::new(MemoryResponseStorage::new()),
)
.await .await
.unwrap(); .unwrap();
...@@ -107,7 +122,11 @@ async fn test_openai_router_health() { ...@@ -107,7 +122,11 @@ async fn test_openai_router_health() {
/// Test server info endpoint /// Test server info endpoint
#[tokio::test] #[tokio::test]
async fn test_openai_router_server_info() { async fn test_openai_router_server_info() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None) let router = OpenAIRouter::new(
"https://api.openai.com".to_string(),
None,
Arc::new(MemoryResponseStorage::new()),
)
.await .await
.unwrap(); .unwrap();
...@@ -132,7 +151,11 @@ async fn test_openai_router_server_info() { ...@@ -132,7 +151,11 @@ async fn test_openai_router_server_info() {
async fn test_openai_router_models() { async fn test_openai_router_models() {
// Use mock server for deterministic models response // Use mock server for deterministic models response
let mock_server = MockOpenAIServer::new().await; let mock_server = MockOpenAIServer::new().await;
let router = OpenAIRouter::new(mock_server.base_url(), None) let router = OpenAIRouter::new(
mock_server.base_url(),
None,
Arc::new(MemoryResponseStorage::new()),
)
.await .await
.unwrap(); .unwrap();
...@@ -154,6 +177,138 @@ async fn test_openai_router_models() { ...@@ -154,6 +177,138 @@ async fn test_openai_router_models() {
assert!(models["data"].is_array()); assert!(models["data"].is_array());
} }
#[tokio::test]
async fn test_openai_router_responses_with_mock() {
// Spin up a minimal mock of the OpenAI /v1/responses endpoint on an
// ephemeral local port; each POST gets a unique, counter-derived id.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let app = Router::new().route(
"/v1/responses",
post({
move |Json(request): Json<serde_json::Value>| {
let counter = counter_clone.clone();
async move {
// 1-based request index; drives the mock ids and output text.
let idx = counter.fetch_add(1, Ordering::SeqCst) + 1;
let model = request
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("gpt-4o-mini")
.to_string();
let id = format!("resp_mock_{idx}");
// Payload shaped like an OpenAI Responses API completion.
let response = json!({
"id": id,
"object": "response",
"created_at": 1_700_000_000 + idx as i64,
"status": "completed",
"model": model,
"output": [{
"type": "message",
"id": format!("msg_{idx}"),
"role": "assistant",
"status": "completed",
"content": [{
"type": "output_text",
"text": format!("mock_output_{idx}"),
"annotations": []
}]
}],
"metadata": {}
});
Json(response)
}
}
}),
);
let server = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let base_url = format!("http://{}", addr);
// Router under test, backed by an in-memory store we can inspect directly.
let storage = Arc::new(MemoryResponseStorage::new());
let router = OpenAIRouter::new(base_url, None, storage.clone())
.await
.unwrap();
// First request: starts a chain, so previous_response_id must come back null.
let request1 = ResponsesRequest {
model: Some("gpt-4o-mini".to_string()),
input: ResponseInput::Text("Say hi".to_string()),
store: true,
..Default::default()
};
let response1 = router.route_responses(None, &request1, None).await;
assert_eq!(response1.status(), StatusCode::OK);
let body1_bytes = axum::body::to_bytes(response1.into_body(), usize::MAX)
.await
.unwrap();
let body1: serde_json::Value = serde_json::from_slice(&body1_bytes).unwrap();
let resp1_id = body1["id"].as_str().expect("id missing").to_string();
assert_eq!(body1["previous_response_id"], serde_json::Value::Null);
// Second request chains onto the first via previous_response_id.
let request2 = ResponsesRequest {
model: Some("gpt-4o-mini".to_string()),
input: ResponseInput::Text("Thanks".to_string()),
store: true,
previous_response_id: Some(resp1_id.clone()),
..Default::default()
};
let response2 = router.route_responses(None, &request2, None).await;
assert_eq!(response2.status(), StatusCode::OK);
let body2_bytes = axum::body::to_bytes(response2.into_body(), usize::MAX)
.await
.unwrap();
let body2: serde_json::Value = serde_json::from_slice(&body2_bytes).unwrap();
let resp2_id = body2["id"].as_str().expect("second id missing");
assert_eq!(
body2["previous_response_id"].as_str(),
Some(resp1_id.as_str())
);
// Both responses were requested with store: true, so they must be
// persisted, with the chain link recorded on the second one.
let stored1 = storage
.get_response(&ResponseId::from_string(resp1_id.clone()))
.await
.unwrap()
.expect("first response missing");
assert_eq!(stored1.input, "Say hi");
assert_eq!(stored1.output, "mock_output_1");
assert!(stored1.previous_response_id.is_none());
let stored2 = storage
.get_response(&ResponseId::from_string(resp2_id.to_string()))
.await
.unwrap()
.expect("second response missing");
assert_eq!(stored2.previous_response_id.unwrap().0, resp1_id);
assert_eq!(stored2.output, "mock_output_2");
// GET by id must round-trip the exact JSON bodies returned at POST time.
let get1 = router
.get_response(None, &stored1.id.0, &ResponsesGetParams::default())
.await;
assert_eq!(get1.status(), StatusCode::OK);
let get1_body_bytes = axum::body::to_bytes(get1.into_body(), usize::MAX)
.await
.unwrap();
let get1_json: serde_json::Value = serde_json::from_slice(&get1_body_bytes).unwrap();
assert_eq!(get1_json, body1);
let get2 = router
.get_response(None, &stored2.id.0, &ResponsesGetParams::default())
.await;
assert_eq!(get2.status(), StatusCode::OK);
let get2_body_bytes = axum::body::to_bytes(get2.into_body(), usize::MAX)
.await
.unwrap();
let get2_json: serde_json::Value = serde_json::from_slice(&get2_body_bytes).unwrap();
assert_eq!(get2_json, body2);
// Stop the mock server task.
server.abort();
}
/// Test router factory with OpenAI routing mode /// Test router factory with OpenAI routing mode
#[tokio::test] #[tokio::test]
async fn test_router_factory_openai_mode() { async fn test_router_factory_openai_mode() {
...@@ -179,7 +334,11 @@ async fn test_router_factory_openai_mode() { ...@@ -179,7 +334,11 @@ async fn test_router_factory_openai_mode() {
/// Test that unsupported endpoints return proper error codes /// Test that unsupported endpoints return proper error codes
#[tokio::test] #[tokio::test]
async fn test_unsupported_endpoints() { async fn test_unsupported_endpoints() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None) let router = OpenAIRouter::new(
"https://api.openai.com".to_string(),
None,
Arc::new(MemoryResponseStorage::new()),
)
.await .await
.unwrap(); .unwrap();
...@@ -219,7 +378,9 @@ async fn test_openai_router_chat_completion_with_mock() { ...@@ -219,7 +378,9 @@ async fn test_openai_router_chat_completion_with_mock() {
let base_url = mock_server.base_url(); let base_url = mock_server.base_url();
// Create router pointing to mock server // Create router pointing to mock server
let router = OpenAIRouter::new(base_url, None).await.unwrap(); let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
.await
.unwrap();
// Create a minimal chat completion request // Create a minimal chat completion request
let mut chat_request = create_minimal_chat_request(); let mut chat_request = create_minimal_chat_request();
...@@ -255,7 +416,9 @@ async fn test_openai_e2e_with_server() { ...@@ -255,7 +416,9 @@ async fn test_openai_e2e_with_server() {
let base_url = mock_server.base_url(); let base_url = mock_server.base_url();
// Create router // Create router
let router = OpenAIRouter::new(base_url, None).await.unwrap(); let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
.await
.unwrap();
// Create Axum app with chat completions endpoint // Create Axum app with chat completions endpoint
let app = Router::new().route( let app = Router::new().route(
...@@ -319,7 +482,9 @@ async fn test_openai_e2e_with_server() { ...@@ -319,7 +482,9 @@ async fn test_openai_e2e_with_server() {
async fn test_openai_router_chat_streaming_with_mock() { async fn test_openai_router_chat_streaming_with_mock() {
let mock_server = MockOpenAIServer::new().await; let mock_server = MockOpenAIServer::new().await;
let base_url = mock_server.base_url(); let base_url = mock_server.base_url();
let router = OpenAIRouter::new(base_url, None).await.unwrap(); let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
.await
.unwrap();
// Build a streaming chat request // Build a streaming chat request
let val = json!({ let val = json!({
...@@ -368,6 +533,7 @@ async fn test_openai_router_circuit_breaker() { ...@@ -368,6 +533,7 @@ async fn test_openai_router_circuit_breaker() {
let router = OpenAIRouter::new( let router = OpenAIRouter::new(
"http://invalid-url-that-will-fail".to_string(), "http://invalid-url-that-will-fail".to_string(),
Some(cb_config), Some(cb_config),
Arc::new(MemoryResponseStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -391,7 +557,11 @@ async fn test_openai_router_models_auth_forwarding() { ...@@ -391,7 +557,11 @@ async fn test_openai_router_models_auth_forwarding() {
// Start a mock server that requires Authorization // Start a mock server that requires Authorization
let expected_auth = "Bearer test-token".to_string(); let expected_auth = "Bearer test-token".to_string();
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await; let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
let router = OpenAIRouter::new(mock_server.base_url(), None) let router = OpenAIRouter::new(
mock_server.base_url(),
None,
Arc::new(MemoryResponseStorage::new()),
)
.await .await
.unwrap(); .unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment