sglang commit a28b394f (unverified)
Authored by Keyang Ru on Oct 01, 2025; committed via GitHub on Oct 01, 2025
Parent: 96fe2d0f

[router] Add multi-turn tool calling loop support for MCP integration (#11143)

Showing 3 changed files with 791 additions and 238 deletions (+791, -238)
Files changed:
  sgl-router/src/protocols/spec.rs                 +11   -0
  sgl-router/src/routers/http/openai_router.rs     +452  -238
  sgl-router/tests/responses_api_test.rs           +328  -0
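For orientation before the file diffs: this change replaces the previous single-shot "resume with one tool result" path with a loop that keeps executing MCP tool calls until the model stops requesting them or a limit is hit. Below is a minimal, self-contained sketch of that flow under simplified assumptions; call_upstream and call_tool are stand-ins for the router's upstream POST and MCP client call, not APIs from this repository.

// Hedged sketch of the multi-turn tool loop added by this commit (names simplified).
use serde_json::{json, Value};

fn call_upstream(input: &[Value]) -> Value {
    // Stand-in for the POST to the upstream responses endpoint.
    // Returns a final message once a tool result is present in the input.
    if input.iter().any(|i| i["type"] == "function_call_output") {
        json!({ "output": [{ "type": "message", "role": "assistant",
                             "content": [{ "type": "output_text", "text": "done" }] }] })
    } else {
        json!({ "output": [{ "type": "function_call", "call_id": "call_1",
                             "name": "search", "arguments": "{\"q\":\"SGLang\"}" }] })
    }
}

fn call_tool(_name: &str, _args: &str) -> String {
    // Stand-in for executing the tool through the MCP client manager.
    json!({ "results": ["..."] }).to_string()
}

fn main() {
    let max_iterations = 10; // safety limit, mirroring McpLoopConfig::default()
    let mut input = vec![json!({ "type": "message", "role": "user",
        "content": [{ "type": "input_text", "text": "search for SGLang" }] })];
    for _ in 0..max_iterations {
        let resp = call_upstream(&input);
        let first = &resp["output"][0];
        if first["type"] == "function_call" {
            // Execute the tool, then grow the input with the call and its output and resume.
            let output = call_tool(first["name"].as_str().unwrap(),
                                   first["arguments"].as_str().unwrap());
            input.push(first.clone());
            input.push(json!({ "type": "function_call_output",
                               "call_id": first["call_id"], "output": output }));
        } else {
            println!("final response: {resp}");
            return;
        }
    }
    println!("stopped: reached safety limit");
}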
sgl-router/src/protocols/spec.rs (view file @ a28b394f)
@@ -723,7 +723,10 @@ pub enum ResponseToolType {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ResponseReasoningParam {
     #[serde(default = "default_reasoning_effort")]
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub effort: Option<ReasoningEffort>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<ReasoningSummary>,
 }

 fn default_reasoning_effort() -> Option<ReasoningEffort> {

@@ -738,6 +741,14 @@ pub enum ReasoningEffort {
     High,
 }

+#[derive(Debug, Clone, Deserialize, Serialize)]
+#[serde(rename_all = "snake_case")]
+pub enum ReasoningSummary {
+    Auto,
+    Concise,
+    Detailed,
+}
+
 #[derive(Debug, Clone, Deserialize, Serialize)]
 #[serde(tag = "type")]
 #[serde(rename_all = "snake_case")]
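A standalone illustration of the new optional summary field's serialization behavior. The types are re-declared here so the snippet compiles on its own: ReasoningEffort is abbreviated to the variants visible in this diff, and the default_reasoning_effort attribute from spec.rs is omitted, so this is a sketch rather than the router's actual definitions.

// Hedged sketch: re-declared types showing how `summary` serializes under serde.
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum ReasoningEffort {
    Medium,
    High,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum ReasoningSummary {
    Auto,
    Concise,
    Detailed,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
struct ResponseReasoningParam {
    #[serde(skip_serializing_if = "Option::is_none")]
    effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    summary: Option<ReasoningSummary>,
}

fn main() {
    let with_summary = ResponseReasoningParam {
        effort: Some(ReasoningEffort::Medium),
        summary: Some(ReasoningSummary::Concise),
    };
    // Prints {"effort":"medium","summary":"concise"} for this sketch's types
    println!("{}", serde_json::to_string(&with_summary).unwrap());

    let without_summary = ResponseReasoningParam {
        effort: Some(ReasoningEffort::High),
        summary: None,
    };
    // Prints {"effort":"high"}; `summary` is skipped entirely when None
    println!("{}", serde_json::to_string(&without_summary).unwrap());
}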
sgl-router/src/routers/http/openai_router.rs (view file @ a28b394f)
@@ -26,7 +26,6 @@ use std::{
     collections::HashMap,
     io,
     sync::{atomic::AtomicBool, Arc},
-    time::SystemTime,
 };
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -57,6 +56,69 @@ impl std::fmt::Debug for OpenAIRouter {
     }
 }

+/// Configuration for MCP tool calling loops
+#[derive(Debug, Clone)]
+struct McpLoopConfig {
+    /// Maximum iterations as safety limit (internal only, default: 10)
+    /// Prevents infinite loops when max_tool_calls is not set
+    max_iterations: usize,
+}
+
+impl Default for McpLoopConfig {
+    fn default() -> Self {
+        Self { max_iterations: 10 }
+    }
+}
+
+/// State for tracking multi-turn tool calling loop
+struct ToolLoopState {
+    /// Current iteration number (starts at 0, increments with each tool call)
+    iteration: usize,
+    /// Total number of tool calls executed
+    total_calls: usize,
+    /// Conversation history (function_call and function_call_output items)
+    conversation_history: Vec<Value>,
+    /// Original user input (preserved for building resume payloads)
+    original_input: ResponseInput,
+}
+
+impl ToolLoopState {
+    fn new(original_input: ResponseInput) -> Self {
+        Self {
+            iteration: 0,
+            total_calls: 0,
+            conversation_history: Vec::new(),
+            original_input,
+        }
+    }
+
+    /// Record a tool call in the loop state
+    fn record_call(
+        &mut self,
+        call_id: String,
+        tool_name: String,
+        args_json_str: String,
+        output_str: String,
+    ) {
+        // Add function_call item to history
+        let func_item = json!({
+            "type": "function_call",
+            "call_id": call_id,
+            "name": tool_name,
+            "arguments": args_json_str
+        });
+        self.conversation_history.push(func_item);
+
+        // Add function_call_output item to history
+        let output_item = json!({
+            "type": "function_call_output",
+            "call_id": call_id,
+            "output": output_str
+        });
+        self.conversation_history.push(output_item);
+    }
+}
+
 /// Helper that parses SSE frames from the OpenAI responses stream and
 /// accumulates enough information to persist the final response locally.
 struct StreamingResponseAccumulator {
@@ -388,126 +450,32 @@ impl OpenAIRouter {
             obj.insert("store".to_string(), Value::Bool(original_body.store));
         }

-        let mut final_response_json = openai_response_json;
-        if let Some(mcp) = active_mcp {
-            if let Some((call_id, tool_name, args_json_str)) =
-                Self::extract_function_call(&final_response_json)
-            {
-                info!(
-                    "Detected function call: name={}, call_id={}, args={}",
-                    tool_name, call_id, args_json_str
-                );
-                let call_started = SystemTime::now();
-                let call_result =
-                    Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await;
-                let call_duration_ms = call_started.elapsed().unwrap_or_default().as_millis();
-                let (output_payload, call_ok, call_error) = match call_result {
-                    Ok((server, out)) => {
-                        info!(
-                            call_id = %call_id,
-                            tool_name = %tool_name,
-                            server = %server,
-                            duration_ms = call_duration_ms,
-                            "MCP tool call succeeded"
-                        );
-                        (out, true, None)
-                    }
-                    Err(err) => {
-                        warn!(
-                            call_id = %call_id,
-                            tool_name = %tool_name,
-                            duration_ms = call_duration_ms,
-                            error = %err,
-                            "MCP tool call failed"
-                        );
-                        (
-                            serde_json::json!({"error": err}).to_string(),
-                            false,
-                            Some(err),
-                        )
-                    }
-                };
-                match self
-                    .resume_with_tool_result(ResumeWithToolArgs {
-                        url: &url,
-                        headers,
-                        original_payload: &payload,
-                        call_id: &call_id,
-                        tool_name: &tool_name,
-                        args_json_str: &args_json_str,
-                        output_str: &output_payload,
-                        original_body,
-                    })
-                    .await
-                {
-                    Ok(mut resumed_json) => {
-                        // Inject MCP output items (mcp_list_tools and mcp_call)
-                        let server_label = original_body
-                            .tools
-                            .iter()
-                            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-                            .and_then(|t| t.server_label.as_deref())
-                            .unwrap_or("mcp");
-                        if let Err(inject_err) = Self::inject_mcp_output_items(
-                            &mut resumed_json,
-                            mcp,
-                            McpOutputItemsArgs {
-                                tool_name: &tool_name,
-                                args_json: &args_json_str,
-                                output: &output_payload,
-                                server_label,
-                                success: call_ok,
-                                error: call_error.as_deref(),
-                            },
-                        ) {
-                            warn!("Failed to inject MCP output items: {}", inject_err);
-                        }
-                        if !call_ok {
-                            if let Some(obj) = resumed_json.as_object_mut() {
-                                let metadata_value = obj.entry("metadata").or_insert_with(|| {
-                                    Value::Object(serde_json::Map::new())
-                                });
-                                if let Some(metadata) = metadata_value.as_object_mut() {
-                                    if let Some(err_msg) = call_error.as_ref() {
-                                        metadata.insert(
-                                            "mcp_error".to_string(),
-                                            Value::String(err_msg.clone()),
-                                        );
-                                    }
-                                }
-                            }
-                        }
-                        final_response_json = resumed_json;
-                    }
-                    Err(err) => {
-                        warn!("Failed to resume with tool result: {}", err);
-                        let error_body = json!({
-                            "error": {
-                                "message": format!("Failed to resume with tool result: {}", err),
-                                "type": "internal_error",
-                            }
-                        })
-                        .to_string();
-                        return (
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            [("content-type", "application/json")],
+        // If MCP is active and we detect a function call, enter the tool loop
+        let mut final_response_json = if let Some(mcp) = active_mcp {
+            if Self::extract_function_call(&openai_response_json).is_some() {
+                // Use the loop to handle potentially multiple tool calls
+                let loop_config = McpLoopConfig::default();
+                match self
+                    .execute_tool_loop(
+                        &url,
+                        headers,
+                        payload.clone(),
+                        original_body,
+                        mcp,
+                        &loop_config,
+                    )
+                    .await
+                {
+                    Ok(loop_result) => loop_result,
+                    Err(err) => {
+                        warn!("Tool loop failed: {}", err);
+                        let error_body = json!({
+                            "error": {
+                                "message": format!("Tool loop failed: {}", err),
+                                "type": "internal_error",
+                            }
+                        })
+                        .to_string();
+                        return (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            [("content-type", "application/json")],
@@ -517,10 +485,14 @@ impl OpenAIRouter {
                     }
                 }
             } else {
-                info!("No function call found in upstream response; skipping MCP");
+                // No function call detected, use response as-is
+                openai_response_json
             }
-        }
+        } else {
+            openai_response_json
+        };

+        // Mask tools back to MCP format for client
         Self::mask_tools_as_mcp(&mut final_response_json, original_body);

         if original_body.store {
             if let Err(e) = self
@@ -1040,26 +1012,6 @@ impl OpenAIRouter {
     }
 }

-struct ResumeWithToolArgs<'a> {
-    url: &'a str,
-    headers: Option<&'a HeaderMap>,
-    original_payload: &'a Value,
-    call_id: &'a str,
-    tool_name: &'a str,
-    args_json_str: &'a str,
-    output_str: &'a str,
-    original_body: &'a ResponsesRequest,
-}
-
-struct McpOutputItemsArgs<'a> {
-    tool_name: &'a str,
-    args_json: &'a str,
-    output: &'a str,
-    server_label: &'a str,
-    success: bool,
-    error: Option<&'a str>,
-}
-
 impl OpenAIRouter {
     fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
         let output = resp.get("output")?.as_array()?;
@@ -1150,6 +1102,375 @@ impl OpenAIRouter {
         Ok((server_name, output_str))
     }

+    /// Build a resume payload with conversation history
+    fn build_resume_payload(
+        base_payload: &Value,
+        conversation_history: &[Value],
+        original_input: &ResponseInput,
+        tools_json: &Value,
+    ) -> Result<Value, String> {
+        // Clone the base payload which already has cleaned fields
+        let mut payload = base_payload.clone();
+        let obj = payload
+            .as_object_mut()
+            .ok_or_else(|| "payload not an object".to_string())?;
+
+        // Build input array: start with original user input
+        let mut input_array = Vec::new();
+
+        // Add original user message
+        // For structured input, serialize the original input items
+        match original_input {
+            ResponseInput::Text(text) => {
+                let user_item = json!({
+                    "type": "message",
+                    "role": "user",
+                    "content": [{"type": "input_text", "text": text}]
+                });
+                input_array.push(user_item);
+            }
+            ResponseInput::Items(items) => {
+                // Items are already structured ResponseInputOutputItem, convert to JSON
+                if let Ok(items_value) = serde_json::to_value(items) {
+                    if let Some(items_arr) = items_value.as_array() {
+                        input_array.extend_from_slice(items_arr);
+                    }
+                }
+            }
+        }
+
+        // Add all conversation history (function calls and outputs)
+        input_array.extend_from_slice(conversation_history);
+
+        obj.insert("input".to_string(), Value::Array(input_array));
+
+        // Use the transformed tools (function tools, not MCP tools)
+        if let Some(tools_arr) = tools_json.as_array() {
+            if !tools_arr.is_empty() {
+                obj.insert("tools".to_string(), tools_json.clone());
+            }
+        }
+
+        // Ensure non-streaming and no store to upstream
+        obj.insert("stream".to_string(), Value::Bool(false));
+        obj.insert("store".to_string(), Value::Bool(false));
+
+        // Note: SGLang-specific fields were already removed from base_payload
+        // before it was passed to execute_tool_loop (see route_responses lines 1935-1946)
+
+        Ok(payload)
+    }
+
+    /// Helper function to build mcp_call items from executed tool calls in conversation history
+    fn build_executed_mcp_call_items(
+        conversation_history: &[Value],
+        server_label: &str,
+    ) -> Vec<Value> {
+        let mut mcp_call_items = Vec::new();
+        for item in conversation_history {
+            if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
+                let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
+                let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
+                let args = item.get("arguments").and_then(|v| v.as_str()).unwrap_or("{}");
+
+                // Find corresponding output
+                let output_item = conversation_history.iter().find(|o| {
+                    o.get("type").and_then(|t| t.as_str()) == Some("function_call_output")
+                        && o.get("call_id").and_then(|c| c.as_str()) == Some(call_id)
+                });
+                let output_str = output_item
+                    .and_then(|o| o.get("output").and_then(|v| v.as_str()))
+                    .unwrap_or("{}");
+
+                // Check if output contains error by parsing JSON
+                let is_error = serde_json::from_str::<serde_json::Value>(output_str)
+                    .map(|v| v.get("error").is_some())
+                    .unwrap_or(false);
+
+                let mcp_call_item = Self::build_mcp_call_item(
+                    tool_name,
+                    args,
+                    output_str,
+                    server_label,
+                    !is_error,
+                    if is_error {
+                        Some("Tool execution failed")
+                    } else {
+                        None
+                    },
+                );
+                mcp_call_items.push(mcp_call_item);
+            }
+        }
+        mcp_call_items
+    }
+
+    /// Build an incomplete response when limits are exceeded
+    fn build_incomplete_response(
+        mut response: Value,
+        state: ToolLoopState,
+        reason: &str,
+        active_mcp: &Arc<crate::mcp::McpClientManager>,
+        original_body: &ResponsesRequest,
+    ) -> Result<Value, String> {
+        let obj = response
+            .as_object_mut()
+            .ok_or_else(|| "response not an object".to_string())?;
+
+        // Set status to completed (not failed - partial success)
+        obj.insert("status".to_string(), Value::String("completed".to_string()));
+
+        // Set incomplete_details
+        obj.insert(
+            "incomplete_details".to_string(),
+            json!({"reason": reason}),
+        );
+
+        // Convert any function_call in output to mcp_call format
+        if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
+            let server_label = original_body
+                .tools
+                .iter()
+                .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                .and_then(|t| t.server_label.as_deref())
+                .unwrap_or("mcp");
+
+            // Find any function_call items and convert them to mcp_call (incomplete)
+            let mut mcp_call_items = Vec::new();
+            for item in output_array.iter() {
+                if item.get("type").and_then(|t| t.as_str()) == Some("function_tool_call") {
+                    let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
+                    let args = item.get("arguments").and_then(|v| v.as_str()).unwrap_or("{}");
+
+                    // Mark as incomplete - not executed
+                    let mcp_call_item = Self::build_mcp_call_item(
+                        tool_name,
+                        args,
+                        "", // No output - wasn't executed
+                        server_label,
+                        false, // Not successful
+                        Some("Not executed - response stopped due to limit"),
+                    );
+                    mcp_call_items.push(mcp_call_item);
+                }
+            }
+
+            // Add mcp_list_tools and executed mcp_call items at the beginning
+            if state.total_calls > 0 || !mcp_call_items.is_empty() {
+                let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
+                output_array.insert(0, list_tools_item);
+
+                // Add mcp_call items for executed calls using helper
+                let executed_items = Self::build_executed_mcp_call_items(
+                    &state.conversation_history,
+                    server_label,
+                );
+                let mut insert_pos = 1;
+                for item in executed_items {
+                    output_array.insert(insert_pos, item);
+                    insert_pos += 1;
+                }
+
+                // Add incomplete mcp_call items
+                for item in mcp_call_items {
+                    output_array.insert(insert_pos, item);
+                    insert_pos += 1;
+                }
+            }
+        }
+
+        // Add warning to metadata
+        if let Some(metadata_val) = obj.get_mut("metadata") {
+            if let Some(metadata_obj) = metadata_val.as_object_mut() {
+                if let Some(mcp_val) = metadata_obj.get_mut("mcp") {
+                    if let Some(mcp_obj) = mcp_val.as_object_mut() {
+                        mcp_obj.insert(
+                            "truncation_warning".to_string(),
+                            Value::String(format!(
+                                "Loop terminated at {} iterations, {} total calls (reason: {})",
+                                state.iteration, state.total_calls, reason
+                            )),
+                        );
+                    }
+                }
+            }
+        }
+
+        Ok(response)
+    }
+
+    /// Execute the tool calling loop
+    async fn execute_tool_loop(
+        &self,
+        url: &str,
+        headers: Option<&HeaderMap>,
+        initial_payload: Value,
+        original_body: &ResponsesRequest,
+        active_mcp: &Arc<crate::mcp::McpClientManager>,
+        config: &McpLoopConfig,
+    ) -> Result<Value, String> {
+        let mut state = ToolLoopState::new(original_body.input.clone());
+
+        // Get max_tool_calls from request (None means no user-specified limit)
+        let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
+
+        // Keep initial_payload as base template (already has fields cleaned)
+        let base_payload = initial_payload.clone();
+        let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
+
+        let mut current_payload = initial_payload;
+
+        info!(
+            "Starting tool loop: max_tool_calls={:?}, max_iterations={}",
+            max_tool_calls, config.max_iterations
+        );
+
+        loop {
+            // Make request to upstream
+            let request_builder = self.client.post(url).json(&current_payload);
+            let request_builder = if let Some(headers) = headers {
+                apply_request_headers(headers, request_builder, true)
+            } else {
+                request_builder
+            };
+
+            let response = request_builder
+                .send()
+                .await
+                .map_err(|e| format!("upstream request failed: {}", e))?;
+
+            if !response.status().is_success() {
+                let status = response.status();
+                let body = response.text().await.unwrap_or_default();
+                return Err(format!("upstream error {}: {}", status, body));
+            }
+
+            let mut response_json = response
+                .json::<Value>()
+                .await
+                .map_err(|e| format!("parse response: {}", e))?;
+
+            // Check for function call
+            if let Some((call_id, tool_name, args_json_str)) =
+                Self::extract_function_call(&response_json)
+            {
+                state.iteration += 1;
+                state.total_calls += 1;
+
+                info!(
+                    "Tool loop iteration {}: calling {} (call_id: {})",
+                    state.iteration, tool_name, call_id
+                );
+
+                // Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations
+                let effective_limit = match max_tool_calls {
+                    Some(user_max) => user_max.min(config.max_iterations),
+                    None => config.max_iterations,
+                };
+
+                if state.total_calls > effective_limit {
+                    if let Some(user_max) = max_tool_calls {
+                        if state.total_calls > user_max {
+                            warn!("Reached user-specified max_tool_calls limit: {}", user_max);
+                        } else {
+                            warn!(
+                                "Reached safety max_iterations limit: {}",
+                                config.max_iterations
+                            );
+                        }
+                    } else {
+                        warn!(
+                            "Reached safety max_iterations limit: {}",
+                            config.max_iterations
+                        );
+                    }
+
+                    return Self::build_incomplete_response(
+                        response_json,
+                        state,
+                        "max_tool_calls",
+                        active_mcp,
+                        original_body,
+                    );
+                }
+
+                // Execute tool
+                let call_result =
+                    Self::execute_mcp_call(active_mcp, &tool_name, &args_json_str).await;
+
+                let output_str = match call_result {
+                    Ok((_, output)) => output,
+                    Err(err) => {
+                        warn!("Tool execution failed: {}", err);
+                        // Return error as output, let model decide how to proceed
+                        json!({"error": err}).to_string()
+                    }
+                };
+
+                // Record the call
+                state.record_call(call_id, tool_name, args_json_str, output_str);
+
+                // Build resume payload
+                current_payload = Self::build_resume_payload(
+                    &base_payload,
+                    &state.conversation_history,
+                    &state.original_input,
+                    &tools_json,
+                )?;
+            } else {
+                // No more tool calls, we're done
+                info!(
+                    "Tool loop completed: {} iterations, {} total calls",
+                    state.iteration, state.total_calls
+                );
+
+                // Inject MCP output items if we executed any tools
+                if state.total_calls > 0 {
+                    let server_label = original_body
+                        .tools
+                        .iter()
+                        .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                        .and_then(|t| t.server_label.as_deref())
+                        .unwrap_or("mcp");
+
+                    // Build mcp_list_tools item
+                    let list_tools_item =
+                        Self::build_mcp_list_tools_item(active_mcp, server_label);
+
+                    // Insert at beginning of output array
+                    if let Some(output_array) = response_json
+                        .get_mut("output")
+                        .and_then(|v| v.as_array_mut())
+                    {
+                        output_array.insert(0, list_tools_item);
+
+                        // Build mcp_call items using helper function
+                        let mcp_call_items = Self::build_executed_mcp_call_items(
+                            &state.conversation_history,
+                            server_label,
+                        );
+
+                        // Insert mcp_call items after mcp_list_tools using mutable position
+                        let mut insert_pos = 1;
+                        for item in mcp_call_items {
+                            output_array.insert(insert_pos, item);
+                            insert_pos += 1;
+                        }
+                    }
+                }
+
+                return Ok(response_json);
+            }
+        }
+    }
+
     /// Generate a unique ID for MCP output items (similar to OpenAI format)
     fn generate_mcp_id(prefix: &str) -> String {
         use rand::RngCore;
@@ -1213,113 +1534,6 @@ impl OpenAIRouter {
             "server_label": server_label
         })
     }

-    /// Inject mcp_list_tools and mcp_call items into the response output array
-    fn inject_mcp_output_items(
-        response_json: &mut Value,
-        mcp: &Arc<crate::mcp::McpClientManager>,
-        args: McpOutputItemsArgs,
-    ) -> Result<(), String> {
-        let output_array = response_json
-            .get_mut("output")
-            .and_then(|v| v.as_array_mut())
-            .ok_or("missing output array")?;
-
-        // Build MCP output items
-        let list_tools_item = Self::build_mcp_list_tools_item(mcp, args.server_label);
-        let call_item = Self::build_mcp_call_item(
-            args.tool_name,
-            args.args_json,
-            args.output,
-            args.server_label,
-            args.success,
-            args.error,
-        );
-
-        // Find the index of the last message item to insert mcp_call before it
-        let call_insertion_index = output_array
-            .iter()
-            .rposition(|item| item.get("type").and_then(|v| v.as_str()) == Some("message"))
-            .unwrap_or(output_array.len());
-
-        // Insert items in-place for efficiency
-        output_array.insert(call_insertion_index, call_item);
-        output_array.insert(0, list_tools_item);
-
-        Ok(())
-    }
-
-    async fn resume_with_tool_result(
-        &self,
-        args: ResumeWithToolArgs<'_>,
-    ) -> Result<Value, String> {
-        let mut payload2 = args.original_payload.clone();
-        let obj = payload2
-            .as_object_mut()
-            .ok_or_else(|| "payload not an object".to_string())?;
-
-        // Build function_call and tool result items per OpenAI Responses spec
-        let user_item = serde_json::json!({
-            "type": "message",
-            "role": "user",
-            "content": args.original_body.input.clone()
-        });
-        // temp system message since currently only support 1 turn of mcp function call
-        let system_item = serde_json::json!({
-            "type": "message",
-            "role": "system",
-            "content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls"
-        });
-        let func_item = serde_json::json!({
-            "type": "function_call",
-            "call_id": args.call_id,
-            "name": args.tool_name,
-            "arguments": args.args_json_str
-        });
-        // Build tool result item as function_call_output per OpenAI Responses spec
-        let tool_item = serde_json::json!({
-            "type": "function_call_output",
-            "call_id": args.call_id,
-            "output": args.output_str
-        });
-        obj.insert(
-            "input".to_string(),
-            Value::Array(vec![user_item, system_item, func_item, tool_item]),
-        );
-
-        // Ensure non-streaming and no store to upstream
-        obj.insert("stream".to_string(), Value::Bool(false));
-        obj.insert("store".to_string(), Value::Bool(false));
-
-        let mut req = self.client.post(args.url).json(&payload2);
-        if let Some(headers) = args.headers {
-            req = apply_request_headers(headers, req, true);
-        }
-        let resp = req
-            .send()
-            .await
-            .map_err(|e| format!("resume request failed: {}", e))?;
-        if !resp.status().is_success() {
-            let status = resp.status();
-            let body = resp.text().await.unwrap_or_default();
-            return Err(format!("resume upstream error {}: {}", status, body));
-        }
-        let mut v = resp
-            .json::<Value>()
-            .await
-            .map_err(|e| format!("parse resume response: {}", e))?;
-
-        if let Some(instr) = &args.original_body.instructions {
-            if let Some(obj) = v.as_object_mut() {
-                obj.entry("instructions")
-                    .or_insert(Value::String(instr.clone()));
-            }
-        }
-
-        // After resume, mask tools as MCP if request used MCP
-        Self::mask_tools_as_mcp(&mut v, args.original_body);
-
-        if let Some(obj) = v.as_object_mut() {
-            obj.insert("store".to_string(), Value::Bool(args.original_body.store));
-        }
-
-        Ok(v)
-    }
 }

 #[async_trait]
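For reference, a hedged sketch of the payload shape that build_resume_payload produces after one executed tool call. Values are illustrative only; the real payload is the cleaned base payload (model, sampling parameters, and so on) with "input", "tools", "stream", and "store" overwritten as shown.

// Hedged sketch of a resume payload (illustrative values, not captured router output).
use serde_json::json;

fn main() {
    let resume_payload = json!({
        "model": "some-model",
        // Original user input first, then the accumulated conversation history.
        "input": [
            { "type": "message", "role": "user",
              "content": [{ "type": "input_text", "text": "search for SGLang" }] },
            { "type": "function_call", "call_id": "call_1",
              "name": "search", "arguments": "{\"q\":\"SGLang\"}" },
            { "type": "function_call_output", "call_id": "call_1",
              "output": "{\"results\":[\"...\"]}" }
        ],
        // The transformed function tools from the base payload go here (not the MCP tool entry).
        "tools": [],
        // The loop always sends non-streaming requests and never asks the upstream to store.
        "stream": false,
        "store": false
    });
    println!("{}", serde_json::to_string_pretty(&resume_payload).unwrap());
}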
sgl-router/tests/responses_api_test.rs (view file @ a28b394f)
@@ -252,6 +252,7 @@ fn test_responses_request_creation() {
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::Medium),
+            summary: None,
         }),
         service_tier: ServiceTier::Auto,
         store: true,

@@ -380,6 +381,7 @@ fn test_usage_conversion() {
 fn test_reasoning_param_default() {
     let param = ResponseReasoningParam {
         effort: Some(ReasoningEffort::Medium),
+        summary: None,
     };
     let json = serde_json::to_string(&param).unwrap();

@@ -403,6 +405,7 @@ fn test_json_serialization() {
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::High),
+            summary: None,
         }),
         service_tier: ServiceTier::Priority,
         store: false,
@@ -437,3 +440,328 @@ fn test_json_serialization() {
     assert!(parsed.stream);
     assert_eq!(parsed.tools.len(), 1);
 }
+
+#[tokio::test]
+async fn test_multi_turn_loop_with_mcp() {
+    // This test verifies the multi-turn loop functionality:
+    // 1. Initial request with MCP tools
+    // 2. Mock worker returns function_call
+    // 3. Router executes MCP tool and resumes
+    // 4. Mock worker returns final answer
+    // 5. Verify the complete flow worked
+
+    // Start mock MCP server
+    let mut mcp = MockMCPServer::start().await.expect("start mcp");
+
+    // Write a temp MCP config file
+    let mcp_yaml = format!(
+        "servers:\n  - name: mock\n    protocol: streamable\n    url: {}\n",
+        mcp.url()
+    );
+    let dir = tempfile::tempdir().expect("tmpdir");
+    let cfg_path = dir.path().join("mcp.yaml");
+    std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
+    std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
+
+    // Start mock OpenAI worker
+    let mut worker = MockWorker::new(MockWorkerConfig {
+        port: 0,
+        worker_type: WorkerType::Regular,
+        health_status: HealthStatus::Healthy,
+        response_delay_ms: 0,
+        fail_rate: 0.0,
+    });
+    let worker_url = worker.start().await.expect("start worker");
+
+    // Build router config
+    let router_cfg = RouterConfig {
+        mode: RoutingMode::OpenAI {
+            worker_urls: vec![worker_url],
+        },
+        connection_mode: ConnectionMode::Http,
+        policy: PolicyConfig::Random,
+        host: "127.0.0.1".to_string(),
+        port: 0,
+        max_payload_size: 8 * 1024 * 1024,
+        request_timeout_secs: 60,
+        worker_startup_timeout_secs: 5,
+        worker_startup_check_interval_secs: 1,
+        dp_aware: false,
+        api_key: None,
+        discovery: None,
+        metrics: None,
+        log_dir: None,
+        log_level: Some("info".to_string()),
+        request_id_headers: None,
+        max_concurrent_requests: 32,
+        queue_size: 0,
+        queue_timeout_secs: 5,
+        rate_limit_tokens_per_second: None,
+        cors_allowed_origins: vec![],
+        retry: RetryConfig::default(),
+        circuit_breaker: CircuitBreakerConfig::default(),
+        disable_retries: false,
+        disable_circuit_breaker: false,
+        health_check: HealthCheckConfig::default(),
+        enable_igw: false,
+        model_path: None,
+        tokenizer_path: None,
+        history_backend: sglang_router_rs::config::HistoryBackend::Memory,
+        oracle: None,
+    };
+
+    let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
+    let router = RouterFactory::create_router(&Arc::new(ctx)).await.expect("router");
+
+    // Build request with MCP tools
+    let req = ResponsesRequest {
+        background: false,
+        include: None,
+        input: ResponseInput::Text("search for SGLang".to_string()),
+        instructions: Some("Be helpful".to_string()),
+        max_output_tokens: Some(128),
+        max_tool_calls: None, // No limit - test unlimited
+        metadata: None,
+        model: Some("mock-model".to_string()),
+        parallel_tool_calls: true,
+        previous_response_id: None,
+        reasoning: None,
+        service_tier: ServiceTier::Auto,
+        store: true,
+        stream: false,
+        temperature: Some(0.7),
+        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
+        tools: vec![ResponseTool {
+            r#type: ResponseToolType::Mcp,
+            server_url: Some(mcp.url()),
+            server_label: Some("mock".to_string()),
+            server_description: Some("Mock MCP server for testing".to_string()),
+            require_approval: Some("never".to_string()),
+            ..Default::default()
+        }],
+        top_logprobs: 0,
+        top_p: Some(1.0),
+        truncation: Truncation::Disabled,
+        user: None,
+        request_id: "resp_multi_turn_test".to_string(),
+        priority: 0,
+        frequency_penalty: 0.0,
+        presence_penalty: 0.0,
+        stop: None,
+        top_k: 50,
+        min_p: 0.0,
+        repetition_penalty: 1.0,
+    };
+
+    // Execute the request (this should trigger the multi-turn loop)
+    let response = router.route_responses(None, &req, None).await;
+
+    // Check status
+    assert_eq!(
+        response.status(),
+        axum::http::StatusCode::OK,
+        "Request should succeed"
+    );
+
+    // Read the response body
+    use axum::body::to_bytes;
+    let response_body = response.into_body();
+    let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
+    let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
+
+    println!(
+        "Multi-turn response: {}",
+        serde_json::to_string_pretty(&response_json).unwrap()
+    );
+
+    // Verify the response structure
+    assert_eq!(response_json["object"], "response");
+    assert_eq!(response_json["status"], "completed");
+
+    // Note: mock worker generates its own ID, so we just verify it exists
+    assert!(
+        response_json["id"].is_string(),
+        "Response should have an id"
+    );
+
+    // Check that output contains final message
+    let output = response_json["output"]
+        .as_array()
+        .expect("output should be array");
+    assert!(!output.is_empty(), "output should not be empty");
+
+    // Find the final message with text
+    let has_final_text = output.iter().any(|item| {
+        item.get("type")
+            .and_then(|t| t.as_str())
+            .map(|t| t == "message")
+            .unwrap_or(false)
+            && item
+                .get("content")
+                .and_then(|c| c.as_array())
+                .map(|arr| {
+                    arr.iter().any(|part| {
+                        part.get("type")
+                            .and_then(|t| t.as_str())
+                            .map(|t| t == "output_text")
+                            .unwrap_or(false)
+                    })
+                })
+                .unwrap_or(false)
+    });
+    assert!(has_final_text, "Should have final text output");
+
+    // Verify tools are masked back to MCP format
+    let tools = response_json["tools"]
+        .as_array()
+        .expect("tools should be array");
+    assert_eq!(tools.len(), 1);
+    assert_eq!(tools[0]["type"], "mcp");
+    assert_eq!(tools[0]["server_label"], "mock");
+
+    // Clean up
+    std::env::remove_var("SGLANG_MCP_CONFIG");
+    worker.stop().await;
+    mcp.stop().await;
+}
+
+#[tokio::test]
+async fn test_max_tool_calls_limit() {
+    // This test verifies that max_tool_calls is respected
+    // Note: The mock worker returns a final answer after one tool call,
+    // so with max_tool_calls=1, it completes normally (doesn't exceed the limit)
+
+    let mut mcp = MockMCPServer::start().await.expect("start mcp");
+
+    let mcp_yaml = format!(
+        "servers:\n  - name: mock\n    protocol: streamable\n    url: {}\n",
+        mcp.url()
+    );
+    let dir = tempfile::tempdir().expect("tmpdir");
+    let cfg_path = dir.path().join("mcp.yaml");
+    std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
+    std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
+
+    let mut worker = MockWorker::new(MockWorkerConfig {
+        port: 0,
+        worker_type: WorkerType::Regular,
+        health_status: HealthStatus::Healthy,
+        response_delay_ms: 0,
+        fail_rate: 0.0,
+    });
+    let worker_url = worker.start().await.expect("start worker");
+
+    let router_cfg = RouterConfig {
+        mode: RoutingMode::OpenAI {
+            worker_urls: vec![worker_url],
+        },
+        connection_mode: ConnectionMode::Http,
+        policy: PolicyConfig::Random,
+        host: "127.0.0.1".to_string(),
+        port: 0,
+        max_payload_size: 8 * 1024 * 1024,
+        request_timeout_secs: 60,
+        worker_startup_timeout_secs: 5,
+        worker_startup_check_interval_secs: 1,
+        dp_aware: false,
+        api_key: None,
+        discovery: None,
+        metrics: None,
+        log_dir: None,
+        log_level: Some("info".to_string()),
+        request_id_headers: None,
+        max_concurrent_requests: 32,
+        queue_size: 0,
+        queue_timeout_secs: 5,
+        rate_limit_tokens_per_second: None,
+        cors_allowed_origins: vec![],
+        retry: RetryConfig::default(),
+        circuit_breaker: CircuitBreakerConfig::default(),
+        disable_retries: false,
+        disable_circuit_breaker: false,
+        health_check: HealthCheckConfig::default(),
+        enable_igw: false,
+        model_path: None,
+        tokenizer_path: None,
+        history_backend: sglang_router_rs::config::HistoryBackend::Memory,
+        oracle: None,
+    };
+
+    let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
+    let router = RouterFactory::create_router(&Arc::new(ctx)).await.expect("router");
+
+    let req = ResponsesRequest {
+        background: false,
+        include: None,
+        input: ResponseInput::Text("test max calls".to_string()),
+        instructions: None,
+        max_output_tokens: Some(128),
+        max_tool_calls: Some(1), // Limit to 1 call
+        metadata: None,
+        model: Some("mock-model".to_string()),
+        parallel_tool_calls: true,
+        previous_response_id: None,
+        reasoning: None,
+        service_tier: ServiceTier::Auto,
+        store: false,
+        stream: false,
+        temperature: Some(0.7),
+        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
+        tools: vec![ResponseTool {
+            r#type: ResponseToolType::Mcp,
+            server_url: Some(mcp.url()),
+            server_label: Some("mock".to_string()),
+            ..Default::default()
+        }],
+        top_logprobs: 0,
+        top_p: Some(1.0),
+        truncation: Truncation::Disabled,
+        user: None,
+        request_id: "resp_max_calls_test".to_string(),
+        priority: 0,
+        frequency_penalty: 0.0,
+        presence_penalty: 0.0,
+        stop: None,
+        top_k: 50,
+        min_p: 0.0,
+        repetition_penalty: 1.0,
+    };
+
+    let response = router.route_responses(None, &req, None).await;
+    assert_eq!(response.status(), axum::http::StatusCode::OK);
+
+    use axum::body::to_bytes;
+    let response_body = response.into_body();
+    let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
+    let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
+
+    println!(
+        "Max calls response: {}",
+        serde_json::to_string_pretty(&response_json).unwrap()
+    );
+
+    // With max_tool_calls=1, the mock returns a final answer after 1 call
+    // So it completes normally without exceeding the limit
+    assert_eq!(response_json["status"], "completed");
+
+    // Verify the basic response structure
+    assert!(response_json["id"].is_string());
+    assert_eq!(response_json["object"], "response");
+
+    // The response should have tools masked back to MCP format
+    let tools = response_json["tools"]
+        .as_array()
+        .expect("tools should be array");
+    assert_eq!(tools.len(), 1);
+    assert_eq!(tools[0]["type"], "mcp");
+
+    // Note: To test actual limit exceeding, we would need a mock that keeps
+    // calling tools indefinitely, which would hit max_iterations (safety limit)
+
+    std::env::remove_var("SGLANG_MCP_CONFIG");
+    worker.stop().await;
+    mcp.stop().await;
+}