"notebooks/vscode:/vscode.git/clone" did not exist on "b20455a2b5599fdb824fa776ee64b0f1606fc545"
Unverified Commit 7ac6b900 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Support history management using conversation (#11339)

parent a1080b72
use std::collections::{BTreeMap, HashMap};
use std::sync::RwLock;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use super::conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
Result, SortOrder,
};
use super::conversations::ConversationId;
#[derive(Default)]
pub struct MemoryConversationItemStorage {
items: RwLock<HashMap<ConversationItemId, ConversationItem>>, // item_id -> item
#[allow(clippy::type_complexity)]
links: RwLock<HashMap<ConversationId, BTreeMap<(i64, String), ConversationItemId>>>,
// Per-conversation reverse index for fast after cursor lookup: item_id_str -> (ts, item_id_str)
#[allow(clippy::type_complexity)]
rev_index: RwLock<HashMap<ConversationId, HashMap<String, (i64, String)>>>,
}
impl MemoryConversationItemStorage {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl ConversationItemStorage for MemoryConversationItemStorage {
async fn create_item(
&self,
new_item: super::conversation_items::NewConversationItem,
) -> Result<ConversationItem> {
let id = new_item
.id
.clone()
.unwrap_or_else(|| make_item_id(&new_item.item_type));
let created_at = Utc::now();
let item = ConversationItem {
id: id.clone(),
response_id: new_item.response_id,
item_type: new_item.item_type,
role: new_item.role,
content: new_item.content,
status: new_item.status,
created_at,
};
let mut items = self.items.write().unwrap();
items.insert(id.clone(), item.clone());
Ok(item)
}
async fn link_item(
&self,
conversation_id: &ConversationId,
item_id: &ConversationItemId,
added_at: DateTime<Utc>,
) -> Result<()> {
{
let mut links = self.links.write().unwrap();
let entry = links.entry(conversation_id.clone()).or_default();
entry.insert((added_at.timestamp(), item_id.0.clone()), item_id.clone());
}
{
let mut rev = self.rev_index.write().unwrap();
let entry = rev.entry(conversation_id.clone()).or_default();
entry.insert(item_id.0.clone(), (added_at.timestamp(), item_id.0.clone()));
}
Ok(())
}
async fn list_items(
&self,
conversation_id: &ConversationId,
params: ListParams,
) -> Result<Vec<ConversationItem>> {
let links_guard = self.links.read().unwrap();
let map = match links_guard.get(conversation_id) {
Some(m) => m,
None => return Ok(Vec::new()),
};
let mut results: Vec<ConversationItem> = Vec::new();
let after_key: Option<(i64, String)> = if let Some(after_id) = &params.after {
// O(1) lookup via reverse index for this conversation
if let Some(conv_idx) = self.rev_index.read().unwrap().get(conversation_id) {
conv_idx.get(after_id).cloned()
} else {
None
}
} else {
None
};
let take = params.limit;
let items_guard = self.items.read().unwrap();
use std::ops::Bound::{Excluded, Unbounded};
// Helper to push item if it exists and stop when reaching the limit
let mut push_item = |key: &ConversationItemId| -> bool {
if let Some(it) = items_guard.get(key) {
results.push(it.clone());
if results.len() == take {
return true;
}
}
false
};
match (params.order, after_key) {
(SortOrder::Desc, Some(k)) => {
for ((_ts, _id), item_key) in map.range(..k).rev() {
if push_item(item_key) {
break;
}
}
}
(SortOrder::Desc, None) => {
for ((_ts, _id), item_key) in map.iter().rev() {
if push_item(item_key) {
break;
}
}
}
(SortOrder::Asc, Some(k)) => {
for ((_ts, _id), item_key) in map.range((Excluded(k), Unbounded)) {
if push_item(item_key) {
break;
}
}
}
(SortOrder::Asc, None) => {
for ((_ts, _id), item_key) in map.iter() {
if push_item(item_key) {
break;
}
}
}
}
Ok(results)
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::{TimeZone, Utc};
fn make_item(
item_type: &str,
role: Option<&str>,
content: serde_json::Value,
) -> super::super::conversation_items::NewConversationItem {
super::super::conversation_items::NewConversationItem {
id: None,
response_id: None,
item_type: item_type.to_string(),
role: role.map(|r| r.to_string()),
content,
status: Some("completed".to_string()),
}
}
#[tokio::test]
async fn test_list_ordering_and_cursors() {
let store = MemoryConversationItemStorage::new();
let conv: ConversationId = "conv_test".into();
// Create 3 items and link them at controlled timestamps
let i1 = store
.create_item(make_item("message", Some("user"), serde_json::json!([])))
.await
.unwrap();
let i2 = store
.create_item(make_item(
"message",
Some("assistant"),
serde_json::json!([]),
))
.await
.unwrap();
let i3 = store
.create_item(make_item("reasoning", None, serde_json::json!([])))
.await
.unwrap();
let t1 = Utc.timestamp_opt(1_700_000_001, 0).single().unwrap();
let t2 = Utc.timestamp_opt(1_700_000_002, 0).single().unwrap();
let t3 = Utc.timestamp_opt(1_700_000_003, 0).single().unwrap();
store.link_item(&conv, &i1.id, t1).await.unwrap();
store.link_item(&conv, &i2.id, t2).await.unwrap();
store.link_item(&conv, &i3.id, t3).await.unwrap();
// Desc order, no cursor
let desc = store
.list_items(
&conv,
ListParams {
limit: 2,
order: SortOrder::Desc,
after: None,
},
)
.await
.unwrap();
assert!(desc.len() >= 2);
assert_eq!(desc[0].id, i3.id);
assert_eq!(desc[1].id, i2.id);
// Desc with cursor = i2 -> expect i1 next
let desc_after = store
.list_items(
&conv,
ListParams {
limit: 2,
order: SortOrder::Desc,
after: Some(i2.id.0.clone()),
},
)
.await
.unwrap();
assert!(!desc_after.is_empty());
assert_eq!(desc_after[0].id, i1.id);
// Asc order, no cursor
let asc = store
.list_items(
&conv,
ListParams {
limit: 2,
order: SortOrder::Asc,
after: None,
},
)
.await
.unwrap();
assert!(asc.len() >= 2);
assert_eq!(asc[0].id, i1.id);
assert_eq!(asc[1].id, i2.id);
// Asc with cursor = i2 -> expect i3 next
let asc_after = store
.list_items(
&conv,
ListParams {
limit: 2,
order: SortOrder::Asc,
after: Some(i2.id.0.clone()),
},
)
.await
.unwrap();
assert!(!asc_after.is_empty());
assert_eq!(asc_after[0].id, i3.id);
}
}
use crate::config::OracleConfig;
use crate::data_connector::conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
};
use crate::data_connector::conversations::ConversationId;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::sql_type::ToSql;
use oracle::Connection;
use serde_json::Value;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct OracleConversationItemStorage {
pool: Pool<ConversationItemOracleConnectionManager>,
}
impl OracleConversationItemStorage {
pub fn new(config: OracleConfig) -> ItemResult<Self> {
configure_oracle_client(&config)?;
initialize_schema(&config)?;
let config = Arc::new(config);
let manager = ConversationItemOracleConnectionManager::new(config.clone());
let mut builder = Pool::builder(manager)
.max_size(config.pool_max)
.runtime(deadpool::Runtime::Tokio1);
if config.pool_timeout_secs > 0 {
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
}
let pool = builder.build().map_err(|err| {
ConversationItemStorageError::StorageError(format!(
"failed to build Oracle pool for conversation items: {err}"
))
})?;
Ok(Self { pool })
}
async fn with_connection<F, T>(&self, func: F) -> ItemResult<T>
where
F: FnOnce(&Connection) -> ItemResult<T> + Send + 'static,
T: Send + 'static,
{
let connection = self.pool.get().await.map_err(map_pool_error)?;
tokio::task::spawn_blocking(move || {
let result = func(&connection);
drop(connection);
result
})
.await
.map_err(|err| {
ConversationItemStorageError::StorageError(format!(
"failed to execute Oracle conversation item task: {err}"
))
})?
}
// reserved for future use when parsing JSON columns directly into Value
// fn parse_json(raw: Option<String>) -> ItemResult<Value> { ... }
}
#[async_trait]
impl ConversationItemStorage for OracleConversationItemStorage {
async fn create_item(
&self,
item: crate::data_connector::conversation_items::NewConversationItem,
) -> ItemResult<ConversationItem> {
let id = item
.id
.clone()
.unwrap_or_else(|| make_item_id(&item.item_type));
let created_at = Utc::now();
let content_json = serde_json::to_string(&item.content)?;
// Build the return value up-front; move inexpensive clones as needed for SQL
let conversation_item = ConversationItem {
id: id.clone(),
response_id: item.response_id.clone(),
item_type: item.item_type.clone(),
role: item.role.clone(),
content: item.content,
status: item.status.clone(),
created_at,
};
// Prepare values for SQL insertion
let id_str = conversation_item.id.0.clone();
let response_id = conversation_item.response_id.clone();
let item_type = conversation_item.item_type.clone();
let role = conversation_item.role.clone();
let status = conversation_item.status.clone();
self.with_connection(move |conn| {
conn.execute(
"INSERT INTO conversation_items (id, response_id, item_type, role, content, status, created_at) \
VALUES (:1, :2, :3, :4, :5, :6, :7)",
&[&id_str, &response_id, &item_type, &role, &content_json, &status, &created_at],
)
.map_err(map_oracle_error)?;
Ok(())
})
.await?;
Ok(conversation_item)
}
async fn link_item(
&self,
conversation_id: &ConversationId,
item_id: &ConversationItemId,
added_at: DateTime<Utc>,
) -> ItemResult<()> {
let cid = conversation_id.0.clone();
let iid = item_id.0.clone();
self.with_connection(move |conn| {
conn.execute(
"INSERT INTO conversation_item_links (conversation_id, item_id, added_at) VALUES (:1, :2, :3)",
&[&cid, &iid, &added_at],
)
.map_err(map_oracle_error)?;
Ok(())
})
.await
}
async fn list_items(
&self,
conversation_id: &ConversationId,
params: ListParams,
) -> ItemResult<Vec<ConversationItem>> {
let cid = conversation_id.0.clone();
let limit: i64 = params.limit as i64;
let order_desc = matches!(params.order, SortOrder::Desc);
let after_id = params.after.clone();
// Resolve the added_at of the after cursor if provided
let after_key: Option<(DateTime<Utc>, String)> = if let Some(ref aid) = after_id {
self.with_connection({
let cid = cid.clone();
let aid = aid.clone();
move |conn| {
let mut stmt = conn
.statement(
"SELECT added_at FROM conversation_item_links WHERE conversation_id = :1 AND item_id = :2",
)
.build()
.map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&cid, &aid]).map_err(map_oracle_error)?;
if let Some(row_res) = rows.next() {
let row = row_res.map_err(map_oracle_error)?;
let ts: DateTime<Utc> = row.get(0).map_err(map_oracle_error)?;
Ok(Some((ts, aid)))
} else {
Ok(None)
}
}
})
.await?
} else {
None
};
// Build the main list query
let rows: Vec<(String, Option<String>, String, Option<String>, Option<String>, Option<String>, DateTime<Utc>)> =
self.with_connection({
let cid = cid.clone();
move |conn| {
let mut sql = String::from(
"SELECT i.id, i.response_id, i.item_type, i.role, i.content, i.status, i.created_at \
FROM conversation_item_links l \
JOIN conversation_items i ON i.id = l.item_id \
WHERE l.conversation_id = :cid",
);
// Cursor predicate
if let Some((_ts, _iid)) = &after_key {
if order_desc {
sql.push_str(" AND (l.added_at < :ats OR (l.added_at = :ats AND l.item_id < :iid))");
} else {
sql.push_str(" AND (l.added_at > :ats OR (l.added_at = :ats AND l.item_id > :iid))");
}
}
// Order and limit
if order_desc {
sql.push_str(" ORDER BY l.added_at DESC, l.item_id DESC");
} else {
sql.push_str(" ORDER BY l.added_at ASC, l.item_id ASC");
}
sql.push_str(" FETCH NEXT :limit ROWS ONLY");
// Build params and perform a named SELECT query
let mut params_vec: Vec<(&str, &dyn ToSql)> = vec![("cid", &cid)];
if let Some((ts, iid)) = &after_key {
params_vec.push(("ats", ts));
params_vec.push(("iid", iid));
}
params_vec.push(("limit", &limit));
let rows_iter = conn.query_named(&sql, &params_vec).map_err(map_oracle_error)?;
let mut out = Vec::new();
for row_res in rows_iter {
let row = row_res.map_err(map_oracle_error)?;
let id: String = row.get(0).map_err(map_oracle_error)?;
let resp_id: Option<String> = row.get(1).map_err(map_oracle_error)?;
let item_type: String = row.get(2).map_err(map_oracle_error)?;
let role: Option<String> = row.get(3).map_err(map_oracle_error)?;
let content_raw: Option<String> = row.get(4).map_err(map_oracle_error)?;
let status: Option<String> = row.get(5).map_err(map_oracle_error)?;
let created_at: DateTime<Utc> = row.get(6).map_err(map_oracle_error)?;
out.push((id, resp_id, item_type, role, content_raw, status, created_at));
}
Ok(out)
}
})
.await?;
// Map rows to ConversationItem; propagate JSON parse errors instead of swallowing
rows.into_iter()
.map(
|(id, resp_id, item_type, role, content_raw, status, created_at)| {
let content = match content_raw {
Some(s) => {
serde_json::from_str(&s).map_err(ConversationItemStorageError::from)?
}
None => Value::Null,
};
Ok(ConversationItem {
id: ConversationItemId(id),
response_id: resp_id,
item_type,
role,
content,
status,
created_at,
})
},
)
.collect()
}
}
#[derive(Clone)]
struct ConversationItemOracleConnectionManager {
params: Arc<OracleConnectParams>,
}
#[derive(Clone)]
struct OracleConnectParams {
username: String,
password: String,
connect_descriptor: String,
}
impl ConversationItemOracleConnectionManager {
fn new(config: Arc<OracleConfig>) -> Self {
let params = OracleConnectParams {
username: config.username.clone(),
password: config.password.clone(),
connect_descriptor: config.connect_descriptor.clone(),
};
Self {
params: Arc::new(params),
}
}
}
impl std::fmt::Debug for ConversationItemOracleConnectionManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConversationItemOracleConnectionManager")
.field("username", &self.params.username)
.field("connect_descriptor", &self.params.connect_descriptor)
.finish()
}
}
#[async_trait]
impl Manager for ConversationItemOracleConnectionManager {
type Type = Connection;
type Error = oracle::Error;
fn create(
&self,
) -> impl std::future::Future<Output = std::result::Result<Connection, oracle::Error>> + Send
{
let params = self.params.clone();
async move {
let mut conn = Connection::connect(
&params.username,
&params.password,
&params.connect_descriptor,
)?;
conn.set_autocommit(true);
Ok(conn)
}
}
#[allow(clippy::manual_async_fn)]
fn recycle(
&self,
conn: &mut Connection,
_: &Metrics,
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
async move { conn.ping().map_err(RecycleError::Backend) }
}
}
fn configure_oracle_client(config: &OracleConfig) -> ItemResult<()> {
if let Some(wallet_path) = &config.wallet_path {
let wallet_path = Path::new(wallet_path);
if !wallet_path.is_dir() {
return Err(ConversationItemStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is not a directory",
wallet_path.display()
)));
}
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
return Err(ConversationItemStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
wallet_path.display()
)));
}
std::env::set_var("TNS_ADMIN", wallet_path);
}
Ok(())
}
fn initialize_schema(config: &OracleConfig) -> ItemResult<()> {
let conn = Connection::connect(
&config.username,
&config.password,
&config.connect_descriptor,
)
.map_err(map_oracle_error)?;
let exists_items: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATION_ITEMS'",
&[],
)
.map_err(map_oracle_error)?;
if exists_items == 0 {
conn.execute(
"CREATE TABLE conversation_items (
id VARCHAR2(64) PRIMARY KEY,
response_id VARCHAR2(64),
item_type VARCHAR2(32) NOT NULL,
role VARCHAR2(32),
content CLOB,
status VARCHAR2(32),
created_at TIMESTAMP WITH TIME ZONE
)",
&[],
)
.map_err(map_oracle_error)?;
}
let exists_links: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATION_ITEM_LINKS'",
&[],
)
.map_err(map_oracle_error)?;
if exists_links == 0 {
conn.execute(
"CREATE TABLE conversation_item_links (
conversation_id VARCHAR2(64) NOT NULL,
item_id VARCHAR2(64) NOT NULL,
added_at TIMESTAMP WITH TIME ZONE,
CONSTRAINT pk_conv_item_link PRIMARY KEY (conversation_id, item_id)
)",
&[],
)
.map_err(map_oracle_error)?;
conn.execute(
"CREATE INDEX conv_item_links_conv_idx ON conversation_item_links (conversation_id, added_at)",
&[],
)
.map_err(map_oracle_error)?;
}
Ok(())
}
fn map_pool_error(err: PoolError<oracle::Error>) -> ConversationItemStorageError {
match err {
PoolError::Backend(e) => map_oracle_error(e),
other => ConversationItemStorageError::StorageError(format!(
"failed to obtain Oracle conversation item connection: {other}"
)),
}
}
fn map_oracle_error(err: oracle::Error) -> ConversationItemStorageError {
if let Some(db_err) = err.db_error() {
ConversationItemStorageError::StorageError(format!(
"Oracle error (code {}): {}",
db_err.code(),
db_err.message()
))
} else {
ConversationItemStorageError::StorageError(err.to_string())
}
}
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use super::conversations::ConversationId;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ConversationItemId(pub String);
impl Display for ConversationItemId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
impl From<String> for ConversationItemId {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for ConversationItemId {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationItem {
pub id: ConversationItemId,
pub response_id: Option<String>,
pub item_type: String,
pub role: Option<String>,
pub content: Value,
pub status: Option<String>,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewConversationItem {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<ConversationItemId>,
pub response_id: Option<String>,
pub item_type: String,
pub role: Option<String>,
pub content: Value,
pub status: Option<String>,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum SortOrder {
Asc,
Desc,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListParams {
pub limit: usize,
pub order: SortOrder,
pub after: Option<String>, // item_id cursor
}
pub type Result<T> = std::result::Result<T, ConversationItemStorageError>;
#[derive(Debug, thiserror::Error)]
pub enum ConversationItemStorageError {
#[error("Not found: {0}")]
NotFound(String),
#[error("Storage error: {0}")]
StorageError(String),
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
}
#[async_trait]
pub trait ConversationItemStorage: Send + Sync + 'static {
async fn create_item(&self, item: NewConversationItem) -> Result<ConversationItem>;
async fn link_item(
&self,
conversation_id: &ConversationId,
item_id: &ConversationItemId,
added_at: DateTime<Utc>,
) -> Result<()>;
async fn list_items(
&self,
conversation_id: &ConversationId,
params: ListParams,
) -> Result<Vec<ConversationItem>>;
}
pub type SharedConversationItemStorage = Arc<dyn ConversationItemStorage>;
/// Helper to build id prefix based on item_type
pub fn make_item_id(item_type: &str) -> ConversationItemId {
// Generate a 24-byte random hex string (48 hex chars), consistent with conversation id style
let mut rng = rand::rng();
let mut bytes = [0u8; 24];
rng.fill_bytes(&mut bytes);
let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
let prefix: String = match item_type {
"message" => "msg".to_string(),
"reasoning" => "rs".to_string(),
"mcp_call" => "mcp".to_string(),
"mcp_list_tools" => "mcpl".to_string(),
"function_tool_call" => "ftc".to_string(),
other => {
// Fallback: first 3 letters of type or "itm"
let mut p = other.chars().take(3).collect::<String>();
if p.is_empty() {
p = "itm".to_string();
}
p
}
};
ConversationItemId(format!("{}_{}", prefix, hex_string))
}
// Data connector module for response storage and conversation storage // Data connector module for response storage and conversation storage
pub mod conversation_item_memory_store;
pub mod conversation_item_oracle_store;
pub mod conversation_items;
pub mod conversation_memory_store; pub mod conversation_memory_store;
pub mod conversation_noop_store; pub mod conversation_noop_store;
pub mod conversation_oracle_store; pub mod conversation_oracle_store;
...@@ -8,6 +11,14 @@ pub mod response_noop_store; ...@@ -8,6 +11,14 @@ pub mod response_noop_store;
pub mod response_oracle_store; pub mod response_oracle_store;
pub mod responses; pub mod responses;
pub use conversation_item_memory_store::MemoryConversationItemStorage;
pub use conversation_item_oracle_store::OracleConversationItemStorage;
pub use conversation_items::{
ConversationItem, ConversationItemId, ConversationItemStorage, ConversationItemStorageError,
ListParams as ConversationItemsListParams, NewConversationItem,
Result as ConversationItemsResult, SharedConversationItemStorage,
SortOrder as ConversationItemsSortOrder,
};
pub use conversation_memory_store::MemoryConversationStorage; pub use conversation_memory_store::MemoryConversationStorage;
pub use conversation_noop_store::NoOpConversationStorage; pub use conversation_noop_store::NoOpConversationStorage;
pub use conversation_oracle_store::OracleConversationStorage; pub use conversation_oracle_store::OracleConversationStorage;
......
...@@ -13,7 +13,7 @@ use std::sync::Arc; ...@@ -13,7 +13,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \ const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, raw_response FROM responses"; tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
#[derive(Clone)] #[derive(Clone)]
pub struct OracleResponseStorage { pub struct OracleResponseStorage {
...@@ -95,8 +95,11 @@ impl OracleResponseStorage { ...@@ -95,8 +95,11 @@ impl OracleResponseStorage {
let model: Option<String> = row let model: Option<String> = row
.get(9) .get(9)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?; .map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?;
let raw_response_json: Option<String> = row let conversation_id: Option<String> = row
.get(10) .get(10)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch conversation_id"))?;
let raw_response_json: Option<String> = row
.get(11)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?; .map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?;
let previous_response_id = previous.map(ResponseId); let previous_response_id = previous.map(ResponseId);
...@@ -115,6 +118,7 @@ impl OracleResponseStorage { ...@@ -115,6 +118,7 @@ impl OracleResponseStorage {
created_at, created_at,
user: user_id, user: user_id,
model, model,
conversation_id,
raw_response, raw_response,
}) })
} }
...@@ -134,6 +138,7 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -134,6 +138,7 @@ impl ResponseStorage for OracleResponseStorage {
created_at, created_at,
user, user,
model, model,
conversation_id,
raw_response, raw_response,
} = response; } = response;
...@@ -147,8 +152,8 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -147,8 +152,8 @@ impl ResponseStorage for OracleResponseStorage {
self.with_connection(move |conn| { self.with_connection(move |conn| {
conn.execute( conn.execute(
"INSERT INTO responses (id, previous_response_id, input, instructions, output, \ "INSERT INTO responses (id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, raw_response) \ tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response) \
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11)", VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11, :12)",
&[ &[
&response_id_str, &response_id_str,
&previous_id, &previous_id,
...@@ -160,6 +165,7 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -160,6 +165,7 @@ impl ResponseStorage for OracleResponseStorage {
&created_at, &created_at,
&user, &user,
&model, &model,
&conversation_id,
&json_raw_response, &json_raw_response,
], ],
) )
...@@ -394,6 +400,7 @@ fn initialize_schema(config: &OracleConfig) -> StorageResult<()> { ...@@ -394,6 +400,7 @@ fn initialize_schema(config: &OracleConfig) -> StorageResult<()> {
conn.execute( conn.execute(
"CREATE TABLE responses ( "CREATE TABLE responses (
id VARCHAR2(64) PRIMARY KEY, id VARCHAR2(64) PRIMARY KEY,
conversation_id VARCHAR2(64),
previous_response_id VARCHAR2(64), previous_response_id VARCHAR2(64),
input CLOB, input CLOB,
instructions CLOB, instructions CLOB,
......
...@@ -65,6 +65,10 @@ pub struct StoredResponse { ...@@ -65,6 +65,10 @@ pub struct StoredResponse {
/// Model used for generation /// Model used for generation
pub model: Option<String>, pub model: Option<String>,
/// Conversation id if associated with a conversation
#[serde(default)]
pub conversation_id: Option<String>,
/// Raw OpenAI response payload /// Raw OpenAI response payload
#[serde(default)] #[serde(default)]
pub raw_response: Value, pub raw_response: Value,
...@@ -83,6 +87,7 @@ impl StoredResponse { ...@@ -83,6 +87,7 @@ impl StoredResponse {
created_at: chrono::Utc::now(), created_at: chrono::Utc::now(),
user: None, user: None,
model: None, model: None,
conversation_id: None,
raw_response: Value::Null, raw_response: Value::Null,
} }
} }
......
...@@ -1103,6 +1103,10 @@ pub struct ResponsesRequest { ...@@ -1103,6 +1103,10 @@ pub struct ResponsesRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>, pub model: Option<String>,
/// Optional conversation id to persist input/output as items
#[serde(skip_serializing_if = "Option::is_none")]
pub conversation: Option<String>,
/// Whether to enable parallel tool calls /// Whether to enable parallel tool calls
#[serde(default = "default_true")] #[serde(default = "default_true")]
pub parallel_tool_calls: bool, pub parallel_tool_calls: bool,
...@@ -1214,6 +1218,7 @@ impl Default for ResponsesRequest { ...@@ -1214,6 +1218,7 @@ impl Default for ResponsesRequest {
max_tool_calls: None, max_tool_calls: None,
metadata: None, metadata: None,
model: None, model: None,
conversation: None,
parallel_tool_calls: true, parallel_tool_calls: true,
previous_response_id: None, previous_response_id: None,
reasoning: None, reasoning: None,
......
...@@ -129,6 +129,7 @@ impl RouterFactory { ...@@ -129,6 +129,7 @@ impl RouterFactory {
Some(ctx.router_config.circuit_breaker.clone()), Some(ctx.router_config.circuit_breaker.clone()),
ctx.response_storage.clone(), ctx.response_storage.clone(),
ctx.conversation_storage.clone(), ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
) )
.await?; .await?;
......
...@@ -173,6 +173,22 @@ pub trait RouterTrait: Send + Sync + Debug { ...@@ -173,6 +173,22 @@ pub trait RouterTrait: Send + Sync + Debug {
.into_response() .into_response()
} }
/// List items for a conversation
async fn list_conversation_items(
&self,
_headers: Option<&HeaderMap>,
_conversation_id: &str,
_limit: Option<usize>,
_order: Option<String>,
_after: Option<String>,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Conversation items list endpoint not implemented",
)
.into_response()
}
/// Get router type name /// Get router type name
fn router_type(&self) -> &'static str; fn router_type(&self) -> &'static str;
......
...@@ -589,6 +589,31 @@ impl RouterTrait for RouterManager { ...@@ -589,6 +589,31 @@ impl RouterTrait for RouterManager {
.into_response() .into_response()
} }
} }
async fn list_conversation_items(
&self,
headers: Option<&HeaderMap>,
conversation_id: &str,
limit: Option<usize>,
order: Option<String>,
after: Option<String>,
) -> Response {
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router
.list_conversation_items(headers, conversation_id, limit, order, after)
.await
} else {
(
StatusCode::NOT_FOUND,
format!(
"No router available to list conversation items for '{}'",
conversation_id
),
)
.into_response()
}
}
} }
impl std::fmt::Debug for RouterManager { impl std::fmt::Debug for RouterManager {
......
...@@ -2,9 +2,10 @@ use crate::{ ...@@ -2,9 +2,10 @@ use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType}, core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
data_connector::{ data_connector::{
MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage, MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
NoOpResponseStorage, OracleConversationStorage, OracleResponseStorage, NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage,
SharedConversationStorage, SharedResponseStorage, OracleConversationStorage, OracleResponseStorage, SharedConversationStorage,
SharedResponseStorage,
}, },
logging::{self, LoggingConfig}, logging::{self, LoggingConfig},
metrics::{self, PrometheusConfig}, metrics::{self, PrometheusConfig},
...@@ -56,6 +57,7 @@ pub struct AppContext { ...@@ -56,6 +57,7 @@ pub struct AppContext {
pub router_manager: Option<Arc<RouterManager>>, pub router_manager: Option<Arc<RouterManager>>,
pub response_storage: SharedResponseStorage, pub response_storage: SharedResponseStorage,
pub conversation_storage: SharedConversationStorage, pub conversation_storage: SharedConversationStorage,
pub conversation_item_storage: crate::data_connector::SharedConversationItemStorage,
pub load_monitor: Option<Arc<LoadMonitor>>, pub load_monitor: Option<Arc<LoadMonitor>>,
pub configured_reasoning_parser: Option<String>, pub configured_reasoning_parser: Option<String>,
pub configured_tool_parser: Option<String>, pub configured_tool_parser: Option<String>,
...@@ -121,8 +123,8 @@ impl AppContext { ...@@ -121,8 +123,8 @@ impl AppContext {
format!("failed to initialize Oracle response storage: {err}") format!("failed to initialize Oracle response storage: {err}")
})?; })?;
let conversation_storage = let conversation_storage = OracleConversationStorage::new(oracle_cfg.clone())
OracleConversationStorage::new(oracle_cfg).map_err(|err| { .map_err(|err| {
format!("failed to initialize Oracle conversation storage: {err}") format!("failed to initialize Oracle conversation storage: {err}")
})?; })?;
...@@ -130,6 +132,20 @@ impl AppContext { ...@@ -130,6 +132,20 @@ impl AppContext {
} }
}; };
// Conversation items storage (memory-backed for now)
let conversation_item_storage: crate::data_connector::SharedConversationItemStorage =
match router_config.history_backend {
HistoryBackend::Oracle => {
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string()
})?;
Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| {
format!("failed to initialize Oracle conversation item storage: {e}")
})?)
}
_ => Arc::new(MemoryConversationItemStorage::new()),
};
let load_monitor = Some(Arc::new(LoadMonitor::new( let load_monitor = Some(Arc::new(LoadMonitor::new(
worker_registry.clone(), worker_registry.clone(),
policy_registry.clone(), policy_registry.clone(),
...@@ -152,6 +168,7 @@ impl AppContext { ...@@ -152,6 +168,7 @@ impl AppContext {
router_manager, router_manager,
response_storage, response_storage,
conversation_storage, conversation_storage,
conversation_item_storage,
load_monitor, load_monitor,
configured_reasoning_parser, configured_reasoning_parser,
configured_tool_parser, configured_tool_parser,
...@@ -400,6 +417,29 @@ async fn v1_conversations_delete( ...@@ -400,6 +417,29 @@ async fn v1_conversations_delete(
.await .await
} }
#[derive(Deserialize, Default)]
struct ListItemsQuery {
limit: Option<usize>,
order: Option<String>,
after: Option<String>,
}
async fn v1_conversations_list_items(
State(state): State<Arc<AppState>>,
Path(conversation_id): Path<String>,
Query(ListItemsQuery {
limit,
order,
after,
}): Query<ListItemsQuery>,
headers: http::HeaderMap,
) -> Response {
state
.router
.list_conversation_items(Some(&headers), &conversation_id, limit, order, after)
.await
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct AddWorkerQuery { struct AddWorkerQuery {
url: String, url: String,
...@@ -674,6 +714,10 @@ pub fn build_app( ...@@ -674,6 +714,10 @@ pub fn build_app(
.post(v1_conversations_update) .post(v1_conversations_update)
.delete(v1_conversations_delete), .delete(v1_conversations_delete),
) )
.route(
"/v1/conversations/{conversation_id}/items",
get(v1_conversations_list_items),
)
.route_layer(axum::middleware::from_fn_with_state( .route_layer(axum::middleware::from_fn_with_state(
app_state.clone(), app_state.clone(),
middleware::concurrency_limit_middleware, middleware::concurrency_limit_middleware,
......
...@@ -543,6 +543,9 @@ mod tests { ...@@ -543,6 +543,9 @@ mod tests {
router_manager: None, router_manager: None,
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
conversation_storage: Arc::new(crate::data_connector::MemoryConversationStorage::new()), conversation_storage: Arc::new(crate::data_connector::MemoryConversationStorage::new()),
conversation_item_storage: Arc::new(
crate::data_connector::MemoryConversationItemStorage::new(),
),
load_monitor: None, load_monitor: None,
configured_reasoning_parser: None, configured_reasoning_parser: None,
configured_tool_parser: None, configured_tool_parser: None,
......
...@@ -125,6 +125,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { ...@@ -125,6 +125,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
top_k: -1, top_k: -1,
min_p: 0.0, min_p: 0.0,
repetition_penalty: 1.0, repetition_penalty: 1.0,
conversation: None,
}; };
let resp = router let resp = router
...@@ -371,6 +372,7 @@ fn test_responses_request_creation() { ...@@ -371,6 +372,7 @@ fn test_responses_request_creation() {
top_k: -1, top_k: -1,
min_p: 0.0, min_p: 0.0,
repetition_penalty: 1.0, repetition_penalty: 1.0,
conversation: None,
}; };
assert!(!request.is_stream()); assert!(!request.is_stream());
...@@ -411,6 +413,7 @@ fn test_sampling_params_conversion() { ...@@ -411,6 +413,7 @@ fn test_sampling_params_conversion() {
top_k: 10, top_k: 10,
min_p: 0.05, min_p: 0.05,
repetition_penalty: 1.1, repetition_penalty: 1.1,
conversation: None,
}; };
let params = request.to_sampling_params(1000, None); let params = request.to_sampling_params(1000, None);
...@@ -524,6 +527,7 @@ fn test_json_serialization() { ...@@ -524,6 +527,7 @@ fn test_json_serialization() {
top_k: 50, top_k: 50,
min_p: 0.1, min_p: 0.1,
repetition_penalty: 1.2, repetition_penalty: 1.2,
conversation: None,
}; };
let json = serde_json::to_string(&request).expect("Serialization should work"); let json = serde_json::to_string(&request).expect("Serialization should work");
...@@ -651,6 +655,7 @@ async fn test_multi_turn_loop_with_mcp() { ...@@ -651,6 +655,7 @@ async fn test_multi_turn_loop_with_mcp() {
top_k: 50, top_k: 50,
min_p: 0.0, min_p: 0.0,
repetition_penalty: 1.0, repetition_penalty: 1.0,
conversation: None,
}; };
// Execute the request (this should trigger the multi-turn loop) // Execute the request (this should trigger the multi-turn loop)
...@@ -828,6 +833,7 @@ async fn test_max_tool_calls_limit() { ...@@ -828,6 +833,7 @@ async fn test_max_tool_calls_limit() {
top_k: 50, top_k: 50,
min_p: 0.0, min_p: 0.0,
repetition_penalty: 1.0, repetition_penalty: 1.0,
conversation: None,
}; };
let response = router.route_responses(None, &req, None).await; let response = router.route_responses(None, &req, None).await;
...@@ -1023,6 +1029,7 @@ async fn test_streaming_with_mcp_tool_calls() { ...@@ -1023,6 +1029,7 @@ async fn test_streaming_with_mcp_tool_calls() {
top_k: 50, top_k: 50,
min_p: 0.0, min_p: 0.0,
repetition_penalty: 1.0, repetition_penalty: 1.0,
conversation: None,
}; };
let response = router.route_responses(None, &req, None).await; let response = router.route_responses(None, &req, None).await;
...@@ -1301,6 +1308,7 @@ async fn test_streaming_multi_turn_with_mcp() { ...@@ -1301,6 +1308,7 @@ async fn test_streaming_multi_turn_with_mcp() {
top_k: 50, top_k: 50,
min_p: 0.0, min_p: 0.0,
repetition_penalty: 1.0, repetition_penalty: 1.0,
conversation: None,
}; };
let response = router.route_responses(None, &req, None).await; let response = router.route_responses(None, &req, None).await;
......
...@@ -9,6 +9,7 @@ use axum::{ ...@@ -9,6 +9,7 @@ use axum::{
Json, Router, Json, Router,
}; };
use serde_json::json; use serde_json::json;
use sglang_router_rs::data_connector::MemoryConversationItemStorage;
use sglang_router_rs::{ use sglang_router_rs::{
config::{ config::{
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode, ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
...@@ -95,6 +96,7 @@ async fn test_openai_router_creation() { ...@@ -95,6 +96,7 @@ async fn test_openai_router_creation() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await; .await;
...@@ -113,6 +115,7 @@ async fn test_openai_router_server_info() { ...@@ -113,6 +115,7 @@ async fn test_openai_router_server_info() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -143,6 +146,7 @@ async fn test_openai_router_models() { ...@@ -143,6 +146,7 @@ async fn test_openai_router_models() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -222,6 +226,7 @@ async fn test_openai_router_responses_with_mock() { ...@@ -222,6 +226,7 @@ async fn test_openai_router_responses_with_mock() {
None, None,
storage.clone(), storage.clone(),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -482,6 +487,7 @@ async fn test_openai_router_responses_streaming_with_mock() { ...@@ -482,6 +487,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
None, None,
storage.clone(), storage.clone(),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -586,6 +592,7 @@ async fn test_unsupported_endpoints() { ...@@ -586,6 +592,7 @@ async fn test_unsupported_endpoints() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -627,6 +634,7 @@ async fn test_openai_router_chat_completion_with_mock() { ...@@ -627,6 +634,7 @@ async fn test_openai_router_chat_completion_with_mock() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -669,6 +677,7 @@ async fn test_openai_e2e_with_server() { ...@@ -669,6 +677,7 @@ async fn test_openai_e2e_with_server() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -739,6 +748,7 @@ async fn test_openai_router_chat_streaming_with_mock() { ...@@ -739,6 +748,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -792,6 +802,7 @@ async fn test_openai_router_circuit_breaker() { ...@@ -792,6 +802,7 @@ async fn test_openai_router_circuit_breaker() {
Some(cb_config), Some(cb_config),
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
...@@ -820,6 +831,7 @@ async fn test_openai_router_models_auth_forwarding() { ...@@ -820,6 +831,7 @@ async fn test_openai_router_models_auth_forwarding() {
None, None,
Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()), Arc::new(MemoryConversationStorage::new()),
Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()),
) )
.await .await
.unwrap(); .unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment