Unverified Commit 34151f17 authored by Keyang Ru, committed by GitHub

[router] Streaming support for MCP Tool Calls in OpenAI Router (#11173)

parent 6794d210
......@@ -31,6 +31,38 @@ use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info, warn};
// SSE Event Type Constants - single source of truth for event type strings
mod event_types {
// Response lifecycle events
pub const RESPONSE_CREATED: &str = "response.created";
pub const RESPONSE_IN_PROGRESS: &str = "response.in_progress";
pub const RESPONSE_COMPLETED: &str = "response.completed";
// Output item events
pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
pub const OUTPUT_ITEM_DELTA: &str = "response.output_item.delta";
// Function call events
pub const FUNCTION_CALL_ARGUMENTS_DELTA: &str = "response.function_call_arguments.delta";
pub const FUNCTION_CALL_ARGUMENTS_DONE: &str = "response.function_call_arguments.done";
// MCP call events
pub const MCP_CALL_ARGUMENTS_DELTA: &str = "response.mcp_call_arguments.delta";
pub const MCP_CALL_ARGUMENTS_DONE: &str = "response.mcp_call_arguments.done";
pub const MCP_CALL_IN_PROGRESS: &str = "response.mcp_call.in_progress";
pub const MCP_CALL_COMPLETED: &str = "response.mcp_call.completed";
pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress";
pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed";
// Item types
pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call";
pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call";
pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call";
pub const ITEM_TYPE_FUNCTION: &str = "function";
pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools";
}
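// Illustrative wire-format sketch (field names mirror the handlers below; the
// concrete values are invented for the example): each upstream SSE block is an
// optional `event:` line plus a `data:` line, separated from the next block by
// a blank line, e.g.
//
//   event: response.function_call_arguments.delta
//   data: {"type":"response.function_call_arguments.delta","output_index":0,"item_id":"fc_123","delta":"{\"query\""}
//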
/// Router for OpenAI backend
pub struct OpenAIRouter {
/// HTTP client for upstream OpenAI-compatible API
......@@ -102,7 +134,7 @@ impl ToolLoopState {
) {
// Add function_call item to history
let func_item = json!({
"type": "function_call",
"type": event_types::ITEM_TYPE_FUNCTION_CALL,
"call_id": call_id,
"name": tool_name,
"arguments": args_json_str
......@@ -133,6 +165,386 @@ struct StreamingResponseAccumulator {
encountered_error: Option<Value>,
}
/// Represents a function call being accumulated across delta events
#[derive(Debug, Clone)]
struct FunctionCallInProgress {
call_id: String,
name: String,
arguments_buffer: String,
output_index: usize,
last_obfuscation: Option<String>,
assigned_output_index: Option<usize>,
}
impl FunctionCallInProgress {
fn new(call_id: String, output_index: usize) -> Self {
Self {
call_id,
name: String::new(),
arguments_buffer: String::new(),
output_index,
last_obfuscation: None,
assigned_output_index: None,
}
}
fn is_complete(&self) -> bool {
// A tool call is complete if it has a name
!self.name.is_empty()
}
fn effective_output_index(&self) -> usize {
self.assigned_output_index.unwrap_or(self.output_index)
}
}
#[derive(Debug, Default)]
struct OutputIndexMapper {
next_index: usize,
// Map upstream output_index -> remapped output_index
assigned: HashMap<usize, usize>,
}
impl OutputIndexMapper {
fn with_start(next_index: usize) -> Self {
Self {
next_index,
assigned: HashMap::new(),
}
}
fn ensure_mapping(&mut self, upstream_index: usize) -> usize {
*self.assigned.entry(upstream_index).or_insert_with(|| {
let assigned = self.next_index;
self.next_index += 1;
assigned
})
}
fn lookup(&self, upstream_index: usize) -> Option<usize> {
self.assigned.get(&upstream_index).copied()
}
fn allocate_synthetic(&mut self) -> usize {
let assigned = self.next_index;
self.next_index += 1;
assigned
}
fn next_index(&self) -> usize {
self.next_index
}
}
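// Minimal usage sketch (indices invented for the example): starting at 2, the
// first upstream index observed maps to 2, a new one maps to 3, repeats are
// stable, and synthetic allocations continue the same sequence.
//
//   let mut mapper = OutputIndexMapper::with_start(2);
//   assert_eq!(mapper.ensure_mapping(0), 2);
//   assert_eq!(mapper.ensure_mapping(5), 3);
//   assert_eq!(mapper.ensure_mapping(0), 2);
//   assert_eq!(mapper.allocate_synthetic(), 4);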
/// Action to take based on streaming event processing
#[derive(Debug)]
enum StreamAction {
Forward, // Pass event to client
Buffer, // Accumulate for tool execution
ExecuteTools, // Function call complete, execute now
}
/// Handles streaming responses with MCP tool call interception
struct StreamingToolHandler {
/// Accumulator for response persistence
accumulator: StreamingResponseAccumulator,
/// Function calls being built from deltas
pending_calls: Vec<FunctionCallInProgress>,
/// Track if we're currently in a function call
in_function_call: bool,
/// Manage output_index remapping so they increment per item
output_index_mapper: OutputIndexMapper,
/// Original response id captured from the first response.created event
original_response_id: Option<String>,
}
impl StreamingToolHandler {
fn with_starting_index(start: usize) -> Self {
Self {
accumulator: StreamingResponseAccumulator::new(),
pending_calls: Vec::new(),
in_function_call: false,
output_index_mapper: OutputIndexMapper::with_start(start),
original_response_id: None,
}
}
fn ensure_output_index(&mut self, upstream_index: usize) -> usize {
self.output_index_mapper.ensure_mapping(upstream_index)
}
fn mapped_output_index(&self, upstream_index: usize) -> Option<usize> {
self.output_index_mapper.lookup(upstream_index)
}
fn allocate_synthetic_output_index(&mut self) -> usize {
self.output_index_mapper.allocate_synthetic()
}
fn next_output_index(&self) -> usize {
self.output_index_mapper.next_index()
}
fn original_response_id(&self) -> Option<&str> {
self.original_response_id
.as_deref()
.or_else(|| self.accumulator.original_response_id())
}
fn snapshot_final_response(&self) -> Option<Value> {
self.accumulator.snapshot_final_response()
}
/// Process an SSE event and determine what action to take
fn process_event(&mut self, event_name: Option<&str>, data: &str) -> StreamAction {
// Always feed to accumulator for storage
self.accumulator.ingest_block(&format!(
"{}data: {}",
event_name
.map(|n| format!("event: {}\n", n))
.unwrap_or_default(),
data
));
let parsed: Value = match serde_json::from_str(data) {
Ok(v) => v,
Err(_) => return StreamAction::Forward,
};
let event_type = event_name
.map(|s| s.to_string())
.or_else(|| {
parsed
.get("type")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
})
.unwrap_or_default();
match event_type.as_str() {
event_types::RESPONSE_CREATED => {
if self.original_response_id.is_none() {
if let Some(response_obj) = parsed.get("response").and_then(|v| v.as_object()) {
if let Some(id) = response_obj.get("id").and_then(|v| v.as_str()) {
self.original_response_id = Some(id.to_string());
}
}
}
StreamAction::Forward
}
event_types::RESPONSE_COMPLETED => StreamAction::Forward,
event_types::OUTPUT_ITEM_ADDED => {
if let Some(idx) = parsed.get("output_index").and_then(|v| v.as_u64()) {
self.ensure_output_index(idx as usize);
}
// Check if this is a function_call item being added
if let Some(item) = parsed.get("item") {
if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) {
if item_type == event_types::ITEM_TYPE_FUNCTION_CALL
|| item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
{
match parsed.get("output_index").and_then(|v| v.as_u64()) {
Some(idx) => {
let output_index = idx as usize;
let assigned_index = self.ensure_output_index(output_index);
let call_id =
item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
let name =
item.get("name").and_then(|v| v.as_str()).unwrap_or("");
// Create or update the function call
let call = self.get_or_create_call(output_index, item);
call.call_id = call_id.to_string();
call.name = name.to_string();
call.assigned_output_index = Some(assigned_index);
self.in_function_call = true;
}
None => {
tracing::warn!(
"Missing output_index in function_call added event, \
forwarding without processing for tool execution"
);
}
}
}
}
}
StreamAction::Forward
}
event_types::FUNCTION_CALL_ARGUMENTS_DELTA => {
// Accumulate arguments for the function call
if let Some(output_index) = parsed
.get("output_index")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
{
let assigned_index = self.ensure_output_index(output_index);
if let Some(delta) = parsed.get("delta").and_then(|v| v.as_str()) {
if let Some(call) = self
.pending_calls
.iter_mut()
.find(|c| c.output_index == output_index)
{
call.arguments_buffer.push_str(delta);
if let Some(obfuscation) =
parsed.get("obfuscation").and_then(|v| v.as_str())
{
call.last_obfuscation = Some(obfuscation.to_string());
}
if call.assigned_output_index.is_none() {
call.assigned_output_index = Some(assigned_index);
}
}
}
}
StreamAction::Forward
}
event_types::FUNCTION_CALL_ARGUMENTS_DONE => {
// Function call arguments complete - check if ready to execute
if let Some(output_index) = parsed
.get("output_index")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
{
let assigned_index = self.ensure_output_index(output_index);
if let Some(call) = self
.pending_calls
.iter_mut()
.find(|c| c.output_index == output_index)
{
if call.assigned_output_index.is_none() {
call.assigned_output_index = Some(assigned_index);
}
}
}
if self.has_complete_calls() {
StreamAction::ExecuteTools
} else {
StreamAction::Forward
}
}
event_types::OUTPUT_ITEM_DELTA => self.process_output_delta(&parsed),
event_types::OUTPUT_ITEM_DONE => {
// Check if we have complete function calls ready to execute
if let Some(output_index) = parsed
.get("output_index")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
{
self.ensure_output_index(output_index);
}
if self.has_complete_calls() {
StreamAction::ExecuteTools
} else {
StreamAction::Forward
}
}
_ => StreamAction::Forward,
}
}
/// Process output delta events to detect and accumulate function calls
fn process_output_delta(&mut self, event: &Value) -> StreamAction {
let output_index = event
.get("output_index")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
.unwrap_or(0);
let assigned_index = self.ensure_output_index(output_index);
let delta = match event.get("delta") {
Some(d) => d,
None => return StreamAction::Forward,
};
// Check if this is a function call delta
let item_type = delta.get("type").and_then(|v| v.as_str());
if item_type == Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL)
|| item_type == Some(event_types::ITEM_TYPE_FUNCTION_CALL)
{
self.in_function_call = true;
// Get or create function call for this output index
let call = self.get_or_create_call(output_index, delta);
call.assigned_output_index = Some(assigned_index);
// Accumulate call_id if present
if let Some(call_id) = delta.get("call_id").and_then(|v| v.as_str()) {
call.call_id = call_id.to_string();
}
// Accumulate name if present
if let Some(name) = delta.get("name").and_then(|v| v.as_str()) {
call.name.push_str(name);
}
// Accumulate arguments if present
if let Some(args) = delta.get("arguments").and_then(|v| v.as_str()) {
call.arguments_buffer.push_str(args);
}
if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) {
call.last_obfuscation = Some(obfuscation.to_string());
}
// Buffer this event, don't forward to client
return StreamAction::Buffer;
}
// Forward non-function-call events
StreamAction::Forward
}
fn get_or_create_call(
&mut self,
output_index: usize,
delta: &Value,
) -> &mut FunctionCallInProgress {
// Find existing call for this output index
// Note: We use position() + index instead of iter_mut().find() because we need
// to potentially mutate pending_calls after the early return, which causes
// borrow checker issues with the iter_mut approach
if let Some(pos) = self
.pending_calls
.iter()
.position(|c| c.output_index == output_index)
{
return &mut self.pending_calls[pos];
}
// Create new call
let call_id = delta
.get("call_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let mut call = FunctionCallInProgress::new(call_id, output_index);
if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) {
call.last_obfuscation = Some(obfuscation.to_string());
}
self.pending_calls.push(call);
self.pending_calls
.last_mut()
.expect("Just pushed to pending_calls, must have at least one element")
}
fn has_complete_calls(&self) -> bool {
!self.pending_calls.is_empty() && self.pending_calls.iter().all(|c| c.is_complete())
}
fn take_pending_calls(&mut self) -> Vec<FunctionCallInProgress> {
std::mem::take(&mut self.pending_calls)
}
}
impl StreamingResponseAccumulator {
fn new() -> Self {
Self {
......@@ -163,6 +575,36 @@ impl StreamingResponseAccumulator {
fn encountered_error(&self) -> Option<&Value> {
self.encountered_error.as_ref()
}
fn original_response_id(&self) -> Option<&str> {
self.initial_response
.as_ref()
.and_then(|response| response.get("id"))
.and_then(|id| id.as_str())
}
fn snapshot_final_response(&self) -> Option<Value> {
if let Some(resp) = &self.completed_response {
return Some(resp.clone());
}
self.build_fallback_response_snapshot()
}
fn build_fallback_response_snapshot(&self) -> Option<Value> {
let mut response = self.initial_response.clone()?;
if let Some(obj) = response.as_object_mut() {
obj.insert("status".to_string(), Value::String("completed".to_string()));
let mut output_items = self.output_items.clone();
output_items.sort_by_key(|(index, _)| *index);
let outputs: Vec<Value> = output_items.into_iter().map(|(_, item)| item).collect();
obj.insert("output".to_string(), Value::Array(outputs));
}
Some(response)
}
fn process_block(&mut self, block: &str) {
let trimmed = block.trim();
if trimmed.is_empty() {
......@@ -208,17 +650,19 @@ impl StreamingResponseAccumulator {
.unwrap_or_default();
match event_type.as_str() {
"response.created" => {
if let Some(response) = parsed.get("response") {
self.initial_response = Some(response.clone());
event_types::RESPONSE_CREATED => {
if self.initial_response.is_none() {
if let Some(response) = parsed.get("response") {
self.initial_response = Some(response.clone());
}
}
}
"response.completed" => {
event_types::RESPONSE_COMPLETED => {
if let Some(response) = parsed.get("response") {
self.completed_response = Some(response.clone());
}
}
"response.output_item.done" => {
event_types::OUTPUT_ITEM_DONE => {
if let (Some(index), Some(item)) = (
parsed
.get("output_index")
......@@ -588,6 +1032,46 @@ impl OpenAIRouter {
payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
) -> Response {
// Check if MCP is active for this request
let req_mcp_manager = Self::mcp_manager_from_request_tools(&original_body.tools).await;
let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
// If no MCP is active, use simple pass-through streaming
if active_mcp.is_none() {
return self
.handle_simple_streaming_passthrough(
url,
headers,
payload,
original_body,
original_previous_response_id,
)
.await;
}
let active_mcp = active_mcp.unwrap();
// MCP is active - transform tools and set up interception
self.handle_streaming_with_tool_interception(
url,
headers,
payload,
original_body,
original_previous_response_id,
active_mcp,
)
.await
}
/// Simple pass-through streaming without MCP interception
async fn handle_simple_streaming_passthrough(
&self,
url: String,
headers: Option<&HeaderMap>,
payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
) -> Response {
let mut request_builder = self.client.post(&url).json(&payload);
......@@ -738,50 +1222,1101 @@ impl OpenAIRouter {
response
}
async fn store_response_internal(
&self,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<(), String> {
if !original_body.store {
return Ok(());
/// Apply all transformations to event data in-place (rewrite + transform)
/// Optimized to parse JSON only once instead of multiple times
/// Returns true if any changes were made
fn apply_event_transformations_inplace(
parsed_data: &mut Value,
server_label: &str,
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
) -> bool {
let mut changed = false;
// 1. Apply rewrite_streaming_block logic (store, previous_response_id, tools masking)
let event_type = parsed_data
.get("type")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.unwrap_or_default();
let should_patch = matches!(
event_type.as_str(),
event_types::RESPONSE_CREATED
| event_types::RESPONSE_IN_PROGRESS
| event_types::RESPONSE_COMPLETED
);
if should_patch {
if let Some(response_obj) = parsed_data
.get_mut("response")
.and_then(|v| v.as_object_mut())
{
let desired_store = Value::Bool(original_request.store);
if response_obj.get("store") != Some(&desired_store) {
response_obj.insert("store".to_string(), desired_store);
changed = true;
}
if let Some(prev_id) = previous_response_id {
let needs_previous = response_obj
.get("previous_response_id")
.map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false))
.unwrap_or(true);
if needs_previous {
response_obj.insert(
"previous_response_id".to_string(),
Value::String(prev_id.to_string()),
);
changed = true;
}
}
// Mask tools from function to MCP format (optimized without cloning)
if response_obj.get("tools").is_some() {
let requested_mcp = original_request
.tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp));
if requested_mcp {
if let Some(mcp_tools) = Self::build_mcp_tools_value(original_request) {
response_obj.insert("tools".to_string(), mcp_tools);
response_obj
.entry("tool_choice".to_string())
.or_insert(Value::String("auto".to_string()));
changed = true;
}
}
}
}
}
match Self::store_response_impl(&self.response_storage, response_json, original_body).await
{
Ok(response_id) => {
info!(response_id = %response_id.0, "Stored response locally");
Ok(())
// 2. Apply transform_streaming_event logic (function_call → mcp_call)
match event_type.as_str() {
event_types::OUTPUT_ITEM_ADDED | event_types::OUTPUT_ITEM_DONE => {
if let Some(item) = parsed_data.get_mut("item") {
if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) {
if item_type == event_types::ITEM_TYPE_FUNCTION_CALL
|| item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
{
item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL);
item["server_label"] = json!(server_label);
// Transform ID from fc_* to mcp_*
if let Some(id) = item.get("id").and_then(|v| v.as_str()) {
if let Some(stripped) = id.strip_prefix("fc_") {
let new_id = format!("mcp_{}", stripped);
item["id"] = json!(new_id);
}
}
changed = true;
}
}
}
}
Err(e) => Err(e),
event_types::FUNCTION_CALL_ARGUMENTS_DONE => {
parsed_data["type"] = json!(event_types::MCP_CALL_ARGUMENTS_DONE);
// Transform item_id from fc_* to mcp_*
if let Some(item_id) = parsed_data.get("item_id").and_then(|v| v.as_str()) {
if let Some(stripped) = item_id.strip_prefix("fc_") {
let new_id = format!("mcp_{}", stripped);
parsed_data["item_id"] = json!(new_id);
}
}
changed = true;
}
_ => {}
}
changed
}
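// Illustrative before/after for the item rewrite above (IDs and label are
// example values; the label falls back to "mcp" when none is configured):
//   before: {"type":"response.output_item.added","item":{"type":"function_call","id":"fc_123","call_id":"call_1","name":"search"}}
//   after:  {"type":"response.output_item.added","item":{"type":"mcp_call","id":"mcp_123","call_id":"call_1","name":"search","server_label":"mcp"}}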
async fn store_response_impl(
response_storage: &SharedResponseStorage,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<ResponseId, String> {
let input_text = match &original_body.input {
ResponseInput::Text(text) => text.clone(),
ResponseInput::Items(_) => "complex input".to_string(),
/// Forward and transform a streaming event to the client
/// Returns false if client disconnected
#[allow(clippy::too_many_arguments)]
fn forward_streaming_event(
raw_block: &str,
event_name: Option<&str>,
data: &str,
handler: &mut StreamingToolHandler,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
server_label: &str,
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
sequence_number: &mut u64,
) -> bool {
// Skip individual function_call_arguments.delta events - we'll send them as one
if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DELTA) {
return true;
}
// Parse JSON data once (optimized!)
let mut parsed_data: Value = match serde_json::from_str(data) {
Ok(v) => v,
Err(_) => {
// If parsing fails, forward raw block as-is
let chunk_to_send = format!("{}\n\n", raw_block);
return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok();
}
};
let output_text = Self::extract_primary_output_text(response_json).unwrap_or_default();
let event_type = event_name
.or_else(|| parsed_data.get("type").and_then(|v| v.as_str()))
.unwrap_or("");
let mut stored_response = StoredResponse::new(input_text, output_text, None);
if event_type == event_types::RESPONSE_COMPLETED {
return true;
}
stored_response.instructions = response_json
.get("instructions")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.instructions.clone());
// Check if this is function_call_arguments.done - need to send buffered args first
let mut mapped_output_index: Option<usize> = None;
stored_response.model = response_json
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.model.clone());
if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DONE) {
if let Some(output_index) = parsed_data
.get("output_index")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
{
let assigned_index = handler
.mapped_output_index(output_index)
.unwrap_or(output_index);
mapped_output_index = Some(assigned_index);
if let Some(call) = handler
.pending_calls
.iter()
.find(|c| c.output_index == output_index)
{
let arguments_value = if call.arguments_buffer.is_empty() {
"{}".to_string()
} else {
call.arguments_buffer.clone()
};
// Make sure the done event carries full arguments
parsed_data["arguments"] = Value::String(arguments_value.clone());
// Get item_id and transform it
let item_id = parsed_data
.get("item_id")
.and_then(|v| v.as_str())
.unwrap_or("");
let mcp_item_id = if let Some(stripped) = item_id.strip_prefix("fc_") {
format!("mcp_{}", stripped)
} else {
item_id.to_string()
};
// Emit a synthetic MCP arguments delta event before the done event
let mut delta_event = json!({
"type": event_types::MCP_CALL_ARGUMENTS_DELTA,
"sequence_number": *sequence_number,
"output_index": assigned_index,
"item_id": mcp_item_id,
"delta": arguments_value,
});
if let Some(obfuscation) = call.last_obfuscation.as_ref() {
if let Some(obj) = delta_event.as_object_mut() {
obj.insert(
"obfuscation".to_string(),
Value::String(obfuscation.clone()),
);
}
} else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() {
if let Some(obj) = delta_event.as_object_mut() {
obj.insert("obfuscation".to_string(), obfuscation);
}
}
let delta_block = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_ARGUMENTS_DELTA,
delta_event
);
if tx.send(Ok(Bytes::from(delta_block))).is_err() {
return false;
}
*sequence_number += 1;
}
}
}
// Remap output_index (if present) so downstream sees sequential indices
if mapped_output_index.is_none() {
if let Some(output_index) = parsed_data
.get("output_index")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
{
mapped_output_index = handler.mapped_output_index(output_index);
}
}
if let Some(mapped) = mapped_output_index {
parsed_data["output_index"] = json!(mapped);
}
// Apply all transformations in-place (single parse/serialize!)
Self::apply_event_transformations_inplace(
&mut parsed_data,
server_label,
original_request,
previous_response_id,
);
if let Some(response_obj) = parsed_data
.get_mut("response")
.and_then(|v| v.as_object_mut())
{
if let Some(original_id) = handler.original_response_id() {
response_obj.insert("id".to_string(), Value::String(original_id.to_string()));
}
}
// Update sequence number if present in the event
if parsed_data.get("sequence_number").is_some() {
parsed_data["sequence_number"] = json!(*sequence_number);
*sequence_number += 1;
}
// Serialize once
let final_data = match serde_json::to_string(&parsed_data) {
Ok(s) => s,
Err(_) => {
// Serialization failed, forward original
let chunk_to_send = format!("{}\n\n", raw_block);
return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok();
}
};
// Rebuild SSE block with potentially transformed event name
let mut final_block = String::new();
if let Some(evt) = event_name {
// Update event name for function_call_arguments events
if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA {
final_block.push_str(&format!(
"event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DELTA
));
} else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE {
final_block.push_str(&format!(
"event: {}\n",
event_types::MCP_CALL_ARGUMENTS_DONE
));
} else {
final_block.push_str(&format!("event: {}\n", evt));
}
}
final_block.push_str(&format!("data: {}", final_data));
let chunk_to_send = format!("{}\n\n", final_block);
if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() {
return false;
}
// After sending output_item.added for mcp_call, inject mcp_call.in_progress event
if event_name == Some(event_types::OUTPUT_ITEM_ADDED) {
if let Some(item) = parsed_data.get("item") {
if item.get("type").and_then(|v| v.as_str())
== Some(event_types::ITEM_TYPE_MCP_CALL)
{
// Already transformed to mcp_call
if let (Some(item_id), Some(output_index)) = (
item.get("id").and_then(|v| v.as_str()),
parsed_data.get("output_index").and_then(|v| v.as_u64()),
) {
let in_progress_event = json!({
"type": event_types::MCP_CALL_IN_PROGRESS,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let in_progress_block = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_IN_PROGRESS,
in_progress_event
);
if tx.send(Ok(Bytes::from(in_progress_block))).is_err() {
return false;
}
}
}
}
}
true
}
/// Execute detected tool calls and send completion events to client
/// Returns false if client disconnected during execution
async fn execute_streaming_tool_calls(
pending_calls: Vec<FunctionCallInProgress>,
active_mcp: &Arc<crate::mcp::McpClientManager>,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
state: &mut ToolLoopState,
server_label: &str,
sequence_number: &mut u64,
) -> bool {
// Execute all pending tool calls (sequential, as PR3 is skipped)
for call in pending_calls {
// Skip if name is empty (invalid call)
if call.name.is_empty() {
warn!(
"Skipping incomplete tool call: name is empty, args_len={}",
call.arguments_buffer.len()
);
continue;
}
info!(
"Executing tool call during streaming: {} ({})",
call.name, call.call_id
);
// Use empty JSON object if arguments_buffer is empty
let args_str = if call.arguments_buffer.is_empty() {
"{}"
} else {
&call.arguments_buffer
};
let call_result = Self::execute_mcp_call(active_mcp, &call.name, args_str).await;
let (output_str, success, error_msg) = match call_result {
Ok((_, output)) => (output, true, None),
Err(err) => {
warn!("Tool execution failed during streaming: {}", err);
(json!({ "error": &err }).to_string(), false, Some(err))
}
};
// Send mcp_call completion event to client
if !OpenAIRouter::send_mcp_call_completion_events_with_error(
tx,
&call,
&output_str,
server_label,
success,
error_msg.as_deref(),
sequence_number,
) {
// Client disconnected, no point continuing tool execution
return false;
}
// Record the call
state.record_call(call.call_id, call.name, call.arguments_buffer, output_str);
}
true
}
/// Transform payload to replace MCP tools with function tools for streaming
fn prepare_mcp_payload_for_streaming(
payload: &mut Value,
active_mcp: &Arc<crate::mcp::McpClientManager>,
) {
if let Some(obj) = payload.as_object_mut() {
// Remove any non-function tools from outgoing payload
if let Some(v) = obj.get_mut("tools") {
if let Some(arr) = v.as_array_mut() {
arr.retain(|item| {
item.get("type")
.and_then(|v| v.as_str())
.map(|s| s == event_types::ITEM_TYPE_FUNCTION)
.unwrap_or(false)
});
}
}
// Build function tools for all discovered MCP tools
let mut tools_json = Vec::new();
let tools = active_mcp.list_tools();
for t in tools {
let parameters = t.parameters.clone().unwrap_or(serde_json::json!({
"type": "object",
"properties": {},
"additionalProperties": false
}));
let tool = serde_json::json!({
"type": event_types::ITEM_TYPE_FUNCTION,
"name": t.name,
"description": t.description,
"parameters": parameters
});
tools_json.push(tool);
}
if !tools_json.is_empty() {
obj.insert("tools".to_string(), Value::Array(tools_json));
obj.insert("tool_choice".to_string(), Value::String("auto".to_string()));
}
}
}
/// Handle streaming WITH MCP tool call interception and execution
async fn handle_streaming_with_tool_interception(
&self,
url: String,
headers: Option<&HeaderMap>,
mut payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
active_mcp: &Arc<crate::mcp::McpClientManager>,
) -> Response {
// Transform MCP tools to function tools in payload
Self::prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store;
let storage = self.response_storage.clone();
let original_request = original_body.clone();
let previous_response_id = original_previous_response_id.clone();
let client = self.client.clone();
let url_clone = url.clone();
let headers_opt = headers.cloned();
let payload_clone = payload.clone();
let active_mcp_clone = Arc::clone(active_mcp);
// Spawn the streaming loop task
tokio::spawn(async move {
let mut state = ToolLoopState::new(original_request.input.clone());
let loop_config = McpLoopConfig::default();
let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);
let tools_json = payload_clone.get("tools").cloned().unwrap_or(json!([]));
let base_payload = payload_clone.clone();
let mut current_payload = payload_clone;
let mut mcp_list_tools_sent = false;
let mut is_first_iteration = true;
let mut sequence_number: u64 = 0; // Track global sequence number across all iterations
let mut next_output_index: usize = 0;
let mut preserved_response_id: Option<String> = None;
let server_label = original_request
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.unwrap_or("mcp");
loop {
// Make streaming request
let mut request_builder = client.post(&url_clone).json(&current_payload);
if let Some(ref h) = headers_opt {
request_builder = apply_request_headers(h, request_builder, true);
}
request_builder = request_builder.header("Accept", "text/event-stream");
let response = match request_builder.send().await {
Ok(r) => r,
Err(e) => {
let error_event = format!(
"event: error\ndata: {{\"error\": {{\"message\": \"{}\"}}}}\n\n",
e
);
let _ = tx.send(Ok(Bytes::from(error_event)));
return;
}
};
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Upstream error {}: {}\"}}}}\n\n", status, body);
let _ = tx.send(Ok(Bytes::from(error_event)));
return;
}
// Stream events and check for tool calls
let mut upstream_stream = response.bytes_stream();
let mut handler = StreamingToolHandler::with_starting_index(next_output_index);
if let Some(ref id) = preserved_response_id {
handler.original_response_id = Some(id.clone());
}
let mut pending = String::new();
let mut tool_calls_detected = false;
let mut seen_in_progress = false;
while let Some(chunk_result) = upstream_stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_text = match std::str::from_utf8(&chunk) {
Ok(text) => Cow::Borrowed(text),
Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()),
};
pending.push_str(&chunk_text.replace("\r\n", "\n"));
while let Some(pos) = pending.find("\n\n") {
let raw_block = pending[..pos].to_string();
pending.drain(..pos + 2);
if raw_block.trim().is_empty() {
continue;
}
// Parse event
let (event_name, data) = Self::parse_sse_block(&raw_block);
if data.is_empty() {
continue;
}
// Process through handler
let action = handler.process_event(event_name, data.as_ref());
match action {
StreamAction::Forward => {
// Skip response.created and response.in_progress on subsequent iterations
// Do NOT consume their sequence numbers - we want continuous numbering
let should_skip = if !is_first_iteration {
if let Ok(parsed) =
serde_json::from_str::<Value>(data.as_ref())
{
matches!(
parsed.get("type").and_then(|v| v.as_str()),
Some(event_types::RESPONSE_CREATED)
| Some(event_types::RESPONSE_IN_PROGRESS)
)
} else {
false
}
} else {
false
};
if !should_skip {
// Forward the event
if !Self::forward_streaming_event(
&raw_block,
event_name,
data.as_ref(),
&mut handler,
&tx,
server_label,
&original_request,
previous_response_id.as_deref(),
&mut sequence_number,
) {
// Client disconnected
return;
}
}
// After forwarding response.in_progress, send mcp_list_tools events (once)
if !seen_in_progress {
if let Ok(parsed) =
serde_json::from_str::<Value>(data.as_ref())
{
if parsed.get("type").and_then(|v| v.as_str())
== Some(event_types::RESPONSE_IN_PROGRESS)
{
seen_in_progress = true;
if !mcp_list_tools_sent {
let list_tools_index = handler
.allocate_synthetic_output_index();
if !OpenAIRouter::send_mcp_list_tools_events(
&tx,
&active_mcp_clone,
server_label,
list_tools_index,
&mut sequence_number,
) {
// Client disconnected
return;
}
mcp_list_tools_sent = true;
}
}
}
}
}
StreamAction::Buffer => {
// Don't forward, just buffer
}
StreamAction::ExecuteTools => {
if !Self::forward_streaming_event(
&raw_block,
event_name,
data.as_ref(),
&mut handler,
&tx,
server_label,
&original_request,
previous_response_id.as_deref(),
&mut sequence_number,
) {
// Client disconnected
return;
}
tool_calls_detected = true;
break; // Exit stream processing to execute tools
}
}
}
if tool_calls_detected {
break;
}
}
Err(e) => {
let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Stream error: {}\"}}}}\n\n", e);
let _ = tx.send(Ok(Bytes::from(error_event)));
return;
}
}
}
next_output_index = handler.next_output_index();
if let Some(id) = handler.original_response_id().map(|s| s.to_string()) {
preserved_response_id = Some(id);
}
// If no tool calls, we're done - stream is complete
if !tool_calls_detected {
if !Self::send_final_response_event(
&handler,
&tx,
&mut sequence_number,
&state,
Some(&active_mcp_clone),
&original_request,
previous_response_id.as_deref(),
server_label,
) {
return;
}
// Send final events and done marker
if should_store {
if let Some(mut response_json) = handler.accumulator.into_final_response() {
if let Some(ref id) = preserved_response_id {
if let Some(obj) = response_json.as_object_mut() {
obj.insert("id".to_string(), Value::String(id.clone()));
}
}
Self::inject_mcp_metadata_streaming(
&mut response_json,
&state,
&active_mcp_clone,
server_label,
);
// Mask tools back to MCP format
Self::mask_tools_as_mcp(&mut response_json, &original_request);
Self::patch_streaming_response_json(
&mut response_json,
&original_request,
previous_response_id.as_deref(),
);
if let Err(err) = Self::store_response_impl(
&storage,
&response_json,
&original_request,
)
.await
{
warn!("Failed to store streaming response: {}", err);
}
}
}
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
return;
}
// Execute tools
let pending_calls = handler.take_pending_calls();
// Check iteration limit
state.iteration += 1;
state.total_calls += pending_calls.len();
let effective_limit = match max_tool_calls {
Some(user_max) => user_max.min(loop_config.max_iterations),
None => loop_config.max_iterations,
};
if state.total_calls > effective_limit {
warn!(
"Reached tool call limit during streaming: {}",
effective_limit
);
let error_event = "event: error\ndata: {\"error\": {\"message\": \"Exceeded max_tool_calls limit\"}}\n\n".to_string();
let _ = tx.send(Ok(Bytes::from(error_event)));
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
return;
}
// Execute all pending tool calls
if !Self::execute_streaming_tool_calls(
pending_calls,
&active_mcp_clone,
&tx,
&mut state,
server_label,
&mut sequence_number,
)
.await
{
// Client disconnected during tool execution
return;
}
// Build resume payload
match Self::build_resume_payload(
&base_payload,
&state.conversation_history,
&state.original_input,
&tools_json,
true, // is_streaming = true
) {
Ok(resume_payload) => {
current_payload = resume_payload;
// Mark that we're no longer on the first iteration
is_first_iteration = false;
// Continue loop to make next streaming request
}
Err(e) => {
let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Failed to build resume payload: {}\"}}}}\n\n", e);
let _ = tx.send(Ok(Bytes::from(error_event)));
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
return;
}
}
}
});
let body_stream = UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(body_stream));
*response.status_mut() = StatusCode::OK;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
/// Parse an SSE block into event name and data
///
/// Returns borrowed strings when possible to avoid allocations in hot paths.
/// Only allocates when multiple data lines need to be joined.
fn parse_sse_block(block: &str) -> (Option<&str>, std::borrow::Cow<'_, str>) {
let mut event_name: Option<&str> = None;
let mut data_lines: Vec<&str> = Vec::new();
for line in block.lines() {
if let Some(rest) = line.strip_prefix("event:") {
event_name = Some(rest.trim());
} else if let Some(rest) = line.strip_prefix("data:") {
data_lines.push(rest.trim_start());
}
}
let data = if data_lines.len() == 1 {
std::borrow::Cow::Borrowed(data_lines[0])
} else {
std::borrow::Cow::Owned(data_lines.join("\n"))
};
(event_name, data)
}
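// Usage sketch (input string invented for the example):
//
//   let block = "event: response.created\ndata: {\"type\":\"response.created\"}";
//   let (name, data) = Self::parse_sse_block(block);
//   assert_eq!(name, Some("response.created"));
//   assert_eq!(data.as_ref(), "{\"type\":\"response.created\"}");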
// Note: transform_streaming_event has been replaced by apply_event_transformations_inplace
// which is more efficient (parses JSON only once instead of twice)
/// Send mcp_list_tools events to client at the start of streaming
/// Returns false if client disconnected
fn send_mcp_list_tools_events(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
mcp: &Arc<crate::mcp::McpClientManager>,
server_label: &str,
output_index: usize,
sequence_number: &mut u64,
) -> bool {
let tools_item_full = Self::build_mcp_list_tools_item(mcp, server_label);
let item_id = tools_item_full
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("");
// Create empty tools version for the initial added event
let mut tools_item_empty = tools_item_full.clone();
if let Some(obj) = tools_item_empty.as_object_mut() {
obj.insert("tools".to_string(), json!([]));
}
// Event 1: response.output_item.added with empty tools
let event1_payload = json!({
"type": event_types::OUTPUT_ITEM_ADDED,
"sequence_number": *sequence_number,
"output_index": output_index,
"item": tools_item_empty
});
*sequence_number += 1;
let event1 = format!(
"event: {}\ndata: {}\n\n",
event_types::OUTPUT_ITEM_ADDED,
event1_payload
);
if tx.send(Ok(Bytes::from(event1))).is_err() {
return false; // Client disconnected
}
// Event 2: response.mcp_list_tools.in_progress
let event2_payload = json!({
"type": event_types::MCP_LIST_TOOLS_IN_PROGRESS,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let event2 = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_LIST_TOOLS_IN_PROGRESS,
event2_payload
);
if tx.send(Ok(Bytes::from(event2))).is_err() {
return false;
}
// Event 3: response.mcp_list_tools.completed
let event3_payload = json!({
"type": event_types::MCP_LIST_TOOLS_COMPLETED,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let event3 = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_LIST_TOOLS_COMPLETED,
event3_payload
);
if tx.send(Ok(Bytes::from(event3))).is_err() {
return false;
}
// Event 4: response.output_item.done with full tools list
let event4_payload = json!({
"type": event_types::OUTPUT_ITEM_DONE,
"sequence_number": *sequence_number,
"output_index": output_index,
"item": tools_item_full
});
*sequence_number += 1;
let event4 = format!(
"event: {}\ndata: {}\n\n",
event_types::OUTPUT_ITEM_DONE,
event4_payload
);
tx.send(Ok(Bytes::from(event4))).is_ok()
}
/// Send mcp_call completion events after tool execution
/// Returns false if client disconnected
fn send_mcp_call_completion_events_with_error(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
call: &FunctionCallInProgress,
output: &str,
server_label: &str,
success: bool,
error_msg: Option<&str>,
sequence_number: &mut u64,
) -> bool {
let effective_output_index = call.effective_output_index();
// Build mcp_call item (reuse existing function)
let mcp_call_item = Self::build_mcp_call_item(
&call.name,
&call.arguments_buffer,
output,
server_label,
success,
error_msg,
);
// Get the mcp_call item_id
let item_id = mcp_call_item
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("");
// Event 1: response.mcp_call.completed
let completed_payload = json!({
"type": event_types::MCP_CALL_COMPLETED,
"sequence_number": *sequence_number,
"output_index": effective_output_index,
"item_id": item_id
});
*sequence_number += 1;
let completed_event = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_COMPLETED,
completed_payload
);
if tx.send(Ok(Bytes::from(completed_event))).is_err() {
return false;
}
// Event 2: response.output_item.done (with completed mcp_call)
let done_payload = json!({
"type": event_types::OUTPUT_ITEM_DONE,
"sequence_number": *sequence_number,
"output_index": effective_output_index,
"item": mcp_call_item
});
*sequence_number += 1;
let done_event = format!(
"event: {}\ndata: {}\n\n",
event_types::OUTPUT_ITEM_DONE,
done_payload
);
tx.send(Ok(Bytes::from(done_event))).is_ok()
}
#[allow(clippy::too_many_arguments)]
fn send_final_response_event(
handler: &StreamingToolHandler,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
sequence_number: &mut u64,
state: &ToolLoopState,
active_mcp: Option<&Arc<crate::mcp::McpClientManager>>,
original_request: &ResponsesRequest,
previous_response_id: Option<&str>,
server_label: &str,
) -> bool {
let mut final_response = match handler.snapshot_final_response() {
Some(resp) => resp,
None => {
warn!("Final response snapshot unavailable; skipping synthetic completion event");
return true;
}
};
if let Some(original_id) = handler.original_response_id() {
if let Some(obj) = final_response.as_object_mut() {
obj.insert("id".to_string(), Value::String(original_id.to_string()));
}
}
if let Some(mcp) = active_mcp {
Self::inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label);
}
Self::mask_tools_as_mcp(&mut final_response, original_request);
Self::patch_streaming_response_json(
&mut final_response,
original_request,
previous_response_id,
);
if let Some(obj) = final_response.as_object_mut() {
obj.insert("status".to_string(), Value::String("completed".to_string()));
}
let completed_payload = json!({
"type": event_types::RESPONSE_COMPLETED,
"sequence_number": *sequence_number,
"response": final_response
});
*sequence_number += 1;
let completed_event = format!(
"event: {}\ndata: {}\n\n",
event_types::RESPONSE_COMPLETED,
completed_payload
);
tx.send(Ok(Bytes::from(completed_event))).is_ok()
}
/// Inject MCP metadata into a streaming response
fn inject_mcp_metadata_streaming(
response: &mut Value,
state: &ToolLoopState,
mcp: &Arc<crate::mcp::McpClientManager>,
server_label: &str,
) {
if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) {
output_array.retain(|item| {
item.get("type").and_then(|t| t.as_str())
!= Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS)
});
let list_tools_item = Self::build_mcp_list_tools_item(mcp, server_label);
output_array.insert(0, list_tools_item);
let mcp_call_items =
Self::build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
} else if let Some(obj) = response.as_object_mut() {
let mut output_items = Vec::new();
output_items.push(Self::build_mcp_list_tools_item(mcp, server_label));
output_items.extend(Self::build_executed_mcp_call_items(
&state.conversation_history,
server_label,
));
obj.insert("output".to_string(), Value::Array(output_items));
}
}
async fn store_response_internal(
&self,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<(), String> {
if !original_body.store {
return Ok(());
}
match Self::store_response_impl(&self.response_storage, response_json, original_body).await
{
Ok(response_id) => {
info!(response_id = %response_id.0, "Stored response locally");
Ok(())
}
Err(e) => Err(e),
}
}
async fn store_response_impl(
response_storage: &SharedResponseStorage,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<ResponseId, String> {
let input_text = match &original_body.input {
ResponseInput::Text(text) => text.clone(),
ResponseInput::Items(_) => "complex input".to_string(),
};
let output_text = Self::extract_primary_output_text(response_json).unwrap_or_default();
let mut stored_response = StoredResponse::new(input_text, output_text, None);
stored_response.instructions = response_json
.get("instructions")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.instructions.clone());
stored_response.model = response_json
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.model.clone());
stored_response.user = response_json
.get("user")
......@@ -926,7 +2461,9 @@ impl OpenAIRouter {
let should_patch = matches!(
event_type,
"response.created" | "response.in_progress" | "response.completed"
event_types::RESPONSE_CREATED
| event_types::RESPONSE_IN_PROGRESS
| event_types::RESPONSE_COMPLETED
);
if !should_patch {
......@@ -1018,7 +2555,9 @@ impl OpenAIRouter {
for item in output {
let obj = item.as_object()?;
let t = obj.get("type")?.as_str()?;
if t == "function_tool_call" || t == "function_call" {
if t == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
|| t == event_types::ITEM_TYPE_FUNCTION_CALL
{
let call_id = obj
.get("call_id")
.and_then(|v| v.as_str())
......@@ -1038,6 +2577,40 @@ impl OpenAIRouter {
/// Replace returned tools with the original request's MCP tool block (if present) so
/// external clients see MCP semantics rather than internal function tools.
/// Build MCP tools array value without cloning entire response object
fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option<Value> {
let mcp_tool = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
let mut m = serde_json::Map::new();
m.insert("type".to_string(), Value::String("mcp".to_string()));
if let Some(label) = &mcp_tool.server_label {
m.insert("server_label".to_string(), Value::String(label.clone()));
}
if let Some(url) = &mcp_tool.server_url {
m.insert("server_url".to_string(), Value::String(url.clone()));
}
if let Some(desc) = &mcp_tool.server_description {
m.insert(
"server_description".to_string(),
Value::String(desc.clone()),
);
}
if let Some(req) = &mcp_tool.require_approval {
m.insert("require_approval".to_string(), Value::String(req.clone()));
}
if let Some(allowed) = &mcp_tool.allowed_tools {
m.insert(
"allowed_tools".to_string(),
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
);
}
Some(Value::Array(vec![Value::Object(m)]))
}
fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
let mcp_tool = original_body
.tools
......@@ -1108,6 +2681,7 @@ impl OpenAIRouter {
conversation_history: &[Value],
original_input: &ResponseInput,
tools_json: &Value,
is_streaming: bool,
) -> Result<Value, String> {
// Clone the base payload which already has cleaned fields
let mut payload = base_payload.clone();
......@@ -1152,8 +2726,8 @@ impl OpenAIRouter {
}
}
// Ensure non-streaming and no store to upstream
obj.insert("stream".to_string(), Value::Bool(false));
// Set streaming mode based on caller's context
obj.insert("stream".to_string(), Value::Bool(is_streaming));
obj.insert("store".to_string(), Value::Bool(false));
// Note: SGLang-specific fields were already removed from base_payload
......@@ -1170,7 +2744,9 @@ impl OpenAIRouter {
let mut mcp_call_items = Vec::new();
for item in conversation_history {
if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
if item.get("type").and_then(|t| t.as_str())
== Some(event_types::ITEM_TYPE_FUNCTION_CALL)
{
let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
let args = item
......@@ -1245,7 +2821,9 @@ impl OpenAIRouter {
// Find any function_call items and convert them to mcp_call (incomplete)
let mut mcp_call_items = Vec::new();
for item in output_array.iter() {
if item.get("type").and_then(|t| t.as_str()) == Some("function_tool_call") {
if item.get("type").and_then(|t| t.as_str())
== Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL)
{
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
let args = item
.get("arguments")
......@@ -1424,6 +3002,7 @@ impl OpenAIRouter {
&state.conversation_history,
&state.original_input,
&tools_json,
false, // is_streaming = false (non-streaming tool loop)
)?;
} else {
// No more tool calls, we're done
......@@ -1507,7 +3086,7 @@ impl OpenAIRouter {
json!({
"id": Self::generate_mcp_id("mcpl"),
"type": "mcp_list_tools",
"type": event_types::ITEM_TYPE_MCP_LIST_TOOLS,
"server_label": server_label,
"tools": tools_json
})
......@@ -1524,7 +3103,7 @@ impl OpenAIRouter {
) -> Value {
json!({
"id": Self::generate_mcp_id("mcp"),
"type": "mcp_call",
"type": event_types::ITEM_TYPE_MCP_CALL,
"status": if success { "completed" } else { "failed" },
"approval_request_id": Value::Null,
"arguments": arguments,
......
......@@ -608,29 +608,353 @@ async fn responses_handler(
if is_stream {
let request_id = format!("resp-{}", Uuid::new_v4());
let stream = stream::once(async move {
let chunk = json!({
"id": request_id,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "in_progress",
"output": [{
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "This is a mock responses streamed output."
// Check if this is an MCP tool call scenario
let has_tools = payload
.get("tools")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter().any(|tool| {
tool.get("type")
.and_then(|t| t.as_str())
.map(|t| t == "function")
.unwrap_or(false)
})
})
.unwrap_or(false);
let has_function_output = payload
.get("input")
.and_then(|v| v.as_array())
.map(|items| {
items.iter().any(|item| {
item.get("type")
.and_then(|t| t.as_str())
.map(|t| t == "function_call_output")
.unwrap_or(false)
})
})
.unwrap_or(false);
if has_tools && !has_function_output {
// First turn: emit streaming tool call events
let call_id = format!(
"call_{}",
Uuid::new_v4().to_string().split('-').next().unwrap()
);
let rid = request_id.clone();
let events = vec![
// response.created
Ok::<_, Infallible>(
Event::default().event("response.created").data(
json!({
"type": "response.created",
"response": {
"id": rid.clone(),
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "in_progress"
}
})
.to_string(),
),
),
// response.in_progress
Ok(Event::default().event("response.in_progress").data(
json!({
"type": "response.in_progress",
"response": {
"id": rid.clone(),
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "in_progress"
}
})
.to_string(),
)),
// response.output_item.added with function_tool_call
Ok(Event::default().event("response.output_item.added").data(
json!({
"type": "response.output_item.added",
"output_index": 0,
"item": {
"id": call_id.clone(),
"type": "function_tool_call",
"name": "brave_web_search",
"arguments": "",
"status": "in_progress"
}
})
.to_string(),
)),
// response.function_call_arguments.delta events
Ok(Event::default()
.event("response.function_call_arguments.delta")
.data(
json!({
"type": "response.function_call_arguments.delta",
"output_index": 0,
"item_id": call_id.clone(),
"delta": "{\"query\""
})
.to_string(),
)),
Ok(Event::default()
.event("response.function_call_arguments.delta")
.data(
json!({
"type": "response.function_call_arguments.delta",
"output_index": 0,
"item_id": call_id.clone(),
"delta": ":\"SGLang"
})
.to_string(),
)),
Ok(Event::default()
.event("response.function_call_arguments.delta")
.data(
json!({
"type": "response.function_call_arguments.delta",
"output_index": 0,
"item_id": call_id.clone(),
"delta": " router MCP"
})
.to_string(),
)),
Ok(Event::default()
.event("response.function_call_arguments.delta")
.data(
json!({
"type": "response.function_call_arguments.delta",
"output_index": 0,
"item_id": call_id.clone(),
"delta": " integration\"}"
})
.to_string(),
)),
// response.function_call_arguments.done
Ok(Event::default()
.event("response.function_call_arguments.done")
.data(
json!({
"type": "response.function_call_arguments.done",
"output_index": 0,
"item_id": call_id.clone()
})
.to_string(),
)),
// response.output_item.done
Ok(Event::default().event("response.output_item.done").data(
json!({
"type": "response.output_item.done",
"output_index": 0,
"item": {
"id": call_id.clone(),
"type": "function_tool_call",
"name": "brave_web_search",
"arguments": "{\"query\":\"SGLang router MCP integration\"}",
"status": "completed"
}
})
.to_string(),
)),
// response.completed
Ok(Event::default().event("response.completed").data(
json!({
"type": "response.completed",
"response": {
"id": rid,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "completed"
}
})
.to_string(),
)),
// [DONE]
Ok(Event::default().data("[DONE]")),
];
let stream = stream::iter(events);
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
} else if has_tools && has_function_output {
// Second turn: emit streaming text response
let rid = request_id.clone();
let msg_id = format!(
"msg_{}",
Uuid::new_v4().to_string().split('-').next().unwrap()
);
let events = vec![
// response.created
Ok::<_, Infallible>(
Event::default().event("response.created").data(
json!({
"type": "response.created",
"response": {
"id": rid.clone(),
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "in_progress"
}
})
.to_string(),
),
),
// response.in_progress
Ok(Event::default().event("response.in_progress").data(
json!({
"type": "response.in_progress",
"response": {
"id": rid.clone(),
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "in_progress"
}
})
.to_string(),
)),
// response.output_item.added with message
Ok(Event::default().event("response.output_item.added").data(
json!({
"type": "response.output_item.added",
"output_index": 0,
"item": {
"id": msg_id.clone(),
"type": "message",
"role": "assistant",
"content": []
}
})
.to_string(),
)),
// response.content_part.added
Ok(Event::default().event("response.content_part.added").data(
json!({
"type": "response.content_part.added",
"output_index": 0,
"item_id": msg_id.clone(),
"part": {
"type": "output_text",
"text": ""
}
})
.to_string(),
)),
// response.output_text.delta events
Ok(Event::default().event("response.output_text.delta").data(
json!({
"type": "response.output_text.delta",
"output_index": 0,
"content_index": 0,
"delta": "Tool result"
})
.to_string(),
)),
Ok(Event::default().event("response.output_text.delta").data(
json!({
"type": "response.output_text.delta",
"output_index": 0,
"content_index": 0,
"delta": " consumed;"
})
.to_string(),
)),
Ok(Event::default().event("response.output_text.delta").data(
json!({
"type": "response.output_text.delta",
"output_index": 0,
"content_index": 0,
"delta": " here is the final answer."
})
.to_string(),
)),
// response.output_text.done
Ok(Event::default().event("response.output_text.done").data(
json!({
"type": "response.output_text.done",
"output_index": 0,
"content_index": 0,
"text": "Tool result consumed; here is the final answer."
})
.to_string(),
)),
// response.output_item.done
Ok(Event::default().event("response.output_item.done").data(
json!({
"type": "response.output_item.done",
"output_index": 0,
"item": {
"id": msg_id,
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "Tool result consumed; here is the final answer."
}]
}
})
.to_string(),
)),
// response.completed
Ok(Event::default().event("response.completed").data(
json!({
"type": "response.completed",
"response": {
"id": rid,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "completed",
"usage": {
"input_tokens": 12,
"output_tokens": 7,
"total_tokens": 19
}
}
})
.to_string(),
)),
// [DONE]
Ok(Event::default().data("[DONE]")),
];
let stream = stream::iter(events);
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
} else {
// Default streaming response
let stream = stream::once(async move {
let chunk = json!({
"id": request_id,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"status": "in_progress",
"output": [{
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "This is a mock responses streamed output."
}]
}]
}]
});
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
})
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
});
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
})
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
}
} else if is_background {
let rid = req_id.unwrap_or_else(|| format!("resp-{}", Uuid::new_v4()));
Json(json!({
......
......@@ -765,3 +765,464 @@ async fn test_max_tool_calls_limit() {
worker.stop().await;
mcp.stop().await;
}
/// Helper function to set up common test infrastructure for streaming MCP tests
/// Returns (mcp_server, worker, router, temp_dir)
async fn setup_streaming_mcp_test() -> (
MockMCPServer,
MockWorker,
Box<dyn sglang_router_rs::routers::RouterTrait>,
tempfile::TempDir,
) {
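// Start a mock MCP server and point the router at it via a temporary YAML config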
let mcp = MockMCPServer::start().await.expect("start mcp");
let mcp_yaml = format!(
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
mcp.url()
);
let dir = tempfile::tempdir().expect("tmpdir");
let cfg_path = dir.path().join("mcp.yaml");
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
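// Start a mock OpenAI-compatible worker for the router to stream from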
let mut worker = MockWorker::new(MockWorkerConfig {
port: 0,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let worker_url = worker.start().await.expect("start worker");
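// Router config in OpenAI mode targeting the single mock worker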
let router_cfg = RouterConfig {
mode: RoutingMode::OpenAI {
worker_urls: vec![worker_url],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 32,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
};
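// Build the shared app context and construct the router from it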
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
let router = RouterFactory::create_router(&Arc::new(ctx))
.await
.expect("router");
(mcp, worker, router, dir)
}
/// Parse SSE (Server-Sent Events) stream into structured events
fn parse_sse_events(body: &str) -> Vec<(Option<String>, serde_json::Value)> {
let mut events = Vec::new();
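// SSE events are separated by blank lines; each block may carry event: and data: fields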
let blocks: Vec<&str> = body
.split("\n\n")
.filter(|s| !s.trim().is_empty())
.collect();
for block in blocks {
let mut event_name: Option<String> = None;
let mut data_lines: Vec<String> = Vec::new();
for line in block.lines() {
if let Some(rest) = line.strip_prefix("event:") {
event_name = Some(rest.trim().to_string());
} else if let Some(rest) = line.strip_prefix("data:") {
let data = rest.trim_start();
// Skip [DONE] marker
if data != "[DONE]" {
data_lines.push(data.to_string());
}
}
}
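// Per the SSE spec, multiple data lines in one block are joined with newlines before JSON parsing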
if !data_lines.is_empty() {
let data = data_lines.join("\n");
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
events.push((event_name, parsed));
}
}
}
events
}
#[tokio::test]
async fn test_streaming_with_mcp_tool_calls() {
// This test verifies that streaming works with MCP tool calls:
// 1. Initial streaming request with MCP tools
// 2. Mock worker streams text, then function_call deltas
// 3. Router buffers function call, executes MCP tool
// 4. Router resumes streaming with tool results
// 5. Mock worker streams final answer
// 6. Verify SSE events are properly formatted
let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
// Build streaming request with MCP tools
let req = ResponsesRequest {
background: false,
include: None,
input: ResponseInput::Text("search for something interesting".to_string()),
instructions: Some("Use tools when needed".to_string()),
max_output_tokens: Some(256),
max_tool_calls: Some(3),
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: true, // KEY: Enable streaming
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
server_description: Some("Mock MCP for streaming test".to_string()),
require_approval: Some("never".to_string()),
..Default::default()
}],
top_logprobs: 0,
top_p: Some(1.0),
truncation: Truncation::Disabled,
user: None,
request_id: "resp_streaming_mcp_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
stop: None,
top_k: 50,
min_p: 0.0,
repetition_penalty: 1.0,
};
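// Send the streaming request through the router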
let response = router.route_responses(None, &req, None).await;
// Verify streaming response
assert_eq!(
response.status(),
axum::http::StatusCode::OK,
"Streaming request should succeed"
);
// Check Content-Type is text/event-stream
let content_type = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok());
assert_eq!(
content_type,
Some("text/event-stream"),
"Should have SSE content type"
);
// Read the streaming body
use axum::body::to_bytes;
let response_body = response.into_body();
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
let body_text = String::from_utf8_lossy(&body_bytes);
println!("Streaming SSE response:\n{}", body_text);
// Parse all SSE events into structured format
let events = parse_sse_events(&body_text);
assert!(!events.is_empty(), "Should have at least one SSE event");
println!("Total parsed SSE events: {}", events.len());
// Check for [DONE] marker
let has_done_marker = body_text.contains("data: [DONE]");
assert!(has_done_marker, "Stream should end with [DONE] marker");
// Track which events we've seen
let mut found_mcp_list_tools = false;
let mut found_mcp_list_tools_in_progress = false;
let mut found_mcp_list_tools_completed = false;
let mut found_response_created = false;
let mut found_mcp_call_added = false;
let mut found_mcp_call_in_progress = false;
let mut found_mcp_call_arguments = false;
let mut found_mcp_call_arguments_done = false;
let mut found_mcp_call_done = false;
let mut found_response_completed = false;
for (event_name, data) in &events {
let event_type = data.get("type").and_then(|v| v.as_str()).unwrap_or("");
match event_type {
"response.output_item.added" => {
// Check if it's an mcp_list_tools item
if let Some(item) = data.get("item") {
if item.get("type").and_then(|v| v.as_str()) == Some("mcp_list_tools") {
found_mcp_list_tools = true;
println!("✓ Found mcp_list_tools added event");
// Verify tools array is present (should be empty in added event)
assert!(
item.get("tools").is_some(),
"mcp_list_tools should have tools array"
);
} else if item.get("type").and_then(|v| v.as_str()) == Some("mcp_call") {
found_mcp_call_added = true;
println!("✓ Found mcp_call added event");
// Verify mcp_call has required fields
assert!(item.get("name").is_some(), "mcp_call should have name");
assert_eq!(
item.get("server_label").and_then(|v| v.as_str()),
Some("mock"),
"mcp_call should have server_label"
);
}
}
}
"response.mcp_list_tools.in_progress" => {
found_mcp_list_tools_in_progress = true;
println!("✓ Found mcp_list_tools.in_progress event");
// Verify it has output_index and item_id
assert!(
data.get("output_index").is_some(),
"mcp_list_tools.in_progress should have output_index"
);
assert!(
data.get("item_id").is_some(),
"mcp_list_tools.in_progress should have item_id"
);
}
"response.mcp_list_tools.completed" => {
found_mcp_list_tools_completed = true;
println!("✓ Found mcp_list_tools.completed event");
// Verify it has output_index and item_id
assert!(
data.get("output_index").is_some(),
"mcp_list_tools.completed should have output_index"
);
assert!(
data.get("item_id").is_some(),
"mcp_list_tools.completed should have item_id"
);
}
"response.mcp_call.in_progress" => {
found_mcp_call_in_progress = true;
println!("✓ Found mcp_call.in_progress event");
// Verify it has output_index and item_id
assert!(
data.get("output_index").is_some(),
"mcp_call.in_progress should have output_index"
);
assert!(
data.get("item_id").is_some(),
"mcp_call.in_progress should have item_id"
);
}
"response.mcp_call_arguments.delta" => {
found_mcp_call_arguments = true;
println!("✓ Found mcp_call_arguments.delta event");
// Delta should include arguments payload
assert!(
data.get("delta").is_some(),
"mcp_call_arguments.delta should include delta text"
);
}
"response.mcp_call_arguments.done" => {
found_mcp_call_arguments_done = true;
println!("✓ Found mcp_call_arguments.done event");
assert!(
data.get("arguments").is_some(),
"mcp_call_arguments.done should include full arguments"
);
}
"response.output_item.done" => {
if let Some(item) = data.get("item") {
if item.get("type").and_then(|v| v.as_str()) == Some("mcp_call") {
found_mcp_call_done = true;
println!("✓ Found mcp_call done event");
// Verify mcp_call.done has output
assert!(
item.get("output").is_some(),
"mcp_call done should have output"
);
}
}
}
"response.created" => {
found_response_created = true;
println!("✓ Found response.created event");
// Verify response has required fields
assert!(
data.get("response").is_some(),
"response.created should have response object"
);
}
"response.completed" => {
found_response_completed = true;
println!("✓ Found response.completed event");
}
_ => {
println!(" Other event: {}", event_type);
}
}
if let Some(name) = event_name {
println!(" Event name: {}", name);
}
}
// Verify key events were present
println!("\n=== Event Summary ===");
println!("MCP list_tools added: {}", found_mcp_list_tools);
println!(
"MCP list_tools in_progress: {}",
found_mcp_list_tools_in_progress
);
println!(
"MCP list_tools completed: {}",
found_mcp_list_tools_completed
);
println!("Response created: {}", found_response_created);
println!("MCP call added: {}", found_mcp_call_added);
println!("MCP call in_progress: {}", found_mcp_call_in_progress);
println!("MCP call arguments delta: {}", found_mcp_call_arguments);
println!("MCP call arguments done: {}", found_mcp_call_arguments_done);
println!("MCP call done: {}", found_mcp_call_done);
println!("Response completed: {}", found_response_completed);
// Assert critical events are present
assert!(
found_mcp_list_tools,
"Should send mcp_list_tools added event at the start"
);
assert!(
found_mcp_list_tools_in_progress,
"Should send mcp_list_tools.in_progress event"
);
assert!(
found_mcp_list_tools_completed,
"Should send mcp_list_tools.completed event"
);
assert!(found_response_created, "Should send response.created event");
assert!(found_mcp_call_added, "Should send mcp_call added event");
assert!(
found_mcp_call_in_progress,
"Should send mcp_call.in_progress event"
);
assert!(found_mcp_call_done, "Should send mcp_call done event");
assert!(
found_mcp_call_arguments,
"Should send mcp_call_arguments.delta event"
);
assert!(
found_mcp_call_arguments_done,
"Should send mcp_call_arguments.done event"
);
// Verify no error events
let has_error = body_text.contains("event: error");
assert!(!has_error, "Should not have error events");
worker.stop().await;
mcp.stop().await;
}
#[tokio::test]
async fn test_streaming_multi_turn_with_mcp() {
// Test streaming with multiple tool call rounds
let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
let req = ResponsesRequest {
background: false,
include: None,
input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()),
instructions: Some("Be thorough".to_string()),
max_output_tokens: Some(512),
max_tool_calls: Some(5), // Allow multiple rounds
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: true,
temperature: Some(0.8),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
..Default::default()
}],
top_logprobs: 0,
top_p: Some(1.0),
truncation: Truncation::Disabled,
user: None,
request_id: "resp_streaming_multiturn_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
stop: None,
top_k: 50,
min_p: 0.0,
repetition_penalty: 1.0,
};
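// Multiple tool-call rounds should still produce a single well-formed SSE stream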
let response = router.route_responses(None, &req, None).await;
assert_eq!(response.status(), axum::http::StatusCode::OK);
use axum::body::to_bytes;
let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
let body_text = String::from_utf8_lossy(&body_bytes);
println!("Multi-turn streaming response:\n{}", body_text);
// Verify streaming completed successfully
assert!(body_text.contains("data: [DONE]"));
assert!(!body_text.contains("event: error"));
// Count events
let event_count = body_text
.split("\n\n")
.filter(|s| !s.trim().is_empty())
.count();
println!("Total events in multi-turn stream: {}", event_count);
assert!(event_count > 0, "Should have received streaming events");
worker.stop().await;
mcp.stop().await;
}