Unverified Commit 63e84352 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] openai router: support grok model (#11511)

parent a20e7df8
......@@ -1073,8 +1073,8 @@ fn generate_request_id() -> String {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesRequest {
/// Run the request in the background
#[serde(default)]
pub background: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub background: Option<bool>,
/// Fields to include in the response
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -1108,8 +1108,8 @@ pub struct ResponsesRequest {
pub conversation: Option<String>,
/// Whether to enable parallel tool calls
#[serde(default = "default_true")]
pub parallel_tool_calls: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
/// ID of previous response to continue from
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -1120,40 +1120,40 @@ pub struct ResponsesRequest {
pub reasoning: Option<ResponseReasoningParam>,
/// Service tier
#[serde(default)]
pub service_tier: ServiceTier,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_tier: Option<ServiceTier>,
/// Whether to store the response
#[serde(default = "default_true")]
pub store: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub store: Option<bool>,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Temperature for sampling
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Tool choice behavior
#[serde(default)]
pub tool_choice: ToolChoice,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Available tools
#[serde(default)]
pub tools: Vec<ResponseTool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ResponseTool>>,
/// Number of top logprobs to return
#[serde(default)]
pub top_logprobs: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
/// Top-p sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Truncation behavior
#[serde(default)]
pub truncation: Truncation,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncation: Option<Truncation>,
/// User identifier
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -1168,12 +1168,12 @@ pub struct ResponsesRequest {
pub priority: i32,
/// Frequency penalty
#[serde(default)]
pub frequency_penalty: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Presence penalty
#[serde(default)]
pub presence_penalty: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Stop sequences
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -1210,7 +1210,7 @@ fn default_repetition_penalty() -> f32 {
impl Default for ResponsesRequest {
fn default() -> Self {
Self {
background: false,
background: None,
include: None,
input: ResponseInput::Text(String::new()),
instructions: None,
......@@ -1219,23 +1219,23 @@ impl Default for ResponsesRequest {
metadata: None,
model: None,
conversation: None,
parallel_tool_calls: true,
parallel_tool_calls: None,
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::default(),
store: true,
stream: false,
service_tier: None,
store: None,
stream: None,
temperature: None,
tool_choice: ToolChoice::default(),
tools: Vec::new(),
top_logprobs: 0,
tool_choice: None,
tools: None,
top_logprobs: None,
top_p: None,
truncation: Truncation::default(),
truncation: None,
user: None,
request_id: generate_request_id(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: None,
presence_penalty: None,
stop: None,
top_k: default_top_k(),
min_p: 0.0,
......@@ -1299,14 +1299,18 @@ impl ResponsesRequest {
"top_p".to_string(),
Value::Number(Number::from_f64(top_p as f64).unwrap()),
);
if let Some(fp) = self.frequency_penalty {
params.insert(
"frequency_penalty".to_string(),
Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
Value::Number(Number::from_f64(fp as f64).unwrap()),
);
}
if let Some(pp) = self.presence_penalty {
params.insert(
"presence_penalty".to_string(),
Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
Value::Number(Number::from_f64(pp as f64).unwrap()),
);
}
params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
params.insert(
"min_p".to_string(),
......@@ -1337,7 +1341,7 @@ impl ResponsesRequest {
impl GenerationRequest for ResponsesRequest {
fn is_stream(&self) -> bool {
self.stream
self.stream.unwrap_or(false)
}
fn get_model(&self) -> Option<&str> {
......@@ -1523,13 +1527,13 @@ impl ResponsesResponse {
max_output_tokens: request.max_output_tokens,
model: model_name,
output,
parallel_tool_calls: request.parallel_tool_calls,
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
previous_response_id: request.previous_response_id.clone(),
reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
summary: None,
}),
store: request.store,
store: request.store.unwrap_or(false),
temperature: request.temperature,
text: Some(ResponseTextFormat {
format: TextFormatType {
......@@ -1537,17 +1541,19 @@ impl ResponsesResponse {
},
}),
tool_choice: match &request.tool_choice {
ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
ToolChoice::Function { .. } => "function".to_string(),
ToolChoice::AllowedTools { mode, .. } => mode.clone(),
Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(),
Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(),
Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(),
Some(ToolChoice::Function { .. }) => "function".to_string(),
Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(),
None => "auto".to_string(),
},
tools: request.tools.clone(),
tools: request.tools.clone().unwrap_or_default(),
top_p: request.top_p,
truncation: match &request.truncation {
Truncation::Auto => Some("auto".to_string()),
Truncation::Disabled => Some("disabled".to_string()),
Some(Truncation::Auto) => Some("auto".to_string()),
Some(Truncation::Disabled) => Some("disabled".to_string()),
None => None,
},
usage: usage.map(ResponsesUsage::Classic),
user: request.user.clone(),
......
......@@ -689,9 +689,13 @@ pub(super) async fn execute_tool_loop(
if state.total_calls > 0 {
let server_label = original_body
.tools
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
})
.unwrap_or("mcp");
// Build mcp_list_tools item
......@@ -747,9 +751,13 @@ pub(super) fn build_incomplete_response(
if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
let server_label = original_body
.tools
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
})
.unwrap_or("mcp");
// Find any function_call items and convert them to mcp_call (incomplete)
......
......@@ -129,7 +129,10 @@ pub(super) fn patch_streaming_response_json(
}
}
obj.insert("store".to_string(), Value::Bool(original_body.store));
obj.insert(
"store".to_string(),
Value::Bool(original_body.store.unwrap_or(false)),
);
if obj
.get("model")
......@@ -205,7 +208,7 @@ pub(super) fn rewrite_streaming_block(
let mut changed = false;
if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
let desired_store = Value::Bool(original_body.store);
let desired_store = Value::Bool(original_body.store.unwrap_or(false));
if response_obj.get("store") != Some(&desired_store) {
response_obj.insert("store".to_string(), desired_store);
changed = true;
......@@ -267,10 +270,11 @@ pub(super) fn rewrite_streaming_block(
/// Mask function tools as MCP tools in response for client
pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
let mcp_tool = original_body
.tools
let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
});
let Some(t) = mcp_tool else {
return;
};
......
......@@ -148,7 +148,11 @@ impl OpenAIRouter {
original_previous_response_id: Option<String>,
) -> Response {
// Check if MCP is active for this request
let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
let req_mcp_manager = if let Some(ref tools) = original_body.tools {
mcp_manager_from_request_tools(tools.as_slice()).await
} else {
None
};
let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
let mut response_json: Value;
......@@ -183,6 +187,7 @@ impl OpenAIRouter {
}
} else {
// No MCP - simple request
let mut request_builder = self.client.post(&url).json(&payload);
if let Some(h) = headers {
request_builder = apply_request_headers(h, request_builder, true);
......@@ -385,6 +390,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}
};
if let Some(obj) = payload.as_object_mut() {
// Always remove SGLang-specific fields (unsupported by OpenAI)
for key in [
"top_k",
"min_p",
......@@ -535,7 +541,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
.into_response();
}
// Clone the body and override model if needed
// Clone the body for validation and logic, but we'll build payload differently
let mut request_body = body.clone();
if let Some(model) = model_id {
request_body.model = Some(model.to_string());
......@@ -690,7 +696,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}
// Always set store=false for upstream (we store internally)
request_body.store = false;
request_body.store = Some(false);
// Convert to JSON and strip SGLang-specific fields
let mut payload = match to_value(&request_body) {
......@@ -704,14 +710,13 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}
};
// Remove SGLang-specific fields
// Remove SGLang-specific fields only
if let Some(obj) = payload.as_object_mut() {
// Remove SGLang-specific fields (not part of OpenAI API)
for key in [
"request_id",
"priority",
"top_k",
"frequency_penalty",
"presence_penalty",
"min_p",
"min_tokens",
"regex",
......@@ -732,10 +737,38 @@ impl crate::routers::RouterTrait for OpenAIRouter {
] {
obj.remove(key);
}
// XAI doesn't support the OPENAI item type input: https://platform.openai.com/docs/api-reference/responses/create#responses-create-input-input-item-list-item
// To Achieve XAI compatibility, strip extra fields from input messages (id, status)
// XAI doesn't support output_text as type for content with role of assistant
// so normalize content types: output_text -> input_text
if let Some(input_arr) = obj.get_mut("input").and_then(Value::as_array_mut) {
for item_obj in input_arr.iter_mut().filter_map(Value::as_object_mut) {
// Remove fields not universally supported
item_obj.remove("id");
item_obj.remove("status");
// Normalize content types to input_text (xAI compatibility)
if let Some(content_arr) =
item_obj.get_mut("content").and_then(Value::as_array_mut)
{
for content_obj in content_arr.iter_mut().filter_map(Value::as_object_mut) {
// Change output_text to input_text
if content_obj.get("type").and_then(Value::as_str)
== Some("output_text")
{
content_obj.insert(
"type".to_string(),
Value::String("input_text".to_string()),
);
}
}
}
}
}
}
// Delegate to streaming or non-streaming handler
if body.stream {
if body.stream.unwrap_or(false) {
handle_streaming_response(
&self.client,
&self.circuit_breaker,
......
......@@ -572,7 +572,7 @@ pub(super) fn apply_event_transformations_inplace(
.get_mut("response")
.and_then(|v| v.as_object_mut())
{
let desired_store = Value::Bool(original_request.store);
let desired_store = Value::Bool(original_request.store.unwrap_or(false));
if response_obj.get("store") != Some(&desired_store) {
response_obj.insert("store".to_string(), desired_store);
changed = true;
......@@ -597,8 +597,13 @@ pub(super) fn apply_event_transformations_inplace(
if response_obj.get("tools").is_some() {
let requested_mcp = original_request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp));
.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
if requested_mcp {
if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
......@@ -658,8 +663,8 @@ pub(super) fn apply_event_transformations_inplace(
/// Helper to build MCP tools value
fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option<Value> {
let mcp_tool = original_body
.tools
let tools = original_body.tools.as_ref()?;
let mcp_tool = tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
......@@ -1000,7 +1005,7 @@ pub(super) async fn handle_simple_streaming_passthrough(
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store;
let should_store = original_body.store.unwrap_or(false);
let original_request = original_body.clone();
let persist_needed = original_request.conversation.is_some();
let previous_response_id = original_previous_response_id.clone();
......@@ -1134,7 +1139,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store;
let should_store = original_body.store.unwrap_or(false);
let original_request = original_body.clone();
let persist_needed = original_request.conversation.is_some();
let previous_response_id = original_previous_response_id.clone();
......@@ -1161,9 +1166,13 @@ pub(super) async fn handle_streaming_with_tool_interception(
let server_label = original_request
.tools
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
})
.unwrap_or("mcp");
loop {
......@@ -1488,7 +1497,11 @@ pub(super) async fn handle_streaming_response(
original_previous_response_id: Option<String>,
) -> Response {
// Check if MCP is active for this request
let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
let req_mcp_manager = if let Some(ref tools) = original_body.tools {
mcp_manager_from_request_tools(tools.as_slice()).await
} else {
None
};
let active_mcp = req_mcp_manager.as_ref().or(mcp_manager);
// If no MCP is active, use simple pass-through streaming
......
......@@ -89,7 +89,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
// Build a simple ResponsesRequest that will trigger the tool call
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("search something".to_string()),
instructions: Some("Be brief".to_string()),
......@@ -97,15 +97,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
max_tool_calls: None,
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(false),
temperature: Some(0.2),
tool_choice: ToolChoice::default(),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::default()),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
authorization: None,
......@@ -113,15 +113,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
server_description: None,
require_approval: None,
allowed_tools: None,
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: None,
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_test_mcp_e2e".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: -1,
min_p: 0.0,
......@@ -338,7 +338,7 @@ async fn test_conversations_crud_basic() {
#[test]
fn test_responses_request_creation() {
let request = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("Hello, world!".to_string()),
instructions: Some("Be helpful".to_string()),
......@@ -346,29 +346,29 @@ fn test_responses_request_creation() {
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium),
summary: None,
}),
service_tier: ServiceTier::Auto,
store: true,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(false),
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::WebSearchPreview,
..Default::default()
}],
top_logprobs: 5,
}]),
top_logprobs: Some(5),
top_p: Some(0.9),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: Some("test-user".to_string()),
request_id: "resp_test123".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: -1,
min_p: 0.0,
......@@ -385,7 +385,7 @@ fn test_responses_request_creation() {
#[test]
fn test_sampling_params_conversion() {
let request = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("Test".to_string()),
instructions: None,
......@@ -393,23 +393,23 @@ fn test_sampling_params_conversion() {
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true, // Use default true
parallel_tool_calls: Some(true), // Use default true
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true, // Use default true
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true), // Use default true
stream: Some(false),
temperature: Some(0.8),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![],
top_logprobs: 0, // Use default 0
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![]),
top_logprobs: Some(0), // Use default 0
top_p: Some(0.95),
truncation: Truncation::Auto,
truncation: Some(Truncation::Auto),
user: None,
request_id: "resp_test456".to_string(),
priority: 0,
frequency_penalty: 0.1,
presence_penalty: 0.2,
frequency_penalty: Some(0.1),
presence_penalty: Some(0.2),
stop: None,
top_k: 10,
min_p: 0.05,
......@@ -493,7 +493,7 @@ fn test_reasoning_param_default() {
#[test]
fn test_json_serialization() {
let request = ResponsesRequest {
background: true,
background: Some(true),
include: None,
input: ResponseInput::Text("Test input".to_string()),
instructions: Some("Test instructions".to_string()),
......@@ -501,29 +501,29 @@ fn test_json_serialization() {
max_tool_calls: Some(5),
metadata: None,
model: Some("gpt-4".to_string()),
parallel_tool_calls: false,
parallel_tool_calls: Some(false),
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::High),
summary: None,
}),
service_tier: ServiceTier::Priority,
store: false,
stream: true,
service_tier: Some(ServiceTier::Priority),
store: Some(false),
stream: Some(true),
temperature: Some(0.9),
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::CodeInterpreter,
..Default::default()
}],
top_logprobs: 10,
}]),
top_logprobs: Some(10),
top_p: Some(0.8),
truncation: Truncation::Auto,
truncation: Some(Truncation::Auto),
user: Some("test_user".to_string()),
request_id: "resp_comprehensive_test".to_string(),
priority: 1,
frequency_penalty: 0.3,
presence_penalty: 0.4,
frequency_penalty: Some(0.3),
presence_penalty: Some(0.4),
stop: None,
top_k: 50,
min_p: 0.1,
......@@ -537,9 +537,9 @@ fn test_json_serialization() {
assert_eq!(parsed.request_id, "resp_comprehensive_test");
assert_eq!(parsed.model, Some("gpt-4".to_string()));
assert!(parsed.background);
assert!(parsed.stream);
assert_eq!(parsed.tools.len(), 1);
assert_eq!(parsed.background, Some(true));
assert_eq!(parsed.stream, Some(true));
assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
}
#[tokio::test]
......@@ -620,7 +620,7 @@ async fn test_multi_turn_loop_with_mcp() {
// Build request with MCP tools
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("search for SGLang".to_string()),
instructions: Some("Be helpful".to_string()),
......@@ -628,30 +628,30 @@ async fn test_multi_turn_loop_with_mcp() {
max_tool_calls: None, // No limit - test unlimited
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(false),
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
server_description: Some("Mock MCP server for testing".to_string()),
require_approval: Some("never".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_multi_turn_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,
......@@ -796,7 +796,7 @@ async fn test_max_tool_calls_limit() {
.expect("router");
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("test max calls".to_string()),
instructions: None,
......@@ -804,28 +804,28 @@ async fn test_max_tool_calls_limit() {
max_tool_calls: Some(1), // Limit to 1 call
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: false,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(false),
stream: Some(false),
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_max_calls_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,
......@@ -990,7 +990,7 @@ async fn test_streaming_with_mcp_tool_calls() {
// Build streaming request with MCP tools
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("search for something interesting".to_string()),
instructions: Some("Use tools when needed".to_string()),
......@@ -998,30 +998,30 @@ async fn test_streaming_with_mcp_tool_calls() {
max_tool_calls: Some(3),
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: true, // KEY: Enable streaming
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(true), // KEY: Enable streaming
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
server_description: Some("Mock MCP for streaming test".to_string()),
require_approval: Some("never".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_streaming_mcp_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,
......@@ -1271,7 +1271,7 @@ async fn test_streaming_multi_turn_with_mcp() {
let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()),
instructions: Some("Be thorough".to_string()),
......@@ -1279,28 +1279,28 @@ async fn test_streaming_multi_turn_with_mcp() {
max_tool_calls: Some(5), // Allow multiple rounds
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: true,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(true),
temperature: Some(0.8),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_streaming_multiturn_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,
......
......@@ -234,7 +234,7 @@ async fn test_openai_router_responses_with_mock() {
let request1 = ResponsesRequest {
model: Some("gpt-4o-mini".to_string()),
input: ResponseInput::Text("Say hi".to_string()),
store: true,
store: Some(true),
..Default::default()
};
......@@ -250,7 +250,7 @@ async fn test_openai_router_responses_with_mock() {
let request2 = ResponsesRequest {
model: Some("gpt-4o-mini".to_string()),
input: ResponseInput::Text("Thanks".to_string()),
store: true,
store: Some(true),
previous_response_id: Some(resp1_id.clone()),
..Default::default()
};
......@@ -501,8 +501,8 @@ async fn test_openai_router_responses_streaming_with_mock() {
instructions: Some("Be kind".to_string()),
metadata: Some(metadata),
previous_response_id: Some("resp_prev_chain".to_string()),
store: true,
stream: true,
store: Some(true),
stream: Some(true),
..Default::default()
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment