Unverified Commit 700daa34 authored by Simo Lin, committed by GitHub

[router] harmony responses api streaming support (#12395)

parent 39cee0fe
......@@ -43,7 +43,9 @@ pub use builder::HarmonyBuilder;
pub use detector::HarmonyDetector;
pub use parser::HarmonyParserAdapter;
pub use processor::{HarmonyResponseProcessor, ResponsesIterationResult};
pub use responses::{serve_harmony_responses, HarmonyResponsesContext};
pub use responses::{
serve_harmony_responses, serve_harmony_responses_stream, HarmonyResponsesContext,
};
pub use stages::{
HarmonyPreparationStage, HarmonyRequestBuildingStage, HarmonyResponseProcessingStage,
};
......
......@@ -103,7 +103,7 @@ impl HarmonyParserAdapter {
/// # Returns
///
/// Tuple of (analysis, commentary, final_text)
fn parse_messages(
pub fn parse_messages(
messages: &[openai_harmony::chat::Message],
) -> (Option<String>, Option<Vec<ToolCall>>, String) {
let mut analysis = None;
......@@ -260,6 +260,51 @@ impl HarmonyParserAdapter {
self.parser.messages().to_vec()
}
/// Extract incomplete commentary content from parser state
///
/// When the stream ends, there may be incomplete commentary content in the parser
/// that hasn't been finalized into a completed message. This method extracts
/// such content and converts it to tool calls.
///
/// # Returns
///
/// Optional vector of ToolCall if incomplete commentary is found
pub fn extract_incomplete_commentary(&self) -> Option<Vec<ToolCall>> {
// Check if current channel is commentary
let current_channel = self.parser.current_channel();
if current_channel.as_deref() != Some("commentary") {
return None;
}
// Get current recipient (should be "functions.{name}")
let recipient = self.parser.current_recipient()?;
if !recipient.starts_with("functions.") {
return None;
}
// Get current incomplete content
let content = self.parser.current_content().ok()?;
if content.is_empty() {
return None;
}
// Extract function name from recipient
let function_name = recipient.strip_prefix("functions.").unwrap();
// Create tool call from incomplete content
let call_id = format!("call_{}", Uuid::new_v4());
let tool_call = ToolCall {
id: call_id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: function_name.to_string(),
arguments: Some(content),
},
};
Some(vec![tool_call])
}
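To make the fallback order concrete, here is a minimal sketch of how a caller can combine this method with `parse_messages` — it mirrors the extraction logic in `process_decode_stream_common` later in this diff; the helper name is hypothetical:

```rust
// Hypothetical helper: prefer tool calls from completed messages, then fall
// back to incomplete commentary left in the parser when the stream ended
// mid-call. Mirrors the streaming processor's extraction logic in this diff.
fn extract_tool_calls_with_fallback(parser: &HarmonyParserAdapter) -> Option<Vec<ToolCall>> {
    let messages = parser.get_messages();
    let (_analysis, commentary, _final_text) = HarmonyParserAdapter::parse_messages(&messages);
    commentary.or_else(|| parser.extract_incomplete_commentary())
}
```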
/// Parse streaming chunk
///
/// Parses incremental token IDs and returns a delta with any new content
......
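For reference, the per-chunk delta consumed by the streaming code below has roughly this shape — inferred from usage in this diff; the real `HarmonyChannelDelta` lives in `harmony::types` and is not shown here:

```rust
// Sketch only — field names as used by process_decode_stream_common below.
struct HarmonyChannelDeltaSketch {
    /// Analysis channel → reasoning output item (wrapper events)
    analysis_delta: Option<String>,
    /// Final channel → streamed message text deltas
    final_delta: Option<String>,
    /// Commentary channel → tool-call delta (index, optional id, function name/args)
    commentary_delta: Option<ToolCallDelta>,
}
```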
......@@ -36,10 +36,17 @@
//! See `/Users/simolin/workspace/sglang/.claude/docs/harmony_pipeline/tool_loop_design.md`
//! for complete architecture, rationale, and implementation details.
use std::sync::Arc;
use std::{
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use axum::response::Response;
use axum::{body::Body, http::StatusCode, response::Response};
use serde_json::Value as JsonValue;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, warn};
use uuid::Uuid;
use crate::{
data_connector::{ResponseId, ResponseStorage},
......@@ -47,13 +54,19 @@ use crate::{
protocols::{
common::{Function, ToolCall},
responses::{
ResponseInput, ResponseInputOutputItem, ResponseTool, ResponsesRequest,
ResponsesResponse, StringOrContentParts,
ResponseInput, ResponseInputOutputItem, ResponseTool, ResponseToolType,
ResponsesRequest, ResponsesResponse, StringOrContentParts,
},
},
routers::grpc::{
context::SharedComponents, harmony::processor::ResponsesIterationResult,
pipeline::RequestPipeline, utils,
routers::{
grpc::{
context::SharedComponents,
harmony::processor::ResponsesIterationResult,
pipeline::RequestPipeline,
responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
utils,
},
openai::mcp::ensure_request_mcp_client,
},
};
......@@ -93,7 +106,7 @@ struct McpCallTracking {
}
impl McpCallTracking {
fn new(server_label: String) -> Self {
pub fn new(server_label: String) -> Self {
Self {
server_label,
tool_calls: Vec::new(),
......@@ -143,7 +156,7 @@ pub struct HarmonyResponsesContext {
pub response_storage: Arc<dyn ResponseStorage>,
/// Optional streaming sender (for future streaming support)
pub stream_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<String, String>>>,
pub stream_tx: Option<mpsc::UnboundedSender<Result<String, String>>>,
}
impl HarmonyResponsesContext {
......@@ -169,7 +182,7 @@ impl HarmonyResponsesContext {
components: Arc<SharedComponents>,
mcp_manager: Arc<McpManager>,
response_storage: Arc<dyn ResponseStorage>,
stream_tx: tokio::sync::mpsc::UnboundedSender<Result<String, String>>,
stream_tx: mpsc::UnboundedSender<Result<String, String>>,
) -> Self {
Self {
pipeline,
......@@ -226,12 +239,6 @@ pub async fn serve_harmony_responses(
let mut current_request = load_previous_messages(ctx, request).await?;
let mut iteration_count = 0;
// Check if request has MCP tools - if so, ensure dynamic client is registered
// and add static MCP tools to the request
use crate::{
protocols::responses::ResponseToolType, routers::openai::mcp::ensure_request_mcp_client,
};
let has_mcp_tools = current_request
.tools
.as_ref()
......@@ -265,7 +272,7 @@ pub async fn serve_harmony_responses(
all_tools.extend(mcp_response_tools);
current_request.tools = Some(all_tools);
tracing::debug!(
debug!(
mcp_tool_count = mcp_tools.len(),
total_tool_count = current_request.tools.as_ref().map(|t| t.len()).unwrap_or(0),
"Request has MCP tools - added static MCP tools to Harmony Responses request"
......@@ -284,7 +291,7 @@ pub async fn serve_harmony_responses(
)));
}
tracing::debug!(
debug!(
iteration = iteration_count,
"Harmony Responses serving iteration"
);
......@@ -308,7 +315,7 @@ pub async fn serve_harmony_responses(
analysis,
partial_text,
} => {
tracing::debug!(
debug!(
tool_call_count = tool_calls.len(),
has_analysis = analysis.is_some(),
partial_text_len = partial_text.len(),
......@@ -347,7 +354,7 @@ pub async fn serve_harmony_responses(
mut response,
usage,
} => {
tracing::debug!(
debug!(
output_items = response.output.len(),
input_tokens = usage.prompt_tokens,
output_tokens = usage.completion_tokens,
......@@ -360,7 +367,7 @@ pub async fn serve_harmony_responses(
if let Some(tracking) = mcp_tracking {
inject_mcp_metadata(&mut response, &tracking, &ctx.mcp_manager);
tracing::debug!(
debug!(
mcp_calls = tracking.total_calls(),
output_items_after = response.output.len(),
"Injected MCP metadata into final response"
......@@ -375,6 +382,327 @@ pub async fn serve_harmony_responses(
}
}
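Both the blocking and streaming serve functions branch on `ResponsesIterationResult`, whose shape, as used in this file (the definition lives in `harmony::processor` and is not part of this diff), is roughly:

```rust
// Sketch inferred from the match arms below; not the authoritative definition.
enum ResponsesIterationResultSketch {
    /// Commentary channel produced tool calls → execute via MCP and loop again
    ToolCallsFound {
        tool_calls: Vec<ToolCall>,
        analysis: Option<String>,
        partial_text: String,
    },
    /// No more tool calls → final response is ready
    Completed {
        response: Box<ResponsesResponse>,
        usage: Usage,
    },
}
```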
/// Serve Harmony Responses API with streaming (SSE)
///
/// This is the streaming equivalent of `serve_harmony_responses()`.
/// Emits SSE events for lifecycle, MCP list_tools, and per-iteration streaming.
///
/// # Architecture
///
/// - Emits `response.created` and `response.in_progress` at start
/// - Emits `mcp_list_tools` events on first iteration (if MCP tools available)
/// - Loops through tool execution iterations (max 10)
/// - Calls `streaming::process_responses_iteration_stream()` for per-iteration events
/// - Emits `response.completed` at end
/// - Handles errors with `response.failed`
///
/// # Arguments
///
/// * `ctx` - Harmony responses context with pipeline and dependencies
/// * `request` - Responses API request
///
/// # Returns
///
/// SSE stream response with proper headers
pub async fn serve_harmony_responses_stream(
ctx: &HarmonyResponsesContext,
request: ResponsesRequest,
) -> Response {
use std::io;
use bytes::Bytes;
// Load previous conversation history if previous_response_id is set
let mut current_request = match load_previous_messages(ctx, request).await {
Ok(req) => req,
Err(err_response) => return err_response,
};
// Create SSE channel
let (tx, rx) = mpsc::unbounded_channel();
let stream = UnboundedReceiverStream::new(rx);
// Create response event emitter
let response_id = format!("resp_{}", Uuid::new_v4());
let model = current_request.model.clone();
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let mut emitter = ResponseStreamEventEmitter::new(response_id.clone(), model, created_at);
// Clone context for spawned task
let ctx_clone = ctx.clone();
// Spawn async task to handle streaming
tokio::spawn(async move {
let ctx = &ctx_clone;
// Clone response_id for closure to avoid borrow conflicts
let response_id_for_error = response_id.clone();
// Helper to emit error and return
let emit_error = |tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, error_msg: &str| {
// Create error event manually since emit_failed doesn't exist
let event = serde_json::json!({
"type": "response.failed",
"response_id": response_id_for_error,
"error": {
"message": error_msg,
"type": "internal_error"
}
});
let sse_data = format!("data: {}\n\n", serde_json::to_string(&event).unwrap());
let _ = tx.send(Ok(Bytes::from(sse_data)));
};
// Emit initial response.created and response.in_progress events
let event = emitter.emit_created();
if emitter.send_event(&event, &tx).is_err() {
return;
}
let event = emitter.emit_in_progress();
if emitter.send_event(&event, &tx).is_err() {
return;
}
// Check if request has MCP tools
let has_mcp_tools = current_request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
// Initialize MCP call tracking
let mut mcp_tracking = if has_mcp_tools {
Some(McpCallTracking::new("sglang-mcp".to_string()))
} else {
None
};
// Setup MCP tools if needed
if has_mcp_tools {
// Ensure dynamic MCP client is registered
if let Some(tools) = &current_request.tools {
ensure_request_mcp_client(&ctx.mcp_manager, tools).await;
}
// Add static MCP tools from inventory
let mcp_tools = ctx.mcp_manager.list_tools();
if !mcp_tools.is_empty() {
let mcp_response_tools = convert_mcp_tools_to_response_tools(&mcp_tools);
let mut all_tools = current_request.tools.clone().unwrap_or_default();
all_tools.extend(mcp_response_tools);
current_request.tools = Some(all_tools);
debug!(
mcp_tool_count = mcp_tools.len(),
total_tool_count = current_request.tools.as_ref().map(|t| t.len()).unwrap_or(0),
"Added static MCP tools to Harmony Responses streaming request"
);
}
}
// Emit mcp_list_tools on first iteration (only if MCP tools available)
if has_mcp_tools {
let mcp_tools = ctx.mcp_manager.list_tools();
let (output_index, item_id) =
emitter.allocate_output_index(OutputItemType::McpListTools);
// Build tools list for item structure
let tool_items: Vec<_> = mcp_tools
.iter()
.map(|t| {
use serde_json::{json, Value};
json!({
"name": t.name,
"description": t.description,
"input_schema": Value::Object((*t.input_schema).clone())
})
})
.collect();
// Emit output_item.added
let item = serde_json::json!({
"id": item_id,
"type": "mcp_list_tools",
"server_label": "sglang-mcp",
"status": "in_progress",
"tools": []
});
let event = emitter.emit_output_item_added(output_index, &item);
if emitter.send_event(&event, &tx).is_err() {
return;
}
// Emit mcp_list_tools.in_progress
let event = emitter.emit_mcp_list_tools_in_progress(output_index);
if emitter.send_event(&event, &tx).is_err() {
return;
}
// Emit mcp_list_tools.completed
let event = emitter.emit_mcp_list_tools_completed(output_index, &mcp_tools);
if emitter.send_event(&event, &tx).is_err() {
return;
}
// Emit output_item.done
let item_done = serde_json::json!({
"id": item_id,
"type": "mcp_list_tools",
"server_label": "sglang-mcp",
"status": "completed",
"tools": tool_items
});
let event = emitter.emit_output_item_done(output_index, &item_done);
if emitter.send_event(&event, &tx).is_err() {
return;
}
emitter.complete_output_item(output_index);
debug!(
tool_count = mcp_tools.len(),
"Emitted mcp_list_tools on first iteration"
);
}
// Tool loop (max 10 iterations)
let mut iteration_count = 0;
loop {
iteration_count += 1;
// Safety check: prevent infinite loops
if iteration_count > MAX_TOOL_ITERATIONS {
let error_msg =
format!("Maximum tool iterations ({}) exceeded", MAX_TOOL_ITERATIONS);
emit_error(&tx, &error_msg);
return;
}
debug!(
iteration = iteration_count,
"Harmony Responses streaming iteration"
);
// Execute through pipeline and get raw stream
let execution_result = match ctx
.pipeline
.execute_harmony_responses_streaming(&current_request, ctx)
.await
{
Ok(result) => result,
Err(err_response) => {
let error_msg = format!("Pipeline execution failed: {:?}", err_response);
emit_error(&tx, &error_msg);
return;
}
};
// Process stream with token-level streaming using HarmonyStreamingProcessor
let iteration_result = match super::streaming::HarmonyStreamingProcessor::process_responses_iteration_stream(
execution_result,
&mut emitter,
&tx,
)
.await
{
Ok(result) => result,
Err(err_msg) => {
emit_error(&tx, &err_msg);
return;
}
};
// Handle iteration result (tool calls or completion)
match iteration_result {
ResponsesIterationResult::ToolCallsFound {
tool_calls,
analysis,
partial_text,
} => {
debug!(
tool_call_count = tool_calls.len(),
has_analysis = analysis.is_some(),
partial_text_len = partial_text.len(),
"Tool calls found in commentary channel"
);
// Execute MCP tools
let tool_results = if let Some(ref mut tracking) = mcp_tracking {
match execute_mcp_tools(&ctx.mcp_manager, &tool_calls, tracking).await {
Ok(results) => results,
Err(err_response) => {
let error_msg =
format!("MCP tool execution failed: {:?}", err_response);
emit_error(&tx, &error_msg);
return;
}
}
} else {
let error_msg = "Tool calls found but MCP tracking not initialized";
emit_error(&tx, error_msg);
return;
};
// Build next request with appended history
current_request = match build_next_request_with_tools(
current_request,
tool_calls,
tool_results,
analysis,
partial_text,
) {
Ok(req) => req,
Err(e) => {
let error_msg = format!("Failed to build next request: {:?}", e);
emit_error(&tx, &error_msg);
return;
}
};
// Continue loop
}
ResponsesIterationResult::Completed { response, usage } => {
debug!(
output_items = response.output.len(),
input_tokens = usage.prompt_tokens,
output_tokens = usage.completion_tokens,
"Harmony Responses streaming completed - no more tool calls"
);
// Emit response.completed with usage
let usage_json = serde_json::json!({
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
"total_tokens": usage.total_tokens,
});
let event = emitter.emit_completed(Some(&usage_json));
emitter.send_event_best_effort(&event, &tx);
// Close channel
drop(tx);
return;
}
}
}
});
// Return SSE stream response
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(Body::from_stream(stream))
.unwrap()
}
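As a usage sketch, a client can consume this endpoint like any OpenAI-style Responses SSE stream. The route path, port, and model name below are assumptions (not shown in this diff); requires `reqwest` with the `json` and `stream` features plus `futures-util`:

```rust
use futures_util::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumed route and model; adjust to your deployment.
    let resp = reqwest::Client::new()
        .post("http://localhost:30000/v1/responses")
        .json(&serde_json::json!({
            "model": "gpt-oss-120b",
            "input": "What is the weather in Paris?",
            "stream": true
        }))
        .send()
        .await?;

    // Print raw SSE frames: "event: <type>\ndata: <json>\n\n"
    let mut stream = resp.bytes_stream();
    while let Some(chunk) = stream.next().await {
        print!("{}", String::from_utf8_lossy(&chunk?));
    }
    Ok(())
}
```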
/// Execute MCP tools and collect results
///
/// Executes each tool call sequentially via the MCP manager.
......@@ -397,7 +725,7 @@ async fn execute_mcp_tools(
let mut results = Vec::new();
for tool_call in tool_calls {
tracing::debug!(
debug!(
tool_name = %tool_call.function.name,
call_id = %tool_call.id,
"Executing MCP tool"
......@@ -425,7 +753,7 @@ async fn execute_mcp_tools(
.await
{
Ok(mcp_result) => {
tracing::debug!(
debug!(
tool_name = %tool_call.function.name,
call_id = %tool_call.id,
"Tool execution succeeded"
......@@ -468,7 +796,7 @@ async fn execute_mcp_tools(
});
}
Err(e) => {
tracing::warn!(
warn!(
tool_name = %tool_call.function.name,
call_id = %tool_call.id,
error = %e,
......@@ -656,7 +984,7 @@ struct ToolResult {
/// # Returns
///
/// Vector of ResponseTool entries in MCP format
fn convert_mcp_tools_to_response_tools(mcp_tools: &[crate::mcp::Tool]) -> Vec<ResponseTool> {
pub fn convert_mcp_tools_to_response_tools(mcp_tools: &[crate::mcp::Tool]) -> Vec<ResponseTool> {
use serde_json::Value;
use crate::protocols::responses::ResponseToolType;
......@@ -800,11 +1128,9 @@ async fn load_previous_messages(
.filter_map(|item| {
serde_json::from_value::<ResponseInputOutputItem>(item.clone())
.map_err(|e| {
tracing::warn!(
warn!(
"Failed to deserialize stored {} item: {}. Item: {}",
item_type,
e,
item
item_type, e, item
);
})
.ok()
......@@ -817,7 +1143,7 @@ async fn load_previous_messages(
history_items.extend(deserialize_items(&stored.output, "output"));
}
tracing::debug!(
debug!(
previous_response_id = %prev_id_str,
history_items_count = history_items.len(),
"Loaded conversation history from previous response"
......@@ -851,29 +1177,3 @@ async fn load_previous_messages(
Ok(modified_request)
}
// TODO: Implement streaming support
// /// Emit intermediate streaming chunks for analysis and partial text
// ///
// /// Emits SSE chunks for Responses API streaming:
// /// - Reasoning chunks for analysis channel
// /// - Message chunks for partial text from final channel
// ///
// /// # Arguments
// ///
// /// * `tx` - Streaming sender
// /// * `analysis` - Analysis channel content
// /// * `partial_text` - Final channel content
// /// * `iteration` - Current iteration number
// async fn emit_intermediate_chunks(
// tx: &tokio::sync::mpsc::UnboundedSender<Result<String, String>>,
// analysis: &Option<String>,
// partial_text: &str,
// iteration: usize,
// ) -> Result<(), Response> {
// // TODO: Implement streaming emission
// // - Emit reasoning chunks for analysis
// // - Emit message chunks for partial_text
// // - Follow OpenAI Responses streaming format (14 SSE event types)
// Ok(())
// }
......@@ -40,25 +40,26 @@ impl Default for HarmonyResponseProcessingStage {
#[async_trait]
impl PipelineStage for HarmonyResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Get execution result (output tokens from model)
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let is_streaming = ctx.is_streaming();
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// Check request type to determine which processor method to call
match &ctx.input.request_type {
RequestType::Chat(_) => {
// Get execution result (output tokens from model)
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// For streaming, delegate to streaming processor and return SSE response
if is_streaming {
return Ok(Some(
......@@ -83,14 +84,28 @@ impl PipelineStage for HarmonyResponseProcessingStage {
Ok(None)
}
RequestType::Responses(_) => {
// For Responses API, process iteration and store result
// Streaming not yet supported for Responses API
// For streaming Responses API, leave execution_result in context
// for external streaming processor (serve_harmony_responses_stream)
if is_streaming {
return Err(utils::internal_error_static(
"Streaming not yet supported for Responses API",
));
// Don't take execution_result - let the caller handle it
return Ok(None);
}
// For non-streaming, process normally
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
let responses_request = ctx.responses_request_arc();
let iteration_result = self
.processor
......
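Taken together with `serve_harmony_responses_stream` and `execute_harmony_responses_streaming` (elsewhere in this diff), the streaming contract is: this stage returns `Ok(None)` and leaves `execution_result` in the context, the pipeline helper takes it, and the SSE processor consumes the raw stream. A sketch, with error handling elided (the real loop reports failures via `emit_error`):

```rust
// Sketch only — names as used elsewhere in this diff.
let execution_result = ctx
    .pipeline
    .execute_harmony_responses_streaming(&current_request, ctx)
    .await
    .expect("sketch: real code matches on Err(Response)");
let iteration_result = HarmonyStreamingProcessor::process_responses_iteration_stream(
    execution_result,
    &mut emitter,
    &tx,
)
.await
.expect("sketch: real code matches on Err(String)");
```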
......@@ -16,20 +16,25 @@ use proto::{
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use tracing::error;
use tracing::{debug, error};
use super::{types::HarmonyChannelDelta, HarmonyParserAdapter};
use super::{
processor::ResponsesIterationResult, types::HarmonyChannelDelta, HarmonyParserAdapter,
};
use crate::{
grpc_client::{proto, sglang_scheduler::AbortOnDropStream},
protocols::{
chat::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
},
common::{FunctionCallDelta, ToolCallDelta, Usage},
common::{FunctionCallDelta, ToolCall, ToolCallDelta, Usage},
responses::{ResponseStatus, ResponseUsage, ResponsesResponse, ResponsesUsage},
},
routers::grpc::{
context,
responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
},
routers::grpc::context,
};
/// Processor for streaming Harmony responses
///
/// Returns an SSE stream that parses Harmony tokens incrementally and
......@@ -526,6 +531,425 @@ impl HarmonyStreamingProcessor {
Ok(())
}
/// Common decode stream processing logic for both single and dual stream modes
///
/// This helper function contains the shared logic for processing the decode stream,
/// parsing Harmony tokens, emitting SSE events, and tracking state.
async fn process_decode_stream_common(
mut decode_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
// Initialize Harmony parser for this iteration
let mut parser =
HarmonyParserAdapter::new().map_err(|e| format!("Failed to create parser: {}", e))?;
// State tracking for channels
let mut has_analysis = false;
let mut accumulated_final_text = String::new();
let mut accumulated_tool_calls: Option<Vec<ToolCall>> = None;
// Track which items we've started
let mut reasoning_output_index: Option<usize> = None;
let mut message_output_index: Option<usize> = None;
let mut message_item_id: Option<String> = None;
let mut has_emitted_content_part_added = false;
// MCP tool call tracking (call_index -> (output_index, item_id))
let mut mcp_call_tracking: HashMap<usize, (usize, String)> = HashMap::new();
// Metadata from Complete message
let mut finish_reason = String::from("stop");
let mut matched_stop: Option<serde_json::Value> = None;
// Process stream
let mut chunk_count = 0;
while let Some(result) = decode_stream.next().await {
chunk_count += 1;
let response = result.map_err(|e| format!("Decode stream error: {}", e))?;
match response.response {
Some(Chunk(chunk)) => {
// Parse chunk via Harmony parser
let delta_result = parser
.parse_chunk(&chunk.token_ids)
.map_err(|e| format!("Parse error: {}", e))?;
// Emit SSE events if there's a delta
if let Some(delta) = delta_result {
// Analysis channel → Reasoning item (wrapper events only, emitted once)
if let Some(_analysis_text) = &delta.analysis_delta {
if reasoning_output_index.is_none() {
// Allocate reasoning item and emit wrapper events
let (output_index, _item_id) =
emitter.allocate_output_index(OutputItemType::Reasoning);
reasoning_output_index = Some(output_index);
// Emit reasoning item (added + done in one call)
// Note: reasoning_content will be provided at finalize
emitter
.emit_reasoning_item(tx, None)
.map_err(|e| format!("Failed to emit reasoning item: {}", e))?;
has_analysis = true;
}
}
// Final channel → Message item (WITH text streaming)
if let Some(final_delta) = &delta.final_delta {
if !final_delta.is_empty() {
// Allocate message item if needed
if message_output_index.is_none() {
let (output_index, item_id) =
emitter.allocate_output_index(OutputItemType::Message);
message_output_index = Some(output_index);
message_item_id = Some(item_id.clone());
// Build message item structure
let item = json!({
"id": item_id,
"type": "message",
"role": "assistant",
"content": []
});
// Emit output_item.added
let event = emitter.emit_output_item_added(output_index, &item);
emitter.send_event_best_effort(&event, tx);
}
let output_index = message_output_index.unwrap();
let item_id = message_item_id.as_ref().unwrap();
let content_index = 0; // Single content part
// Emit content_part.added before first delta
if !has_emitted_content_part_added {
let event = emitter.emit_content_part_added(
output_index,
item_id,
content_index,
);
emitter.send_event_best_effort(&event, tx);
has_emitted_content_part_added = true;
}
// Emit text delta
let event = emitter.emit_text_delta(
final_delta,
output_index,
item_id,
content_index,
);
emitter.send_event_best_effort(&event, tx);
accumulated_final_text.push_str(final_delta);
}
}
// Commentary channel → MCP tool call streaming
if let Some(tc_delta) = &delta.commentary_delta {
let call_index = tc_delta.index;
// Check if this is a new tool call (has id and name)
if tc_delta.id.is_some() {
// NEW MCP CALL: Allocate output item and emit in_progress
let (output_index, item_id) =
emitter.allocate_output_index(OutputItemType::McpCall);
// Store tracking info
mcp_call_tracking
.insert(call_index, (output_index, item_id.clone()));
// Emit mcp_call.in_progress
let event =
emitter.emit_mcp_call_in_progress(output_index, &item_id);
emitter.send_event_best_effort(&event, tx);
// If we have function name, emit initial mcp_call_arguments.delta
if let Some(func) = &tc_delta.function {
if func.name.is_some() {
let event = emitter.emit_mcp_call_arguments_delta(
output_index,
&item_id,
"",
);
emitter.send_event_best_effort(&event, tx);
}
}
} else {
// CONTINUING MCP CALL: Emit arguments delta
if let Some((output_index, item_id)) =
mcp_call_tracking.get(&call_index)
{
if let Some(args) = tc_delta
.function
.as_ref()
.and_then(|f| f.arguments.as_ref())
.filter(|a| !a.is_empty())
{
let event = emitter.emit_mcp_call_arguments_delta(
*output_index,
item_id,
args,
);
emitter.send_event_best_effort(&event, tx);
}
}
}
}
}
}
Some(Complete(complete)) => {
// Store final metadata
finish_reason = complete.finish_reason.clone();
matched_stop = complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
serde_json::json!(id)
}
MatchedStopStr(s) => {
serde_json::json!(s)
}
});
// Finalize parser and get complete output
let final_output = parser
.finalize(finish_reason.clone(), matched_stop.clone())
.map_err(|e| format!("Finalize error: {}", e))?;
// Store finalized tool calls
accumulated_tool_calls = final_output.commentary.clone();
// Complete all MCP tool calls if we have commentary
if let Some(ref tool_calls) = accumulated_tool_calls {
for (call_idx, tool_call) in tool_calls.iter().enumerate() {
if let Some((output_index, item_id)) = mcp_call_tracking.get(&call_idx)
{
// Emit mcp_call_arguments.done with final arguments
let args_str =
tool_call.function.arguments.as_deref().unwrap_or("");
let event = emitter.emit_mcp_call_arguments_done(
*output_index,
item_id,
args_str,
);
emitter.send_event_best_effort(&event, tx);
// Emit mcp_call.completed
let event = emitter.emit_mcp_call_completed(*output_index, item_id);
emitter.send_event_best_effort(&event, tx);
// Mark output item as completed
emitter.complete_output_item(*output_index);
}
}
}
// Close message item if we opened one
if let Some(output_index) = message_output_index {
let item_id = message_item_id.as_ref().unwrap();
let content_index = 0;
// Emit text_done
let event = emitter.emit_text_done(output_index, item_id, content_index);
emitter.send_event_best_effort(&event, tx);
// Emit content_part.done
let event =
emitter.emit_content_part_done(output_index, item_id, content_index);
emitter.send_event_best_effort(&event, tx);
// Emit output_item.done
let item = json!({
"id": item_id,
"type": "message",
"role": "assistant",
"content": [{
"type": "text",
"text": accumulated_final_text.clone()
}]
});
let event = emitter.emit_output_item_done(output_index, &item);
emitter.send_event_best_effort(&event, tx);
emitter.complete_output_item(output_index);
}
}
Some(proto::generate_response::Response::Error(err)) => {
return Err(format!("Server error: {}", err.message));
}
None => {}
}
}
debug!(
"Stream loop ended. Total chunks received: {}, has_analysis: {}, tool_calls: {}, final_text_len: {}",
chunk_count,
has_analysis,
accumulated_tool_calls.as_ref().map(|tc| tc.len()).unwrap_or(0),
accumulated_final_text.len()
);
// Extract tool calls from completed messages or incomplete commentary
if chunk_count > 0 && accumulated_tool_calls.is_none() {
let messages = parser.get_messages();
// Try extracting from completed messages first
let (analysis_opt, commentary_opt, final_text_extracted) =
HarmonyParserAdapter::parse_messages(&messages);
accumulated_tool_calls = commentary_opt.clone();
// If no tool calls found, check for incomplete commentary in parser state
if accumulated_tool_calls.is_none() {
accumulated_tool_calls = parser.extract_incomplete_commentary();
}
debug!(
"Tool call extraction: completed_msgs={}, tool_calls={}, has_analysis={}, final_text_len={}",
messages.len(),
accumulated_tool_calls.as_ref().map(|tc| tc.len()).unwrap_or(0),
analysis_opt.is_some(),
final_text_extracted.len()
);
// Complete any pending MCP tool calls with data from completed messages
if let Some(ref tool_calls) = accumulated_tool_calls {
for (call_idx, tool_call) in tool_calls.iter().enumerate() {
if let Some((output_index, item_id)) = mcp_call_tracking.get(&call_idx) {
// Emit mcp_call_arguments.done with final arguments
let args_str = tool_call.function.arguments.as_deref().unwrap_or("");
let event =
emitter.emit_mcp_call_arguments_done(*output_index, item_id, args_str);
emitter.send_event_best_effort(&event, tx);
// Emit mcp_call.completed
let event = emitter.emit_mcp_call_completed(*output_index, item_id);
emitter.send_event_best_effort(&event, tx);
}
}
}
}
// Mark stream as completed successfully to prevent abort on drop
decode_stream.mark_completed();
// Return result based on whether tool calls were found
if let Some(tool_calls) = accumulated_tool_calls {
if !tool_calls.is_empty() {
let analysis_content = if has_analysis {
// Get analysis from finalized parser output by calling finalize again
// This is safe because finalize can be called multiple times
let output = parser
.finalize(finish_reason.clone(), matched_stop.clone())
.map_err(|e| format!("Finalize error: {}", e))?;
output.analysis
} else {
None
};
return Ok(ResponsesIterationResult::ToolCallsFound {
tool_calls,
analysis: analysis_content,
partial_text: accumulated_final_text,
});
}
}
// For streaming, we don't build the full ResponsesResponse here
// The caller will build it from the SSE events
// Return a placeholder Completed result (caller ignores these fields in streaming mode)
Ok(ResponsesIterationResult::Completed {
response: Box::new(ResponsesResponse {
id: String::new(),
object: "response".to_string(),
created_at: 0,
status: ResponseStatus::Completed,
error: None,
incomplete_details: None,
instructions: None,
max_output_tokens: None,
model: String::new(),
output: vec![],
parallel_tool_calls: true,
previous_response_id: None,
reasoning: None,
store: true,
temperature: None,
text: None,
tool_choice: "auto".to_string(),
tools: vec![],
top_p: None,
truncation: None,
user: None,
metadata: HashMap::new(),
usage: Some(ResponsesUsage::Modern(ResponseUsage {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
input_tokens_details: None,
output_tokens_details: None,
})),
}),
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
completion_tokens_details: None,
},
})
}
/// Process streaming chunks for Responses API iteration
///
/// Returns ResponsesIterationResult indicating whether tool calls were found
/// (requiring MCP loop continuation) or if the iteration is complete.
pub async fn process_responses_iteration_stream(
execution_result: context::ExecutionResult,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
match execution_result {
context::ExecutionResult::Single { stream } => {
debug!("Processing Responses API single stream mode");
Self::process_responses_single_stream(stream, emitter, tx).await
}
context::ExecutionResult::Dual { prefill, decode } => {
debug!("Processing Responses API dual stream mode");
Self::process_responses_dual_stream(prefill, *decode, emitter, tx).await
}
}
}
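The `ExecutionResult` matched above comes from `routers::grpc::context` and is not part of this diff; inferred from usage, its shape is roughly as below (the `Dual` variant presumably corresponds to prefill/decode disaggregation):

```rust
// Sketch only — inferred from the match arms above.
enum ExecutionResultSketch {
    /// One gRPC stream carries the whole generation
    Single { stream: AbortOnDropStream },
    /// Separate prefill and decode streams; decode is boxed
    Dual { prefill: AbortOnDropStream, decode: Box<AbortOnDropStream> },
}
```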
/// Process streaming chunks from a single stream (Responses API)
async fn process_responses_single_stream(
grpc_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
// Delegate to common helper
Self::process_decode_stream_common(grpc_stream, emitter, tx).await
}
/// Process streaming chunks from dual streams (Responses API)
async fn process_responses_dual_stream(
mut prefill_stream: AbortOnDropStream,
decode_stream: AbortOnDropStream,
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<ResponsesIterationResult, String> {
// Phase 1: Process prefill stream (collect metadata, no output)
while let Some(result) = prefill_stream.next().await {
let _response = result.map_err(|e| format!("Prefill stream error: {}", e))?;
// No-op for prefill in Responses API (just metadata collection)
}
// Phase 2: Process decode stream using common helper
let result = Self::process_decode_stream_common(decode_stream, emitter, tx).await;
// Mark prefill stream as completed AFTER decode completes successfully
// This ensures that if client disconnects during decode, BOTH streams send abort
prefill_stream.mark_completed();
result
}
/// Build SSE response from receiver
fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, io::Error>>) -> Response {
let stream = UnboundedReceiverStream::new(rx);
......
......@@ -454,4 +454,53 @@ impl RequestPipeline {
utils::internal_error_static("No ResponsesIterationResult produced by pipeline")
})
}
/// Execute Harmony Responses pipeline iteration with streaming support
///
/// This version executes the pipeline up to the dispatch stage and returns
/// the raw ExecutionResult (with stream) for token-level streaming processing.
pub async fn execute_harmony_responses_streaming(
&self,
request: &crate::protocols::responses::ResponsesRequest,
harmony_ctx: &harmony::responses::HarmonyResponsesContext,
) -> Result<ExecutionResult, Response> {
// Create RequestContext for this Responses request
let mut ctx = RequestContext::for_responses(
Arc::new(request.clone()),
None,
None,
harmony_ctx.components.clone(),
);
// Execute pipeline stages up to dispatch (which creates the stream)
for (idx, stage) in self.stages.iter().enumerate() {
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
error!(
"Stage {} ({}) returned unexpected response during streaming Responses",
idx + 1,
stage.name()
);
return Err(response);
}
Ok(None) => continue,
Err(response) => {
error!(
"Stage {} ({}) failed with status {}",
idx + 1,
stage.name(),
response.status()
);
return Err(response);
}
}
}
// Extract execution_result (the raw stream from workers)
ctx.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No ExecutionResult produced by pipeline"))
}
}
......@@ -9,7 +9,7 @@ use uuid::Uuid;
use crate::protocols::chat::ChatCompletionStreamResponse;
pub(super) enum OutputItemType {
pub enum OutputItemType {
Message,
McpListTools,
McpCall,
......@@ -53,9 +53,9 @@ struct OutputItemState {
/// - response.mcp_call_arguments.done
/// - response.mcp_call.completed
/// - response.mcp_call.failed
pub(super) struct ResponseStreamEventEmitter {
pub struct ResponseStreamEventEmitter {
sequence_number: u64,
response_id: String,
pub response_id: String,
model: String,
created_at: u64,
message_id: String,
......@@ -74,7 +74,7 @@ pub(super) struct ResponseStreamEventEmitter {
}
impl ResponseStreamEventEmitter {
pub(super) fn new(response_id: String, model: String, created_at: u64) -> Self {
pub fn new(response_id: String, model: String, created_at: u64) -> Self {
let message_id = format!("msg_{}", Uuid::new_v4());
Self {
......@@ -102,7 +102,7 @@ impl ResponseStreamEventEmitter {
seq
}
pub(super) fn emit_created(&mut self) -> serde_json::Value {
pub fn emit_created(&mut self) -> serde_json::Value {
self.has_emitted_created = true;
json!({
"type": "response.created",
......@@ -118,7 +118,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_in_progress(&mut self) -> serde_json::Value {
pub fn emit_in_progress(&mut self) -> serde_json::Value {
self.has_emitted_in_progress = true;
json!({
"type": "response.in_progress",
......@@ -131,7 +131,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_content_part_added(
pub fn emit_content_part_added(
&mut self,
output_index: usize,
item_id: &str,
......@@ -151,7 +151,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_text_delta(
pub fn emit_text_delta(
&mut self,
delta: &str,
output_index: usize,
......@@ -169,7 +169,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_text_done(
pub fn emit_text_done(
&mut self,
output_index: usize,
item_id: &str,
......@@ -185,7 +185,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_content_part_done(
pub fn emit_content_part_done(
&mut self,
output_index: usize,
item_id: &str,
......@@ -204,10 +204,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_completed(
&mut self,
usage: Option<&serde_json::Value>,
) -> serde_json::Value {
pub fn emit_completed(&mut self, usage: Option<&serde_json::Value>) -> serde_json::Value {
let mut response = json!({
"type": "response.completed",
"sequence_number": self.next_sequence(),
......@@ -240,10 +237,7 @@ impl ResponseStreamEventEmitter {
// MCP Event Emission Methods
// ========================================================================
pub(super) fn emit_mcp_list_tools_in_progress(
&mut self,
output_index: usize,
) -> serde_json::Value {
pub fn emit_mcp_list_tools_in_progress(&mut self, output_index: usize) -> serde_json::Value {
json!({
"type": "response.mcp_list_tools.in_progress",
"sequence_number": self.next_sequence(),
......@@ -251,7 +245,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_list_tools_completed(
pub fn emit_mcp_list_tools_completed(
&mut self,
output_index: usize,
tools: &[crate::mcp::Tool],
......@@ -275,7 +269,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_in_progress(
pub fn emit_mcp_call_in_progress(
&mut self,
output_index: usize,
item_id: &str,
......@@ -288,7 +282,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_arguments_delta(
pub fn emit_mcp_call_arguments_delta(
&mut self,
output_index: usize,
item_id: &str,
......@@ -309,7 +303,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_arguments_done(
pub fn emit_mcp_call_arguments_done(
&mut self,
output_index: usize,
item_id: &str,
......@@ -324,7 +318,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_completed(
pub fn emit_mcp_call_completed(
&mut self,
output_index: usize,
item_id: &str,
......@@ -357,7 +351,7 @@ impl ResponseStreamEventEmitter {
// ========================================================================
/// Emit response.output_item.added event
pub(super) fn emit_output_item_added(
pub fn emit_output_item_added(
&mut self,
output_index: usize,
item: &serde_json::Value,
......@@ -371,7 +365,7 @@ impl ResponseStreamEventEmitter {
}
/// Emit response.output_item.done event
pub(super) fn emit_output_item_done(
pub fn emit_output_item_done(
&mut self,
output_index: usize,
item: &serde_json::Value,
......@@ -390,7 +384,7 @@ impl ResponseStreamEventEmitter {
}
/// Allocate next output index and track item
pub(super) fn allocate_output_index(&mut self, item_type: OutputItemType) -> (usize, String) {
pub fn allocate_output_index(&mut self, item_type: OutputItemType) -> (usize, String) {
let index = self.next_output_index;
self.next_output_index += 1;
......@@ -412,7 +406,7 @@ impl ResponseStreamEventEmitter {
}
/// Mark output item as completed
pub(super) fn complete_output_item(&mut self, output_index: usize) {
pub fn complete_output_item(&mut self, output_index: usize) {
if let Some(item) = self
.output_items
.iter_mut()
......@@ -426,7 +420,7 @@ impl ResponseStreamEventEmitter {
///
/// Reasoning items in OpenAI format are simple placeholders emitted between tool iterations.
/// They don't have streaming content - just wrapper events with empty/null content.
pub(super) fn emit_reasoning_item(
pub fn emit_reasoning_item(
&mut self,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
reasoning_content: Option<String>,
......@@ -550,7 +544,7 @@ impl ResponseStreamEventEmitter {
Ok(())
}
pub(super) fn send_event(
pub fn send_event(
&self,
event: &serde_json::Value,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
......@@ -558,13 +552,38 @@ impl ResponseStreamEventEmitter {
let event_json = serde_json::to_string(event)
.map_err(|e| format!("Failed to serialize event: {}", e))?;
if tx
.send(Ok(Bytes::from(format!("data: {}\n\n", event_json))))
.is_err()
{
// Extract event type from the JSON for SSE event field
let event_type = event
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("message");
// Format as SSE with event: field
let sse_message = format!("event: {}\ndata: {}\n\n", event_type, event_json);
if tx.send(Ok(Bytes::from(sse_message))).is_err() {
return Err("Client disconnected".to_string());
}
Ok(())
}
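With the `event:` field added here, each frame on the wire pairs the SSE event name with the JSON `type` field. A minimal sketch of the framing (payload abridged):

```rust
// Reproduces the framing logic of send_event for illustration.
let event = serde_json::json!({ "type": "response.output_text.delta", "delta": "Hi" });
let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("message");
let frame = format!("event: {}\ndata: {}\n\n", event_type, event);
assert!(frame.starts_with("event: response.output_text.delta\ndata: "));
```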
/// Send event and log any errors (typically client disconnect)
///
/// This is a convenience method for streaming scenarios where client
/// disconnection is expected and should be logged but not fail the operation.
/// Returns true if sent successfully, false if client disconnected.
pub fn send_event_best_effort(
&self,
event: &serde_json::Value,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) -> bool {
match self.send_event(event, tx) {
Ok(()) => true,
Err(e) => {
tracing::debug!("Failed to send event (likely client disconnect): {}", e);
false
}
}
}
}
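Every output item the emitter produces follows the same added → (content events) → done lifecycle. A condensed usage sketch for a message item, using only methods defined in this file (the surrounding wiring is elided):

```rust
// Sketch of the per-item event sequence used throughout this diff.
fn emit_message_item_skeleton(
    emitter: &mut ResponseStreamEventEmitter,
    tx: &tokio::sync::mpsc::UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
) {
    let (output_index, item_id) = emitter.allocate_output_index(OutputItemType::Message);
    let item = serde_json::json!({
        "id": item_id, "type": "message", "role": "assistant", "content": []
    });
    let event = emitter.emit_output_item_added(output_index, &item);
    emitter.send_event_best_effort(&event, tx);
    // ...emit_content_part_added / emit_text_delta / emit_text_done /
    // emit_content_part_done go here, as in process_decode_stream_common...
    let event = emitter.emit_output_item_done(output_index, &item);
    emitter.send_event_best_effort(&event, tx);
    emitter.complete_output_item(output_index);
}
```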
......@@ -13,7 +13,10 @@ use tracing::debug;
use super::{
context::SharedComponents,
harmony::{serve_harmony_responses, HarmonyDetector, HarmonyResponsesContext},
harmony::{
serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector,
HarmonyResponsesContext,
},
pipeline::RequestPipeline,
responses,
};
......@@ -192,8 +195,8 @@ impl GrpcRouter {
model_id: Option<&str>,
) -> Response {
debug!(
"Processing Harmony responses request for model: {:?}",
model_id
"Processing Harmony responses request for model: {:?}, streaming: {:?}",
model_id, body.stream
);
// Create HarmonyResponsesContext from existing responses context
......@@ -204,10 +207,15 @@ impl GrpcRouter {
self.harmony_responses_context.response_storage.clone(),
);
// Use serve_harmony_responses for multi-turn MCP tool orchestration
match serve_harmony_responses(&harmony_ctx, body.clone()).await {
Ok(response) => axum::Json(response).into_response(),
Err(error_response) => error_response,
// Check if streaming is requested
if body.stream.unwrap_or(false) {
serve_harmony_responses_stream(&harmony_ctx, body.clone()).await
} else {
// Use non-streaming version for standard JSON responses
match serve_harmony_responses(&harmony_ctx, body.clone()).await {
Ok(response) => axum::Json(response).into_response(),
Err(error_response) => error_response,
}
}
}
}
......
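For reference, assembling the emitters added in this diff, a streamed Responses request with MCP tools produces roughly the event-type sequence below. The exact interleaving of reasoning, message, and mcp_call items depends on model output, and the content-part/text event names follow the OpenAI Responses streaming format where their bodies are elided above, so treat this as an assumption-laden summary:

```rust
const EXPECTED_EVENT_TYPES: &[&str] = &[
    "response.created",
    "response.in_progress",
    "response.output_item.added",          // mcp_list_tools, first iteration only
    "response.mcp_list_tools.in_progress",
    "response.mcp_list_tools.completed",
    "response.output_item.done",
    "response.mcp_call.in_progress",       // once per tool call
    "response.mcp_call_arguments.delta",
    "response.mcp_call_arguments.done",
    "response.mcp_call.completed",
    "response.content_part.added",         // final message text
    "response.output_text.delta",
    "response.output_text.done",
    "response.content_part.done",
    "response.completed",
];
```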