"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "12ab8f8f5f9a1e206d57cb7ce664bffff5faeed3"
Unverified Commit 4ed67c27 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] support Openai router conversation API CRUD (#11297)

parent cd4b39a9
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation,
Result,
};
/// In-memory conversation storage used for development and tests
#[derive(Default, Clone)]
pub struct MemoryConversationStorage {
inner: Arc<RwLock<HashMap<ConversationId, Conversation>>>,
}
impl MemoryConversationStorage {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait]
impl ConversationStorage for MemoryConversationStorage {
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation> {
let conversation = Conversation::new(input);
self.inner
.write()
.insert(conversation.id.clone(), conversation.clone());
Ok(conversation)
}
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>> {
Ok(self.inner.read().get(id).cloned())
}
async fn update_conversation(
&self,
id: &ConversationId,
metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>> {
let mut store = self.inner.write();
if let Some(entry) = store.get_mut(id) {
entry.metadata = metadata;
return Ok(Some(entry.clone()));
}
Ok(None)
}
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool> {
let removed = self.inner.write().remove(id).is_some();
Ok(removed)
}
}
use async_trait::async_trait;
use super::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage, Result,
};
/// No-op implementation that synthesizes conversation responses without persistence
#[derive(Default, Debug, Clone)]
pub struct NoOpConversationStorage;
impl NoOpConversationStorage {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl ConversationStorage for NoOpConversationStorage {
async fn create_conversation(
&self,
input: super::conversations::NewConversation,
) -> Result<Conversation> {
Ok(Conversation::new(input))
}
async fn get_conversation(&self, _id: &ConversationId) -> Result<Option<Conversation>> {
Ok(None)
}
async fn update_conversation(
&self,
_id: &ConversationId,
_metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>> {
Ok(None)
}
async fn delete_conversation(&self, _id: &ConversationId) -> Result<bool> {
Ok(false)
}
}
use crate::config::OracleConfig;
use crate::data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
};
use async_trait::async_trait;
use chrono::Utc;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{sql_type::OracleType, Connection};
use serde_json::Value;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct OracleConversationStorage {
pool: Pool<ConversationOracleConnectionManager>,
}
impl OracleConversationStorage {
pub fn new(config: OracleConfig) -> Result<Self> {
configure_oracle_client(&config)?;
initialize_schema(&config)?;
let config = Arc::new(config);
let manager = ConversationOracleConnectionManager::new(config.clone());
let mut builder = Pool::builder(manager)
.max_size(config.pool_max)
.runtime(deadpool::Runtime::Tokio1);
if config.pool_timeout_secs > 0 {
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
}
let pool = builder.build().map_err(|err| {
ConversationStorageError::StorageError(format!(
"failed to build Oracle pool for conversations: {err}"
))
})?;
Ok(Self { pool })
}
async fn with_connection<F, T>(&self, func: F) -> Result<T>
where
F: FnOnce(&Connection) -> Result<T> + Send + 'static,
T: Send + 'static,
{
let connection = self.pool.get().await.map_err(map_pool_error)?;
tokio::task::spawn_blocking(move || {
let result = func(&connection);
drop(connection);
result
})
.await
.map_err(|err| {
ConversationStorageError::StorageError(format!(
"failed to execute Oracle conversation task: {err}"
))
})?
}
fn parse_metadata(raw: Option<String>) -> Result<Option<ConversationMetadata>> {
match raw {
Some(json) if !json.is_empty() => {
let value: Value = serde_json::from_str(&json)?;
match value {
Value::Object(map) => Ok(Some(map)),
Value::Null => Ok(None),
other => Err(ConversationStorageError::StorageError(format!(
"conversation metadata expected object, got {other}"
))),
}
}
_ => Ok(None),
}
}
}
#[async_trait]
impl ConversationStorage for OracleConversationStorage {
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation> {
let conversation = Conversation::new(input);
let id_str = conversation.id.0.clone();
let created_at = conversation.created_at;
let metadata_json = conversation
.metadata
.as_ref()
.map(serde_json::to_string)
.transpose()?;
self.with_connection(move |conn| {
conn.execute(
"INSERT INTO conversations (id, created_at, metadata) VALUES (:1, :2, :3)",
&[&id_str, &created_at, &metadata_json],
)
.map(|_| ())
.map_err(map_oracle_error)
})
.await?;
Ok(conversation)
}
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>> {
let lookup = id.0.clone();
self.with_connection(move |conn| {
let mut stmt = conn
.statement("SELECT id, created_at, metadata FROM conversations WHERE id = :1")
.build()
.map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&lookup]).map_err(map_oracle_error)?;
if let Some(row_res) = rows.next() {
let row = row_res.map_err(map_oracle_error)?;
let id: String = row.get(0).map_err(map_oracle_error)?;
let created_at: chrono::DateTime<Utc> = row.get(1).map_err(map_oracle_error)?;
let metadata_raw: Option<String> = row.get(2).map_err(map_oracle_error)?;
let metadata = Self::parse_metadata(metadata_raw)?;
Ok(Some(Conversation::with_parts(
ConversationId(id),
created_at,
metadata,
)))
} else {
Ok(None)
}
})
.await
}
async fn update_conversation(
&self,
id: &ConversationId,
metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>> {
let id_str = id.0.clone();
let metadata_json = metadata.as_ref().map(serde_json::to_string).transpose()?;
let conversation_id = id.clone();
self.with_connection(move |conn| {
let mut stmt = conn
.statement(
"UPDATE conversations \
SET metadata = :1 \
WHERE id = :2 \
RETURNING created_at INTO :3",
)
.build()
.map_err(map_oracle_error)?;
stmt.bind(3, &OracleType::TimestampTZ(6))
.map_err(map_oracle_error)?;
stmt.execute(&[&metadata_json, &id_str])
.map_err(map_oracle_error)?;
if stmt.row_count().map_err(map_oracle_error)? == 0 {
return Ok(None);
}
let mut created_at: Vec<chrono::DateTime<Utc>> =
stmt.returned_values(3).map_err(map_oracle_error)?;
let created_at = created_at.pop().ok_or_else(|| {
ConversationStorageError::StorageError(
"Oracle update did not return created_at".to_string(),
)
})?;
Ok(Some(Conversation::with_parts(
conversation_id,
created_at,
metadata,
)))
})
.await
}
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool> {
let id_str = id.0.clone();
let res = self
.with_connection(move |conn| {
conn.execute("DELETE FROM conversations WHERE id = :1", &[&id_str])
.map_err(map_oracle_error)
})
.await?;
Ok(res.row_count().map_err(map_oracle_error)? > 0)
}
}
#[derive(Clone)]
struct ConversationOracleConnectionManager {
params: Arc<OracleConnectParams>,
}
impl ConversationOracleConnectionManager {
fn new(config: Arc<OracleConfig>) -> Self {
let params = OracleConnectParams {
username: config.username.clone(),
password: config.password.clone(),
connect_descriptor: config.connect_descriptor.clone(),
};
Self {
params: Arc::new(params),
}
}
}
impl std::fmt::Debug for ConversationOracleConnectionManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConversationOracleConnectionManager")
.field("username", &self.params.username)
.field("connect_descriptor", &self.params.connect_descriptor)
.finish()
}
}
#[derive(Clone)]
struct OracleConnectParams {
username: String,
password: String,
connect_descriptor: String,
}
impl std::fmt::Debug for OracleConnectParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OracleConnectParams")
.field("username", &self.username)
.field("connect_descriptor", &self.connect_descriptor)
.finish()
}
}
#[async_trait]
impl Manager for ConversationOracleConnectionManager {
type Type = Connection;
type Error = oracle::Error;
fn create(
&self,
) -> impl std::future::Future<Output = std::result::Result<Connection, oracle::Error>> + Send
{
let params = self.params.clone();
async move {
let mut conn = Connection::connect(
&params.username,
&params.password,
&params.connect_descriptor,
)?;
conn.set_autocommit(true);
Ok(conn)
}
}
#[allow(clippy::manual_async_fn)]
fn recycle(
&self,
conn: &mut Connection,
_: &Metrics,
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
async move { conn.ping().map_err(RecycleError::Backend) }
}
}
fn configure_oracle_client(config: &OracleConfig) -> Result<()> {
if let Some(wallet_path) = &config.wallet_path {
let wallet_path = Path::new(wallet_path);
if !wallet_path.is_dir() {
return Err(ConversationStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is not a directory",
wallet_path.display()
)));
}
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
return Err(ConversationStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
wallet_path.display()
)));
}
std::env::set_var("TNS_ADMIN", wallet_path);
}
Ok(())
}
fn initialize_schema(config: &OracleConfig) -> Result<()> {
let conn = Connection::connect(
&config.username,
&config.password,
&config.connect_descriptor,
)
.map_err(map_oracle_error)?;
let exists: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATIONS'",
&[],
)
.map_err(map_oracle_error)?;
if exists == 0 {
conn.execute(
"CREATE TABLE conversations (
id VARCHAR2(64) PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE,
metadata CLOB
)",
&[],
)
.map_err(map_oracle_error)?;
}
Ok(())
}
fn map_pool_error(err: PoolError<oracle::Error>) -> ConversationStorageError {
match err {
PoolError::Backend(e) => map_oracle_error(e),
other => ConversationStorageError::StorageError(format!(
"failed to obtain Oracle conversation connection: {other}"
)),
}
}
fn map_oracle_error(err: oracle::Error) -> ConversationStorageError {
if let Some(db_err) = err.db_error() {
ConversationStorageError::StorageError(format!(
"Oracle error (code {}): {}",
db_err.code(),
db_err.message()
))
} else {
ConversationStorageError::StorageError(err.to_string())
}
}
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use serde_json::{Map as JsonMap, Value};
use std::fmt::{Display, Formatter};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ConversationId(pub String);
impl ConversationId {
pub fn new() -> Self {
let mut rng = rand::rng();
let mut bytes = [0u8; 24];
rng.fill_bytes(&mut bytes);
let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
Self(format!("conv_{}", hex_string))
}
}
impl Default for ConversationId {
fn default() -> Self {
Self::new()
}
}
impl From<String> for ConversationId {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for ConversationId {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
impl Display for ConversationId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
/// Metadata payload persisted with a conversation
pub type ConversationMetadata = JsonMap<String, Value>;
/// Input payload for creating a conversation
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct NewConversation {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<ConversationMetadata>,
}
/// Stored conversation data structure
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Conversation {
pub id: ConversationId,
pub created_at: DateTime<Utc>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<ConversationMetadata>,
}
impl Conversation {
pub fn new(new_conversation: NewConversation) -> Self {
Self {
id: ConversationId::new(),
created_at: Utc::now(),
metadata: new_conversation.metadata,
}
}
pub fn with_parts(
id: ConversationId,
created_at: DateTime<Utc>,
metadata: Option<ConversationMetadata>,
) -> Self {
Self {
id,
created_at,
metadata,
}
}
}
/// Result alias for conversation storage operations
pub type Result<T> = std::result::Result<T, ConversationStorageError>;
/// Error type for conversation storage operations
#[derive(Debug, thiserror::Error)]
pub enum ConversationStorageError {
#[error("Conversation not found: {0}")]
ConversationNotFound(String),
#[error("Storage error: {0}")]
StorageError(String),
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
}
/// Trait describing the CRUD interface for conversation storage backends
#[async_trait]
pub trait ConversationStorage: Send + Sync + 'static {
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation>;
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>>;
async fn update_conversation(
&self,
id: &ConversationId,
metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>>;
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool>;
}
/// Shared pointer alias for conversation storage
pub type SharedConversationStorage = Arc<dyn ConversationStorage>;
// Data connector module for response storage // Data connector module for response storage and conversation storage
pub mod conversation_memory_store;
pub mod conversation_noop_store;
pub mod conversation_oracle_store;
pub mod conversations;
pub mod response_memory_store; pub mod response_memory_store;
pub mod response_noop_store; pub mod response_noop_store;
pub mod response_oracle_store; pub mod response_oracle_store;
pub mod responses; pub mod responses;
pub use conversation_memory_store::MemoryConversationStorage;
pub use conversation_noop_store::NoOpConversationStorage;
pub use conversation_oracle_store::OracleConversationStorage;
pub use conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result as ConversationResult,
SharedConversationStorage,
};
pub use response_memory_store::MemoryResponseStorage; pub use response_memory_store::MemoryResponseStorage;
pub use response_noop_store::NoOpResponseStorage; pub use response_noop_store::NoOpResponseStorage;
pub use response_oracle_store::OracleResponseStorage; pub use response_oracle_store::OracleResponseStorage;
......
...@@ -207,10 +207,10 @@ mod tests { ...@@ -207,10 +207,10 @@ mod tests {
async fn test_store_with_custom_id() { async fn test_store_with_custom_id() {
let store = MemoryResponseStorage::new(); let store = MemoryResponseStorage::new();
let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None); let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None);
response.id = ResponseId::from_string("resp_custom".to_string()); response.id = ResponseId::from("resp_custom");
store.store_response(response.clone()).await.unwrap(); store.store_response(response.clone()).await.unwrap();
let retrieved = store let retrieved = store
.get_response(&ResponseId::from_string("resp_custom".to_string())) .get_response(&ResponseId::from("resp_custom"))
.await .await
.unwrap(); .unwrap();
assert!(retrieved.is_some()); assert!(retrieved.is_some());
......
...@@ -12,10 +12,6 @@ impl ResponseId { ...@@ -12,10 +12,6 @@ impl ResponseId {
pub fn new() -> Self { pub fn new() -> Self {
Self(ulid::Ulid::new().to_string()) Self(ulid::Ulid::new().to_string())
} }
pub fn from_string(s: String) -> Self {
Self(s)
}
} }
impl Default for ResponseId { impl Default for ResponseId {
...@@ -24,6 +20,18 @@ impl Default for ResponseId { ...@@ -24,6 +20,18 @@ impl Default for ResponseId {
} }
} }
impl From<String> for ResponseId {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for ResponseId {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
/// Stored response data /// Stored response data
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredResponse { pub struct StoredResponse {
......
...@@ -128,6 +128,7 @@ impl RouterFactory { ...@@ -128,6 +128,7 @@ impl RouterFactory {
base_url, base_url,
Some(ctx.router_config.circuit_breaker.clone()), Some(ctx.router_config.circuit_breaker.clone()),
ctx.response_storage.clone(), ctx.response_storage.clone(),
ctx.conversation_storage.clone(),
) )
.await?; .await?;
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
use crate::config::CircuitBreakerConfig; use crate::config::CircuitBreakerConfig;
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse}; use crate::data_connector::{
Conversation, ConversationId, ConversationMetadata, ResponseId, SharedConversationStorage,
SharedResponseStorage, StoredResponse,
};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
...@@ -16,6 +19,7 @@ use axum::{ ...@@ -16,6 +19,7 @@ use axum::{
extract::Request, extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json,
}; };
use bytes::Bytes; use bytes::Bytes;
use futures_util::StreamExt; use futures_util::StreamExt;
...@@ -75,6 +79,8 @@ pub struct OpenAIRouter { ...@@ -75,6 +79,8 @@ pub struct OpenAIRouter {
healthy: AtomicBool, healthy: AtomicBool,
/// Response storage for managing conversation history /// Response storage for managing conversation history
response_storage: SharedResponseStorage, response_storage: SharedResponseStorage,
/// Conversation storage backend
conversation_storage: SharedConversationStorage,
/// Optional MCP manager (enabled via config presence) /// Optional MCP manager (enabled via config presence)
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>, mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
} }
...@@ -705,6 +711,7 @@ impl OpenAIRouter { ...@@ -705,6 +711,7 @@ impl OpenAIRouter {
base_url: String, base_url: String,
circuit_breaker_config: Option<CircuitBreakerConfig>, circuit_breaker_config: Option<CircuitBreakerConfig>,
response_storage: SharedResponseStorage, response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
) -> Result<Self, String> { ) -> Result<Self, String> {
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300)) .timeout(std::time::Duration::from_secs(300))
...@@ -751,6 +758,7 @@ impl OpenAIRouter { ...@@ -751,6 +758,7 @@ impl OpenAIRouter {
circuit_breaker, circuit_breaker,
healthy: AtomicBool::new(true), healthy: AtomicBool::new(true),
response_storage, response_storage,
conversation_storage,
mcp_manager, mcp_manager,
}) })
} }
...@@ -2337,16 +2345,16 @@ impl OpenAIRouter { ...@@ -2337,16 +2345,16 @@ impl OpenAIRouter {
stored_response.previous_response_id = response_json stored_response.previous_response_id = response_json
.get("previous_response_id") .get("previous_response_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(|s| ResponseId::from_string(s.to_string())) .map(ResponseId::from)
.or_else(|| { .or_else(|| {
original_body original_body
.previous_response_id .previous_response_id
.as_ref() .as_ref()
.map(|id| ResponseId::from_string(id.clone())) .map(|id| ResponseId::from(id.as_str()))
}); });
if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) { if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) {
stored_response.id = ResponseId::from_string(id_str.to_string()); stored_response.id = ResponseId::from(id_str);
} }
stored_response.raw_response = response_json.clone(); stored_response.raw_response = response_json.clone();
...@@ -3393,7 +3401,7 @@ impl super::super::RouterTrait for OpenAIRouter { ...@@ -3393,7 +3401,7 @@ impl super::super::RouterTrait for OpenAIRouter {
// Handle previous_response_id by loading prior context // Handle previous_response_id by loading prior context
let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None; let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
if let Some(prev_id_str) = request_body.previous_response_id.clone() { if let Some(prev_id_str) = request_body.previous_response_id.clone() {
let prev_id = ResponseId::from_string(prev_id_str.clone()); let prev_id = ResponseId::from(prev_id_str.as_str());
match self match self
.response_storage .response_storage
.get_response_chain(&prev_id, None) .get_response_chain(&prev_id, None)
...@@ -3516,7 +3524,7 @@ impl super::super::RouterTrait for OpenAIRouter { ...@@ -3516,7 +3524,7 @@ impl super::super::RouterTrait for OpenAIRouter {
response_id: &str, response_id: &str,
params: &ResponsesGetParams, params: &ResponsesGetParams,
) -> Response { ) -> Response {
let stored_id = ResponseId::from_string(response_id.to_string()); let stored_id = ResponseId::from(response_id);
if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await { if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await {
let stream_requested = params.stream.unwrap_or(false); let stream_requested = params.stream.unwrap_or(false);
let raw_value = stored_response.raw_response.clone(); let raw_value = stored_response.raw_response.clone();
...@@ -3646,10 +3654,6 @@ impl super::super::RouterTrait for OpenAIRouter { ...@@ -3646,10 +3654,6 @@ impl super::super::RouterTrait for OpenAIRouter {
} }
} }
fn router_type(&self) -> &'static str {
"openai"
}
async fn route_embeddings( async fn route_embeddings(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
...@@ -3675,4 +3679,309 @@ impl super::super::RouterTrait for OpenAIRouter { ...@@ -3675,4 +3679,309 @@ impl super::super::RouterTrait for OpenAIRouter {
) )
.into_response() .into_response()
} }
async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response {
// TODO: move this spec validation to the right place
let metadata = match body.get("metadata") {
Some(Value::Object(map)) => {
if map.len() > MAX_METADATA_PROPERTIES {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": format!(
"Invalid 'metadata': too many properties. Max {}, got {}",
MAX_METADATA_PROPERTIES, map.len()
),
"type": "invalid_request_error",
"param": "metadata",
"code": "metadata_max_properties_exceeded"
}
})),
)
.into_response();
}
Some(map.clone())
}
Some(Value::Null) | None => None,
Some(other) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": format!(
"Invalid 'metadata': expected object or null but got {}",
other
),
"type": "invalid_request_error",
"param": "metadata",
"code": "metadata_invalid_type"
}
})),
)
.into_response();
}
};
match self
.conversation_storage
.create_conversation(crate::data_connector::NewConversation { metadata })
.await
{
Ok(conversation) => {
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
}
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": err.to_string(),
"type": "internal_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response(),
}
}
async fn get_conversation(
&self,
_headers: Option<&HeaderMap>,
conversation_id: &str,
) -> Response {
let id: ConversationId = conversation_id.to_string().into();
match self.conversation_storage.get_conversation(&id).await {
Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(),
Ok(None) => (
StatusCode::NOT_FOUND,
Json(json!({
"error": {
"message": format!("Conversation with id '{}' not found.", conversation_id),
"type": "invalid_request_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": err.to_string(),
"type": "internal_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response(),
}
}
async fn update_conversation(
&self,
_headers: Option<&HeaderMap>,
conversation_id: &str,
body: &Value,
) -> Response {
let id: ConversationId = conversation_id.to_string().into();
let existing = match self.conversation_storage.get_conversation(&id).await {
Ok(Some(c)) => c,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({
"error": {
"message": format!("Conversation with id '{}' not found.", conversation_id),
"type": "invalid_request_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response();
}
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": err.to_string(),
"type": "internal_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response();
}
};
// Parse metadata patch
enum Patch {
NoChange,
ClearAll,
Merge(ConversationMetadata),
}
let patch = match body.get("metadata") {
None => Patch::NoChange,
Some(Value::Null) => Patch::ClearAll,
Some(Value::Object(map)) => Patch::Merge(map.clone()),
Some(other) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": format!(
"Invalid 'metadata': expected object or null but got {}",
other
),
"type": "invalid_request_error",
"param": "metadata",
"code": "metadata_invalid_type"
}
})),
)
.into_response();
}
};
let merged_metadata = match patch {
Patch::NoChange => {
return (StatusCode::OK, Json(conversation_to_json(&existing))).into_response();
}
Patch::ClearAll => None,
Patch::Merge(upd) => {
let mut merged = existing.metadata.clone().unwrap_or_default();
let previous = merged.len();
for (k, v) in upd.into_iter() {
if v.is_null() {
merged.remove(&k);
} else {
merged.insert(k, v);
}
}
let updated = merged.len();
if updated > MAX_METADATA_PROPERTIES {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": format!(
"Invalid 'metadata': too many properties after update. Max {} ({} -> {}).",
MAX_METADATA_PROPERTIES, previous, updated
),
"type": "invalid_request_error",
"param": "metadata",
"code": "metadata_max_properties_exceeded",
"extra": {
"previous_property_count": previous,
"updated_property_count": updated
}
}
})),
)
.into_response();
}
if merged.is_empty() {
None
} else {
Some(merged)
}
}
};
match self
.conversation_storage
.update_conversation(&id, merged_metadata)
.await
{
Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(),
Ok(None) => (
StatusCode::NOT_FOUND,
Json(json!({
"error": {
"message": format!("Conversation with id '{}' not found.", conversation_id),
"type": "invalid_request_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": err.to_string(),
"type": "internal_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response(),
}
}
async fn delete_conversation(
&self,
_headers: Option<&HeaderMap>,
conversation_id: &str,
) -> Response {
let id: ConversationId = conversation_id.to_string().into();
match self.conversation_storage.delete_conversation(&id).await {
Ok(true) => (
StatusCode::OK,
Json(json!({
"id": conversation_id,
"object": "conversation.deleted",
"deleted": true
})),
)
.into_response(),
Ok(false) => (
StatusCode::NOT_FOUND,
Json(json!({
"error": {
"message": format!("Conversation with id '{}' not found.", conversation_id),
"type": "invalid_request_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": err.to_string(),
"type": "internal_error",
"param": Value::Null,
"code": Value::Null
}
})),
)
.into_response(),
}
}
fn router_type(&self) -> &'static str {
"openai"
}
}
// Maximum number of properties allowed in conversation metadata (align with server)
const MAX_METADATA_PROPERTIES: usize = 16;
fn conversation_to_json(conversation: &Conversation) -> Value {
json!({
"id": conversation.id.0,
"object": "conversation",
"created_at": conversation.created_at.timestamp(),
"metadata": to_value(&conversation.metadata).unwrap_or(Value::Null),
})
} }
...@@ -13,6 +13,7 @@ use crate::protocols::spec::{ ...@@ -13,6 +13,7 @@ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest, ResponsesGetParams, ResponsesRequest,
}; };
use serde_json::Value;
pub mod factory; pub mod factory;
pub mod grpc; pub mod grpc;
...@@ -126,6 +127,52 @@ pub trait RouterTrait: Send + Sync + Debug { ...@@ -126,6 +127,52 @@ pub trait RouterTrait: Send + Sync + Debug {
model_id: Option<&str>, model_id: Option<&str>,
) -> Response; ) -> Response;
// Conversations API
async fn create_conversation(&self, _headers: Option<&HeaderMap>, _body: &Value) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Conversations create endpoint not implemented",
)
.into_response()
}
async fn get_conversation(
&self,
_headers: Option<&HeaderMap>,
_conversation_id: &str,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Conversations get endpoint not implemented",
)
.into_response()
}
async fn update_conversation(
&self,
_headers: Option<&HeaderMap>,
_conversation_id: &str,
_body: &Value,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Conversations update endpoint not implemented",
)
.into_response()
}
async fn delete_conversation(
&self,
_headers: Option<&HeaderMap>,
_conversation_id: &str,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Conversations delete endpoint not implemented",
)
.into_response()
}
/// Get router type name /// Get router type name
fn router_type(&self) -> &'static str; fn router_type(&self) -> &'static str;
......
...@@ -20,6 +20,7 @@ use axum::{ ...@@ -20,6 +20,7 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use dashmap::DashMap; use dashmap::DashMap;
use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
...@@ -511,6 +512,83 @@ impl RouterTrait for RouterManager { ...@@ -511,6 +512,83 @@ impl RouterTrait for RouterManager {
fn router_type(&self) -> &'static str { fn router_type(&self) -> &'static str {
"manager" "manager"
} }
// Conversations API delegates
async fn create_conversation(&self, headers: Option<&HeaderMap>, body: &Value) -> Response {
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router.create_conversation(headers, body).await
} else {
(
StatusCode::NOT_FOUND,
"No router available to create conversation",
)
.into_response()
}
}
async fn get_conversation(
&self,
headers: Option<&HeaderMap>,
conversation_id: &str,
) -> Response {
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router.get_conversation(headers, conversation_id).await
} else {
(
StatusCode::NOT_FOUND,
format!(
"No router available to get conversation '{}'",
conversation_id
),
)
.into_response()
}
}
async fn update_conversation(
&self,
headers: Option<&HeaderMap>,
conversation_id: &str,
body: &Value,
) -> Response {
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router
.update_conversation(headers, conversation_id, body)
.await
} else {
(
StatusCode::NOT_FOUND,
format!(
"No router available to update conversation '{}'",
conversation_id
),
)
.into_response()
}
}
async fn delete_conversation(
&self,
headers: Option<&HeaderMap>,
conversation_id: &str,
) -> Response {
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router.delete_conversation(headers, conversation_id).await
} else {
(
StatusCode::NOT_FOUND,
format!(
"No router available to delete conversation '{}'",
conversation_id
),
)
.into_response()
}
}
} }
impl std::fmt::Debug for RouterManager { impl std::fmt::Debug for RouterManager {
......
...@@ -2,7 +2,9 @@ use crate::{ ...@@ -2,7 +2,9 @@ use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType}, core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
data_connector::{ data_connector::{
MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage, MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage,
NoOpResponseStorage, OracleConversationStorage, OracleResponseStorage,
SharedConversationStorage, SharedResponseStorage,
}, },
logging::{self, LoggingConfig}, logging::{self, LoggingConfig},
metrics::{self, PrometheusConfig}, metrics::{self, PrometheusConfig},
...@@ -39,6 +41,8 @@ use std::{ ...@@ -39,6 +41,8 @@ use std::{
use tokio::{net::TcpListener, signal, spawn}; use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level}; use tracing::{error, info, warn, Level};
//
#[derive(Clone)] #[derive(Clone)]
pub struct AppContext { pub struct AppContext {
pub client: Client, pub client: Client,
...@@ -51,6 +55,7 @@ pub struct AppContext { ...@@ -51,6 +55,7 @@ pub struct AppContext {
pub policy_registry: Arc<PolicyRegistry>, pub policy_registry: Arc<PolicyRegistry>,
pub router_manager: Option<Arc<RouterManager>>, pub router_manager: Option<Arc<RouterManager>>,
pub response_storage: SharedResponseStorage, pub response_storage: SharedResponseStorage,
pub conversation_storage: SharedConversationStorage,
pub load_monitor: Option<Arc<LoadMonitor>>, pub load_monitor: Option<Arc<LoadMonitor>>,
pub configured_reasoning_parser: Option<String>, pub configured_reasoning_parser: Option<String>,
pub configured_tool_parser: Option<String>, pub configured_tool_parser: Option<String>,
...@@ -94,19 +99,34 @@ impl AppContext { ...@@ -94,19 +99,34 @@ impl AppContext {
let router_manager = None; let router_manager = None;
let response_storage: SharedResponseStorage = match router_config.history_backend { let (response_storage, conversation_storage): (
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()), SharedResponseStorage,
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()), SharedConversationStorage,
) = match router_config.history_backend {
HistoryBackend::Memory => (
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
),
HistoryBackend::None => (
Arc::new(NoOpResponseStorage::new()),
Arc::new(NoOpConversationStorage::new()),
),
HistoryBackend::Oracle => { HistoryBackend::Oracle => {
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string() "oracle configuration is required when history_backend=oracle".to_string()
})?; })?;
let storage = OracleResponseStorage::new(oracle_cfg).map_err(|err| { let response_storage =
format!("failed to initialize Oracle response storage: {err}") OracleResponseStorage::new(oracle_cfg.clone()).map_err(|err| {
})?; format!("failed to initialize Oracle response storage: {err}")
})?;
let conversation_storage =
OracleConversationStorage::new(oracle_cfg).map_err(|err| {
format!("failed to initialize Oracle conversation storage: {err}")
})?;
Arc::new(storage) (Arc::new(response_storage), Arc::new(conversation_storage))
} }
}; };
...@@ -131,6 +151,7 @@ impl AppContext { ...@@ -131,6 +151,7 @@ impl AppContext {
policy_registry, policy_registry,
router_manager, router_manager,
response_storage, response_storage,
conversation_storage,
load_monitor, load_monitor,
configured_reasoning_parser, configured_reasoning_parser,
configured_tool_parser, configured_tool_parser,
...@@ -334,6 +355,51 @@ async fn v1_responses_list_input_items( ...@@ -334,6 +355,51 @@ async fn v1_responses_list_input_items(
.await .await
} }
async fn v1_conversations_create(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<Value>,
) -> Response {
state
.router
.create_conversation(Some(&headers), &body)
.await
}
async fn v1_conversations_get(
State(state): State<Arc<AppState>>,
Path(conversation_id): Path<String>,
headers: http::HeaderMap,
) -> Response {
state
.router
.get_conversation(Some(&headers), &conversation_id)
.await
}
async fn v1_conversations_update(
State(state): State<Arc<AppState>>,
Path(conversation_id): Path<String>,
headers: http::HeaderMap,
Json(body): Json<Value>,
) -> Response {
state
.router
.update_conversation(Some(&headers), &conversation_id, &body)
.await
}
async fn v1_conversations_delete(
State(state): State<Arc<AppState>>,
Path(conversation_id): Path<String>,
headers: http::HeaderMap,
) -> Response {
state
.router
.delete_conversation(Some(&headers), &conversation_id)
.await
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct AddWorkerQuery { struct AddWorkerQuery {
url: String, url: String,
...@@ -601,6 +667,13 @@ pub fn build_app( ...@@ -601,6 +667,13 @@ pub fn build_app(
"/v1/responses/{response_id}/input", "/v1/responses/{response_id}/input",
get(v1_responses_list_input_items), get(v1_responses_list_input_items),
) )
.route("/v1/conversations", post(v1_conversations_create))
.route(
"/v1/conversations/{conversation_id}",
get(v1_conversations_get)
.post(v1_conversations_update)
.delete(v1_conversations_delete),
)
.route_layer(axum::middleware::from_fn_with_state( .route_layer(axum::middleware::from_fn_with_state(
app_state.clone(), app_state.clone(),
middleware::concurrency_limit_middleware, middleware::concurrency_limit_middleware,
......
...@@ -542,6 +542,7 @@ mod tests { ...@@ -542,6 +542,7 @@ mod tests {
tool_parser_factory: None, tool_parser_factory: None,
router_manager: None, router_manager: None,
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
conversation_storage: Arc::new(crate::data_connector::MemoryConversationStorage::new()),
load_monitor: None, load_monitor: None,
configured_reasoning_parser: None, configured_reasoning_parser: None,
configured_tool_parser: None, configured_tool_parser: None,
......
...@@ -239,6 +239,100 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { ...@@ -239,6 +239,100 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
mcp.stop().await; mcp.stop().await;
} }
#[tokio::test]
async fn test_conversations_crud_basic() {
// Router in OpenAI mode (no actual upstream calls in these tests)
let router_cfg = RouterConfig {
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://localhost".to_string()],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 8,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
let router = RouterFactory::create_router(&Arc::new(ctx))
.await
.expect("router");
// Create
let create_body = serde_json::json!({ "metadata": { "project": "alpha" } });
let create_resp = router.create_conversation(None, &create_body).await;
assert_eq!(create_resp.status(), axum::http::StatusCode::OK);
let create_bytes = axum::body::to_bytes(create_resp.into_body(), usize::MAX)
.await
.unwrap();
let create_json: serde_json::Value = serde_json::from_slice(&create_bytes).unwrap();
let conv_id = create_json["id"].as_str().expect("id missing");
assert!(conv_id.starts_with("conv_"));
assert_eq!(create_json["object"], "conversation");
// Get
let get_resp = router.get_conversation(None, conv_id).await;
assert_eq!(get_resp.status(), axum::http::StatusCode::OK);
let get_bytes = axum::body::to_bytes(get_resp.into_body(), usize::MAX)
.await
.unwrap();
let get_json: serde_json::Value = serde_json::from_slice(&get_bytes).unwrap();
assert_eq!(get_json["metadata"]["project"], serde_json::json!("alpha"));
// Update (merge)
let update_body = serde_json::json!({ "metadata": { "owner": "alice" } });
let upd_resp = router
.update_conversation(None, conv_id, &update_body)
.await;
assert_eq!(upd_resp.status(), axum::http::StatusCode::OK);
let upd_bytes = axum::body::to_bytes(upd_resp.into_body(), usize::MAX)
.await
.unwrap();
let upd_json: serde_json::Value = serde_json::from_slice(&upd_bytes).unwrap();
assert_eq!(upd_json["metadata"]["project"], serde_json::json!("alpha"));
assert_eq!(upd_json["metadata"]["owner"], serde_json::json!("alice"));
// Delete
let del_resp = router.delete_conversation(None, conv_id).await;
assert_eq!(del_resp.status(), axum::http::StatusCode::OK);
let del_bytes = axum::body::to_bytes(del_resp.into_body(), usize::MAX)
.await
.unwrap();
let del_json: serde_json::Value = serde_json::from_slice(&del_bytes).unwrap();
assert_eq!(del_json["deleted"], serde_json::json!(true));
// Get again -> 404
let not_found = router.get_conversation(None, conv_id).await;
assert_eq!(not_found.status(), axum::http::StatusCode::NOT_FOUND);
}
#[test] #[test]
fn test_responses_request_creation() { fn test_responses_request_creation() {
let request = ResponsesRequest { let request = ResponsesRequest {
......
...@@ -13,7 +13,10 @@ use sglang_router_rs::{ ...@@ -13,7 +13,10 @@ use sglang_router_rs::{
config::{ config::{
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode, ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
}, },
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse}, data_connector::{
MemoryConversationStorage, MemoryResponseStorage, ResponseId, ResponseStorage,
StoredResponse,
},
protocols::spec::{ protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
ResponsesGetParams, ResponsesRequest, UserMessageContent, ResponsesGetParams, ResponsesRequest, UserMessageContent,
...@@ -91,6 +94,7 @@ async fn test_openai_router_creation() { ...@@ -91,6 +94,7 @@ async fn test_openai_router_creation() {
"https://api.openai.com".to_string(), "https://api.openai.com".to_string(),
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
) )
.await; .await;
...@@ -108,6 +112,7 @@ async fn test_openai_router_server_info() { ...@@ -108,6 +112,7 @@ async fn test_openai_router_server_info() {
"https://api.openai.com".to_string(), "https://api.openai.com".to_string(),
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -137,6 +142,7 @@ async fn test_openai_router_models() { ...@@ -137,6 +142,7 @@ async fn test_openai_router_models() {
mock_server.base_url(), mock_server.base_url(),
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -211,9 +217,14 @@ async fn test_openai_router_responses_with_mock() { ...@@ -211,9 +217,14 @@ async fn test_openai_router_responses_with_mock() {
let base_url = format!("http://{}", addr); let base_url = format!("http://{}", addr);
let storage = Arc::new(MemoryResponseStorage::new()); let storage = Arc::new(MemoryResponseStorage::new());
let router = OpenAIRouter::new(base_url, None, storage.clone()) let router = OpenAIRouter::new(
.await base_url,
.unwrap(); None,
storage.clone(),
Arc::new(MemoryConversationStorage::new()),
)
.await
.unwrap();
let request1 = ResponsesRequest { let request1 = ResponsesRequest {
model: Some("gpt-4o-mini".to_string()), model: Some("gpt-4o-mini".to_string()),
...@@ -252,7 +263,7 @@ async fn test_openai_router_responses_with_mock() { ...@@ -252,7 +263,7 @@ async fn test_openai_router_responses_with_mock() {
); );
let stored1 = storage let stored1 = storage
.get_response(&ResponseId::from_string(resp1_id.clone())) .get_response(&ResponseId::from(resp1_id.clone()))
.await .await
.unwrap() .unwrap()
.expect("first response missing"); .expect("first response missing");
...@@ -261,7 +272,7 @@ async fn test_openai_router_responses_with_mock() { ...@@ -261,7 +272,7 @@ async fn test_openai_router_responses_with_mock() {
assert!(stored1.previous_response_id.is_none()); assert!(stored1.previous_response_id.is_none());
let stored2 = storage let stored2 = storage
.get_response(&ResponseId::from_string(resp2_id.to_string())) .get_response(&ResponseId::from(resp2_id))
.await .await
.unwrap() .unwrap()
.expect("second response missing"); .expect("second response missing");
...@@ -463,12 +474,17 @@ async fn test_openai_router_responses_streaming_with_mock() { ...@@ -463,12 +474,17 @@ async fn test_openai_router_responses_streaming_with_mock() {
"Earlier answer".to_string(), "Earlier answer".to_string(),
None, None,
); );
previous.id = ResponseId::from_string("resp_prev_chain".to_string()); previous.id = ResponseId::from("resp_prev_chain");
storage.store_response(previous).await.unwrap(); storage.store_response(previous).await.unwrap();
let router = OpenAIRouter::new(base_url, None, storage.clone()) let router = OpenAIRouter::new(
.await base_url,
.unwrap(); None,
storage.clone(),
Arc::new(MemoryConversationStorage::new()),
)
.await
.unwrap();
let mut metadata = HashMap::new(); let mut metadata = HashMap::new();
metadata.insert("topic".to_string(), json!("unicorns")); metadata.insert("topic".to_string(), json!("unicorns"));
...@@ -504,7 +520,7 @@ async fn test_openai_router_responses_streaming_with_mock() { ...@@ -504,7 +520,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
assert!(body_text.contains("Once upon a streamed unicorn adventure.")); assert!(body_text.contains("Once upon a streamed unicorn adventure."));
// Wait for the storage task to persist the streaming response. // Wait for the storage task to persist the streaming response.
let target_id = ResponseId::from_string("resp_stream_123".to_string()); let target_id = ResponseId::from("resp_stream_123");
let stored = loop { let stored = loop {
if let Some(resp) = storage.get_response(&target_id).await.unwrap() { if let Some(resp) = storage.get_response(&target_id).await.unwrap() {
break resp; break resp;
...@@ -569,6 +585,7 @@ async fn test_unsupported_endpoints() { ...@@ -569,6 +585,7 @@ async fn test_unsupported_endpoints() {
"https://api.openai.com".to_string(), "https://api.openai.com".to_string(),
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -605,9 +622,14 @@ async fn test_openai_router_chat_completion_with_mock() { ...@@ -605,9 +622,14 @@ async fn test_openai_router_chat_completion_with_mock() {
let base_url = mock_server.base_url(); let base_url = mock_server.base_url();
// Create router pointing to mock server // Create router pointing to mock server
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) let router = OpenAIRouter::new(
.await base_url,
.unwrap(); None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
)
.await
.unwrap();
// Create a minimal chat completion request // Create a minimal chat completion request
let mut chat_request = create_minimal_chat_request(); let mut chat_request = create_minimal_chat_request();
...@@ -642,9 +664,14 @@ async fn test_openai_e2e_with_server() { ...@@ -642,9 +664,14 @@ async fn test_openai_e2e_with_server() {
let base_url = mock_server.base_url(); let base_url = mock_server.base_url();
// Create router // Create router
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) let router = OpenAIRouter::new(
.await base_url,
.unwrap(); None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
)
.await
.unwrap();
// Create Axum app with chat completions endpoint // Create Axum app with chat completions endpoint
let app = Router::new().route( let app = Router::new().route(
...@@ -707,9 +734,14 @@ async fn test_openai_e2e_with_server() { ...@@ -707,9 +734,14 @@ async fn test_openai_e2e_with_server() {
async fn test_openai_router_chat_streaming_with_mock() { async fn test_openai_router_chat_streaming_with_mock() {
let mock_server = MockOpenAIServer::new().await; let mock_server = MockOpenAIServer::new().await;
let base_url = mock_server.base_url(); let base_url = mock_server.base_url();
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) let router = OpenAIRouter::new(
.await base_url,
.unwrap(); None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
)
.await
.unwrap();
// Build a streaming chat request // Build a streaming chat request
let val = json!({ let val = json!({
...@@ -759,6 +791,7 @@ async fn test_openai_router_circuit_breaker() { ...@@ -759,6 +791,7 @@ async fn test_openai_router_circuit_breaker() {
"http://invalid-url-that-will-fail".to_string(), "http://invalid-url-that-will-fail".to_string(),
Some(cb_config), Some(cb_config),
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -786,6 +819,7 @@ async fn test_openai_router_models_auth_forwarding() { ...@@ -786,6 +819,7 @@ async fn test_openai_router_models_auth_forwarding() {
mock_server.base_url(), mock_server.base_url(),
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment