Unverified commit 4d2f17bd, authored by Keyang Ru and committed by GitHub
Browse files

[router] Function call support for openai router Responses API (#12386)

parent 7cd716f7
......@@ -97,6 +97,7 @@ jobs:
run: |
source "$HOME/.cargo/env"
cd sgl-router/
rustup component add clippy
cargo clippy --all-targets --all-features -- -D warnings
- name: Run fmt
......
......@@ -11,7 +11,7 @@ else
fi
# Install rustup (Rust installer and version manager)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.90
# Follow the installation prompts, then reload your shell
......
......@@ -268,7 +268,7 @@ pub fn make_item_id(item_type: &str) -> ConversationItemId {
"reasoning" => "rs".to_string(),
"mcp_call" => "mcp".to_string(),
"mcp_list_tools" => "mcpl".to_string(),
"function_tool_call" => "ftc".to_string(),
"function_call" => "fc".to_string(),
other => {
// Fallback: first 3 letters of type or "itm"
let mut p = other.chars().take(3).collect::<String>();
......
......@@ -9,7 +9,7 @@ use validator::Validate;
// Import shared types from common module
use super::common::{
default_model, default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo,
default_model, default_true, ChatLogProbs, Function, GenerationRequest, PromptTokenUsageInfo,
StringOrArray, ToolChoice, UsageInfo,
};
......@@ -22,8 +22,10 @@ pub struct ResponseTool {
#[serde(rename = "type")]
pub r#type: ResponseToolType,
// Function tool fields (used when type == "function")
// In Responses API, function fields are flattened at the top level
#[serde(flatten)]
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<crate::protocols::common::Function>,
pub function: Option<Function>,
// MCP-specific fields (used when type == "mcp")
#[serde(skip_serializing_if = "Option::is_none")]
pub server_url: Option<String>,
......@@ -123,15 +125,16 @@ pub enum ResponseInputOutputItem {
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
#[serde(rename = "function_call")]
FunctionToolCall {
id: String,
call_id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -141,6 +144,7 @@ pub enum ResponseInputOutputItem {
},
#[serde(rename = "function_call_output")]
FunctionCallOutput {
id: Option<String>,
call_id: String,
output: String,
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -207,15 +211,15 @@ pub enum ResponseOutputItem {
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
#[serde(rename = "function_call")]
FunctionToolCall {
id: String,
call_id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -925,6 +929,7 @@ impl ResponseOutputItem {
/// Create a new function tool call output item
pub fn new_function_tool_call(
id: String,
call_id: String,
name: String,
arguments: String,
output: Option<String>,
......@@ -932,6 +937,7 @@ impl ResponseOutputItem {
) -> Self {
Self::FunctionToolCall {
id,
call_id,
name,
arguments,
output,
......
......@@ -912,6 +912,7 @@ fn build_next_request_with_tools(
for tool_call in tool_calls {
items.push(ResponseInputOutputItem::FunctionToolCall {
id: tool_call.id.clone(),
call_id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
arguments: tool_call
.function
......
......@@ -304,6 +304,7 @@ pub fn chat_to_responses(
for tool_call in tool_calls {
output.push(ResponseOutputItem::FunctionToolCall {
id: tool_call.id.clone(),
call_id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
arguments: tool_call.function.arguments.clone().unwrap_or_default(),
output: None, // Tool hasn't been executed yet
......
......@@ -721,6 +721,7 @@ impl StreamingResponseAccumulator {
while self.tool_calls.len() <= index {
self.tool_calls.push(ResponseOutputItem::FunctionToolCall {
id: String::new(),
call_id: String::new(),
name: String::new(),
arguments: String::new(),
output: None,
......
......@@ -124,6 +124,7 @@ impl ToolLoopState {
self.conversation_history
.push(ResponseInputOutputItem::FunctionToolCall {
id: call_id.clone(),
call_id: call_id.clone(),
name: tool_name.clone(),
arguments: args_json_str.clone(),
output: Some(output_str.clone()),
......
......@@ -384,9 +384,9 @@ const SUPPORTED_ITEM_TYPES: &[&str] = &[
"mcp_list_tools",
"mcp_call",
"item_reference",
// Accepted but not yet implemented (stored, warning returned)
"function_tool_call",
"function_call",
"function_call_output",
// Accepted but not yet implemented (stored, warning returned)
"file_search_call",
"computer_call",
"computer_call_output",
......@@ -936,6 +936,26 @@ fn item_to_json(item: &crate::data_connector::ConversationItem) -> Value {
}
}
}
"function_call" => {
// Extract function_call fields: call_id, name, arguments, output
if let Some(content_obj) = item.content.as_object() {
for field in ["call_id", "name", "arguments", "output"] {
if let Some(value) = content_obj.get(field) {
obj.insert(field.to_string(), value.clone());
}
}
}
}
"function_call_output" => {
// Extract function_call_output fields: call_id, output
if let Some(content_obj) = item.content.as_object() {
for field in ["call_id", "output"] {
if let Some(value) = content_obj.get(field) {
obj.insert(field.to_string(), value.clone());
}
}
}
}
_ => {
// For all other types (message, reasoning, etc.), keep content as-is
obj.insert("content".to_string(), item.content.clone());
......@@ -1144,7 +1164,7 @@ fn extract_input_items(input: &ResponseInput) -> Result<Vec<Value>, String> {
}))
}
_ => {
// For other item types (Message, Reasoning, FunctionToolCall), serialize and ensure ID
// For other item types (Message, Reasoning, FunctionToolCall, FunctionCallOutput), serialize and ensure ID
let mut value = serde_json::to_value(item)
.map_err(|e| format!("Failed to serialize item: {}", e))?;
......@@ -1157,7 +1177,15 @@ fn extract_input_items(input: &ResponseInput) -> Result<Vec<Value>, String> {
.map(|s| s.is_empty())
.unwrap_or(true)
{
obj.insert("id".to_string(), json!(generate_id("item")));
// Generate ID with appropriate prefix based on type
let item_type =
obj.get("type").and_then(|v| v.as_str()).unwrap_or("item");
let prefix = match item_type {
"function_call" | "function_call_output" => "fc",
"message" => "msg",
_ => "item",
};
obj.insert("id".to_string(), json!(generate_id(prefix)));
}
}
......@@ -1201,17 +1229,31 @@ async fn link_items_to_conversation(
.get("role")
.and_then(|v| v.as_str())
.map(String::from);
let content = input_item_value
// For function_call and function_call_output, store the entire item as content
// For message types, extract just the content field
let content = if item_type == "function_call" || item_type == "function_call_output" {
input_item_value.clone()
} else {
input_item_value
.get("content")
.cloned()
.unwrap_or(json!([]));
.unwrap_or(json!([]))
};
let status = input_item_value
.get("status")
.and_then(|v| v.as_str())
.map(String::from);
// Extract the original item ID from input if present
let item_id = input_item_value
.get("id")
.and_then(|v| v.as_str())
.map(ConversationItemId::from);
let new_item = NewConversationItem {
id: None, // Let storage generate ID
id: item_id, // Preserve ID if present
response_id: response_id_opt.clone(),
item_type: item_type.to_string(),
role,
......@@ -1252,7 +1294,7 @@ async fn link_items_to_conversation(
.cloned()
.unwrap_or(json!([]))
} else {
// For other types (reasoning, function_tool_call, mcp_call, etc.)
// For other types (reasoning, function_call, function_call_output, mcp_call, etc.)
// store the entire item structure
output_item_value.clone()
};
......
......@@ -810,21 +810,77 @@ impl crate::routers::RouterTrait for OpenAIRouter {
Ok(stored_items) => {
let mut items: Vec<ResponseInputOutputItem> = Vec::new();
for item in stored_items.into_iter() {
// Only use message items for conversation context
// Skip non-message items (reasoning, function calls, etc.)
if item.item_type == "message" {
if let Ok(content_parts) =
serde_json::from_value::<Vec<ResponseContentPart>>(
// Include messages, function calls, and function call outputs
// Skip reasoning items as they're internal processing details
match item.item_type.as_str() {
"message" => {
match serde_json::from_value::<Vec<ResponseContentPart>>(
item.content.clone(),
)
{
) {
Ok(content_parts) => {
items.push(ResponseInputOutputItem::Message {
id: item.id.0.clone(),
role: item.role.clone().unwrap_or_else(|| "user".to_string()),
role: item
.role
.clone()
.unwrap_or_else(|| "user".to_string()),
content: content_parts,
status: item.status.clone(),
});
}
Err(e) => {
tracing::error!(
"Failed to deserialize message content: {}",
e
);
}
}
}
"function_call" => {
// The entire function_call item is stored in content field
match serde_json::from_value::<ResponseInputOutputItem>(
item.content.clone(),
) {
Ok(func_call) => items.push(func_call),
Err(e) => {
tracing::error!(
"Failed to deserialize function_call: {}",
e
);
}
}
}
"function_call_output" => {
// The entire function_call_output item is stored in content field
tracing::debug!(
"Loading function_call_output from DB - content: {}",
serde_json::to_string_pretty(&item.content)
.unwrap_or_else(|_| "failed to serialize".to_string())
);
match serde_json::from_value::<ResponseInputOutputItem>(
item.content.clone(),
) {
Ok(func_output) => {
tracing::debug!(
"Successfully deserialized function_call_output"
);
items.push(func_output);
}
Err(e) => {
tracing::error!(
"Failed to deserialize function_call_output: {}",
e
);
}
}
}
"reasoning" => {
// Skip reasoning items - they're internal processing details
}
_ => {
// Skip unknown item types
warn!("Unknown item type in conversation: {}", item.item_type);
}
}
}
......@@ -889,6 +945,10 @@ impl crate::routers::RouterTrait for OpenAIRouter {
// Always set store=false for upstream (we store internally)
request_body.store = Some(false);
// Filter out reasoning items from input - they're internal processing details
if let ResponseInput::Items(ref mut items) = request_body.input {
items.retain(|item| !matches!(item, ResponseInputOutputItem::Reasoning { .. }));
}
// Convert to JSON and strip SGLang-specific fields
let mut payload = match to_value(&request_body) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment