Unverified commit d7f0d88f, authored by Keyang Ru and committed by GitHub
Browse files

[router] implement response api get input item function and refactor input/output store (#11924)

parent fc86b18b
...@@ -150,9 +150,9 @@ class ResponseAPIBaseTest(CustomTestCase): ...@@ -150,9 +150,9 @@ class ResponseAPIBaseTest(CustomTestCase):
"""Cancel response by ID via POST /v1/responses/{response_id}/cancel.""" """Cancel response by ID via POST /v1/responses/{response_id}/cancel."""
return self.make_request(f"/v1/responses/{response_id}/cancel", "POST", {}) return self.make_request(f"/v1/responses/{response_id}/cancel", "POST", {})
def get_response_input(self, response_id: str) -> requests.Response: def get_response_input_items(self, response_id: str) -> requests.Response:
"""Get response input items via GET /v1/responses/{response_id}/input.""" """Get response input items via GET /v1/responses/{response_id}/input_items."""
return self.make_request(f"/v1/responses/{response_id}/input", "GET") return self.make_request(f"/v1/responses/{response_id}/input_items", "GET")
def create_conversation(self, metadata: Optional[dict] = None) -> requests.Response: def create_conversation(self, metadata: Optional[dict] = None) -> requests.Response:
"""Create conversation via POST /v1/conversations.""" """Create conversation via POST /v1/conversations."""
...@@ -359,13 +359,11 @@ class ResponseCRUDBaseTest(ResponseAPIBaseTest): ...@@ -359,13 +359,11 @@ class ResponseCRUDBaseTest(ResponseAPIBaseTest):
self.assertEqual(get_data["id"], response_id) self.assertEqual(get_data["id"], response_id)
self.assertEqual(get_data["status"], "completed") self.assertEqual(get_data["status"], "completed")
input_resp = self.get_response_input(get_data["id"]) input_resp = self.get_response_input_items(get_data["id"])
# change not merge yet self.assertEqual(input_resp.status_code, 200)
self.assertEqual(input_resp.status_code, 501) input_data = input_resp.json()
# self.assertEqual(input_resp.status_code, 200) self.assertIn("data", input_data)
# input_data = input_resp.json() self.assertGreater(len(input_data["data"]), 0)
# self.assertIn("data", input_data)
# self.assertGreater(len(input_data["data"]), 0)
@unittest.skip("TODO: Add delete response feature") @unittest.skip("TODO: Add delete response feature")
def test_delete_response(self): def test_delete_response(self):
......
...@@ -206,15 +206,17 @@ mod tests { ...@@ -206,15 +206,17 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_store_with_custom_id() { async fn test_store_with_custom_id() {
let store = MemoryResponseStorage::new(); let store = MemoryResponseStorage::new();
let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None); let mut response = StoredResponse::new(None);
response.id = ResponseId::from("resp_custom"); response.id = ResponseId::from("resp_custom");
response.input = serde_json::json!("Input");
response.output = serde_json::json!("Output");
store.store_response(response.clone()).await.unwrap(); store.store_response(response.clone()).await.unwrap();
let retrieved = store let retrieved = store
.get_response(&ResponseId::from("resp_custom")) .get_response(&ResponseId::from("resp_custom"))
.await .await
.unwrap(); .unwrap();
assert!(retrieved.is_some()); assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().output, "Output"); assert_eq!(retrieved.unwrap().output, serde_json::json!("Output"));
} }
#[tokio::test] #[tokio::test]
...@@ -222,13 +224,15 @@ mod tests { ...@@ -222,13 +224,15 @@ mod tests {
let store = MemoryResponseStorage::new(); let store = MemoryResponseStorage::new();
// Store a response // Store a response
let response = StoredResponse::new("Hello".to_string(), "Hi there!".to_string(), None); let mut response = StoredResponse::new(None);
response.input = serde_json::json!("Hello");
response.output = serde_json::json!("Hi there!");
let response_id = store.store_response(response).await.unwrap(); let response_id = store.store_response(response).await.unwrap();
// Retrieve it // Retrieve it
let retrieved = store.get_response(&response_id).await.unwrap(); let retrieved = store.get_response(&response_id).await.unwrap();
assert!(retrieved.is_some()); assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().input, "Hello"); assert_eq!(retrieved.unwrap().input, serde_json::json!("Hello"));
// Delete it // Delete it
store.delete_response(&response_id).await.unwrap(); store.delete_response(&response_id).await.unwrap();
...@@ -241,35 +245,35 @@ mod tests { ...@@ -241,35 +245,35 @@ mod tests {
let store = MemoryResponseStorage::new(); let store = MemoryResponseStorage::new();
// Create a chain of responses // Create a chain of responses
let response1 = let mut response1 = StoredResponse::new(None);
StoredResponse::new("First".to_string(), "First response".to_string(), None); response1.input = serde_json::json!("First");
response1.output = serde_json::json!("First response");
let id1 = store.store_response(response1).await.unwrap(); let id1 = store.store_response(response1).await.unwrap();
let response2 = StoredResponse::new( let mut response2 = StoredResponse::new(Some(id1.clone()));
"Second".to_string(), response2.input = serde_json::json!("Second");
"Second response".to_string(), response2.output = serde_json::json!("Second response");
Some(id1.clone()),
);
let id2 = store.store_response(response2).await.unwrap(); let id2 = store.store_response(response2).await.unwrap();
let response3 = StoredResponse::new( let mut response3 = StoredResponse::new(Some(id2.clone()));
"Third".to_string(), response3.input = serde_json::json!("Third");
"Third response".to_string(), response3.output = serde_json::json!("Third response");
Some(id2.clone()),
);
let id3 = store.store_response(response3).await.unwrap(); let id3 = store.store_response(response3).await.unwrap();
// Get the chain // Get the chain
let chain = store.get_response_chain(&id3, None).await.unwrap(); let chain = store.get_response_chain(&id3, None).await.unwrap();
assert_eq!(chain.responses.len(), 3); assert_eq!(chain.responses.len(), 3);
assert_eq!(chain.responses[0].input, "First"); assert_eq!(chain.responses[0].input, serde_json::json!("First"));
assert_eq!(chain.responses[1].input, "Second"); assert_eq!(chain.responses[1].input, serde_json::json!("Second"));
assert_eq!(chain.responses[2].input, "Third"); assert_eq!(chain.responses[2].input, serde_json::json!("Third"));
let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap(); let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap();
assert_eq!(limited_chain.responses.len(), 2); assert_eq!(limited_chain.responses.len(), 2);
assert_eq!(limited_chain.responses[0].input, "Second"); assert_eq!(
assert_eq!(limited_chain.responses[1].input, "Third"); limited_chain.responses[0].input,
serde_json::json!("Second")
);
assert_eq!(limited_chain.responses[1].input, serde_json::json!("Third"));
} }
#[tokio::test] #[tokio::test]
...@@ -277,27 +281,21 @@ mod tests { ...@@ -277,27 +281,21 @@ mod tests {
let store = MemoryResponseStorage::new(); let store = MemoryResponseStorage::new();
// Store responses for different users // Store responses for different users
let mut response1 = StoredResponse::new( let mut response1 = StoredResponse::new(None);
"User1 message".to_string(), response1.input = serde_json::json!("User1 message");
"Response to user1".to_string(), response1.output = serde_json::json!("Response to user1");
None,
);
response1.user = Some("user1".to_string()); response1.user = Some("user1".to_string());
store.store_response(response1).await.unwrap(); store.store_response(response1).await.unwrap();
let mut response2 = StoredResponse::new( let mut response2 = StoredResponse::new(None);
"Another user1 message".to_string(), response2.input = serde_json::json!("Another user1 message");
"Another response to user1".to_string(), response2.output = serde_json::json!("Another response to user1");
None,
);
response2.user = Some("user1".to_string()); response2.user = Some("user1".to_string());
store.store_response(response2).await.unwrap(); store.store_response(response2).await.unwrap();
let mut response3 = StoredResponse::new( let mut response3 = StoredResponse::new(None);
"User2 message".to_string(), response3.input = serde_json::json!("User2 message");
"Response to user2".to_string(), response3.output = serde_json::json!("Response to user2");
None,
);
response3.user = Some("user2".to_string()); response3.user = Some("user2".to_string());
store.store_response(response3).await.unwrap(); store.store_response(response3).await.unwrap();
...@@ -325,11 +323,15 @@ mod tests { ...@@ -325,11 +323,15 @@ mod tests {
async fn test_memory_store_stats() { async fn test_memory_store_stats() {
let store = MemoryResponseStorage::new(); let store = MemoryResponseStorage::new();
let mut response1 = StoredResponse::new("Test1".to_string(), "Reply1".to_string(), None); let mut response1 = StoredResponse::new(None);
response1.input = serde_json::json!("Test1");
response1.output = serde_json::json!("Reply1");
response1.user = Some("user1".to_string()); response1.user = Some("user1".to_string());
store.store_response(response1).await.unwrap(); store.store_response(response1).await.unwrap();
let mut response2 = StoredResponse::new("Test2".to_string(), "Reply2".to_string(), None); let mut response2 = StoredResponse::new(None);
response2.input = serde_json::json!("Test2");
response2.output = serde_json::json!("Reply2");
response2.user = Some("user2".to_string()); response2.user = Some("user2".to_string());
store.store_response(response2).await.unwrap(); store.store_response(response2).await.unwrap();
......
...@@ -72,13 +72,13 @@ impl OracleResponseStorage { ...@@ -72,13 +72,13 @@ impl OracleResponseStorage {
let previous: Option<String> = row.get(1).map_err(|err| { let previous: Option<String> = row.get(1).map_err(|err| {
map_oracle_error(err).into_storage_error("fetch previous_response_id") map_oracle_error(err).into_storage_error("fetch previous_response_id")
})?; })?;
let input: String = row let input_json: Option<String> = row
.get(2) .get(2)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch input"))?; .map_err(|err| map_oracle_error(err).into_storage_error("fetch input"))?;
let instructions: Option<String> = row let instructions: Option<String> = row
.get(3) .get(3)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch instructions"))?; .map_err(|err| map_oracle_error(err).into_storage_error("fetch instructions"))?;
let output: String = row let output_json: Option<String> = row
.get(4) .get(4)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch output"))?; .map_err(|err| map_oracle_error(err).into_storage_error("fetch output"))?;
let tool_calls_json: Option<String> = row let tool_calls_json: Option<String> = row
...@@ -107,6 +107,8 @@ impl OracleResponseStorage { ...@@ -107,6 +107,8 @@ impl OracleResponseStorage {
let tool_calls = parse_tool_calls(tool_calls_json)?; let tool_calls = parse_tool_calls(tool_calls_json)?;
let metadata = parse_metadata(metadata_json)?; let metadata = parse_metadata(metadata_json)?;
let raw_response = parse_raw_response(raw_response_json)?; let raw_response = parse_raw_response(raw_response_json)?;
let input = parse_json_value(input_json)?;
let output = parse_json_value(output_json)?;
Ok(StoredResponse { Ok(StoredResponse {
id: ResponseId(id), id: ResponseId(id),
...@@ -146,6 +148,8 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -146,6 +148,8 @@ impl ResponseStorage for OracleResponseStorage {
let response_id = id.clone(); let response_id = id.clone();
let response_id_str = response_id.0.clone(); let response_id_str = response_id.0.clone();
let previous_id = previous_response_id.map(|r| r.0); let previous_id = previous_response_id.map(|r| r.0);
let json_input = serde_json::to_string(&input)?;
let json_output = serde_json::to_string(&output)?;
let json_tool_calls = serde_json::to_string(&tool_calls)?; let json_tool_calls = serde_json::to_string(&tool_calls)?;
let json_metadata = serde_json::to_string(&metadata)?; let json_metadata = serde_json::to_string(&metadata)?;
let json_raw_response = serde_json::to_string(&raw_response)?; let json_raw_response = serde_json::to_string(&raw_response)?;
...@@ -158,9 +162,9 @@ impl ResponseStorage for OracleResponseStorage { ...@@ -158,9 +162,9 @@ impl ResponseStorage for OracleResponseStorage {
&[ &[
&response_id_str, &response_id_str,
&previous_id, &previous_id,
&input, &json_input,
&instructions, &instructions,
&output, &json_output,
&json_tool_calls, &json_tool_calls,
&json_metadata, &json_metadata,
&created_at, &created_at,
...@@ -478,6 +482,15 @@ fn parse_raw_response(raw: Option<String>) -> StorageResult<Value> { ...@@ -478,6 +482,15 @@ fn parse_raw_response(raw: Option<String>) -> StorageResult<Value> {
} }
} }
/// Decode an optional JSON text column into a `Value`.
///
/// A missing column (`None`) and an empty string are both treated as "no
/// items" and produce an empty JSON array. Any other text is parsed as JSON;
/// a parse failure is returned as `ResponseStorageError::SerializationError`.
fn parse_json_value(raw: Option<String>) -> StorageResult<Value> {
    match raw.as_deref() {
        // Nothing stored: fall back to an empty array.
        None | Some("") => Ok(Value::Array(Vec::new())),
        Some(text) => {
            serde_json::from_str(text).map_err(ResponseStorageError::SerializationError)
        }
    }
}
fn map_pool_error(err: PoolError<oracle::Error>) -> ResponseStorageError { fn map_pool_error(err: PoolError<oracle::Error>) -> ResponseStorageError {
match err { match err {
PoolError::Backend(e) => map_oracle_error(e), PoolError::Backend(e) => map_oracle_error(e),
......
...@@ -41,14 +41,14 @@ pub struct StoredResponse { ...@@ -41,14 +41,14 @@ pub struct StoredResponse {
/// ID of the previous response in the chain (if any) /// ID of the previous response in the chain (if any)
pub previous_response_id: Option<ResponseId>, pub previous_response_id: Option<ResponseId>,
/// The user input for this response /// Input items as JSON array
pub input: String, pub input: Value,
/// System instructions used /// System instructions used
pub instructions: Option<String>, pub instructions: Option<String>,
/// The model's output /// Output items as JSON array
pub output: String, pub output: Value,
/// Tool calls made by the model (if any) /// Tool calls made by the model (if any)
pub tool_calls: Vec<Value>, pub tool_calls: Vec<Value>,
...@@ -75,13 +75,13 @@ pub struct StoredResponse { ...@@ -75,13 +75,13 @@ pub struct StoredResponse {
} }
impl StoredResponse { impl StoredResponse {
pub fn new(input: String, output: String, previous_response_id: Option<ResponseId>) -> Self { pub fn new(previous_response_id: Option<ResponseId>) -> Self {
Self { Self {
id: ResponseId::new(), id: ResponseId::new(),
previous_response_id, previous_response_id,
input, input: Value::Array(vec![]),
instructions: None, instructions: None,
output, output: Value::Array(vec![]),
tool_calls: Vec::new(), tool_calls: Vec::new(),
metadata: HashMap::new(), metadata: HashMap::new(),
created_at: chrono::Utc::now(), created_at: chrono::Utc::now(),
...@@ -128,7 +128,7 @@ impl ResponseChain { ...@@ -128,7 +128,7 @@ impl ResponseChain {
} }
/// Build context from the chain for the next request /// Build context from the chain for the next request
pub fn build_context(&self, max_responses: Option<usize>) -> Vec<(String, String)> { pub fn build_context(&self, max_responses: Option<usize>) -> Vec<(Value, Value)> {
let responses = if let Some(max) = max_responses { let responses = if let Some(max) = max_responses {
let start = self.responses.len().saturating_sub(max); let start = self.responses.len().saturating_sub(max);
&self.responses[start..] &self.responses[start..]
...@@ -197,6 +197,6 @@ pub type SharedResponseStorage = Arc<dyn ResponseStorage>; ...@@ -197,6 +197,6 @@ pub type SharedResponseStorage = Arc<dyn ResponseStorage>;
impl Default for StoredResponse { impl Default for StoredResponse {
fn default() -> Self { fn default() -> Self {
Self::new(String::new(), String::new(), None) Self::new(None)
} }
} }
...@@ -94,6 +94,14 @@ pub enum ReasoningSummary { ...@@ -94,6 +94,14 @@ pub enum ReasoningSummary {
// Input/Output Items // Input/Output Items
// ============================================================================ // ============================================================================
/// Content can be either a simple string or array of content parts (for SimpleInputMessage)
///
/// The enum is `#[serde(untagged)]`, so no discriminator is written or read:
/// the JSON value itself is the payload. On deserialization serde tries the
/// variants in declaration order and keeps the first one that matches.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrContentParts {
/// Content given as a single plain string.
String(String),
/// Content given as a list of structured content parts.
Array(Vec<ResponseContentPart>),
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -125,6 +133,14 @@ pub enum ResponseInputOutputItem { ...@@ -125,6 +133,14 @@ pub enum ResponseInputOutputItem {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>, status: Option<String>,
}, },
#[serde(untagged)]
SimpleInputMessage {
content: StringOrContentParts,
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "type")]
r#type: Option<String>,
},
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
...@@ -551,8 +567,8 @@ pub struct ResponsesRequest { ...@@ -551,8 +567,8 @@ pub struct ResponsesRequest {
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum ResponseInput { pub enum ResponseInput {
Text(String),
Items(Vec<ResponseInputOutputItem>), Items(Vec<ResponseInputOutputItem>),
Text(String),
} }
impl Default for ResponsesRequest { impl Default for ResponsesRequest {
...@@ -622,6 +638,28 @@ impl GenerationRequest for ResponsesRequest { ...@@ -622,6 +638,28 @@ impl GenerationRequest for ResponsesRequest {
Some(texts.join(" ")) Some(texts.join(" "))
} }
} }
ResponseInputOutputItem::SimpleInputMessage { content, .. } => {
match content {
StringOrContentParts::String(s) => Some(s.clone()),
StringOrContentParts::Array(parts) => {
// SimpleInputMessage only supports InputText
let texts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ResponseContentPart::InputText { text } => {
Some(text.clone())
}
_ => None,
})
.collect();
if texts.is_empty() {
None
} else {
Some(texts.join(" "))
}
}
}
}
ResponseInputOutputItem::Reasoning { content, .. } => { ResponseInputOutputItem::Reasoning { content, .. } => {
let texts: Vec<String> = content let texts: Vec<String> = content
.iter() .iter()
...@@ -645,6 +683,50 @@ impl GenerationRequest for ResponsesRequest { ...@@ -645,6 +683,50 @@ impl GenerationRequest for ResponsesRequest {
} }
} }
/// Convert a `SimpleInputMessage` into a fully-structured `Message` item.
///
/// Any item that is not a `SimpleInputMessage` is returned as an unchanged clone.
///
/// For a `SimpleInputMessage`:
/// * string content is wrapped in a single `InputText` content part;
/// * array content is copied as-is;
/// * the role is carried over;
/// * a fresh ID is produced with `generate_id("msg")`, so every call yields a
///   different ID for the same input;
/// * the status is set to `"completed"`.
///
/// # Arguments
/// * `item` - The input item to normalize
///
/// # Returns
/// The resulting `Message` item, or a clone of `item` when no conversion applies.
pub fn normalize_input_item(item: &ResponseInputOutputItem) -> ResponseInputOutputItem {
    if let ResponseInputOutputItem::SimpleInputMessage { content, role, .. } = item {
        // Bring both content shapes to the array-of-parts form used by `Message`.
        let parts = match content {
            StringOrContentParts::Array(existing) => existing.clone(),
            StringOrContentParts::String(text) => {
                vec![ResponseContentPart::InputText { text: text.clone() }]
            }
        };

        return ResponseInputOutputItem::Message {
            id: generate_id("msg"),
            role: role.clone(),
            content: parts,
            status: Some("completed".to_string()),
        };
    }

    item.clone()
}
/// Build an identifier of the form `<prefix>_<hex>`.
///
/// The part after the underscore is 25 random bytes rendered as lowercase,
/// zero-padded hexadecimal, i.e. exactly 50 characters.
///
/// # Arguments
/// * `prefix` - Text placed before the underscore (for example `"msg"`)
pub fn generate_id(prefix: &str) -> String {
    use rand::RngCore;

    let mut random_bytes = [0u8; 25];
    rand::rng().fill_bytes(&mut random_bytes);

    // Two hex digits per byte.
    let mut suffix = String::with_capacity(random_bytes.len() * 2);
    for byte in random_bytes.iter() {
        suffix.push_str(&format!("{:02x}", byte));
    }

    format!("{}_{}", prefix, suffix)
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesResponse { pub struct ResponsesResponse {
/// Response ID /// Response ID
......
...@@ -48,6 +48,56 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest ...@@ -48,6 +48,56 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
// Structured items → convert each to appropriate chat message // Structured items → convert each to appropriate chat message
for item in items { for item in items {
match item { match item {
ResponseInputOutputItem::SimpleInputMessage { content, role, .. } => {
// Convert SimpleInputMessage to chat message
use crate::protocols::responses::StringOrContentParts;
let text = match content {
StringOrContentParts::String(s) => s.clone(),
StringOrContentParts::Array(parts) => {
// Extract text from content parts (only InputText supported)
parts
.iter()
.filter_map(|part| match part {
ResponseContentPart::InputText { text } => {
Some(text.as_str())
}
_ => None,
})
.collect::<Vec<_>>()
.join(" ")
}
};
match role.as_str() {
"user" => {
messages.push(ChatMessage::User {
content: UserMessageContent::Text(text),
name: None,
});
}
"assistant" => {
messages.push(ChatMessage::Assistant {
content: Some(text),
name: None,
tool_calls: None,
reasoning_content: None,
});
}
"system" => {
messages.push(ChatMessage::System {
content: text,
name: None,
});
}
_ => {
// Unknown role, treat as user message
messages.push(ChatMessage::User {
content: UserMessageContent::Text(text),
name: None,
});
}
}
}
ResponseInputOutputItem::Message { role, content, .. } => { ResponseInputOutputItem::Message { role, content, .. } => {
// Extract text from content parts // Extract text from content parts
let text = extract_text_from_content(content); let text = extract_text_from_content(content);
......
...@@ -324,11 +324,7 @@ async fn route_responses_background( ...@@ -324,11 +324,7 @@ async fn route_responses_background(
incomplete_details: None, incomplete_details: None,
instructions: request.instructions.clone(), instructions: request.instructions.clone(),
max_output_tokens: request.max_output_tokens, max_output_tokens: request.max_output_tokens,
model: if request.model.is_empty() { model: request.model.clone(),
"default".to_string()
} else {
request.model.clone()
},
output: Vec::new(), output: Vec::new(),
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true), parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
previous_response_id: request.previous_response_id.clone(), previous_response_id: request.previous_response_id.clone(),
...@@ -623,11 +619,7 @@ async fn process_and_transform_sse_stream( ...@@ -623,11 +619,7 @@ async fn process_and_transform_sse_stream(
// Create event emitter for OpenAI-compatible streaming // Create event emitter for OpenAI-compatible streaming
let response_id = format!("resp_{}", Uuid::new_v4()); let response_id = format!("resp_{}", Uuid::new_v4());
let model = if original_request.model.is_empty() { let model = original_request.model.clone();
"default".to_string()
} else {
original_request.model.clone()
};
let created_at = chrono::Utc::now().timestamp() as u64; let created_at = chrono::Utc::now().timestamp() as u64;
let mut event_emitter = ResponseStreamEventEmitter::new(response_id, model, created_at); let mut event_emitter = ResponseStreamEventEmitter::new(response_id, model, created_at);
...@@ -965,26 +957,37 @@ async fn load_conversation_history( ...@@ -965,26 +957,37 @@ async fn load_conversation_history(
Ok(chain) => { Ok(chain) => {
let mut items = Vec::new(); let mut items = Vec::new();
for stored in chain.responses.iter() { for stored in chain.responses.iter() {
// Convert input to conversation item // Convert input items from stored input (which is now a JSON array)
items.push(ResponseInputOutputItem::Message { if let Some(input_arr) = stored.input.as_array() {
id: format!("msg_u_{}", stored.id.0.trim_start_matches("resp_")), for item in input_arr {
role: "user".to_string(), match serde_json::from_value::<ResponseInputOutputItem>(item.clone()) {
content: vec![ResponseContentPart::InputText { Ok(input_item) => {
text: stored.input.clone(), items.push(input_item);
}], }
status: Some("completed".to_string()), Err(e) => {
}); warn!(
"Failed to deserialize stored input item: {}. Item: {}",
e, item
);
}
}
}
}
// Convert output to conversation items // Convert output items from stored output (which is now a JSON array)
if let Some(output_arr) = if let Some(output_arr) = stored.output.as_array() {
stored.raw_response.get("output").and_then(|v| v.as_array())
{
for item in output_arr { for item in output_arr {
if let Ok(output_item) = match serde_json::from_value::<ResponseInputOutputItem>(item.clone()) {
serde_json::from_value::<ResponseInputOutputItem>(item.clone()) Ok(output_item) => {
{
items.push(output_item); items.push(output_item);
} }
Err(e) => {
warn!(
"Failed to deserialize stored output item: {}. Item: {}",
e, item
);
}
}
} }
} }
} }
...@@ -1065,7 +1068,12 @@ async fn load_conversation_history( ...@@ -1065,7 +1068,12 @@ async fn load_conversation_history(
}); });
} }
ResponseInput::Items(current_items) => { ResponseInput::Items(current_items) => {
items.extend_from_slice(current_items); // Process all item types, converting SimpleInputMessage to Message
for item in current_items.iter() {
let normalized =
crate::protocols::responses::normalize_input_item(item);
items.push(normalized);
}
} }
} }
...@@ -1096,7 +1104,11 @@ async fn load_conversation_history( ...@@ -1096,7 +1104,11 @@ async fn load_conversation_history(
}); });
} }
ResponseInput::Items(current_items) => { ResponseInput::Items(current_items) => {
items.extend_from_slice(current_items); // Process all item types, converting SimpleInputMessage to Message
for item in current_items.iter() {
let normalized = crate::protocols::responses::normalize_input_item(item);
items.push(normalized);
}
} }
} }
......
...@@ -414,7 +414,10 @@ pub(super) async fn execute_tool_loop( ...@@ -414,7 +414,10 @@ pub(super) async fn execute_tool_loop(
content: vec![ResponseContentPart::InputText { text: text.clone() }], content: vec![ResponseContentPart::InputText { text: text.clone() }],
status: Some("completed".to_string()), status: Some("completed".to_string()),
}], }],
ResponseInput::Items(items) => items.clone(), ResponseInput::Items(items) => items
.iter()
.map(crate::protocols::responses::normalize_input_item)
.collect(),
}; };
// Append all conversation history (function calls and outputs) // Append all conversation history (function calls and outputs)
...@@ -608,11 +611,7 @@ async fn execute_tool_loop_streaming_internal( ...@@ -608,11 +611,7 @@ async fn execute_tool_loop_streaming_internal(
// Create response event emitter // Create response event emitter
let response_id = format!("resp_{}", Uuid::new_v4()); let response_id = format!("resp_{}", Uuid::new_v4());
let model = if current_request.model.is_empty() { let model = current_request.model.clone();
"default".to_string()
} else {
current_request.model.clone()
};
let created_at = SystemTime::now() let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .unwrap()
...@@ -871,13 +870,14 @@ async fn execute_tool_loop_streaming_internal( ...@@ -871,13 +870,14 @@ async fn execute_tool_loop_streaming_internal(
content: vec![ResponseContentPart::InputText { text: text.clone() }], content: vec![ResponseContentPart::InputText { text: text.clone() }],
status: Some("completed".to_string()), status: Some("completed".to_string()),
}], }],
ResponseInput::Items(items) => items.clone(), ResponseInput::Items(items) => items
.iter()
.map(crate::protocols::responses::normalize_input_item)
.collect(),
}; };
// Append all conversation history
input_items.extend_from_slice(&state.conversation_history); input_items.extend_from_slice(&state.conversation_history);
// Build new request for next iteration
current_request = ResponsesRequest { current_request = ResponsesRequest {
input: ResponseInput::Items(input_items), input: ResponseInput::Items(input_items),
model: current_request.model.clone(), model: current_request.model.clone(),
...@@ -886,8 +886,8 @@ async fn execute_tool_loop_streaming_internal( ...@@ -886,8 +886,8 @@ async fn execute_tool_loop_streaming_internal(
max_output_tokens: current_request.max_output_tokens, max_output_tokens: current_request.max_output_tokens,
temperature: current_request.temperature, temperature: current_request.temperature,
top_p: current_request.top_p, top_p: current_request.top_p,
stream: Some(true), // Keep streaming enabled stream: Some(true),
store: Some(false), // Don't store intermediate responses store: Some(false),
background: Some(false), background: Some(false),
max_tool_calls: current_request.max_tool_calls, max_tool_calls: current_request.max_tool_calls,
tool_choice: current_request.tool_choice.clone(), tool_choice: current_request.tool_choice.clone(),
......
...@@ -11,7 +11,7 @@ use chrono::Utc; ...@@ -11,7 +11,7 @@ use chrono::Utc;
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use super::responses::build_stored_response; use super::{responses::build_stored_response, utils::generate_id};
use crate::{ use crate::{
data_connector::{ data_connector::{
conversation_items::{ListParams, SortOrder}, conversation_items::{ListParams, SortOrder},
...@@ -19,7 +19,7 @@ use crate::{ ...@@ -19,7 +19,7 @@ use crate::{
ConversationStorage, NewConversation, NewConversationItem, ResponseId, ResponseStorage, ConversationStorage, NewConversation, NewConversationItem, ResponseId, ResponseStorage,
SharedConversationItemStorage, SharedConversationStorage, SharedConversationItemStorage, SharedConversationStorage,
}, },
protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest}, protocols::responses::{ResponseInput, ResponsesRequest},
}; };
/// Maximum number of properties allowed in conversation metadata /// Maximum number of properties allowed in conversation metadata
...@@ -1015,6 +1015,12 @@ async fn create_and_link_item( ...@@ -1015,6 +1015,12 @@ async fn create_and_link_item(
} }
/// Persist conversation items with all storages /// Persist conversation items with all storages
///
/// This function:
/// 1. Extracts and normalizes input items from the request
/// 2. Extracts output items from the response
/// 3. Stores ALL items in response storage (always)
/// 4. If conversation provided, also links items to conversation
async fn persist_items_with_storages( async fn persist_items_with_storages(
conversation_storage: Arc<dyn ConversationStorage>, conversation_storage: Arc<dyn ConversationStorage>,
item_storage: Arc<dyn ConversationItemStorage>, item_storage: Arc<dyn ConversationItemStorage>,
...@@ -1022,7 +1028,32 @@ async fn persist_items_with_storages( ...@@ -1022,7 +1028,32 @@ async fn persist_items_with_storages(
response_json: &Value, response_json: &Value,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
) -> Result<(), String> { ) -> Result<(), String> {
// Check if conversation is provided and validate it // Step 1: Extract response ID
let response_id_str = response_json
.get("id")
.and_then(|v| v.as_str())
.ok_or_else(|| "Response missing id field".to_string())?;
let response_id = ResponseId::from(response_id_str);
// Step 2: Parse and normalize input items from request
let input_items = extract_input_items(&original_body.input)?;
// Step 3: Parse output items from response
let output_items = extract_output_items(response_json)?;
// Step 4: Build StoredResponse with input and output as JSON arrays
let mut stored_response = build_stored_response(response_json, original_body);
stored_response.id = response_id.clone();
stored_response.input = Value::Array(input_items.clone());
stored_response.output = Value::Array(output_items.clone());
// Step 5: Store response (ALWAYS, regardless of conversation)
response_storage
.store_response(stored_response)
.await
.map_err(|e| format!("Failed to store response: {}", e))?;
// Step 6: Check if conversation is provided and validate it
let conv_id_opt = match &original_body.conversation { let conv_id_opt = match &original_body.conversation {
Some(id) => { Some(id) => {
let conv_id = ConversationId::from(id.as_str()); let conv_id = ConversationId::from(id.as_str());
...@@ -1034,130 +1065,209 @@ async fn persist_items_with_storages( ...@@ -1034,130 +1065,209 @@ async fn persist_items_with_storages(
.is_none() .is_none()
{ {
warn!(conversation_id = %conv_id.0, "Conversation not found, skipping item linking"); warn!(conversation_id = %conv_id.0, "Conversation not found, skipping item linking");
None // Conversation doesn't exist, store items without linking None // Conversation doesn't exist, items already stored in response
} else { } else {
Some(conv_id) Some(conv_id)
} }
} }
None => None, // No conversation provided, store items without linking None => None, // No conversation provided, items already stored in response
}; };
let response_id_str = response_json // Step 7: If conversation exists, link items to it
.get("id") if let Some(conv_id) = conv_id_opt {
.and_then(|v| v.as_str()) link_items_to_conversation(
.ok_or_else(|| "Response missing id field".to_string())?; &item_storage,
let response_id = ResponseId::from(response_id_str); &conv_id,
&input_items,
&output_items,
response_id_str,
)
.await?;
let response_id_opt = Some(response_id_str.to_string()); info!(
conversation_id = %conv_id.0,
response_id = %response_id.0,
input_count = input_items.len(),
output_count = output_items.len(),
"Persisted response and linked items to conversation"
);
} else {
info!(
response_id = %response_id.0,
input_count = input_items.len(),
output_count = output_items.len(),
"Persisted response without conversation linking"
);
}
Ok(())
}
/// Extract and normalize input items from ResponseInput
fn extract_input_items(input: &ResponseInput) -> Result<Vec<Value>, String> {
use crate::protocols::responses::{ResponseInputOutputItem, StringOrContentParts};
// Persist input items (only if conversation is provided) let items = match input {
if conv_id_opt.is_some() {
match &original_body.input {
ResponseInput::Text(text) => { ResponseInput::Text(text) => {
let new_item = NewConversationItem { // Convert simple text to message item
id: None, // Let storage generate ID vec![json!({
response_id: response_id_opt.clone(), "id": generate_id("msg"),
item_type: "message".to_string(), "type": "message",
role: Some("user".to_string()), "role": "user",
content: json!([{ "type": "input_text", "text": text }]), "content": [{"type": "input_text", "text": text}],
status: Some("completed".to_string()), "status": "completed"
}; })]
create_and_link_item(&item_storage, conv_id_opt.as_ref(), new_item).await?; }
} ResponseInput::Items(items) => {
ResponseInput::Items(items_array) => { // Process all item types and ensure IDs
for input_item in items_array { items
match input_item { .iter()
ResponseInputOutputItem::Message { .map(|item| {
role, match item {
content, ResponseInputOutputItem::SimpleInputMessage { content, role, .. } => {
status, // Convert SimpleInputMessage to standard message format with ID
.. let content_json = match content {
} => { StringOrContentParts::String(s) => {
let content_v = serde_json::to_value(content) json!([{"type": "input_text", "text": s}])
.map_err(|e| format!("Failed to serialize content: {}", e))?; }
let new_item = NewConversationItem { StringOrContentParts::Array(parts) => serde_json::to_value(parts)
id: None, .map_err(|e| {
response_id: response_id_opt.clone(), format!("Failed to serialize content: {}", e)
item_type: "message".to_string(), })?,
role: Some(role.clone()),
content: content_v,
status: status.clone(),
}; };
create_and_link_item(&item_storage, conv_id_opt.as_ref(), new_item)
.await?; Ok(json!({
"id": generate_id("msg"),
"type": "message",
"role": role,
"content": content_json,
"status": "completed"
}))
} }
_ => { _ => {
// For other types (FunctionToolCall, etc.), serialize the whole item // For other item types (Message, Reasoning, FunctionToolCall), serialize and ensure ID
let item_val = serde_json::to_value(input_item) let mut value = serde_json::to_value(item)
.map_err(|e| format!("Failed to serialize item: {}", e))?; .map_err(|e| format!("Failed to serialize item: {}", e))?;
let new_item = NewConversationItem {
id: None, // Ensure ID exists - generate if missing
response_id: response_id_opt.clone(), if let Some(obj) = value.as_object_mut() {
item_type: "unknown".to_string(), if !obj.contains_key("id")
role: None, || obj
content: item_val, .get("id")
status: Some("completed".to_string()), .and_then(|v| v.as_str())
}; .map(|s| s.is_empty())
create_and_link_item(&item_storage, conv_id_opt.as_ref(), new_item) .unwrap_or(true)
.await?; {
} obj.insert("id".to_string(), json!(generate_id("item")));
} }
} }
Ok(value)
} }
} }
})
.collect::<Result<Vec<_>, String>>()?
} }
};
Ok(items)
}
// Persist output items - ALWAYS persist output items, even if no conversation /// Extract ALL output items from response JSON
if let Some(output_arr) = response_json.get("output").and_then(|v| v.as_array()) { fn extract_output_items(response_json: &Value) -> Result<Vec<Value>, String> {
for output_item in output_arr { response_json
if let Some(obj) = output_item.as_object() { .get("output")
let item_type = obj .and_then(|v| v.as_array())
.cloned()
.ok_or_else(|| "No output array in response".to_string())
}
/// Link ALL input and output items to a conversation
async fn link_items_to_conversation(
item_storage: &Arc<dyn ConversationItemStorage>,
conv_id: &ConversationId,
input_items: &[Value],
output_items: &[Value],
response_id: &str,
) -> Result<(), String> {
let response_id_opt = Some(response_id.to_string());
// Link ALL input items (no filtering by type)
for input_item_value in input_items {
let item_type = input_item_value
.get("type") .get("type")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("message"); .unwrap_or("message");
let role = input_item_value
.get("role")
.and_then(|v| v.as_str())
.map(String::from);
let content = input_item_value
.get("content")
.cloned()
.unwrap_or(json!([]));
let status = input_item_value
.get("status")
.and_then(|v| v.as_str())
.map(String::from);
let role = obj.get("role").and_then(|v| v.as_str()).map(String::from); let new_item = NewConversationItem {
let status = obj.get("status").and_then(|v| v.as_str()).map(String::from); id: None, // Let storage generate ID
response_id: response_id_opt.clone(),
item_type: item_type.to_string(),
role,
content,
status,
};
create_and_link_item(item_storage, Some(conv_id), new_item).await?;
}
// Link ALL output items (no filtering by type)
// Store reasoning, function_tool_call, mcp_call, and any other types
for output_item_value in output_items {
let item_type = output_item_value
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("message");
let role = output_item_value
.get("role")
.and_then(|v| v.as_str())
.map(String::from);
let status = output_item_value
.get("status")
.and_then(|v| v.as_str())
.map(String::from);
// Extract the original item ID from the response // Extract the original item ID from the response
let item_id = obj let item_id = output_item_value
.get("id") .get("id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(ConversationItemId::from); .map(ConversationItemId::from);
// For non-message types, store the entire item as content
// For message types, extract just the content field
let content = if item_type == "message" { let content = if item_type == "message" {
obj.get("content").cloned().unwrap_or(json!([])) output_item_value
.get("content")
.cloned()
.unwrap_or(json!([]))
} else { } else {
output_item.clone() // For other types (reasoning, function_tool_call, mcp_call, etc.)
// store the entire item structure
output_item_value.clone()
}; };
let new_item = NewConversationItem { let new_item = NewConversationItem {
id: item_id, // Use the original ID from response id: item_id, // Preserve ID if present
response_id: response_id_opt.clone(), response_id: response_id_opt.clone(),
item_type: item_type.to_string(), item_type: item_type.to_string(),
role, role,
content, content,
status, status,
}; };
create_and_link_item(&item_storage, conv_id_opt.as_ref(), new_item).await?;
}
}
}
// Store the full response using the shared helper
let mut stored_response = build_stored_response(response_json, original_body);
stored_response.id = response_id;
let final_response_id = stored_response.id.clone();
response_storage create_and_link_item(item_storage, Some(conv_id), new_item).await?;
.store_response(stored_response)
.await
.map_err(|e| format!("Failed to store response: {}", e))?;
if let Some(conv_id) = &conv_id_opt {
info!(conversation_id = %conv_id.0, response_id = %final_response_id.0, "Persisted conversation items and response");
} else {
info!(response_id = %final_response_id.0, "Persisted items and response (no conversation)");
} }
Ok(()) Ok(())
......
...@@ -16,7 +16,7 @@ use serde_json::{json, to_value, Value}; ...@@ -16,7 +16,7 @@ use serde_json::{json, to_value, Value};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{info, warn}; use tracing::{info, warn};
use super::utils::event_types; use super::utils::{event_types, generate_id};
use crate::{ use crate::{
mcp::McpClientManager, mcp::McpClientManager,
protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest}, protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest},
...@@ -338,7 +338,7 @@ pub(super) fn build_resume_payload( ...@@ -338,7 +338,7 @@ pub(super) fn build_resume_payload(
input_array.push(user_item); input_array.push(user_item);
} }
ResponseInput::Items(items) => { ResponseInput::Items(items) => {
// Items are already structured ResponseInputOutputItem, convert to JSON // Items are ResponseInputOutputItem (including SimpleInputMessage), convert to JSON
if let Ok(items_value) = to_value(items) { if let Ok(items_value) = to_value(items) {
if let Some(items_arr) = items_value.as_array() { if let Some(items_arr) = items_value.as_array() {
input_array.extend_from_slice(items_arr); input_array.extend_from_slice(items_arr);
...@@ -836,17 +836,6 @@ pub(super) fn build_incomplete_response( ...@@ -836,17 +836,6 @@ pub(super) fn build_incomplete_response(
// Output Item Builders // Output Item Builders
// ============================================================================ // ============================================================================
/// Generate a unique ID for MCP output items (similar to OpenAI format).
///
/// The result is `<prefix>_` followed by 50 lowercase hexadecimal characters,
/// which are the hex rendering of 25 randomly generated bytes.
pub(super) fn generate_mcp_id(prefix: &str) -> String {
    use rand::RngCore;

    // 25 random bytes -> exactly 50 hex characters after the underscore.
    let mut raw = [0u8; 25];
    rand::rng().fill_bytes(&mut raw);

    let mut id = format!("{}_", prefix);
    id.reserve(raw.len() * 2);
    for byte in raw.iter() {
        id.push_str(&format!("{:02x}", byte));
    }
    id
}
/// Build an mcp_list_tools output item /// Build an mcp_list_tools output item
pub(super) fn build_mcp_list_tools_item(mcp: &Arc<McpClientManager>, server_label: &str) -> Value { pub(super) fn build_mcp_list_tools_item(mcp: &Arc<McpClientManager>, server_label: &str) -> Value {
let tools = mcp.list_tools(); let tools = mcp.list_tools();
...@@ -869,7 +858,7 @@ pub(super) fn build_mcp_list_tools_item(mcp: &Arc<McpClientManager>, server_labe ...@@ -869,7 +858,7 @@ pub(super) fn build_mcp_list_tools_item(mcp: &Arc<McpClientManager>, server_labe
.collect(); .collect();
json!({ json!({
"id": generate_mcp_id("mcpl"), "id": generate_id("mcpl"),
"type": event_types::ITEM_TYPE_MCP_LIST_TOOLS, "type": event_types::ITEM_TYPE_MCP_LIST_TOOLS,
"server_label": server_label, "server_label": server_label,
"tools": tools_json "tools": tools_json
...@@ -886,7 +875,7 @@ pub(super) fn build_mcp_call_item( ...@@ -886,7 +875,7 @@ pub(super) fn build_mcp_call_item(
error: Option<&str>, error: Option<&str>,
) -> Value { ) -> Value {
json!({ json!({
"id": generate_mcp_id("mcp"), "id": generate_id("mcp"),
"type": event_types::ITEM_TYPE_MCP_CALL, "type": event_types::ITEM_TYPE_MCP_CALL,
"status": if success { "completed" } else { "failed" }, "status": if success { "completed" } else { "failed" },
"approval_request_id": Value::Null, "approval_request_id": Value::Null,
......
...@@ -8,7 +8,7 @@ use tracing::warn; ...@@ -8,7 +8,7 @@ use tracing::warn;
use super::utils::event_types; use super::utils::event_types;
use crate::{ use crate::{
data_connector::{ResponseId, StoredResponse}, data_connector::{ResponseId, StoredResponse},
protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest}, protocols::responses::{ResponseToolType, ResponsesRequest},
}; };
// ============================================================================ // ============================================================================
...@@ -20,14 +20,11 @@ pub(super) fn build_stored_response( ...@@ -20,14 +20,11 @@ pub(super) fn build_stored_response(
response_json: &Value, response_json: &Value,
original_body: &ResponsesRequest, original_body: &ResponsesRequest,
) -> StoredResponse { ) -> StoredResponse {
let input_text = match &original_body.input { let mut stored_response = StoredResponse::new(None);
ResponseInput::Text(text) => text.clone(),
ResponseInput::Items(_) => "complex input".to_string(),
};
let output_text = extract_primary_output_text(response_json).unwrap_or_default();
let mut stored_response = StoredResponse::new(input_text, output_text, None); // Initialize empty arrays - will be populated by persist_items_with_storages
stored_response.input = Value::Array(vec![]);
stored_response.output = Value::Array(vec![]);
stored_response.instructions = response_json stored_response.instructions = response_json
.get("instructions") .get("instructions")
...@@ -313,31 +310,3 @@ pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesReque ...@@ -313,31 +310,3 @@ pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesReque
.or_insert(Value::String("auto".to_string())); .or_insert(Value::String("auto".to_string()));
} }
} }
// ============================================================================
// Output Text Extraction
// ============================================================================
/// Extract primary output text from response JSON
pub(super) fn extract_primary_output_text(response_json: &Value) -> Option<String> {
if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) {
for item in items {
if let Some(content) = item.get("content").and_then(|v| v.as_array()) {
for part in content {
if part
.get("type")
.and_then(|v| v.as_str())
.map(|t| t == "output_text")
.unwrap_or(false)
{
if let Some(text) = part.get("text").and_then(|v| v.as_str()) {
return Some(text.to_string());
}
}
}
}
}
}
None
}
...@@ -744,26 +744,38 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -744,26 +744,38 @@ impl crate::routers::RouterTrait for OpenAIRouter {
Ok(chain) => { Ok(chain) => {
let mut items = Vec::new(); let mut items = Vec::new();
for stored in chain.responses.iter() { for stored in chain.responses.iter() {
// Convert input to conversation item // Convert input items from stored input (which is now a JSON array)
items.push(ResponseInputOutputItem::Message { if let Some(input_arr) = stored.input.as_array() {
id: format!("msg_u_{}", stored.id.0.trim_start_matches("resp_")), for item in input_arr {
role: "user".to_string(), match serde_json::from_value::<ResponseInputOutputItem>(
content: vec![ResponseContentPart::InputText { item.clone(),
text: stored.input.clone(), ) {
}], Ok(input_item) => {
status: Some("completed".to_string()), items.push(input_item);
}); }
Err(e) => {
warn!(
"Failed to deserialize stored input item: {}. Item: {}",
e, item
);
}
}
}
}
// Convert output to conversation items directly from stored response // Convert output items from stored output (which is now a JSON array)
if let Some(output_arr) = if let Some(output_arr) = stored.output.as_array() {
stored.raw_response.get("output").and_then(|v| v.as_array())
{
for item in output_arr { for item in output_arr {
if let Ok(output_item) = match serde_json::from_value::<ResponseInputOutputItem>(
serde_json::from_value::<ResponseInputOutputItem>(item.clone()) item.clone(),
{ ) {
Ok(output_item) => {
items.push(output_item); items.push(output_item);
} }
Err(e) => {
warn!("Failed to deserialize stored output item: {}. Item: {}", e, item);
}
}
} }
} }
} }
...@@ -838,7 +850,12 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -838,7 +850,12 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}); });
} }
ResponseInput::Items(current_items) => { ResponseInput::Items(current_items) => {
items.extend_from_slice(current_items); // Process all item types, converting SimpleInputMessage to Message
for item in current_items.iter() {
let normalized =
crate::protocols::responses::normalize_input_item(item);
items.push(normalized);
}
} }
} }
...@@ -868,7 +885,11 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -868,7 +885,11 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}); });
} }
ResponseInput::Items(current_items) => { ResponseInput::Items(current_items) => {
items.extend_from_slice(current_items); // Process all item types, converting SimpleInputMessage to Message
for item in current_items.iter() {
let normalized = crate::protocols::responses::normalize_input_item(item);
items.push(normalized);
}
} }
} }
...@@ -1023,6 +1044,78 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -1023,6 +1044,78 @@ impl crate::routers::RouterTrait for OpenAIRouter {
.into_response() .into_response()
} }
/// Return the input items stored for `response_id` as a list object.
///
/// The response is read from `self.response_storage`. Its `input` array is
/// returned as `data`, with `first_id` / `last_id` taken from the first and
/// last items. `has_more` is always `false`: the whole array is returned in
/// one page. A stored `input` that is not a JSON array yields an empty list.
///
/// Any object item without a usable `id` (missing, not a string, or an empty
/// string) is given a freshly generated `msg_…` id, matching how input items
/// are normalized before they are stored. The generated id is not written
/// back to storage here, so it is only valid for this one reply.
///
/// Status codes:
/// - `200 OK` with the list body when the response exists
/// - `404 Not Found` when storage has no response with this id
/// - `500 Internal Server Error` when the storage lookup fails
async fn list_response_input_items(
    &self,
    _headers: Option<&HeaderMap>,
    response_id: &str,
) -> Response {
    let resp_id = ResponseId::from(response_id);

    let stored = match self.response_storage.get_response(&resp_id).await {
        Ok(Some(stored)) => stored,
        Ok(None) => {
            return (
                StatusCode::NOT_FOUND,
                Json(json!({
                    "error": {
                        "message": format!("No response found with id '{}'", response_id),
                        "type": "invalid_request_error",
                        "param": Value::Null,
                        "code": "not_found"
                    }
                })),
            )
                .into_response();
        }
        Err(e) => {
            warn!("Failed to retrieve input items for {}: {}", response_id, e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "error": {
                        "message": format!("Failed to retrieve input items: {}", e),
                        "type": "internal_error",
                        "param": Value::Null,
                        "code": "storage_error"
                    }
                })),
            )
                .into_response();
        }
    };

    // Copy the stored items, filling in an id wherever the stored one is unusable.
    let items_with_ids: Vec<Value> = stored
        .input
        .as_array()
        .map(|arr| {
            arr.iter()
                .cloned()
                .map(|mut item| {
                    // A null, non-string, or empty id is as useless to a client as a
                    // missing one, so all of these cases get a generated id.
                    let has_usable_id = item
                        .get("id")
                        .and_then(|v| v.as_str())
                        .map(|s| !s.is_empty())
                        .unwrap_or(false);
                    if !has_usable_id {
                        // Non-object items cannot carry an id and are returned unchanged.
                        if let Some(obj) = item.as_object_mut() {
                            obj.insert(
                                "id".to_string(),
                                json!(super::utils::generate_id("msg")),
                            );
                        }
                    }
                    item
                })
                .collect::<Vec<Value>>()
        })
        .unwrap_or_default();

    let first_id = items_with_ids
        .first()
        .and_then(|v| v.get("id"))
        .and_then(|v| v.as_str());
    let last_id = items_with_ids
        .last()
        .and_then(|v| v.get("id"))
        .and_then(|v| v.as_str());

    let response_body = json!({
        "object": "list",
        "data": items_with_ids,
        "first_id": first_id,
        "last_id": last_id,
        "has_more": false
    });

    (StatusCode::OK, Json(response_body)).into_response()
}
async fn route_embeddings( async fn route_embeddings(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
......
...@@ -220,8 +220,17 @@ pub async fn probe_endpoint_for_model( ...@@ -220,8 +220,17 @@ pub async fn probe_endpoint_for_model(
} }
} }
/// Generate a random identifier of the form `<prefix>_<hex>`.
///
/// The suffix is 25 randomly generated bytes rendered as exactly 50 lowercase
/// hexadecimal characters.
pub fn generate_id(prefix: &str) -> String {
    use rand::RngCore;

    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut random_bytes = [0u8; 25];
    rand::rng().fill_bytes(&mut random_bytes);

    // prefix + '_' + two hex digits per random byte
    let mut id = String::with_capacity(prefix.len() + 1 + random_bytes.len() * 2);
    id.push_str(prefix);
    id.push('_');
    for &byte in random_bytes.iter() {
        id.push(HEX_DIGITS[usize::from(byte >> 4)] as char);
        id.push(HEX_DIGITS[usize::from(byte & 0x0f)] as char);
    }
    id
}
// ============================================================================ // ============================================================================
// Re-export FunctionCallInProgress from mcp module // Re-export FunctionCallInProgress from mcp module
// ============================================================================ // ============================================================================
pub(crate) use super::mcp::FunctionCallInProgress; pub(crate) use super::mcp::FunctionCallInProgress;
...@@ -434,15 +434,22 @@ impl RouterTrait for RouterManager { ...@@ -434,15 +434,22 @@ impl RouterTrait for RouterManager {
async fn list_response_input_items( async fn list_response_input_items(
&self, &self,
_headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
_response_id: &str, response_id: &str,
) -> Response { ) -> Response {
// Delegate to the default router (typically http-regular)
// Response storage is shared across all routers via AppContext
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router.list_response_input_items(headers, response_id).await
} else {
( (
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_FOUND,
"responses api not yet implemented in inference gateway mode", "No router available to list response input items",
) )
.into_response() .into_response()
} }
}
async fn get_response( async fn get_response(
&self, &self,
......
...@@ -690,7 +690,7 @@ pub fn build_app( ...@@ -690,7 +690,7 @@ pub fn build_app(
) )
.route("/v1/responses/{response_id}", delete(v1_responses_delete)) .route("/v1/responses/{response_id}", delete(v1_responses_delete))
.route( .route(
"/v1/responses/{response_id}/input", "/v1/responses/{response_id}/input_items",
get(v1_responses_list_input_items), get(v1_responses_list_input_items),
) )
.route("/v1/conversations", post(v1_conversations_create)) .route("/v1/conversations", post(v1_conversations_create))
......
...@@ -11,8 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType ...@@ -11,8 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::{ use sglang_router_rs::{
config::{RouterConfig, RoutingMode}, config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
core::Job, core::{ConnectionMode, Job},
routers::{RouterFactory, RouterTrait}, routers::{RouterFactory, RouterTrait},
server::AppContext, server::AppContext,
}; };
...@@ -66,14 +66,23 @@ impl TestContext { ...@@ -66,14 +66,23 @@ impl TestContext {
} }
// Update config with worker URLs if not already set // Update config with worker URLs if not already set
if let RoutingMode::Regular { match &mut config.mode {
RoutingMode::Regular {
worker_urls: ref mut urls, worker_urls: ref mut urls,
} = config.mode } => {
{
if urls.is_empty() { if urls.is_empty() {
*urls = worker_urls.clone(); *urls = worker_urls.clone();
} }
} }
RoutingMode::OpenAI {
worker_urls: ref mut urls,
} => {
if urls.is_empty() {
*urls = worker_urls.clone();
}
}
_ => {} // PrefillDecode mode has its own setup
}
let client = Client::builder() let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.request_timeout_secs)) .timeout(std::time::Duration::from_secs(config.request_timeout_secs))
...@@ -212,7 +221,6 @@ mod health_tests { ...@@ -212,7 +221,6 @@ mod health_tests {
let resp = app.oneshot(req).await.unwrap(); let resp = app.oneshot(req).await.unwrap();
// With no workers, readiness should return SERVICE_UNAVAILABLE // With no workers, readiness should return SERVICE_UNAVAILABLE
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
ctx.shutdown().await; ctx.shutdown().await;
} }
...@@ -967,7 +975,7 @@ mod responses_endpoint_tests { ...@@ -967,7 +975,7 @@ mod responses_endpoint_tests {
} }
#[tokio::test] #[tokio::test]
async fn test_v1_responses_delete_and_list_not_implemented() { async fn test_v1_responses_delete_not_implemented() {
let ctx = TestContext::new(vec![MockWorkerConfig { let ctx = TestContext::new(vec![MockWorkerConfig {
port: 18954, port: 18954,
worker_type: WorkerType::Regular, worker_type: WorkerType::Regular,
...@@ -979,7 +987,7 @@ mod responses_endpoint_tests { ...@@ -979,7 +987,7 @@ mod responses_endpoint_tests {
let app = ctx.create_app().await; let app = ctx.create_app().await;
// Use an arbitrary id for delete/list // Test DELETE is not implemented
let resp_id = "resp-test-123"; let resp_id = "resp-test-123";
let req = Request::builder() let req = Request::builder()
...@@ -990,13 +998,100 @@ mod responses_endpoint_tests { ...@@ -990,13 +998,100 @@ mod responses_endpoint_tests {
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED); assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_responses_input_items() {
// This test uses OpenAI mode because the input_items endpoint
// is only implemented in OpenAIRouter and reads from storage (no workers needed)
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://dummy.local".to_string()], // Dummy URL (won't be called)
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3002,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
discovery: None,
dp_aware: false,
api_key: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = TestContext::new_with_config(
config,
vec![], // No workers needed
)
.await;
let app = ctx.create_app().await;
// Directly store a response in the storage to test the retrieval endpoint
use sglang_router_rs::data_connector::{ResponseId, StoredResponse};
let mut stored_response = StoredResponse::new(None);
stored_response.id = ResponseId::from("resp_test_input_items");
stored_response.input = json!([
{"id": "item_1", "content": "hello", "role": "user"},
{"id": "item_2", "content": "hi there", "role": "assistant"}
]);
stored_response.output = json!([
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "test response"}]}
]);
ctx.app_context
.response_storage
.store_response(stored_response)
.await
.expect("Failed to store response");
// Fetch input_items for the created response
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri(format!("/v1/responses/{}/input", resp_id)) .uri("/v1/responses/resp_test_input_items/input_items")
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED); assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let items_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
// Verify response structure
assert_eq!(items_json["object"], "list");
assert!(items_json["data"].is_array());
// Should have 2 input items
let items = items_json["data"].as_array().unwrap();
assert_eq!(items.len(), 2);
ctx.shutdown().await; ctx.shutdown().await;
} }
......
...@@ -279,8 +279,20 @@ async fn test_openai_router_responses_with_mock() { ...@@ -279,8 +279,20 @@ async fn test_openai_router_responses_with_mock() {
.await .await
.unwrap() .unwrap()
.expect("first response missing"); .expect("first response missing");
assert_eq!(stored1.input, "Say hi"); // Input is now stored as a JSON array of items
assert_eq!(stored1.output, "mock_output_1"); assert!(stored1.input.is_array());
let input_items = stored1.input.as_array().unwrap();
assert_eq!(input_items.len(), 1);
assert_eq!(input_items[0]["type"], "message");
assert_eq!(input_items[0]["role"], "user");
assert_eq!(input_items[0]["content"][0]["text"], "Say hi");
// Output is now stored as a JSON array of items
assert!(stored1.output.is_array());
let output_items = stored1.output.as_array().unwrap();
assert_eq!(output_items.len(), 1);
assert_eq!(output_items[0]["content"][0]["text"], "mock_output_1");
assert!(stored1.previous_response_id.is_none()); assert!(stored1.previous_response_id.is_none());
let stored2 = storage let stored2 = storage
...@@ -289,7 +301,12 @@ async fn test_openai_router_responses_with_mock() { ...@@ -289,7 +301,12 @@ async fn test_openai_router_responses_with_mock() {
.unwrap() .unwrap()
.expect("second response missing"); .expect("second response missing");
assert_eq!(stored2.previous_response_id.unwrap().0, resp1_id); assert_eq!(stored2.previous_response_id.unwrap().0, resp1_id);
assert_eq!(stored2.output, "mock_output_2");
// Output is now stored as a JSON array
assert!(stored2.output.is_array());
let output_items2 = stored2.output.as_array().unwrap();
assert_eq!(output_items2.len(), 1);
assert_eq!(output_items2[0]["content"][0]["text"], "mock_output_2");
let get1 = router let get1 = router
.get_response(None, &stored1.id.0, &ResponsesGetParams::default()) .get_response(None, &stored1.id.0, &ResponsesGetParams::default())
...@@ -481,12 +498,10 @@ async fn test_openai_router_responses_streaming_with_mock() { ...@@ -481,12 +498,10 @@ async fn test_openai_router_responses_streaming_with_mock() {
let storage = Arc::new(MemoryResponseStorage::new()); let storage = Arc::new(MemoryResponseStorage::new());
// Seed a previous response so previous_response_id logic has data to pull from. // Seed a previous response so previous_response_id logic has data to pull from.
let mut previous = StoredResponse::new( let mut previous = StoredResponse::new(None);
"Earlier bedtime question".to_string(),
"Earlier answer".to_string(),
None,
);
previous.id = ResponseId::from("resp_prev_chain"); previous.id = ResponseId::from("resp_prev_chain");
previous.input = serde_json::json!("Earlier bedtime question");
previous.output = serde_json::json!("Earlier answer");
storage.store_response(previous).await.unwrap(); storage.store_response(previous).await.unwrap();
let router = OpenAIRouter::new( let router = OpenAIRouter::new(
...@@ -541,8 +556,25 @@ async fn test_openai_router_responses_streaming_with_mock() { ...@@ -541,8 +556,25 @@ async fn test_openai_router_responses_streaming_with_mock() {
sleep(Duration::from_millis(10)).await; sleep(Duration::from_millis(10)).await;
}; };
assert_eq!(stored.input, "Tell me a bedtime story."); // Input is now stored as a JSON array of items
assert_eq!(stored.output, "Once upon a streamed unicorn adventure."); assert!(stored.input.is_array());
let input_items = stored.input.as_array().unwrap();
assert_eq!(input_items.len(), 1);
assert_eq!(input_items[0]["type"], "message");
assert_eq!(input_items[0]["role"], "user");
assert_eq!(
input_items[0]["content"][0]["text"],
"Tell me a bedtime story."
);
// Output is now stored as a JSON array of items
assert!(stored.output.is_array());
let output_items = stored.output.as_array().unwrap();
assert_eq!(output_items.len(), 1);
assert_eq!(
output_items[0]["content"][0]["text"],
"Once upon a streamed unicorn adventure."
);
assert_eq!( assert_eq!(
stored stored
.previous_response_id .previous_response_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment