Unverified Commit 2cf3d0f8 authored by rongfu.leng's avatar rongfu.leng Committed by GitHub
Browse files

[router] use safety_identifier replace user on chat history storage (#12185)

parent 1ed1abfd
...@@ -340,8 +340,8 @@ pub struct StoredResponse { ...@@ -340,8 +340,8 @@ pub struct StoredResponse {
/// When this response was created /// When this response was created
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
/// User identifier (optional) /// Safety identifier for content moderation
pub user: Option<String>, pub safety_identifier: Option<String>,
/// Model used for generation /// Model used for generation
pub model: Option<String>, pub model: Option<String>,
...@@ -366,7 +366,7 @@ impl StoredResponse { ...@@ -366,7 +366,7 @@ impl StoredResponse {
tool_calls: Vec::new(), tool_calls: Vec::new(),
metadata: HashMap::new(), metadata: HashMap::new(),
created_at: Utc::now(), created_at: Utc::now(),
user: None, safety_identifier: None,
model: None, model: None,
conversation_id: None, conversation_id: None,
raw_response: Value::Null, raw_response: Value::Null,
...@@ -465,15 +465,15 @@ pub trait ResponseStorage: Send + Sync { ...@@ -465,15 +465,15 @@ pub trait ResponseStorage: Send + Sync {
max_depth: Option<usize>, max_depth: Option<usize>,
) -> ResponseResult<ResponseChain>; ) -> ResponseResult<ResponseChain>;
/// List recent responses for a user /// List recent responses for a safety identifier
async fn list_user_responses( async fn list_identifier_responses(
&self, &self,
user: &str, identifier: &str,
limit: Option<usize>, limit: Option<usize>,
) -> ResponseResult<Vec<StoredResponse>>; ) -> ResponseResult<Vec<StoredResponse>>;
/// Delete all responses for a user /// Delete all responses for a safety identifier
async fn delete_user_responses(&self, user: &str) -> ResponseResult<usize>; async fn delete_identifier_responses(&self, identifier: &str) -> ResponseResult<usize>;
} }
impl Default for StoredResponse { impl Default for StoredResponse {
......
...@@ -270,8 +270,8 @@ impl ConversationItemStorage for MemoryConversationItemStorage { ...@@ -270,8 +270,8 @@ impl ConversationItemStorage for MemoryConversationItemStorage {
struct InnerStore { struct InnerStore {
/// All stored responses indexed by ID /// All stored responses indexed by ID
responses: HashMap<ResponseId, StoredResponse>, responses: HashMap<ResponseId, StoredResponse>,
/// Index of response IDs by user /// Index of response IDs by safety identifier
user_index: HashMap<String, Vec<ResponseId>>, identifier_index: HashMap<String, Vec<ResponseId>>,
} }
/// In-memory implementation of response storage /// In-memory implementation of response storage
...@@ -292,7 +292,7 @@ impl MemoryResponseStorage { ...@@ -292,7 +292,7 @@ impl MemoryResponseStorage {
let store = self.store.read(); let store = self.store.read();
MemoryStoreStats { MemoryStoreStats {
response_count: store.responses.len(), response_count: store.responses.len(),
user_count: store.user_index.len(), identifier_count: store.identifier_index.len(),
} }
} }
...@@ -300,7 +300,7 @@ impl MemoryResponseStorage { ...@@ -300,7 +300,7 @@ impl MemoryResponseStorage {
pub fn clear(&self) { pub fn clear(&self) {
let mut store = self.store.write(); let mut store = self.store.write();
store.responses.clear(); store.responses.clear();
store.user_index.clear(); store.identifier_index.clear();
} }
} }
...@@ -323,11 +323,11 @@ impl ResponseStorage for MemoryResponseStorage { ...@@ -323,11 +323,11 @@ impl ResponseStorage for MemoryResponseStorage {
// Single lock acquisition for atomic update // Single lock acquisition for atomic update
let mut store = self.store.write(); let mut store = self.store.write();
// Update user index if user is specified // Update safety identifier index if specified
if let Some(ref user) = response.user { if let Some(ref safety_identifier) = response.safety_identifier {
store store
.user_index .identifier_index
.entry(user.clone()) .entry(safety_identifier.clone())
.or_default() .or_default()
.push(response_id.clone()); .push(response_id.clone());
} }
...@@ -354,8 +354,8 @@ impl ResponseStorage for MemoryResponseStorage { ...@@ -354,8 +354,8 @@ impl ResponseStorage for MemoryResponseStorage {
// Remove the response and update user index if needed // Remove the response and update user index if needed
if let Some(response) = store.responses.remove(response_id) { if let Some(response) = store.responses.remove(response_id) {
if let Some(ref user) = response.user { if let Some(ref safety_identifier) = response.safety_identifier {
if let Some(user_responses) = store.user_index.get_mut(user) { if let Some(user_responses) = store.identifier_index.get_mut(safety_identifier) {
user_responses.retain(|id| id != response_id); user_responses.retain(|id| id != response_id);
} }
} }
...@@ -409,14 +409,14 @@ impl ResponseStorage for MemoryResponseStorage { ...@@ -409,14 +409,14 @@ impl ResponseStorage for MemoryResponseStorage {
Ok(chain) Ok(chain)
} }
async fn list_user_responses( async fn list_identifier_responses(
&self, &self,
user: &str, identifier: &str,
limit: Option<usize>, limit: Option<usize>,
) -> ResponseResult<Vec<StoredResponse>> { ) -> ResponseResult<Vec<StoredResponse>> {
let store = self.store.read(); let store = self.store.read();
if let Some(user_response_ids) = store.user_index.get(user) { if let Some(user_response_ids) = store.identifier_index.get(identifier) {
// Collect responses with their timestamps for sorting // Collect responses with their timestamps for sorting
let mut responses_with_time: Vec<_> = user_response_ids let mut responses_with_time: Vec<_> = user_response_ids
.iter() .iter()
...@@ -440,10 +440,10 @@ impl ResponseStorage for MemoryResponseStorage { ...@@ -440,10 +440,10 @@ impl ResponseStorage for MemoryResponseStorage {
} }
} }
async fn delete_user_responses(&self, user: &str) -> ResponseResult<usize> { async fn delete_identifier_responses(&self, identifier: &str) -> ResponseResult<usize> {
let mut store = self.store.write(); let mut store = self.store.write();
if let Some(user_response_ids) = store.user_index.remove(user) { if let Some(user_response_ids) = store.identifier_index.remove(identifier) {
let count = user_response_ids.len(); let count = user_response_ids.len();
for id in user_response_ids { for id in user_response_ids {
store.responses.remove(&id); store.responses.remove(&id);
...@@ -459,7 +459,7 @@ impl ResponseStorage for MemoryResponseStorage { ...@@ -459,7 +459,7 @@ impl ResponseStorage for MemoryResponseStorage {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MemoryStoreStats { pub struct MemoryStoreStats {
pub response_count: usize, pub response_count: usize,
pub user_count: usize, pub identifier_count: usize,
} }
#[cfg(test)] #[cfg(test)]
...@@ -660,38 +660,50 @@ mod tests { ...@@ -660,38 +660,50 @@ mod tests {
let mut response1 = StoredResponse::new(None); let mut response1 = StoredResponse::new(None);
response1.input = json!("User1 message"); response1.input = json!("User1 message");
response1.output = json!("Response to user1"); response1.output = json!("Response to user1");
response1.user = Some("user1".to_string()); response1.safety_identifier = Some("user1".to_string());
store.store_response(response1).await.unwrap(); store.store_response(response1).await.unwrap();
let mut response2 = StoredResponse::new(None); let mut response2 = StoredResponse::new(None);
response2.input = json!("Another user1 message"); response2.input = json!("Another user1 message");
response2.output = json!("Another response to user1"); response2.output = json!("Another response to user1");
response2.user = Some("user1".to_string()); response2.safety_identifier = Some("user1".to_string());
store.store_response(response2).await.unwrap(); store.store_response(response2).await.unwrap();
let mut response3 = StoredResponse::new(None); let mut response3 = StoredResponse::new(None);
response3.input = json!("User2 message"); response3.input = json!("User2 message");
response3.output = json!("Response to user2"); response3.output = json!("Response to user2");
response3.user = Some("user2".to_string()); response3.safety_identifier = Some("user2".to_string());
store.store_response(response3).await.unwrap(); store.store_response(response3).await.unwrap();
// List user1's responses // List user1's responses
let user1_responses = store.list_user_responses("user1", None).await.unwrap(); let user1_responses = store
.list_identifier_responses("user1", None)
.await
.unwrap();
assert_eq!(user1_responses.len(), 2); assert_eq!(user1_responses.len(), 2);
// List user2's responses // List user2's responses
let user2_responses = store.list_user_responses("user2", None).await.unwrap(); let user2_responses = store
.list_identifier_responses("user2", None)
.await
.unwrap();
assert_eq!(user2_responses.len(), 1); assert_eq!(user2_responses.len(), 1);
// Delete user1's responses // Delete user1's responses
let deleted_count = store.delete_user_responses("user1").await.unwrap(); let deleted_count = store.delete_identifier_responses("user1").await.unwrap();
assert_eq!(deleted_count, 2); assert_eq!(deleted_count, 2);
let user1_responses_after = store.list_user_responses("user1", None).await.unwrap(); let user1_responses_after = store
.list_identifier_responses("user1", None)
.await
.unwrap();
assert_eq!(user1_responses_after.len(), 0); assert_eq!(user1_responses_after.len(), 0);
// User2's responses should still be there // User2's responses should still be there
let user2_responses_after = store.list_user_responses("user2", None).await.unwrap(); let user2_responses_after = store
.list_identifier_responses("user2", None)
.await
.unwrap();
assert_eq!(user2_responses_after.len(), 1); assert_eq!(user2_responses_after.len(), 1);
} }
...@@ -702,17 +714,17 @@ mod tests { ...@@ -702,17 +714,17 @@ mod tests {
let mut response1 = StoredResponse::new(None); let mut response1 = StoredResponse::new(None);
response1.input = json!("Test1"); response1.input = json!("Test1");
response1.output = json!("Reply1"); response1.output = json!("Reply1");
response1.user = Some("user1".to_string()); response1.safety_identifier = Some("user1".to_string());
store.store_response(response1).await.unwrap(); store.store_response(response1).await.unwrap();
let mut response2 = StoredResponse::new(None); let mut response2 = StoredResponse::new(None);
response2.input = json!("Test2"); response2.input = json!("Test2");
response2.output = json!("Reply2"); response2.output = json!("Reply2");
response2.user = Some("user2".to_string()); response2.safety_identifier = Some("user2".to_string());
store.store_response(response2).await.unwrap(); store.store_response(response2).await.unwrap();
let stats = store.stats(); let stats = store.stats();
assert_eq!(stats.response_count, 2); assert_eq!(stats.response_count, 2);
assert_eq!(stats.user_count, 2); assert_eq!(stats.identifier_count, 2);
} }
} }
...@@ -175,15 +175,15 @@ impl ResponseStorage for NoOpResponseStorage { ...@@ -175,15 +175,15 @@ impl ResponseStorage for NoOpResponseStorage {
Ok(ResponseChain::new()) Ok(ResponseChain::new())
} }
async fn list_user_responses( async fn list_identifier_responses(
&self, &self,
_user: &str, _identifier: &str,
_limit: Option<usize>, _limit: Option<usize>,
) -> ResponseResult<Vec<StoredResponse>> { ) -> ResponseResult<Vec<StoredResponse>> {
Ok(Vec::new()) Ok(Vec::new())
} }
async fn delete_user_responses(&self, _user: &str) -> ResponseResult<usize> { async fn delete_identifier_responses(&self, _identifier: &str) -> ResponseResult<usize> {
Ok(0) Ok(0)
} }
} }
...@@ -768,7 +768,7 @@ impl ConversationItemStorage for OracleConversationItemStorage { ...@@ -768,7 +768,7 @@ impl ConversationItemStorage for OracleConversationItemStorage {
// ============================================================================ // ============================================================================
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \ const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses"; tool_calls, metadata, created_at, safety_identifier, model, conversation_id, raw_response FROM responses";
#[derive(Clone)] #[derive(Clone)]
pub struct OracleResponseStorage { pub struct OracleResponseStorage {
...@@ -798,13 +798,16 @@ impl OracleResponseStorage { ...@@ -798,13 +798,16 @@ impl OracleResponseStorage {
tool_calls CLOB, tool_calls CLOB,
metadata CLOB, metadata CLOB,
created_at TIMESTAMP WITH TIME ZONE, created_at TIMESTAMP WITH TIME ZONE,
user_id VARCHAR2(128), safety_identifier VARCHAR2(128),
model VARCHAR2(128), model VARCHAR2(128),
raw_response CLOB raw_response CLOB
)", )",
&[], &[],
) )
.map_err(map_oracle_error)?; .map_err(map_oracle_error)?;
} else {
Self::alter_safety_identifier_column(conn)?;
Self::remove_user_id_column_if_exists(conn)?;
} }
// Create indexes // Create indexes
...@@ -816,7 +819,7 @@ impl OracleResponseStorage { ...@@ -816,7 +819,7 @@ impl OracleResponseStorage {
create_index_if_missing( create_index_if_missing(
conn, conn,
"RESPONSES_USER_IDX", "RESPONSES_USER_IDX",
"CREATE INDEX responses_user_idx ON responses(user_id)", "CREATE INDEX responses_user_idx ON responses(safety_identifier)",
)?; )?;
Ok(()) Ok(())
...@@ -826,6 +829,61 @@ impl OracleResponseStorage { ...@@ -826,6 +829,61 @@ impl OracleResponseStorage {
Ok(Self { store }) Ok(Self { store })
} }
// Alter safety_identifier column if missing
fn alter_safety_identifier_column(conn: &Connection) -> Result<(), String> {
let present: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'SAFETY_IDENTIFIER'",
&[],
)
.map_err(map_oracle_error)?;
if present == 0 {
if let Err(err) = conn.execute(
"ALTER TABLE responses ADD (safety_identifier VARCHAR2(128))",
&[],
) {
let present_after: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'SAFETY_IDENTIFIER'",
&[],
)
.map_err(map_oracle_error)?;
if present_after == 0 {
return Err(map_oracle_error(err));
}
}
}
Ok(())
}
// Remove user_id column if exists
fn remove_user_id_column_if_exists(conn: &Connection) -> Result<(), String> {
let present: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'USER_ID'",
&[],
)
.map_err(map_oracle_error)?;
if present > 0 {
if let Err(err) = conn.execute("ALTER TABLE responses DROP COLUMN USER_ID", &[]) {
let present_after: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tab_columns WHERE table_name = 'RESPONSES' AND column_name = 'USER_ID'",
&[],
)
.map_err(map_oracle_error)?;
if present_after > 0 {
return Err(map_oracle_error(err));
}
}
}
Ok(())
}
fn build_response_from_row(row: &Row) -> Result<StoredResponse, String> { fn build_response_from_row(row: &Row) -> Result<StoredResponse, String> {
let id: String = row.get(0).map_err(map_oracle_error)?; let id: String = row.get(0).map_err(map_oracle_error)?;
let previous: Option<String> = row.get(1).map_err(map_oracle_error)?; let previous: Option<String> = row.get(1).map_err(map_oracle_error)?;
...@@ -835,7 +893,7 @@ impl OracleResponseStorage { ...@@ -835,7 +893,7 @@ impl OracleResponseStorage {
let tool_calls_json: Option<String> = row.get(5).map_err(map_oracle_error)?; let tool_calls_json: Option<String> = row.get(5).map_err(map_oracle_error)?;
let metadata_json: Option<String> = row.get(6).map_err(map_oracle_error)?; let metadata_json: Option<String> = row.get(6).map_err(map_oracle_error)?;
let created_at: DateTime<Utc> = row.get(7).map_err(map_oracle_error)?; let created_at: DateTime<Utc> = row.get(7).map_err(map_oracle_error)?;
let user_id: Option<String> = row.get(8).map_err(map_oracle_error)?; let safety_identifier: Option<String> = row.get(8).map_err(map_oracle_error)?;
let model: Option<String> = row.get(9).map_err(map_oracle_error)?; let model: Option<String> = row.get(9).map_err(map_oracle_error)?;
let conversation_id: Option<String> = row.get(10).map_err(map_oracle_error)?; let conversation_id: Option<String> = row.get(10).map_err(map_oracle_error)?;
let raw_response_json: Option<String> = row.get(11).map_err(map_oracle_error)?; let raw_response_json: Option<String> = row.get(11).map_err(map_oracle_error)?;
...@@ -856,7 +914,7 @@ impl OracleResponseStorage { ...@@ -856,7 +914,7 @@ impl OracleResponseStorage {
tool_calls, tool_calls,
metadata, metadata,
created_at, created_at,
user: user_id, safety_identifier,
model, model,
conversation_id, conversation_id,
raw_response, raw_response,
...@@ -880,7 +938,7 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -880,7 +938,7 @@ impl ResponseStorage for OracleResponseStorage {
let json_raw_response = serde_json::to_string(&response.raw_response)?; let json_raw_response = serde_json::to_string(&response.raw_response)?;
let instructions = response.instructions.clone(); let instructions = response.instructions.clone();
let created_at = response.created_at; let created_at = response.created_at;
let user = response.user.clone(); let safety_identifier = response.safety_identifier.clone();
let model = response.model.clone(); let model = response.model.clone();
let conversation_id = response.conversation_id.clone(); let conversation_id = response.conversation_id.clone();
...@@ -888,7 +946,7 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -888,7 +946,7 @@ impl ResponseStorage for OracleResponseStorage {
.execute(move |conn| { .execute(move |conn| {
conn.execute( conn.execute(
"INSERT INTO responses (id, previous_response_id, input, instructions, output, \ "INSERT INTO responses (id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response) \ tool_calls, metadata, created_at, safety_identifier, model, conversation_id, raw_response) \
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11, :12)", VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11, :12)",
&[ &[
&response_id_str, &response_id_str,
...@@ -899,7 +957,7 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -899,7 +957,7 @@ impl ResponseStorage for OracleResponseStorage {
&json_tool_calls, &json_tool_calls,
&json_metadata, &json_metadata,
&created_at, &created_at,
&user, &safety_identifier,
&model, &model,
&conversation_id, &conversation_id,
&json_raw_response, &json_raw_response,
...@@ -981,26 +1039,26 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -981,26 +1039,26 @@ impl ResponseStorage for OracleResponseStorage {
Ok(chain) Ok(chain)
} }
async fn list_user_responses( async fn list_identifier_responses(
&self, &self,
user: &str, identifier: &str,
limit: Option<usize>, limit: Option<usize>,
) -> Result<Vec<StoredResponse>, ResponseStorageError> { ) -> Result<Vec<StoredResponse>, ResponseStorageError> {
let user = user.to_string(); let identifier = identifier.to_string();
self.store self.store
.execute(move |conn| { .execute(move |conn| {
let sql = if let Some(limit) = limit { let sql = if let Some(limit) = limit {
format!( format!(
"SELECT * FROM ({} WHERE user_id = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}", "SELECT * FROM ({} WHERE safety_identifier = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}",
SELECT_BASE, limit SELECT_BASE, limit
) )
} else { } else {
format!("{} WHERE user_id = :1 ORDER BY created_at DESC", SELECT_BASE) format!("{} WHERE safety_identifier = :1 ORDER BY created_at DESC", SELECT_BASE)
}; };
let mut stmt = conn.statement(&sql).build().map_err(map_oracle_error)?; let mut stmt = conn.statement(&sql).build().map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&user]).map_err(map_oracle_error)?; let mut rows = stmt.query(&[&identifier]).map_err(map_oracle_error)?;
let mut results = Vec::new(); let mut results = Vec::new();
for row in &mut rows { for row in &mut rows {
...@@ -1014,13 +1072,19 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -1014,13 +1072,19 @@ impl ResponseStorage for OracleResponseStorage {
.map_err(ResponseStorageError::StorageError) .map_err(ResponseStorageError::StorageError)
} }
async fn delete_user_responses(&self, user: &str) -> Result<usize, ResponseStorageError> { async fn delete_identifier_responses(
let user = user.to_string(); &self,
identifier: &str,
) -> Result<usize, ResponseStorageError> {
let identifier = identifier.to_string();
let affected = self let affected = self
.store .store
.execute(move |conn| { .execute(move |conn| {
conn.execute("DELETE FROM responses WHERE user_id = :1", &[&user]) conn.execute(
.map_err(map_oracle_error) "DELETE FROM responses WHERE safety_identifier = :1",
&[&identifier],
)
.map_err(map_oracle_error)
}) })
.await .await
.map_err(ResponseStorageError::StorageError)?; .map_err(ResponseStorageError::StorageError)?;
......
...@@ -857,6 +857,10 @@ pub struct ResponsesResponse { ...@@ -857,6 +857,10 @@ pub struct ResponsesResponse {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>, pub user: Option<String>,
/// Safety identifier for content moderation
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_identifier: Option<String>,
/// Additional metadata /// Additional metadata
#[serde(default)] #[serde(default)]
pub metadata: HashMap<String, Value>, pub metadata: HashMap<String, Value>,
......
...@@ -325,6 +325,7 @@ impl HarmonyResponseProcessor { ...@@ -325,6 +325,7 @@ impl HarmonyResponseProcessor {
output_tokens_details: None, output_tokens_details: None,
})), })),
user: None, user: None,
safety_identifier: responses_request.user.clone(),
metadata: responses_request.metadata.clone().unwrap_or_default(), metadata: responses_request.metadata.clone().unwrap_or_default(),
}; };
......
...@@ -358,7 +358,8 @@ pub fn chat_to_responses( ...@@ -358,7 +358,8 @@ pub fn chat_to_responses(
top_p: original_req.top_p, top_p: original_req.top_p,
truncation: None, truncation: None,
usage, usage,
user: None, // No user field in chat response user: None,
safety_identifier: original_req.user.clone(),
metadata: original_req.metadata.clone().unwrap_or_default(), metadata: original_req.metadata.clone().unwrap_or_default(),
}) })
} }
......
...@@ -325,7 +325,8 @@ async fn route_responses_background( ...@@ -325,7 +325,8 @@ async fn route_responses_background(
top_p: request.top_p, top_p: request.top_p,
truncation: None, truncation: None,
usage: None, usage: None,
user: request.user.clone(), user: None,
safety_identifier: request.user.clone(),
metadata: request.metadata.clone().unwrap_or_default(), metadata: request.metadata.clone().unwrap_or_default(),
}; };
...@@ -842,6 +843,7 @@ impl StreamingResponseAccumulator { ...@@ -842,6 +843,7 @@ impl StreamingResponseAccumulator {
truncation: None, truncation: None,
usage, usage,
user: None, user: None,
safety_identifier: self.original_request.user.clone(),
metadata: self.original_request.metadata.clone().unwrap_or_default(), metadata: self.original_request.metadata.clone().unwrap_or_default(),
} }
} }
......
...@@ -38,11 +38,9 @@ pub(super) fn build_stored_response( ...@@ -38,11 +38,9 @@ pub(super) fn build_stored_response(
.map(|s| s.to_string()) .map(|s| s.to_string())
.or_else(|| Some(original_body.model.clone())); .or_else(|| Some(original_body.model.clone()));
stored_response.user = response_json if let Some(safety_identifier) = original_body.user.clone() {
.get("user") stored_response.safety_identifier = Some(safety_identifier);
.and_then(|v| v.as_str()) }
.map(|s| s.to_string())
.or_else(|| original_body.user.clone());
// Set conversation id from request if provided // Set conversation id from request if provided
if let Some(conv_id) = original_body.conversation.clone() { if let Some(conv_id) = original_body.conversation.clone() {
...@@ -146,9 +144,16 @@ pub(super) fn patch_streaming_response_json( ...@@ -146,9 +144,16 @@ pub(super) fn patch_streaming_response_json(
); );
} }
if obj.get("user").map(|v| v.is_null()).unwrap_or(false) { if obj
if let Some(user) = &original_body.user { .get("safety_identifier")
obj.insert("user".to_string(), Value::String(user.clone())); .map(|v| v.is_null())
.unwrap_or(false)
{
if let Some(safety_identifier) = &original_body.user {
obj.insert(
"safety_identifier".to_string(),
Value::String(safety_identifier.clone()),
);
} }
} }
......
...@@ -555,7 +555,7 @@ async fn test_openai_router_responses_streaming_with_mock() { ...@@ -555,7 +555,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
assert_eq!(stored.metadata.get("topic"), Some(&json!("unicorns"))); assert_eq!(stored.metadata.get("topic"), Some(&json!("unicorns")));
assert_eq!(stored.instructions.as_deref(), Some("Be kind")); assert_eq!(stored.instructions.as_deref(), Some("Be kind"));
assert_eq!(stored.model.as_deref(), Some("gpt-5-nano")); assert_eq!(stored.model.as_deref(), Some("gpt-5-nano"));
assert_eq!(stored.user, None); assert_eq!(stored.safety_identifier, None);
assert_eq!(stored.raw_response["store"], json!(true)); assert_eq!(stored.raw_response["store"], json!(true));
assert_eq!( assert_eq!(
stored.raw_response["previous_response_id"].as_str(), stored.raw_response["previous_response_id"].as_str(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment