Unverified commit 700daa34, authored by Simo Lin and committed by GitHub
Browse files

[router] harmony responses api streaming support (#12395)

parent 39cee0fe
......@@ -43,7 +43,9 @@ pub use builder::HarmonyBuilder;
pub use detector::HarmonyDetector;
pub use parser::HarmonyParserAdapter;
pub use processor::{HarmonyResponseProcessor, ResponsesIterationResult};
pub use responses::{serve_harmony_responses, HarmonyResponsesContext};
pub use responses::{
serve_harmony_responses, serve_harmony_responses_stream, HarmonyResponsesContext,
};
pub use stages::{
HarmonyPreparationStage, HarmonyRequestBuildingStage, HarmonyResponseProcessingStage,
};
......
......@@ -103,7 +103,7 @@ impl HarmonyParserAdapter {
/// # Returns
///
/// Tuple of (analysis, commentary, final_text)
fn parse_messages(
pub fn parse_messages(
messages: &[openai_harmony::chat::Message],
) -> (Option<String>, Option<Vec<ToolCall>>, String) {
let mut analysis = None;
......@@ -260,6 +260,51 @@ impl HarmonyParserAdapter {
self.parser.messages().to_vec()
}
/// Extract incomplete commentary content from parser state
///
/// When the stream ends, there may be incomplete commentary content in the parser
/// that hasn't been finalized into a completed message. This method extracts
/// such content and converts it to tool calls.
///
/// # Returns
///
/// `Some(vec![ToolCall])` when the parser is mid-way through a commentary
/// message addressed to a `functions.*` recipient with non-empty content;
/// `None` otherwise.
pub fn extract_incomplete_commentary(&self) -> Option<Vec<ToolCall>> {
    // Only commentary-channel content can represent an in-flight tool call.
    if self.parser.current_channel().as_deref() != Some("commentary") {
        return None;
    }

    // The recipient must be "functions.{name}". strip_prefix both validates
    // the prefix and yields the function name in one step (no unwrap needed).
    let recipient = self.parser.current_recipient()?;
    let function_name = recipient.strip_prefix("functions.")?;

    // The current (not yet finalized) content is the partially-streamed
    // argument payload; nothing to emit if it's empty.
    let content = self.parser.current_content().ok()?;
    if content.is_empty() {
        return None;
    }

    // Wrap the partial arguments in a synthetic tool call with a fresh id.
    let tool_call = ToolCall {
        id: format!("call_{}", Uuid::new_v4()),
        tool_type: "function".to_string(),
        function: FunctionCallResponse {
            name: function_name.to_string(),
            arguments: Some(content),
        },
    };
    Some(vec![tool_call])
}
/// Parse streaming chunk
///
/// Parses incremental token IDs and returns a delta with any new content
......
......@@ -40,25 +40,26 @@ impl Default for HarmonyResponseProcessingStage {
#[async_trait]
impl PipelineStage for HarmonyResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Get execution result (output tokens from model)
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let is_streaming = ctx.is_streaming();
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// Check request type to determine which processor method to call
match &ctx.input.request_type {
RequestType::Chat(_) => {
// Get execution result (output tokens from model)
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// For streaming, delegate to streaming processor and return SSE response
if is_streaming {
return Ok(Some(
......@@ -83,14 +84,28 @@ impl PipelineStage for HarmonyResponseProcessingStage {
Ok(None)
}
RequestType::Responses(_) => {
// For Responses API, process iteration and store result
// Streaming not yet supported for Responses API
// For streaming Responses API, leave execution_result in context
// for external streaming processor (serve_harmony_responses_stream)
if is_streaming {
return Err(utils::internal_error_static(
"Streaming not yet supported for Responses API",
));
// Don't take execution_result - let the caller handle it
return Ok(None);
}
// For non-streaming, process normally
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
let responses_request = ctx.responses_request_arc();
let iteration_result = self
.processor
......
......@@ -454,4 +454,53 @@ impl RequestPipeline {
utils::internal_error_static("No ResponsesIterationResult produced by pipeline")
})
}
/// Execute Harmony Responses pipeline iteration with streaming support
///
/// Runs every pipeline stage in order (the dispatch stage is the one that
/// creates the worker stream) and hands back the raw `ExecutionResult` so
/// the caller can drive token-level streaming processing itself.
pub async fn execute_harmony_responses_streaming(
    &self,
    request: &crate::protocols::responses::ResponsesRequest,
    harmony_ctx: &harmony::responses::HarmonyResponsesContext,
) -> Result<ExecutionResult, Response> {
    // Build a fresh RequestContext for this Responses request.
    let mut ctx = RequestContext::for_responses(
        Arc::new(request.clone()),
        None,
        None,
        harmony_ctx.components.clone(),
    );

    // Drive each stage in turn. In the streaming path no stage is expected
    // to short-circuit with a finished Response.
    for (stage_index, stage) in self.stages.iter().enumerate() {
        let outcome = stage.execute(&mut ctx).await;
        match outcome {
            Ok(None) => {}
            Ok(Some(response)) => {
                error!(
                    "Stage {} ({}) returned unexpected response during streaming Responses",
                    stage_index + 1,
                    stage.name()
                );
                return Err(response);
            }
            Err(response) => {
                error!(
                    "Stage {} ({}) failed with status {}",
                    stage_index + 1,
                    stage.name(),
                    response.status()
                );
                return Err(response);
            }
        }
    }

    // The dispatch stage stored the raw worker stream here; hand it over.
    ctx.state
        .response
        .execution_result
        .take()
        .ok_or_else(|| utils::internal_error_static("No ExecutionResult produced by pipeline"))
}
}
......@@ -9,7 +9,7 @@ use uuid::Uuid;
use crate::protocols::chat::ChatCompletionStreamResponse;
pub(super) enum OutputItemType {
pub enum OutputItemType {
Message,
McpListTools,
McpCall,
......@@ -53,9 +53,9 @@ struct OutputItemState {
/// - response.mcp_call_arguments.done
/// - response.mcp_call.completed
/// - response.mcp_call.failed
pub(super) struct ResponseStreamEventEmitter {
pub struct ResponseStreamEventEmitter {
sequence_number: u64,
response_id: String,
pub response_id: String,
model: String,
created_at: u64,
message_id: String,
......@@ -74,7 +74,7 @@ pub(super) struct ResponseStreamEventEmitter {
}
impl ResponseStreamEventEmitter {
pub(super) fn new(response_id: String, model: String, created_at: u64) -> Self {
pub fn new(response_id: String, model: String, created_at: u64) -> Self {
let message_id = format!("msg_{}", Uuid::new_v4());
Self {
......@@ -102,7 +102,7 @@ impl ResponseStreamEventEmitter {
seq
}
pub(super) fn emit_created(&mut self) -> serde_json::Value {
pub fn emit_created(&mut self) -> serde_json::Value {
self.has_emitted_created = true;
json!({
"type": "response.created",
......@@ -118,7 +118,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_in_progress(&mut self) -> serde_json::Value {
pub fn emit_in_progress(&mut self) -> serde_json::Value {
self.has_emitted_in_progress = true;
json!({
"type": "response.in_progress",
......@@ -131,7 +131,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_content_part_added(
pub fn emit_content_part_added(
&mut self,
output_index: usize,
item_id: &str,
......@@ -151,7 +151,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_text_delta(
pub fn emit_text_delta(
&mut self,
delta: &str,
output_index: usize,
......@@ -169,7 +169,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_text_done(
pub fn emit_text_done(
&mut self,
output_index: usize,
item_id: &str,
......@@ -185,7 +185,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_content_part_done(
pub fn emit_content_part_done(
&mut self,
output_index: usize,
item_id: &str,
......@@ -204,10 +204,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_completed(
&mut self,
usage: Option<&serde_json::Value>,
) -> serde_json::Value {
pub fn emit_completed(&mut self, usage: Option<&serde_json::Value>) -> serde_json::Value {
let mut response = json!({
"type": "response.completed",
"sequence_number": self.next_sequence(),
......@@ -240,10 +237,7 @@ impl ResponseStreamEventEmitter {
// MCP Event Emission Methods
// ========================================================================
pub(super) fn emit_mcp_list_tools_in_progress(
&mut self,
output_index: usize,
) -> serde_json::Value {
pub fn emit_mcp_list_tools_in_progress(&mut self, output_index: usize) -> serde_json::Value {
json!({
"type": "response.mcp_list_tools.in_progress",
"sequence_number": self.next_sequence(),
......@@ -251,7 +245,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_list_tools_completed(
pub fn emit_mcp_list_tools_completed(
&mut self,
output_index: usize,
tools: &[crate::mcp::Tool],
......@@ -275,7 +269,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_in_progress(
pub fn emit_mcp_call_in_progress(
&mut self,
output_index: usize,
item_id: &str,
......@@ -288,7 +282,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_arguments_delta(
pub fn emit_mcp_call_arguments_delta(
&mut self,
output_index: usize,
item_id: &str,
......@@ -309,7 +303,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_arguments_done(
pub fn emit_mcp_call_arguments_done(
&mut self,
output_index: usize,
item_id: &str,
......@@ -324,7 +318,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_completed(
pub fn emit_mcp_call_completed(
&mut self,
output_index: usize,
item_id: &str,
......@@ -357,7 +351,7 @@ impl ResponseStreamEventEmitter {
// ========================================================================
/// Emit response.output_item.added event
pub(super) fn emit_output_item_added(
pub fn emit_output_item_added(
&mut self,
output_index: usize,
item: &serde_json::Value,
......@@ -371,7 +365,7 @@ impl ResponseStreamEventEmitter {
}
/// Emit response.output_item.done event
pub(super) fn emit_output_item_done(
pub fn emit_output_item_done(
&mut self,
output_index: usize,
item: &serde_json::Value,
......@@ -390,7 +384,7 @@ impl ResponseStreamEventEmitter {
}
/// Allocate next output index and track item
pub(super) fn allocate_output_index(&mut self, item_type: OutputItemType) -> (usize, String) {
pub fn allocate_output_index(&mut self, item_type: OutputItemType) -> (usize, String) {
let index = self.next_output_index;
self.next_output_index += 1;
......@@ -412,7 +406,7 @@ impl ResponseStreamEventEmitter {
}
/// Mark output item as completed
pub(super) fn complete_output_item(&mut self, output_index: usize) {
pub fn complete_output_item(&mut self, output_index: usize) {
if let Some(item) = self
.output_items
.iter_mut()
......@@ -426,7 +420,7 @@ impl ResponseStreamEventEmitter {
///
/// Reasoning items in OpenAI format are simple placeholders emitted between tool iterations.
/// They don't have streaming content - just wrapper events with empty/null content.
pub(super) fn emit_reasoning_item(
pub fn emit_reasoning_item(
&mut self,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
reasoning_content: Option<String>,
......@@ -550,7 +544,7 @@ impl ResponseStreamEventEmitter {
Ok(())
}
pub(super) fn send_event(
pub fn send_event(
&self,
event: &serde_json::Value,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
......@@ -558,13 +552,38 @@ impl ResponseStreamEventEmitter {
let event_json = serde_json::to_string(event)
.map_err(|e| format!("Failed to serialize event: {}", e))?;
if tx
.send(Ok(Bytes::from(format!("data: {}\n\n", event_json))))
.is_err()
{
// Extract event type from the JSON for SSE event field
let event_type = event
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("message");
// Format as SSE with event: field
let sse_message = format!("event: {}\ndata: {}\n\n", event_type, event_json);
if tx.send(Ok(Bytes::from(sse_message))).is_err() {
return Err("Client disconnected".to_string());
}
Ok(())
}
/// Send event and log any errors (typically client disconnect)
///
/// Convenience wrapper for streaming paths where a client disconnect is an
/// expected, non-fatal condition: the failure is logged at debug level and
/// reported via the return value instead of an error.
/// Returns `true` if the event was sent, `false` if the client disconnected.
pub fn send_event_best_effort(
    &self,
    event: &serde_json::Value,
    tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) -> bool {
    if let Err(e) = self.send_event(event, tx) {
        tracing::debug!("Failed to send event (likely client disconnect): {}", e);
        false
    } else {
        true
    }
}
}
......@@ -13,7 +13,10 @@ use tracing::debug;
use super::{
context::SharedComponents,
harmony::{serve_harmony_responses, HarmonyDetector, HarmonyResponsesContext},
harmony::{
serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector,
HarmonyResponsesContext,
},
pipeline::RequestPipeline,
responses,
};
......@@ -192,8 +195,8 @@ impl GrpcRouter {
model_id: Option<&str>,
) -> Response {
debug!(
"Processing Harmony responses request for model: {:?}",
model_id
"Processing Harmony responses request for model: {:?}, streaming: {:?}",
model_id, body.stream
);
// Create HarmonyResponsesContext from existing responses context
......@@ -204,10 +207,15 @@ impl GrpcRouter {
self.harmony_responses_context.response_storage.clone(),
);
// Use serve_harmony_responses for multi-turn MCP tool orchestration
match serve_harmony_responses(&harmony_ctx, body.clone()).await {
Ok(response) => axum::Json(response).into_response(),
Err(error_response) => error_response,
// Check if streaming is requested
if body.stream.unwrap_or(false) {
serve_harmony_responses_stream(&harmony_ctx, body.clone()).await
} else {
// Use non-streaming version for standard JSON responses
match serve_harmony_responses(&harmony_ctx, body.clone()).await {
Ok(response) => axum::Json(response).into_response(),
Err(error_response) => error_response,
}
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment