Unverified Commit 2cf3d0f8 authored by rongfu.leng's avatar rongfu.leng Committed by GitHub
Browse files

[router] use safety_identifier replace user on chat history storage (#12185)

parent 1ed1abfd
......@@ -340,8 +340,8 @@ pub struct StoredResponse {
/// When this response was created
pub created_at: DateTime<Utc>,
/// User identifier (optional)
pub user: Option<String>,
/// Safety identifier for content moderation
pub safety_identifier: Option<String>,
/// Model used for generation
pub model: Option<String>,
......@@ -366,7 +366,7 @@ impl StoredResponse {
tool_calls: Vec::new(),
metadata: HashMap::new(),
created_at: Utc::now(),
user: None,
safety_identifier: None,
model: None,
conversation_id: None,
raw_response: Value::Null,
......@@ -465,15 +465,15 @@ pub trait ResponseStorage: Send + Sync {
max_depth: Option<usize>,
) -> ResponseResult<ResponseChain>;
/// List recent responses for a user
async fn list_user_responses(
/// List recent responses for a safety identifier
async fn list_identifier_responses(
&self,
user: &str,
identifier: &str,
limit: Option<usize>,
) -> ResponseResult<Vec<StoredResponse>>;
/// Delete all responses for a user
async fn delete_user_responses(&self, user: &str) -> ResponseResult<usize>;
/// Delete all responses for a safety identifier
async fn delete_identifier_responses(&self, identifier: &str) -> ResponseResult<usize>;
}
impl Default for StoredResponse {
......
......@@ -270,8 +270,8 @@ impl ConversationItemStorage for MemoryConversationItemStorage {
struct InnerStore {
/// All stored responses indexed by ID
responses: HashMap<ResponseId, StoredResponse>,
/// Index of response IDs by user
user_index: HashMap<String, Vec<ResponseId>>,
/// Index of response IDs by safety identifier
identifier_index: HashMap<String, Vec<ResponseId>>,
}
/// In-memory implementation of response storage
......@@ -292,7 +292,7 @@ impl MemoryResponseStorage {
let store = self.store.read();
MemoryStoreStats {
response_count: store.responses.len(),
user_count: store.user_index.len(),
identifier_count: store.identifier_index.len(),
}
}
......@@ -300,7 +300,7 @@ impl MemoryResponseStorage {
pub fn clear(&self) {
let mut store = self.store.write();
store.responses.clear();
store.user_index.clear();
store.identifier_index.clear();
}
}
......@@ -323,11 +323,11 @@ impl ResponseStorage for MemoryResponseStorage {
// Single lock acquisition for atomic update
let mut store = self.store.write();
// Update user index if user is specified
if let Some(ref user) = response.user {
// Update safety identifier index if specified
if let Some(ref safety_identifier) = response.safety_identifier {
store
.user_index
.entry(user.clone())
.identifier_index
.entry(safety_identifier.clone())
.or_default()
.push(response_id.clone());
}
......@@ -354,8 +354,8 @@ impl ResponseStorage for MemoryResponseStorage {
// Remove the response and update user index if needed
if let Some(response) = store.responses.remove(response_id) {
if let Some(ref user) = response.user {
if let Some(user_responses) = store.user_index.get_mut(user) {
if let Some(ref safety_identifier) = response.safety_identifier {
if let Some(user_responses) = store.identifier_index.get_mut(safety_identifier) {
user_responses.retain(|id| id != response_id);
}
}
......@@ -409,14 +409,14 @@ impl ResponseStorage for MemoryResponseStorage {
Ok(chain)
}
async fn list_user_responses(
async fn list_identifier_responses(
&self,
user: &str,
identifier: &str,
limit: Option<usize>,
) -> ResponseResult<Vec<StoredResponse>> {
let store = self.store.read();
if let Some(user_response_ids) = store.user_index.get(user) {
if let Some(user_response_ids) = store.identifier_index.get(identifier) {
// Collect responses with their timestamps for sorting
let mut responses_with_time: Vec<_> = user_response_ids
.iter()
......@@ -440,10 +440,10 @@ impl ResponseStorage for MemoryResponseStorage {
}
}
async fn delete_user_responses(&self, user: &str) -> ResponseResult<usize> {
async fn delete_identifier_responses(&self, identifier: &str) -> ResponseResult<usize> {
let mut store = self.store.write();
if let Some(user_response_ids) = store.user_index.remove(user) {
if let Some(user_response_ids) = store.identifier_index.remove(identifier) {
let count = user_response_ids.len();
for id in user_response_ids {
store.responses.remove(&id);
......@@ -459,7 +459,7 @@ impl ResponseStorage for MemoryResponseStorage {
#[derive(Debug, Clone)]
pub struct MemoryStoreStats {
pub response_count: usize,
pub user_count: usize,
pub identifier_count: usize,
}
#[cfg(test)]
......@@ -660,38 +660,50 @@ mod tests {
let mut response1 = StoredResponse::new(None);
response1.input = json!("User1 message");
response1.output = json!("Response to user1");
response1.user = Some("user1".to_string());
response1.safety_identifier = Some("user1".to_string());
store.store_response(response1).await.unwrap();
let mut response2 = StoredResponse::new(None);
response2.input = json!("Another user1 message");
response2.output = json!("Another response to user1");
response2.user = Some("user1".to_string());
response2.safety_identifier = Some("user1".to_string());
store.store_response(response2).await.unwrap();
let mut response3 = StoredResponse::new(None);
response3.input = json!("User2 message");
response3.output = json!("Response to user2");
response3.user = Some("user2".to_string());
response3.safety_identifier = Some("user2".to_string());
store.store_response(response3).await.unwrap();
// List user1's responses
let user1_responses = store.list_user_responses("user1", None).await.unwrap();
let user1_responses = store
.list_identifier_responses("user1", None)
.await
.unwrap();
assert_eq!(user1_responses.len(), 2);
// List user2's responses
let user2_responses = store.list_user_responses("user2", None).await.unwrap();
let user2_responses = store
.list_identifier_responses("user2", None)
.await
.unwrap();
assert_eq!(user2_responses.len(), 1);
// Delete user1's responses
let deleted_count = store.delete_user_responses("user1").await.unwrap();
let deleted_count = store.delete_identifier_responses("user1").await.unwrap();
assert_eq!(deleted_count, 2);
let user1_responses_after = store.list_user_responses("user1", None).await.unwrap();
let user1_responses_after = store
.list_identifier_responses("user1", None)
.await
.unwrap();
assert_eq!(user1_responses_after.len(), 0);
// User2's responses should still be there
let user2_responses_after = store.list_user_responses("user2", None).await.unwrap();
let user2_responses_after = store
.list_identifier_responses("user2", None)
.await
.unwrap();
assert_eq!(user2_responses_after.len(), 1);
}
......@@ -702,17 +714,17 @@ mod tests {
let mut response1 = StoredResponse::new(None);
response1.input = json!("Test1");
response1.output = json!("Reply1");
response1.user = Some("user1".to_string());
response1.safety_identifier = Some("user1".to_string());
store.store_response(response1).await.unwrap();
let mut response2 = StoredResponse::new(None);
response2.input = json!("Test2");
response2.output = json!("Reply2");
response2.user = Some("user2".to_string());
response2.safety_identifier = Some("user2".to_string());
store.store_response(response2).await.unwrap();
let stats = store.stats();
assert_eq!(stats.response_count, 2);
assert_eq!(stats.user_count, 2);
assert_eq!(stats.identifier_count, 2);
}
}
......@@ -175,15 +175,15 @@ impl ResponseStorage for NoOpResponseStorage {
Ok(ResponseChain::new())
}
async fn list_user_responses(
async fn list_identifier_responses(
&self,
_user: &str,
_identifier: &str,
_limit: Option<usize>,
) -> ResponseResult<Vec<StoredResponse>> {
Ok(Vec::new())
}
async fn delete_user_responses(&self, _user: &str) -> ResponseResult<usize> {
async fn delete_identifier_responses(&self, _identifier: &str) -> ResponseResult<usize> {
Ok(0)
}
}
......@@ -768,7 +768,7 @@ impl ConversationItemStorage for OracleConversationItemStorage {
// ============================================================================
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
tool_calls, metadata, created_at, safety_identifier, model, conversation_id, raw_response FROM responses";
#[derive(Clone)]
pub struct OracleResponseStorage {
......@@ -798,13 +798,16 @@ impl OracleResponseStorage {
tool_calls CLOB,
metadata CLOB,
created_at TIMESTAMP WITH TIME ZONE,
user_id VARCHAR2(128),
safety_identifier VARCHAR2(128),
model VARCHAR2(128),
raw_response CLOB
)",
&[],
)
.map_err(map_oracle_error)?;
} else {
Self::alter_safety_identifier_column(conn)?;
Self::remove_user_id_column_if_exists(conn)?;
}
// Create indexes
......@@ -816,7 +819,7 @@ impl OracleResponseStorage {
create_index_if_missing(
conn,
"RESPONSES_USER_IDX",
"CREATE INDEX responses_user_idx ON responses(user_id)",
"CREATE INDEX responses_user_idx ON responses(safety_identifier)",
)?;
Ok(())
......@@ -826,6 +829,61 @@ impl OracleResponseStorage {
Ok(Self { store })
}
// Alter safety_identifier column if missing
fn alter_safety_identifier_column(conn: &Connection) -> Result<(), String> {
let present: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'SAFETY_IDENTIFIER'",
&[],
)
.map_err(map_oracle_error)?;
if present == 0 {
if let Err(err) = conn.execute(
"ALTER TABLE responses ADD (safety_identifier VARCHAR2(128))",
&[],
) {
let present_after: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'SAFETY_IDENTIFIER'",
&[],
)
.map_err(map_oracle_error)?;
if present_after == 0 {
return Err(map_oracle_error(err));
}
}
}
Ok(())
}
// Remove user_id column if exists
fn remove_user_id_column_if_exists(conn: &Connection) -> Result<(), String> {
let present: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'USER_ID'",
&[],
)
.map_err(map_oracle_error)?;
if present > 0 {
if let Err(err) = conn.execute("ALTER TABLE responses DROP COLUMN USER_ID", &[]) {
let present_after: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'USER_ID'",
&[],
)
.map_err(map_oracle_error)?;
if present_after > 0 {
return Err(map_oracle_error(err));
}
}
}
Ok(())
}
fn build_response_from_row(row: &Row) -> Result<StoredResponse, String> {
let id: String = row.get(0).map_err(map_oracle_error)?;
let previous: Option<String> = row.get(1).map_err(map_oracle_error)?;
......@@ -835,7 +893,7 @@ impl OracleResponseStorage {
let tool_calls_json: Option<String> = row.get(5).map_err(map_oracle_error)?;
let metadata_json: Option<String> = row.get(6).map_err(map_oracle_error)?;
let created_at: DateTime<Utc> = row.get(7).map_err(map_oracle_error)?;
let user_id: Option<String> = row.get(8).map_err(map_oracle_error)?;
let safety_identifier: Option<String> = row.get(8).map_err(map_oracle_error)?;
let model: Option<String> = row.get(9).map_err(map_oracle_error)?;
let conversation_id: Option<String> = row.get(10).map_err(map_oracle_error)?;
let raw_response_json: Option<String> = row.get(11).map_err(map_oracle_error)?;
......@@ -856,7 +914,7 @@ impl OracleResponseStorage {
tool_calls,
metadata,
created_at,
user: user_id,
safety_identifier,
model,
conversation_id,
raw_response,
......@@ -880,7 +938,7 @@ impl ResponseStorage for OracleResponseStorage {
let json_raw_response = serde_json::to_string(&response.raw_response)?;
let instructions = response.instructions.clone();
let created_at = response.created_at;
let user = response.user.clone();
let safety_identifier = response.safety_identifier.clone();
let model = response.model.clone();
let conversation_id = response.conversation_id.clone();
......@@ -888,7 +946,7 @@ impl ResponseStorage for OracleResponseStorage {
.execute(move |conn| {
conn.execute(
"INSERT INTO responses (id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response) \
tool_calls, metadata, created_at, safety_identifier, model, conversation_id, raw_response) \
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11, :12)",
&[
&response_id_str,
......@@ -899,7 +957,7 @@ impl ResponseStorage for OracleResponseStorage {
&json_tool_calls,
&json_metadata,
&created_at,
&user,
&safety_identifier,
&model,
&conversation_id,
&json_raw_response,
......@@ -981,26 +1039,26 @@ impl ResponseStorage for OracleResponseStorage {
Ok(chain)
}
async fn list_user_responses(
async fn list_identifier_responses(
&self,
user: &str,
identifier: &str,
limit: Option<usize>,
) -> Result<Vec<StoredResponse>, ResponseStorageError> {
let user = user.to_string();
let identifier = identifier.to_string();
self.store
.execute(move |conn| {
let sql = if let Some(limit) = limit {
format!(
"SELECT * FROM ({} WHERE user_id = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}",
"SELECT * FROM ({} WHERE safety_identifier = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}",
SELECT_BASE, limit
)
} else {
format!("{} WHERE user_id = :1 ORDER BY created_at DESC", SELECT_BASE)
format!("{} WHERE safety_identifier = :1 ORDER BY created_at DESC", SELECT_BASE)
};
let mut stmt = conn.statement(&sql).build().map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&user]).map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&identifier]).map_err(map_oracle_error)?;
let mut results = Vec::new();
for row in &mut rows {
......@@ -1014,13 +1072,19 @@ impl ResponseStorage for OracleResponseStorage {
.map_err(ResponseStorageError::StorageError)
}
async fn delete_user_responses(&self, user: &str) -> Result<usize, ResponseStorageError> {
let user = user.to_string();
async fn delete_identifier_responses(
&self,
identifier: &str,
) -> Result<usize, ResponseStorageError> {
let identifier = identifier.to_string();
let affected = self
.store
.execute(move |conn| {
conn.execute("DELETE FROM responses WHERE user_id = :1", &[&user])
.map_err(map_oracle_error)
conn.execute(
"DELETE FROM responses WHERE safety_identifier = :1",
&[&identifier],
)
.map_err(map_oracle_error)
})
.await
.map_err(ResponseStorageError::StorageError)?;
......
......@@ -857,6 +857,10 @@ pub struct ResponsesResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// Safety identifier for content moderation
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_identifier: Option<String>,
/// Additional metadata
#[serde(default)]
pub metadata: HashMap<String, Value>,
......
......@@ -325,6 +325,7 @@ impl HarmonyResponseProcessor {
output_tokens_details: None,
})),
user: None,
safety_identifier: responses_request.user.clone(),
metadata: responses_request.metadata.clone().unwrap_or_default(),
};
......
......@@ -358,7 +358,8 @@ pub fn chat_to_responses(
top_p: original_req.top_p,
truncation: None,
usage,
user: None, // No user field in chat response
user: None,
safety_identifier: original_req.user.clone(),
metadata: original_req.metadata.clone().unwrap_or_default(),
})
}
......
......@@ -325,7 +325,8 @@ async fn route_responses_background(
top_p: request.top_p,
truncation: None,
usage: None,
user: request.user.clone(),
user: None,
safety_identifier: request.user.clone(),
metadata: request.metadata.clone().unwrap_or_default(),
};
......@@ -842,6 +843,7 @@ impl StreamingResponseAccumulator {
truncation: None,
usage,
user: None,
safety_identifier: self.original_request.user.clone(),
metadata: self.original_request.metadata.clone().unwrap_or_default(),
}
}
......
......@@ -38,11 +38,9 @@ pub(super) fn build_stored_response(
.map(|s| s.to_string())
.or_else(|| Some(original_body.model.clone()));
stored_response.user = response_json
.get("user")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.user.clone());
if let Some(safety_identifier) = original_body.user.clone() {
stored_response.safety_identifier = Some(safety_identifier);
}
// Set conversation id from request if provided
if let Some(conv_id) = original_body.conversation.clone() {
......@@ -146,9 +144,16 @@ pub(super) fn patch_streaming_response_json(
);
}
if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
if let Some(user) = &original_body.user {
obj.insert("user".to_string(), Value::String(user.clone()));
if obj
.get("safety_identifier")
.map(|v| v.is_null())
.unwrap_or(false)
{
if let Some(safety_identifier) = &original_body.user {
obj.insert(
"safety_identifier".to_string(),
Value::String(safety_identifier.clone()),
);
}
}
......
......@@ -555,7 +555,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
assert_eq!(stored.metadata.get("topic"), Some(&json!("unicorns")));
assert_eq!(stored.instructions.as_deref(), Some("Be kind"));
assert_eq!(stored.model.as_deref(), Some("gpt-5-nano"));
assert_eq!(stored.user, None);
assert_eq!(stored.safety_identifier, None);
assert_eq!(stored.raw_response["store"], json!(true));
assert_eq!(
stored.raw_response["previous_response_id"].as_str(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment