Unverified commit 4d2f17bd, authored by Keyang Ru and committed by GitHub
Browse files

[router] Function call support for openai router Responses API (#12386)

parent 7cd716f7
...@@ -97,6 +97,7 @@ jobs: ...@@ -97,6 +97,7 @@ jobs:
run: | run: |
source "$HOME/.cargo/env" source "$HOME/.cargo/env"
cd sgl-router/ cd sgl-router/
rustup component add clippy
cargo clippy --all-targets --all-features -- -D warnings cargo clippy --all-targets --all-features -- -D warnings
- name: Run fmt - name: Run fmt
......
...@@ -11,7 +11,7 @@ else ...@@ -11,7 +11,7 @@ else
fi fi
# Install rustup (Rust installer and version manager) # Install rustup (Rust installer and version manager)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.90
# Follow the installation prompts, then reload your shell # Follow the installation prompts, then reload your shell
......
...@@ -268,7 +268,7 @@ pub fn make_item_id(item_type: &str) -> ConversationItemId { ...@@ -268,7 +268,7 @@ pub fn make_item_id(item_type: &str) -> ConversationItemId {
"reasoning" => "rs".to_string(), "reasoning" => "rs".to_string(),
"mcp_call" => "mcp".to_string(), "mcp_call" => "mcp".to_string(),
"mcp_list_tools" => "mcpl".to_string(), "mcp_list_tools" => "mcpl".to_string(),
"function_tool_call" => "ftc".to_string(), "function_call" => "fc".to_string(),
other => { other => {
// Fallback: first 3 letters of type or "itm" // Fallback: first 3 letters of type or "itm"
let mut p = other.chars().take(3).collect::<String>(); let mut p = other.chars().take(3).collect::<String>();
......
...@@ -9,7 +9,7 @@ use validator::Validate; ...@@ -9,7 +9,7 @@ use validator::Validate;
// Import shared types from common module // Import shared types from common module
use super::common::{ use super::common::{
default_model, default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo, default_model, default_true, ChatLogProbs, Function, GenerationRequest, PromptTokenUsageInfo,
StringOrArray, ToolChoice, UsageInfo, StringOrArray, ToolChoice, UsageInfo,
}; };
...@@ -22,8 +22,10 @@ pub struct ResponseTool { ...@@ -22,8 +22,10 @@ pub struct ResponseTool {
#[serde(rename = "type")] #[serde(rename = "type")]
pub r#type: ResponseToolType, pub r#type: ResponseToolType,
// Function tool fields (used when type == "function") // Function tool fields (used when type == "function")
// In Responses API, function fields are flattened at the top level
#[serde(flatten)]
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<crate::protocols::common::Function>, pub function: Option<Function>,
// MCP-specific fields (used when type == "mcp") // MCP-specific fields (used when type == "mcp")
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub server_url: Option<String>, pub server_url: Option<String>,
...@@ -123,15 +125,16 @@ pub enum ResponseInputOutputItem { ...@@ -123,15 +125,16 @@ pub enum ResponseInputOutputItem {
#[serde(rename = "reasoning")] #[serde(rename = "reasoning")]
Reasoning { Reasoning {
id: String, id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>, summary: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
content: Vec<ResponseReasoningContent>, content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>, status: Option<String>,
}, },
#[serde(rename = "function_tool_call")] #[serde(rename = "function_call")]
FunctionToolCall { FunctionToolCall {
id: String, id: String,
call_id: String,
name: String, name: String,
arguments: String, arguments: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -141,6 +144,7 @@ pub enum ResponseInputOutputItem { ...@@ -141,6 +144,7 @@ pub enum ResponseInputOutputItem {
}, },
#[serde(rename = "function_call_output")] #[serde(rename = "function_call_output")]
FunctionCallOutput { FunctionCallOutput {
id: Option<String>,
call_id: String, call_id: String,
output: String, output: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -207,15 +211,15 @@ pub enum ResponseOutputItem { ...@@ -207,15 +211,15 @@ pub enum ResponseOutputItem {
#[serde(rename = "reasoning")] #[serde(rename = "reasoning")]
Reasoning { Reasoning {
id: String, id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>, summary: Vec<String>,
content: Vec<ResponseReasoningContent>, content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>, status: Option<String>,
}, },
#[serde(rename = "function_tool_call")] #[serde(rename = "function_call")]
FunctionToolCall { FunctionToolCall {
id: String, id: String,
call_id: String,
name: String, name: String,
arguments: String, arguments: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -925,6 +929,7 @@ impl ResponseOutputItem { ...@@ -925,6 +929,7 @@ impl ResponseOutputItem {
/// Create a new function tool call output item /// Create a new function tool call output item
pub fn new_function_tool_call( pub fn new_function_tool_call(
id: String, id: String,
call_id: String,
name: String, name: String,
arguments: String, arguments: String,
output: Option<String>, output: Option<String>,
...@@ -932,6 +937,7 @@ impl ResponseOutputItem { ...@@ -932,6 +937,7 @@ impl ResponseOutputItem {
) -> Self { ) -> Self {
Self::FunctionToolCall { Self::FunctionToolCall {
id, id,
call_id,
name, name,
arguments, arguments,
output, output,
......
...@@ -912,6 +912,7 @@ fn build_next_request_with_tools( ...@@ -912,6 +912,7 @@ fn build_next_request_with_tools(
for tool_call in tool_calls { for tool_call in tool_calls {
items.push(ResponseInputOutputItem::FunctionToolCall { items.push(ResponseInputOutputItem::FunctionToolCall {
id: tool_call.id.clone(), id: tool_call.id.clone(),
call_id: tool_call.id.clone(),
name: tool_call.function.name.clone(), name: tool_call.function.name.clone(),
arguments: tool_call arguments: tool_call
.function .function
......
...@@ -304,6 +304,7 @@ pub fn chat_to_responses( ...@@ -304,6 +304,7 @@ pub fn chat_to_responses(
for tool_call in tool_calls { for tool_call in tool_calls {
output.push(ResponseOutputItem::FunctionToolCall { output.push(ResponseOutputItem::FunctionToolCall {
id: tool_call.id.clone(), id: tool_call.id.clone(),
call_id: tool_call.id.clone(),
name: tool_call.function.name.clone(), name: tool_call.function.name.clone(),
arguments: tool_call.function.arguments.clone().unwrap_or_default(), arguments: tool_call.function.arguments.clone().unwrap_or_default(),
output: None, // Tool hasn't been executed yet output: None, // Tool hasn't been executed yet
......
...@@ -721,6 +721,7 @@ impl StreamingResponseAccumulator { ...@@ -721,6 +721,7 @@ impl StreamingResponseAccumulator {
while self.tool_calls.len() <= index { while self.tool_calls.len() <= index {
self.tool_calls.push(ResponseOutputItem::FunctionToolCall { self.tool_calls.push(ResponseOutputItem::FunctionToolCall {
id: String::new(), id: String::new(),
call_id: String::new(),
name: String::new(), name: String::new(),
arguments: String::new(), arguments: String::new(),
output: None, output: None,
......
...@@ -124,6 +124,7 @@ impl ToolLoopState { ...@@ -124,6 +124,7 @@ impl ToolLoopState {
self.conversation_history self.conversation_history
.push(ResponseInputOutputItem::FunctionToolCall { .push(ResponseInputOutputItem::FunctionToolCall {
id: call_id.clone(), id: call_id.clone(),
call_id: call_id.clone(),
name: tool_name.clone(), name: tool_name.clone(),
arguments: args_json_str.clone(), arguments: args_json_str.clone(),
output: Some(output_str.clone()), output: Some(output_str.clone()),
......
...@@ -384,9 +384,9 @@ const SUPPORTED_ITEM_TYPES: &[&str] = &[ ...@@ -384,9 +384,9 @@ const SUPPORTED_ITEM_TYPES: &[&str] = &[
"mcp_list_tools", "mcp_list_tools",
"mcp_call", "mcp_call",
"item_reference", "item_reference",
// Accepted but not yet implemented (stored, warning returned) "function_call",
"function_tool_call",
"function_call_output", "function_call_output",
// Accepted but not yet implemented (stored, warning returned)
"file_search_call", "file_search_call",
"computer_call", "computer_call",
"computer_call_output", "computer_call_output",
...@@ -936,6 +936,26 @@ fn item_to_json(item: &crate::data_connector::ConversationItem) -> Value { ...@@ -936,6 +936,26 @@ fn item_to_json(item: &crate::data_connector::ConversationItem) -> Value {
} }
} }
} }
"function_call" => {
// Extract function_call fields: call_id, name, arguments, output
if let Some(content_obj) = item.content.as_object() {
for field in ["call_id", "name", "arguments", "output"] {
if let Some(value) = content_obj.get(field) {
obj.insert(field.to_string(), value.clone());
}
}
}
}
"function_call_output" => {
// Extract function_call_output fields: call_id, output
if let Some(content_obj) = item.content.as_object() {
for field in ["call_id", "output"] {
if let Some(value) = content_obj.get(field) {
obj.insert(field.to_string(), value.clone());
}
}
}
}
_ => { _ => {
// For all other types (message, reasoning, etc.), keep content as-is // For all other types (message, reasoning, etc.), keep content as-is
obj.insert("content".to_string(), item.content.clone()); obj.insert("content".to_string(), item.content.clone());
...@@ -1144,7 +1164,7 @@ fn extract_input_items(input: &ResponseInput) -> Result<Vec<Value>, String> { ...@@ -1144,7 +1164,7 @@ fn extract_input_items(input: &ResponseInput) -> Result<Vec<Value>, String> {
})) }))
} }
_ => { _ => {
// For other item types (Message, Reasoning, FunctionToolCall), serialize and ensure ID // For other item types (Message, Reasoning, FunctionToolCall, FunctionCallOutput), serialize and ensure ID
let mut value = serde_json::to_value(item) let mut value = serde_json::to_value(item)
.map_err(|e| format!("Failed to serialize item: {}", e))?; .map_err(|e| format!("Failed to serialize item: {}", e))?;
...@@ -1157,7 +1177,15 @@ fn extract_input_items(input: &ResponseInput) -> Result<Vec<Value>, String> { ...@@ -1157,7 +1177,15 @@ fn extract_input_items(input: &ResponseInput) -> Result<Vec<Value>, String> {
.map(|s| s.is_empty()) .map(|s| s.is_empty())
.unwrap_or(true) .unwrap_or(true)
{ {
obj.insert("id".to_string(), json!(generate_id("item"))); // Generate ID with appropriate prefix based on type
let item_type =
obj.get("type").and_then(|v| v.as_str()).unwrap_or("item");
let prefix = match item_type {
"function_call" | "function_call_output" => "fc",
"message" => "msg",
_ => "item",
};
obj.insert("id".to_string(), json!(generate_id(prefix)));
} }
} }
...@@ -1201,17 +1229,31 @@ async fn link_items_to_conversation( ...@@ -1201,17 +1229,31 @@ async fn link_items_to_conversation(
.get("role") .get("role")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(String::from); .map(String::from);
let content = input_item_value
.get("content") // For function_call and function_call_output, store the entire item as content
.cloned() // For message types, extract just the content field
.unwrap_or(json!([])); let content = if item_type == "function_call" || item_type == "function_call_output" {
input_item_value.clone()
} else {
input_item_value
.get("content")
.cloned()
.unwrap_or(json!([]))
};
let status = input_item_value let status = input_item_value
.get("status") .get("status")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(String::from); .map(String::from);
// Extract the original item ID from input if present
let item_id = input_item_value
.get("id")
.and_then(|v| v.as_str())
.map(ConversationItemId::from);
let new_item = NewConversationItem { let new_item = NewConversationItem {
id: None, // Let storage generate ID id: item_id, // Preserve ID if present
response_id: response_id_opt.clone(), response_id: response_id_opt.clone(),
item_type: item_type.to_string(), item_type: item_type.to_string(),
role, role,
...@@ -1252,7 +1294,7 @@ async fn link_items_to_conversation( ...@@ -1252,7 +1294,7 @@ async fn link_items_to_conversation(
.cloned() .cloned()
.unwrap_or(json!([])) .unwrap_or(json!([]))
} else { } else {
// For other types (reasoning, function_tool_call, mcp_call, etc.) // For other types (reasoning, function_call, function_call_output, mcp_call, etc.)
// store the entire item structure // store the entire item structure
output_item_value.clone() output_item_value.clone()
}; };
......
...@@ -810,20 +810,76 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -810,20 +810,76 @@ impl crate::routers::RouterTrait for OpenAIRouter {
Ok(stored_items) => { Ok(stored_items) => {
let mut items: Vec<ResponseInputOutputItem> = Vec::new(); let mut items: Vec<ResponseInputOutputItem> = Vec::new();
for item in stored_items.into_iter() { for item in stored_items.into_iter() {
// Only use message items for conversation context // Include messages, function calls, and function call outputs
// Skip non-message items (reasoning, function calls, etc.) // Skip reasoning items as they're internal processing details
if item.item_type == "message" { match item.item_type.as_str() {
if let Ok(content_parts) = "message" => {
serde_json::from_value::<Vec<ResponseContentPart>>( match serde_json::from_value::<Vec<ResponseContentPart>>(
item.content.clone(), item.content.clone(),
) ) {
{ Ok(content_parts) => {
items.push(ResponseInputOutputItem::Message { items.push(ResponseInputOutputItem::Message {
id: item.id.0.clone(), id: item.id.0.clone(),
role: item.role.clone().unwrap_or_else(|| "user".to_string()), role: item
content: content_parts, .role
status: item.status.clone(), .clone()
}); .unwrap_or_else(|| "user".to_string()),
content: content_parts,
status: item.status.clone(),
});
}
Err(e) => {
tracing::error!(
"Failed to deserialize message content: {}",
e
);
}
}
}
"function_call" => {
// The entire function_call item is stored in content field
match serde_json::from_value::<ResponseInputOutputItem>(
item.content.clone(),
) {
Ok(func_call) => items.push(func_call),
Err(e) => {
tracing::error!(
"Failed to deserialize function_call: {}",
e
);
}
}
}
"function_call_output" => {
// The entire function_call_output item is stored in content field
tracing::debug!(
"Loading function_call_output from DB - content: {}",
serde_json::to_string_pretty(&item.content)
.unwrap_or_else(|_| "failed to serialize".to_string())
);
match serde_json::from_value::<ResponseInputOutputItem>(
item.content.clone(),
) {
Ok(func_output) => {
tracing::debug!(
"Successfully deserialized function_call_output"
);
items.push(func_output);
}
Err(e) => {
tracing::error!(
"Failed to deserialize function_call_output: {}",
e
);
}
}
}
"reasoning" => {
// Skip reasoning items - they're internal processing details
}
_ => {
// Skip unknown item types
warn!("Unknown item type in conversation: {}", item.item_type);
} }
} }
} }
...@@ -889,6 +945,10 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -889,6 +945,10 @@ impl crate::routers::RouterTrait for OpenAIRouter {
// Always set store=false for upstream (we store internally) // Always set store=false for upstream (we store internally)
request_body.store = Some(false); request_body.store = Some(false);
// Filter out reasoning items from input - they're internal processing details
if let ResponseInput::Items(ref mut items) = request_body.input {
items.retain(|item| !matches!(item, ResponseInputOutputItem::Reasoning { .. }));
}
// Convert to JSON and strip SGLang-specific fields // Convert to JSON and strip SGLang-specific fields
let mut payload = match to_value(&request_body) { let mut payload = match to_value(&request_body) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment