Unverified commit 63e84352, authored by Keyang Ru, committed by GitHub

[router] openai router: support grok model (#11511)

parent a20e7df8
@@ -1073,8 +1073,8 @@ fn generate_request_id() -> String {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ResponsesRequest {
     /// Run the request in the background
-    #[serde(default)]
-    pub background: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub background: Option<bool>,
     /// Fields to include in the response
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1108,8 +1108,8 @@ pub struct ResponsesRequest {
     pub conversation: Option<String>,
     /// Whether to enable parallel tool calls
-    #[serde(default = "default_true")]
-    pub parallel_tool_calls: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
     /// ID of previous response to continue from
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1120,40 +1120,40 @@ pub struct ResponsesRequest {
     pub reasoning: Option<ResponseReasoningParam>,
     /// Service tier
-    #[serde(default)]
-    pub service_tier: ServiceTier,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<ServiceTier>,
     /// Whether to store the response
-    #[serde(default = "default_true")]
-    pub store: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub store: Option<bool>,
     /// Whether to stream the response
-    #[serde(default)]
-    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
     /// Temperature for sampling
     #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
     /// Tool choice behavior
-    #[serde(default)]
-    pub tool_choice: ToolChoice,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
     /// Available tools
-    #[serde(default)]
-    pub tools: Vec<ResponseTool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ResponseTool>>,
     /// Number of top logprobs to return
-    #[serde(default)]
-    pub top_logprobs: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<u32>,
     /// Top-p sampling parameter
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f32>,
     /// Truncation behavior
-    #[serde(default)]
-    pub truncation: Truncation,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub truncation: Option<Truncation>,
     /// User identifier
     #[serde(skip_serializing_if = "Option::is_none")]
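The pattern repeated throughout this struct is the core of the change: fields that previously deserialized with `#[serde(default)]` into concrete values become `Option<T>` with `skip_serializing_if = "Option::is_none"`, so a field the client never sent is also absent from the JSON the router re-serializes for the upstream provider. The commit title suggests this is what Grok/xAI needs: fields OpenAI merely defaults should not be sent at all. A minimal sketch of the two behaviors, using a hypothetical `Demo` struct (not the router's types) and assuming serde and serde_json are available:

```rust
use serde::{Deserialize, Serialize};

// `Demo` is a hypothetical stand-in for ResponsesRequest, just to contrast
// the old and new serde attribute styles.
#[derive(Debug, Default, Deserialize, Serialize)]
struct Demo {
    // Old style: materializes a default and is always re-serialized.
    #[serde(default)]
    top_k: u32,
    // New style: stays None when absent and disappears on re-serialization.
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
}

fn main() {
    let d: Demo = serde_json::from_str("{}").unwrap();
    // `store` was never sent, so it is omitted; `top_k` is emitted as 0.
    assert_eq!(serde_json::to_string(&d).unwrap(), r#"{"top_k":0}"#);
}
```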
@@ -1168,12 +1168,12 @@ pub struct ResponsesRequest {
     pub priority: i32,
     /// Frequency penalty
-    #[serde(default)]
-    pub frequency_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
     /// Presence penalty
-    #[serde(default)]
-    pub presence_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
     /// Stop sequences
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1210,7 +1210,7 @@ fn default_repetition_penalty() -> f32 {
 impl Default for ResponsesRequest {
     fn default() -> Self {
         Self {
-            background: false,
+            background: None,
             include: None,
             input: ResponseInput::Text(String::new()),
             instructions: None,
@@ -1219,23 +1219,23 @@ impl Default for ResponsesRequest {
             metadata: None,
             model: None,
             conversation: None,
-            parallel_tool_calls: true,
+            parallel_tool_calls: None,
             previous_response_id: None,
             reasoning: None,
-            service_tier: ServiceTier::default(),
-            store: true,
-            stream: false,
+            service_tier: None,
+            store: None,
+            stream: None,
             temperature: None,
-            tool_choice: ToolChoice::default(),
-            tools: Vec::new(),
-            top_logprobs: 0,
+            tool_choice: None,
+            tools: None,
+            top_logprobs: None,
             top_p: None,
-            truncation: Truncation::default(),
+            truncation: None,
             user: None,
             request_id: generate_request_id(),
             priority: 0,
-            frequency_penalty: 0.0,
-            presence_penalty: 0.0,
+            frequency_penalty: None,
+            presence_penalty: None,
             stop: None,
             top_k: default_top_k(),
             min_p: 0.0,
@@ -1299,14 +1299,18 @@ impl ResponsesRequest {
             "top_p".to_string(),
             Value::Number(Number::from_f64(top_p as f64).unwrap()),
         );
-        params.insert(
-            "frequency_penalty".to_string(),
-            Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
-        );
-        params.insert(
-            "presence_penalty".to_string(),
-            Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
-        );
+        if let Some(fp) = self.frequency_penalty {
+            params.insert(
+                "frequency_penalty".to_string(),
+                Value::Number(Number::from_f64(fp as f64).unwrap()),
+            );
+        }
+        if let Some(pp) = self.presence_penalty {
+            params.insert(
+                "presence_penalty".to_string(),
+                Value::Number(Number::from_f64(pp as f64).unwrap()),
+            );
+        }
         params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
         params.insert(
             "min_p".to_string(),
@@ -1337,7 +1341,7 @@ impl ResponsesRequest {
 impl GenerationRequest for ResponsesRequest {
     fn is_stream(&self) -> bool {
-        self.stream
+        self.stream.unwrap_or(false)
     }
     fn get_model(&self) -> Option<&str> {
@@ -1523,13 +1527,13 @@ impl ResponsesResponse {
             max_output_tokens: request.max_output_tokens,
             model: model_name,
             output,
-            parallel_tool_calls: request.parallel_tool_calls,
+            parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
             previous_response_id: request.previous_response_id.clone(),
             reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
                 effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
                 summary: None,
             }),
-            store: request.store,
+            store: request.store.unwrap_or(false),
             temperature: request.temperature,
             text: Some(ResponseTextFormat {
                 format: TextFormatType {
@@ -1537,17 +1541,19 @@ impl ResponsesResponse {
                 },
             }),
             tool_choice: match &request.tool_choice {
-                ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
-                ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
-                ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
-                ToolChoice::Function { .. } => "function".to_string(),
-                ToolChoice::AllowedTools { mode, .. } => mode.clone(),
+                Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(),
+                Some(ToolChoice::Function { .. }) => "function".to_string(),
+                Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(),
+                None => "auto".to_string(),
             },
-            tools: request.tools.clone(),
+            tools: request.tools.clone().unwrap_or_default(),
             top_p: request.top_p,
             truncation: match &request.truncation {
-                Truncation::Auto => Some("auto".to_string()),
-                Truncation::Disabled => Some("disabled".to_string()),
+                Some(Truncation::Auto) => Some("auto".to_string()),
+                Some(Truncation::Disabled) => Some("disabled".to_string()),
+                None => None,
             },
             usage: usage.map(ResponsesUsage::Classic),
             user: request.user.clone(),
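When the router echoes the request back in a response, each `Option` needs an explicit default: `parallel_tool_calls` falls back to `true`, `store` to `false`, and the enum matches gain `Some(..)` arms plus a `None` fallback. A sketch of those match shapes with simplified stand-in enums (not the router's full `ToolChoice`/`Truncation` definitions):

```rust
// Simplified stand-ins, just to show the Option-aware match arms.
enum ToolChoice { Auto, Required }
enum Truncation { Auto, Disabled }

fn tool_choice_label(tc: &Option<ToolChoice>) -> String {
    match tc {
        Some(ToolChoice::Auto) => "auto".to_string(),
        Some(ToolChoice::Required) => "required".to_string(),
        None => "auto".to_string(), // unset behaves like the OpenAI default
    }
}

fn truncation_label(t: &Option<Truncation>) -> Option<String> {
    match t {
        Some(Truncation::Auto) => Some("auto".to_string()),
        Some(Truncation::Disabled) => Some("disabled".to_string()),
        None => None, // unset is simply omitted from the response
    }
}

fn main() {
    assert_eq!(tool_choice_label(&Some(ToolChoice::Auto)), "auto");
    assert_eq!(tool_choice_label(&Some(ToolChoice::Required)), "required");
    assert_eq!(tool_choice_label(&None), "auto");
    assert_eq!(truncation_label(&Some(Truncation::Auto)).as_deref(), Some("auto"));
    assert_eq!(truncation_label(&Some(Truncation::Disabled)).as_deref(), Some("disabled"));
    assert_eq!(truncation_label(&None), None);
}
```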
......
@@ -689,9 +689,13 @@ pub(super) async fn execute_tool_loop(
     if state.total_calls > 0 {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
         // Build mcp_list_tools item
@@ -747,9 +751,13 @@ pub(super) fn build_incomplete_response(
     if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
         // Find any function_call items and convert them to mcp_call (incomplete)
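The same `as_ref().and_then(..)` chain recurs wherever the code needs the first MCP tool's label from the now-optional tools list. A standalone sketch with a simplified stand-in for `ResponseTool`:

```rust
// Stand-in for ResponseTool with only the fields this lookup touches.
struct Tool {
    is_mcp: bool,
    server_label: Option<String>,
}

// Mirrors the lookup above: walk the optional tool list, find the first MCP
// tool, and fall back to "mcp" when the list or the label is absent.
fn server_label(tools: &Option<Vec<Tool>>) -> &str {
    tools
        .as_ref()
        .and_then(|tools| {
            tools
                .iter()
                .find(|t| t.is_mcp)
                .and_then(|t| t.server_label.as_deref())
        })
        .unwrap_or("mcp")
}

fn main() {
    assert_eq!(server_label(&None), "mcp");
    let tools = Some(vec![Tool { is_mcp: true, server_label: Some("mock".into()) }]);
    assert_eq!(server_label(&tools), "mock");
}
```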
......
@@ -129,7 +129,10 @@ pub(super) fn patch_streaming_response_json(
         }
     }
-    obj.insert("store".to_string(), Value::Bool(original_body.store));
+    obj.insert(
+        "store".to_string(),
+        Value::Bool(original_body.store.unwrap_or(false)),
+    );
     if obj
         .get("model")
@@ -205,7 +208,7 @@ pub(super) fn rewrite_streaming_block(
     let mut changed = false;
     if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
-        let desired_store = Value::Bool(original_body.store);
+        let desired_store = Value::Bool(original_body.store.unwrap_or(false));
         if response_obj.get("store") != Some(&desired_store) {
             response_obj.insert("store".to_string(), desired_store);
             changed = true;
@@ -267,10 +270,11 @@ pub(super) fn rewrite_streaming_block(
 /// Mask function tools as MCP tools in response for client
 pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
-    let mcp_tool = original_body
-        .tools
-        .iter()
-        .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
+    let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
+        tools
+            .iter()
+            .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
+    });
     let Some(t) = mcp_tool else {
         return;
     };
......
@@ -148,7 +148,11 @@ impl OpenAIRouter {
         original_previous_response_id: Option<String>,
     ) -> Response {
         // Check if MCP is active for this request
-        let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
+        let req_mcp_manager = if let Some(ref tools) = original_body.tools {
+            mcp_manager_from_request_tools(tools.as_slice()).await
+        } else {
+            None
+        };
         let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
         let mut response_json: Value;
@@ -183,6 +187,7 @@ impl OpenAIRouter {
             }
         } else {
             // No MCP - simple request
             let mut request_builder = self.client.post(&url).json(&payload);
             if let Some(h) = headers {
                 request_builder = apply_request_headers(h, request_builder, true);
@@ -385,6 +390,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             }
         };
         if let Some(obj) = payload.as_object_mut() {
+            // Always remove SGLang-specific fields (unsupported by OpenAI)
             for key in [
                 "top_k",
                 "min_p",
@@ -535,7 +541,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
                 .into_response();
         }
-        // Clone the body and override model if needed
+        // Clone the body for validation and logic, but we'll build the payload differently
         let mut request_body = body.clone();
         if let Some(model) = model_id {
             request_body.model = Some(model.to_string());
@@ -690,7 +696,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         }
         // Always set store=false for upstream (we store internally)
-        request_body.store = false;
+        request_body.store = Some(false);
         // Convert to JSON and strip SGLang-specific fields
         let mut payload = match to_value(&request_body) {
@@ -704,14 +710,13 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             }
         };
-        // Remove SGLang-specific fields
+        // Remove SGLang-specific fields only
         if let Some(obj) = payload.as_object_mut() {
+            // Remove SGLang-specific fields (not part of OpenAI API)
             for key in [
                 "request_id",
                 "priority",
                 "top_k",
-                "frequency_penalty",
-                "presence_penalty",
                 "min_p",
                 "min_tokens",
                 "regex",
@@ -732,10 +737,38 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             ] {
                 obj.remove(key);
            }
+            // xAI doesn't support the OpenAI item-type input: https://platform.openai.com/docs/api-reference/responses/create#responses-create-input-input-item-list-item
+            // To achieve xAI compatibility, strip extra fields from input messages (id, status).
+            // xAI also doesn't support output_text as the content type for assistant-role messages,
+            // so normalize content types: output_text -> input_text
+            if let Some(input_arr) = obj.get_mut("input").and_then(Value::as_array_mut) {
+                for item_obj in input_arr.iter_mut().filter_map(Value::as_object_mut) {
+                    // Remove fields not universally supported
+                    item_obj.remove("id");
+                    item_obj.remove("status");
+                    // Normalize content types to input_text (xAI compatibility)
+                    if let Some(content_arr) =
+                        item_obj.get_mut("content").and_then(Value::as_array_mut)
+                    {
+                        for content_obj in content_arr.iter_mut().filter_map(Value::as_object_mut) {
+                            // Change output_text to input_text
+                            if content_obj.get("type").and_then(Value::as_str)
+                                == Some("output_text")
+                            {
+                                content_obj.insert(
+                                    "type".to_string(),
+                                    Value::String("input_text".to_string()),
+                                );
+                            }
+                        }
+                    }
+                }
+            }
         }
         // Delegate to streaming or non-streaming handler
-        if body.stream {
+        if body.stream.unwrap_or(false) {
             handle_streaming_response(
                 &self.client,
                 &self.circuit_breaker,
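The new normalization block is the xAI-specific part of the change: each input item loses its `id` and `status` fields, and any `output_text` content part is rewritten to `input_text` before the payload is forwarded. A sketch of the same transformation applied to a standalone `serde_json::Value` (the helper name is illustrative, not from the router):

```rust
use serde_json::{json, Value};

// Illustrative helper mirroring the normalization in the hunk above.
fn normalize_input_for_xai(input: &mut Value) {
    if let Some(items) = input.as_array_mut() {
        for item in items.iter_mut().filter_map(Value::as_object_mut) {
            // Drop fields not universally supported across providers.
            item.remove("id");
            item.remove("status");
            if let Some(parts) = item.get_mut("content").and_then(Value::as_array_mut) {
                for part in parts.iter_mut().filter_map(Value::as_object_mut) {
                    // Rewrite output_text parts so xAI accepts assistant history.
                    if part.get("type").and_then(Value::as_str) == Some("output_text") {
                        part.insert("type".into(), Value::String("input_text".into()));
                    }
                }
            }
        }
    }
}

fn main() {
    let mut input = json!([{
        "id": "msg_1",
        "status": "completed",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "hi"}]
    }]);
    normalize_input_for_xai(&mut input);
    assert_eq!(
        input,
        json!([{"role": "assistant", "content": [{"type": "input_text", "text": "hi"}]}])
    );
}
```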
......
@@ -572,7 +572,7 @@ pub(super) fn apply_event_transformations_inplace(
         .get_mut("response")
         .and_then(|v| v.as_object_mut())
     {
-        let desired_store = Value::Bool(original_request.store);
+        let desired_store = Value::Bool(original_request.store.unwrap_or(false));
         if response_obj.get("store") != Some(&desired_store) {
             response_obj.insert("store".to_string(), desired_store);
             changed = true;
@@ -597,8 +597,13 @@ pub(super) fn apply_event_transformations_inplace(
     if response_obj.get("tools").is_some() {
         let requested_mcp = original_request
             .tools
-            .iter()
-            .any(|t| matches!(t.r#type, ResponseToolType::Mcp));
+            .as_ref()
+            .map(|tools| {
+                tools
+                    .iter()
+                    .any(|t| matches!(t.r#type, ResponseToolType::Mcp))
+            })
+            .unwrap_or(false);
         if requested_mcp {
             if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
@@ -658,8 +663,8 @@ pub(super) fn apply_event_transformations_inplace(
 /// Helper to build MCP tools value
 fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option<Value> {
-    let mcp_tool = original_body
-        .tools
+    let tools = original_body.tools.as_ref()?;
+    let mcp_tool = tools
         .iter()
         .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
@@ -1000,7 +1005,7 @@ pub(super) async fn handle_simple_streaming_passthrough(
     let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
-    let should_store = original_body.store;
+    let should_store = original_body.store.unwrap_or(false);
     let original_request = original_body.clone();
     let persist_needed = original_request.conversation.is_some();
     let previous_response_id = original_previous_response_id.clone();
@@ -1134,7 +1139,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
     prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
     let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
-    let should_store = original_body.store;
+    let should_store = original_body.store.unwrap_or(false);
     let original_request = original_body.clone();
     let persist_needed = original_request.conversation.is_some();
     let previous_response_id = original_previous_response_id.clone();
@@ -1161,9 +1166,13 @@ pub(super) async fn handle_streaming_with_tool_interception(
     let server_label = original_request
         .tools
-        .iter()
-        .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-        .and_then(|t| t.server_label.as_deref())
+        .as_ref()
+        .and_then(|tools| {
+            tools
+                .iter()
+                .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                .and_then(|t| t.server_label.as_deref())
+        })
         .unwrap_or("mcp");
     loop {
@@ -1488,7 +1497,11 @@ pub(super) async fn handle_streaming_response(
     original_previous_response_id: Option<String>,
 ) -> Response {
     // Check if MCP is active for this request
-    let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
+    let req_mcp_manager = if let Some(ref tools) = original_body.tools {
+        mcp_manager_from_request_tools(tools.as_slice()).await
+    } else {
+        None
+    };
     let active_mcp = req_mcp_manager.as_ref().or(mcp_manager);
     // If no MCP is active, use simple pass-through streaming
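`mcp_manager_from_request_tools` still takes a slice, so both call sites unwrap the `Option` with an `if let` first. An equivalent shape using `as_deref()` is sketched below with hypothetical stand-in types; the equivalence assumes the helper treats an empty slice as "no tools", which the real implementation may or may not do:

```rust
// Hypothetical stand-ins; the real mcp_manager_from_request_tools is async
// router code with its own tool and manager types.
struct ResponseTool;
struct McpManager;

async fn mcp_manager_from_request_tools(tools: &[ResponseTool]) -> Option<McpManager> {
    // Stand-in body: assume no tools means no per-request manager.
    if tools.is_empty() {
        None
    } else {
        Some(McpManager)
    }
}

// Same effect as the if-let in the diff, written as a one-liner:
// Option<Vec<T>> becomes &[T] via as_deref(), with an empty slice for None.
async fn resolve(tools: &Option<Vec<ResponseTool>>) -> Option<McpManager> {
    mcp_manager_from_request_tools(tools.as_deref().unwrap_or(&[])).await
}
```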
......
@@ -89,7 +89,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
     // Build a simple ResponsesRequest that will trigger the tool call
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search something".to_string()),
         instructions: Some("Be brief".to_string()),
@@ -97,15 +97,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
         max_tool_calls: None,
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.2),
-        tool_choice: ToolChoice::default(),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::default()),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             authorization: None,
@@ -113,15 +113,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
             server_description: None,
             require_approval: None,
             allowed_tools: None,
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: None,
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_test_mcp_e2e".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: -1,
         min_p: 0.0,
@@ -338,7 +338,7 @@ async fn test_conversations_crud_basic() {
 #[test]
 fn test_responses_request_creation() {
     let request = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("Hello, world!".to_string()),
         instructions: Some("Be helpful".to_string()),
@@ -346,29 +346,29 @@ fn test_responses_request_creation() {
         max_tool_calls: None,
         metadata: None,
         model: Some("test-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::Medium),
             summary: None,
         }),
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::WebSearchPreview,
             ..Default::default()
-        }],
-        top_logprobs: 5,
+        }]),
+        top_logprobs: Some(5),
         top_p: Some(0.9),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: Some("test-user".to_string()),
         request_id: "resp_test123".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: -1,
         min_p: 0.0,
@@ -385,7 +385,7 @@ fn test_responses_request_creation() {
 #[test]
 fn test_sampling_params_conversion() {
     let request = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("Test".to_string()),
         instructions: None,
@@ -393,23 +393,23 @@ fn test_sampling_params_conversion() {
         max_tool_calls: None,
         metadata: None,
         model: Some("test-model".to_string()),
-        parallel_tool_calls: true, // Use default true
+        parallel_tool_calls: Some(true), // Use default true
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true, // Use default true
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true), // Use default true
+        stream: Some(false),
         temperature: Some(0.8),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![],
-        top_logprobs: 0, // Use default 0
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![]),
+        top_logprobs: Some(0), // Use default 0
         top_p: Some(0.95),
-        truncation: Truncation::Auto,
+        truncation: Some(Truncation::Auto),
         user: None,
         request_id: "resp_test456".to_string(),
         priority: 0,
-        frequency_penalty: 0.1,
-        presence_penalty: 0.2,
+        frequency_penalty: Some(0.1),
+        presence_penalty: Some(0.2),
         stop: None,
         top_k: 10,
         min_p: 0.05,
@@ -493,7 +493,7 @@ fn test_reasoning_param_default() {
 #[test]
 fn test_json_serialization() {
     let request = ResponsesRequest {
-        background: true,
+        background: Some(true),
         include: None,
         input: ResponseInput::Text("Test input".to_string()),
         instructions: Some("Test instructions".to_string()),
@@ -501,29 +501,29 @@ fn test_json_serialization() {
         max_tool_calls: Some(5),
         metadata: None,
         model: Some("gpt-4".to_string()),
-        parallel_tool_calls: false,
+        parallel_tool_calls: Some(false),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::High),
             summary: None,
         }),
-        service_tier: ServiceTier::Priority,
-        store: false,
-        stream: true,
+        service_tier: Some(ServiceTier::Priority),
+        store: Some(false),
+        stream: Some(true),
         temperature: Some(0.9),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::CodeInterpreter,
             ..Default::default()
-        }],
-        top_logprobs: 10,
+        }]),
+        top_logprobs: Some(10),
         top_p: Some(0.8),
-        truncation: Truncation::Auto,
+        truncation: Some(Truncation::Auto),
         user: Some("test_user".to_string()),
         request_id: "resp_comprehensive_test".to_string(),
         priority: 1,
-        frequency_penalty: 0.3,
-        presence_penalty: 0.4,
+        frequency_penalty: Some(0.3),
+        presence_penalty: Some(0.4),
         stop: None,
         top_k: 50,
         min_p: 0.1,
@@ -537,9 +537,9 @@ fn test_json_serialization() {
     assert_eq!(parsed.request_id, "resp_comprehensive_test");
     assert_eq!(parsed.model, Some("gpt-4".to_string()));
-    assert!(parsed.background);
-    assert!(parsed.stream);
-    assert_eq!(parsed.tools.len(), 1);
+    assert_eq!(parsed.background, Some(true));
+    assert_eq!(parsed.stream, Some(true));
+    assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
 }
 #[tokio::test]
@@ -620,7 +620,7 @@ async fn test_multi_turn_loop_with_mcp() {
     // Build request with MCP tools
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search for SGLang".to_string()),
         instructions: Some("Be helpful".to_string()),
@@ -628,30 +628,30 @@ async fn test_multi_turn_loop_with_mcp() {
         max_tool_calls: None, // No limit - test unlimited
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             server_description: Some("Mock MCP server for testing".to_string()),
             require_approval: Some("never".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_multi_turn_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -796,7 +796,7 @@ async fn test_max_tool_calls_limit() {
         .expect("router");
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("test max calls".to_string()),
         instructions: None,
@@ -804,28 +804,28 @@ async fn test_max_tool_calls_limit() {
         max_tool_calls: Some(1), // Limit to 1 call
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: false,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(false),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_max_calls_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -990,7 +990,7 @@ async fn test_streaming_with_mcp_tool_calls() {
     // Build streaming request with MCP tools
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search for something interesting".to_string()),
         instructions: Some("Use tools when needed".to_string()),
@@ -998,30 +998,30 @@ async fn test_streaming_with_mcp_tool_calls() {
         max_tool_calls: Some(3),
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: true, // KEY: Enable streaming
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(true), // KEY: Enable streaming
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             server_description: Some("Mock MCP for streaming test".to_string()),
             require_approval: Some("never".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_streaming_mcp_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -1271,7 +1271,7 @@ async fn test_streaming_multi_turn_with_mcp() {
     let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()),
         instructions: Some("Be thorough".to_string()),
@@ -1279,28 +1279,28 @@ async fn test_streaming_multi_turn_with_mcp() {
         max_tool_calls: Some(5), // Allow multiple rounds
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: true,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(true),
         temperature: Some(0.8),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_streaming_multiturn_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
......
@@ -234,7 +234,7 @@ async fn test_openai_router_responses_with_mock() {
     let request1 = ResponsesRequest {
         model: Some("gpt-4o-mini".to_string()),
         input: ResponseInput::Text("Say hi".to_string()),
-        store: true,
+        store: Some(true),
         ..Default::default()
     };
@@ -250,7 +250,7 @@ async fn test_openai_router_responses_with_mock() {
     let request2 = ResponsesRequest {
         model: Some("gpt-4o-mini".to_string()),
         input: ResponseInput::Text("Thanks".to_string()),
-        store: true,
+        store: Some(true),
         previous_response_id: Some(resp1_id.clone()),
         ..Default::default()
     };
@@ -501,8 +501,8 @@ async fn test_openai_router_responses_streaming_with_mock() {
         instructions: Some("Be kind".to_string()),
         metadata: Some(metadata),
         previous_response_id: Some("resp_prev_chain".to_string()),
-        store: true,
-        stream: true,
+        store: Some(true),
+        stream: Some(true),
         ..Default::default()
     };
......