Unverified Commit 72392f29 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] basic mcp support for openai router response api (#10978)

parent c1c8dd1d
......@@ -683,6 +683,33 @@ pub struct CompletionStreamChoice {
/// A single tool entry carried in a Responses API request.
///
/// `r#type` selects the tool kind. Every other field is optional and is left out of
/// the serialized form entirely when it is `None`.
pub struct ResponseTool {
/// Tool kind; appears under the key `type` when (de)serialized.
#[serde(rename = "type")]
pub r#type: ResponseToolType,
// MCP-specific fields (used when type == "mcp")
/// Address of the MCP server (presumably the endpoint the router connects to — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub server_url: Option<String>,
/// NOTE(review): looks like a credential passed along to the MCP server — confirm the expected format.
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization: Option<String>,
/// Label identifying the MCP server.
#[serde(skip_serializing_if = "Option::is_none")]
pub server_label: Option<String>,
/// Free-form description of the MCP server.
#[serde(skip_serializing_if = "Option::is_none")]
pub server_description: Option<String>,
/// NOTE(review): presumably an approval policy for tool calls; accepted values are not visible here — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub require_approval: Option<String>,
/// NOTE(review): presumably restricts which tools of the server may be used — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_tools: Option<Vec<String>>,
}
impl Default for ResponseTool {
fn default() -> Self {
Self {
r#type: ResponseToolType::WebSearchPreview,
server_url: None,
authorization: None,
server_label: None,
server_description: None,
require_approval: None,
allowed_tools: None,
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
......@@ -690,6 +717,7 @@ pub struct ResponseTool {
/// Kinds of tool that a `ResponseTool` can describe.
///
/// NOTE(review): the wire names depend on serde attributes that are not visible in
/// this excerpt; the MCP end-to-end test expects `Mcp` to appear as `"mcp"`.
pub enum ResponseToolType {
/// Web search preview tool; this is the kind used by `ResponseTool::default()`.
WebSearchPreview,
/// Code interpreter tool.
CodeInterpreter,
/// MCP tool; the MCP-specific optional fields on `ResponseTool` are meant for this kind.
Mcp,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
......
......@@ -644,27 +644,96 @@ async fn responses_handler(
}))
.into_response()
} else {
Json(json!({
"id": format!("resp-{}", Uuid::new_v4()),
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"output": [{
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "This is a mock responses output."
}]
}],
"status": "completed",
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15
}
}))
.into_response()
// If function tools are provided and the input carries no function_call_output yet
// (treated here as the first call), emit a single function_tool_call to trigger the
// router's MCP flow.
let has_tools = payload
.get("tools")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter().any(|tool| {
tool.get("type")
.and_then(|t| t.as_str())
.map(|t| t == "function")
.unwrap_or(false)
})
})
.unwrap_or(false);
let has_function_output = payload
.get("input")
.and_then(|v| v.as_array())
.map(|items| {
items.iter().any(|item| {
item.get("type")
.and_then(|t| t.as_str())
.map(|t| t == "function_call_output")
.unwrap_or(false)
})
})
.unwrap_or(false);
if has_tools && !has_function_output {
let rid = format!("resp-{}", Uuid::new_v4());
Json(json!({
"id": rid,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"output": [{
"type": "function_tool_call",
"id": "call_1",
"name": "brave_web_search",
"arguments": "{\"query\":\"SGLang router MCP integration\"}",
"status": "in_progress"
}],
"status": "in_progress",
"usage": null
}))
.into_response()
} else if has_tools && has_function_output {
Json(json!({
"id": format!("resp-{}", Uuid::new_v4()),
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"output": [{
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "Tool result consumed; here is the final answer."
}]
}],
"status": "completed",
"usage": {
"input_tokens": 12,
"output_tokens": 7,
"total_tokens": 19
}
}))
.into_response()
} else {
Json(json!({
"id": format!("resp-{}", Uuid::new_v4()),
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"output": [{
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "This is a mock responses output."
}]
}],
"status": "completed",
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15
}
}))
.into_response()
}
}
}
......
......@@ -6,6 +6,186 @@ use sglang_router_rs::protocols::spec::{
ToolChoiceValue, Truncation, UsageInfo,
};
mod common;
use common::mock_mcp_server::MockMCPServer;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use sglang_router_rs::config::{
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
};
use sglang_router_rs::routers::RouterFactory;
use sglang_router_rs::server::AppContext;
use std::sync::Arc;
/// End-to-end check of a non-streaming `route_responses` call that carries one MCP tool.
///
/// Spins up a mock MCP server and a mock worker, builds an OpenAI-mode router with the
/// in-memory history backend, sends a single request and inspects the JSON body.
#[tokio::test]
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
    // --- Mock MCP server ---------------------------------------------------------
    let mut mock_mcp = MockMCPServer::start().await.expect("start mcp");

    // --- Temporary MCP config file ------------------------------------------------
    // NOTE(review): `mcp_cfg_path` is not referenced again in this test — confirm how
    // (or whether) the router is expected to discover this file.
    let mcp_cfg = format!(
        "servers:\n - name: mock\n protocol: streamable\n url: {}\n",
        mock_mcp.url()
    );
    let tmp = tempfile::tempdir().expect("tmpdir");
    let mcp_cfg_path = tmp.path().join("mcp.yaml");
    std::fs::write(&mcp_cfg_path, mcp_cfg).expect("write mcp cfg");

    // --- Mock OpenAI worker -------------------------------------------------------
    let worker_cfg = MockWorkerConfig {
        port: 0,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    };
    let mut worker = MockWorker::new(worker_cfg);
    let worker_url = worker.start().await.expect("start worker");

    // --- Router configuration (HTTP, OpenAI routing mode) -------------------------
    let config = RouterConfig {
        mode: RoutingMode::OpenAI {
            worker_urls: vec![worker_url],
        },
        connection_mode: ConnectionMode::Http,
        policy: PolicyConfig::Random,
        host: String::from("127.0.0.1"),
        port: 0,
        max_payload_size: 8 * 1024 * 1024,
        request_timeout_secs: 60,
        worker_startup_timeout_secs: 5,
        worker_startup_check_interval_secs: 1,
        dp_aware: false,
        api_key: None,
        discovery: None,
        metrics: None,
        log_dir: None,
        log_level: Some(String::from("warn")),
        request_id_headers: None,
        max_concurrent_requests: 32,
        queue_size: 0,
        queue_timeout_secs: 5,
        rate_limit_tokens_per_second: None,
        cors_allowed_origins: Vec::new(),
        retry: RetryConfig::default(),
        circuit_breaker: CircuitBreakerConfig::default(),
        disable_retries: false,
        disable_circuit_breaker: false,
        health_check: HealthCheckConfig::default(),
        enable_igw: false,
        model_path: None,
        tokenizer_path: None,
        history_backend: sglang_router_rs::config::HistoryBackend::Memory,
        oracle: None,
    };

    // --- Application context and router -------------------------------------------
    let app_ctx = AppContext::new(config, reqwest::Client::new(), 64, None).expect("ctx");
    let router = RouterFactory::create_router(&Arc::new(app_ctx))
        .await
        .expect("router");

    // --- Request with a single MCP tool pointing at the mock server ----------------
    let mcp_tool = ResponseTool {
        r#type: ResponseToolType::Mcp,
        server_url: Some(mock_mcp.url()),
        server_label: Some(String::from("mock")),
        ..Default::default()
    };
    let request = ResponsesRequest {
        background: false,
        include: None,
        input: ResponseInput::Text(String::from("search something")),
        instructions: Some(String::from("Be brief")),
        max_output_tokens: Some(64),
        max_tool_calls: None,
        metadata: None,
        model: Some(String::from("mock-model")),
        parallel_tool_calls: true,
        previous_response_id: None,
        reasoning: None,
        service_tier: sglang_router_rs::protocols::spec::ServiceTier::Auto,
        store: true,
        stream: false,
        temperature: Some(0.2),
        tool_choice: sglang_router_rs::protocols::spec::ToolChoice::default(),
        tools: vec![mcp_tool],
        top_logprobs: 0,
        top_p: None,
        truncation: sglang_router_rs::protocols::spec::Truncation::Disabled,
        user: None,
        request_id: String::from("resp_test_mcp_e2e"),
        priority: 0,
        frequency_penalty: 0.0,
        presence_penalty: 0.0,
        stop: None,
        top_k: -1,
        min_p: 0.0,
        repetition_penalty: 1.0,
    };

    // --- Exercise the router --------------------------------------------------------
    let response = router
        .route_responses(None, &request, request.model.as_deref())
        .await;
    assert_eq!(response.status(), axum::http::StatusCode::OK);

    let raw_body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("Failed to read response body");
    let parsed: serde_json::Value =
        serde_json::from_slice(&raw_body).expect("Failed to parse response JSON");

    let output = parsed
        .get("output")
        .and_then(serde_json::Value::as_array)
        .expect("response output missing");
    assert!(!output.is_empty(), "expected at least one output item");

    // Scan output items from last to first and take the first text part found.
    let final_text = output
        .iter()
        .rev()
        .flat_map(|item| {
            item.get("content")
                .and_then(serde_json::Value::as_array)
                .into_iter()
                .flatten()
        })
        .find_map(|part| part.get("text").and_then(serde_json::Value::as_str));

    // Two outcomes are accepted: a final assistant message, or (when no text content
    // came back) a function tool call that is still in progress.
    match final_text {
        Some(text) => {
            assert_eq!(text, "Tool result consumed; here is the final answer.");
        }
        None => {
            let pending_call = output
                .iter()
                .find(|item| {
                    item.get("type").and_then(serde_json::Value::as_str)
                        == Some("function_tool_call")
                })
                .expect("missing function tool call entry");
            assert_eq!(
                pending_call.get("status").and_then(|v| v.as_str()),
                Some("in_progress"),
                "function call should be in progress when no content is returned"
            );
        }
    }

    // The response must echo back exactly the one MCP tool that was requested.
    let echoed_tools = parsed
        .get("tools")
        .and_then(serde_json::Value::as_array)
        .expect("tools array missing");
    assert_eq!(echoed_tools.len(), 1);
    let echoed = &echoed_tools[0];
    assert_eq!(echoed.get("type").and_then(|v| v.as_str()), Some("mcp"));
    assert_eq!(
        echoed.get("server_label").and_then(|v| v.as_str()),
        Some("mock")
    );

    // --- Cleanup ----------------------------------------------------------------------
    worker.stop().await;
    mock_mcp.stop().await;
}
#[test]
fn test_responses_request_creation() {
let request = ResponsesRequest {
......@@ -29,6 +209,7 @@ fn test_responses_request_creation() {
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
r#type: ResponseToolType::WebSearchPreview,
..Default::default()
}],
top_logprobs: 5,
top_p: Some(0.9),
......@@ -179,6 +360,7 @@ fn test_json_serialization() {
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
tools: vec![ResponseTool {
r#type: ResponseToolType::CodeInterpreter,
..Default::default()
}],
top_logprobs: 10,
top_p: Some(0.8),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment