Unverified Commit 03b3e89a authored by Simo Lin, committed by GitHub

[router] Harmony Pipeline: Chat Completion & Responses API with MCP Support (#12153)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 9ff9fa7f
@@ -83,6 +83,7 @@ oracle = { version = "0.6.3", features = ["chrono"] }
subtle = "2.6"
rustpython-parser = "0.4.0"
num-traits = "0.2"
+openai-harmony = { git = "https://github.com/openai/harmony", tag = "v0.0.4" }
# gRPC and Protobuf dependencies
tonic = { version = "0.14.2", features = ["gzip", "transport"] }
...
-# SGLang Router Makefile
+# Model Gateway Makefile
# Provides convenient shortcuts for common development tasks
# Check if sccache is available and set RUSTC_WRAPPER accordingly
@@ -13,14 +13,14 @@ endif
.PHONY: help bench bench-quick bench-baseline bench-compare test build clean
help: ## Show this help message
-@echo "SGLang Router Development Commands"
+@echo "Model Gateway Development Commands"
@echo "=================================="
@echo ""
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}'
@echo ""
build: ## Build the project in release mode
-@echo "Building SGLang Router..."
+@echo "Building SGLang Model Gateway..."
@cargo build --release
test: ## Run all tests
@@ -59,11 +59,11 @@ check: ## Run cargo check and clippy
@echo "Running cargo check..."
@cargo check
@echo "Running clippy..."
-@cargo clippy
+@cargo clippy --all-targets --all-features -- -D warnings
fmt: ## Format code with rustfmt
@echo "Formatting code..."
-@cargo fmt
+@rustup run nightly cargo fmt
# Development workflow shortcuts
dev-setup: build test ## Set up development environment
...
@@ -16,6 +16,7 @@ use crate::protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
+responses::ResponsesRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
};
@@ -301,6 +302,42 @@ impl SglangSchedulerClient {
Ok(grpc_request)
}
/// Build a GenerateRequest from ResponsesRequest (OpenAI Responses API)
pub fn build_generate_request_from_responses(
&self,
request_id: String,
body: &ResponsesRequest,
processed_text: String,
token_ids: Vec<u32>,
harmony_stop_ids: Option<Vec<u32>>,
) -> Result<proto::GenerateRequest, String> {
// Build sampling params from ResponsesRequest
let mut sampling_params = self.build_grpc_sampling_params_from_responses(body)?;
// Inject Harmony stop token IDs if provided
if let Some(stop_ids) = harmony_stop_ids {
sampling_params.stop_token_ids = stop_ids;
}
let grpc_request = proto::GenerateRequest {
request_id,
tokenized: Some(proto::TokenizedInput {
original_text: processed_text,
input_ids: token_ids,
}),
mm_inputs: None, // Responses API doesn't support multimodal yet
sampling_params: Some(sampling_params),
return_logprob: false, // Responses API uses top_logprobs field instead
logprob_start_len: -1,
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
return_hidden_states: false,
stream: body.stream.unwrap_or(false),
..Default::default()
};
Ok(grpc_request)
}
/// Build gRPC SamplingParams from OpenAI request
fn build_grpc_sampling_params(
&self,
@@ -400,6 +437,37 @@ impl SglangSchedulerClient {
}
}
/// Build gRPC SamplingParams from ResponsesRequest
fn build_grpc_sampling_params_from_responses(
&self,
request: &ResponsesRequest,
) -> Result<proto::SamplingParams, String> {
// ResponsesRequest doesn't have stop sequences in the same way
// Tools are handled externally by MCP loop, not via constraints
let max_new_tokens = request.max_output_tokens.map(|v| v as i32);
Ok(proto::SamplingParams {
temperature: request.temperature.unwrap_or(1.0),
top_p: request.top_p.unwrap_or(1.0),
top_k: -1, // ResponsesRequest doesn't expose top_k
min_p: 0.0, // ResponsesRequest doesn't expose min_p
frequency_penalty: 0.0, // ResponsesRequest doesn't expose frequency_penalty
presence_penalty: 0.0, // ResponsesRequest doesn't expose presence_penalty
repetition_penalty: 1.0, // ResponsesRequest doesn't expose repetition_penalty
max_new_tokens,
stop: vec![], // No stop sequences in Responses API
stop_token_ids: vec![], // Handled by Harmony stop tokens
skip_special_tokens: false, // Keep special tokens for Harmony
spaces_between_special_tokens: true,
ignore_eos: false,
no_stop_trim: false,
n: 1, // Responses API doesn't support n>1
constraint: None, // No constraints - tools handled by MCP
..Default::default()
})
}
fn build_single_constraint_from_plain(
params: &GenerateSamplingParams,
) -> Result<Option<proto::sampling_params::Constraint>, String> {
...
@@ -674,7 +674,6 @@ pub struct ChatMessageDelta {
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatMessageDelta,
-#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
...
@@ -21,6 +21,9 @@ use super::common::{
pub struct ResponseTool {
#[serde(rename = "type")]
pub r#type: ResponseToolType,
+// Function tool fields (used when type == "function")
+#[serde(skip_serializing_if = "Option::is_none")]
+pub function: Option<crate::protocols::common::Function>,
// MCP-specific fields (used when type == "mcp")
#[serde(skip_serializing_if = "Option::is_none")]
pub server_url: Option<String>,
@@ -40,6 +43,7 @@ impl Default for ResponseTool {
fn default() -> Self {
Self {
r#type: ResponseToolType::WebSearchPreview,
+function: None,
server_url: None,
authorization: None,
server_label: None,
@@ -53,6 +57,7 @@ impl Default for ResponseTool {
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
+Function,
WebSearchPreview,
CodeInterpreter,
Mcp,
@@ -134,6 +139,13 @@ pub enum ResponseInputOutputItem {
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
+#[serde(rename = "function_call_output")]
+FunctionCallOutput {
+call_id: String,
+output: String,
+#[serde(skip_serializing_if = "Option::is_none")]
+status: Option<String>,
+},
#[serde(untagged)]
SimpleInputMessage {
content: StringOrContentParts,
@@ -499,7 +511,7 @@ pub struct ResponsesRequest {
pub store: Option<bool>,
/// Whether to stream the response
-#[serde(skip_serializing_if = "Option::is_none")]
+#[serde(default)]
pub stream: Option<bool>,
/// Temperature for sampling
@@ -678,6 +690,9 @@ impl GenerationRequest for ResponsesRequest {
ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
Some(arguments.clone())
}
+ResponseInputOutputItem::FunctionCallOutput { output, .. } => {
+Some(output.clone())
+}
})
.collect::<Vec<String>>()
.join(" "),
...
@@ -15,6 +15,7 @@ use crate::{
protocols::{
chat::{ChatCompletionRequest, ChatCompletionResponse},
generate::{GenerateRequest, GenerateResponse},
+responses::ResponsesRequest,
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::{stop::StopSequenceDecoder, traits::Tokenizer},
@@ -53,6 +54,7 @@ pub struct RequestInput {
pub enum RequestType {
Chat(Arc<ChatCompletionRequest>),
Generate(Arc<GenerateRequest>),
+Responses(Arc<ResponsesRequest>),
}
/// Shared components (injected once at creation)
@@ -104,6 +106,19 @@ pub struct PreparationOutput {
/// Filtered request (if tools were filtered)
pub filtered_request: Option<ChatCompletionRequest>,
+// Harmony-specific fields
+/// Whether this is a Harmony request (default: false)
+pub harmony_mode: bool,
+/// Selection text for worker routing (Harmony only)
+pub selection_text: Option<String>,
+/// Harmony messages for history tracking (Harmony only)
+pub harmony_messages: Option<Vec<super::harmony::HarmonyMessage>>,
+/// Stop token IDs for Harmony models
+pub harmony_stop_ids: Option<Vec<u32>>,
}
/// Worker selection (Step 2)
@@ -155,6 +170,16 @@ pub struct ResponseState {
/// Final processed response
pub final_response: Option<FinalResponse>,
+/// Responses API iteration result (Harmony only, for tool loop orchestration)
+pub responses_iteration_result: Option<super::harmony::ResponsesIterationResult>,
+// Harmony-specific parser state
+/// Harmony parser for non-streaming (single parser for all indices)
+pub harmony_parser: Option<super::harmony::HarmonyParserAdapter>,
+/// Harmony parsers for streaming (one per index for n>1 support)
+pub harmony_parser_per_index: Option<HashMap<usize, super::harmony::HarmonyParserAdapter>>,
}
/// Streaming state (per-choice tracking)
@@ -217,6 +242,24 @@ impl RequestContext {
}
}
/// Create context for Responses API request
pub fn for_responses(
request: Arc<ResponsesRequest>,
headers: Option<HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Self {
Self {
input: RequestInput {
request_type: RequestType::Responses(request),
headers,
model_id,
},
components,
state: ProcessingState::default(),
}
}
/// Get reference to original request (type-safe)
pub fn request(&self) -> &RequestType {
&self.input.request_type
@@ -254,11 +297,28 @@ impl RequestContext {
}
}
/// Get responses request (panics if not responses)
pub fn responses_request(&self) -> &ResponsesRequest {
match &self.input.request_type {
RequestType::Responses(req) => req.as_ref(),
_ => panic!("Expected responses request"),
}
}
/// Get Arc clone of responses request (panics if not responses)
pub fn responses_request_arc(&self) -> Arc<ResponsesRequest> {
match &self.input.request_type {
RequestType::Responses(req) => Arc::clone(req),
_ => panic!("Expected responses request"),
}
}
/// Check if request is streaming
pub fn is_streaming(&self) -> bool {
match &self.input.request_type {
RequestType::Chat(req) => req.stream,
RequestType::Generate(req) => req.stream,
+RequestType::Responses(req) => req.stream.unwrap_or(false),
}
}
}
...
... (diff collapsed)
//! Harmony model detection
/// Harmony model detector
///
/// Detects if a model name indicates support for Harmony encoding/parsing.
pub struct HarmonyDetector;
impl HarmonyDetector {
pub fn is_harmony_model(model_name: &str) -> bool {
// Case-insensitive substring search without heap allocation
// More efficient than to_lowercase() which allocates a new String
model_name
.as_bytes()
.windows(7) // "gpt-oss".len()
.any(|window| window.eq_ignore_ascii_case(b"gpt-oss"))
}
}
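A quick usage sketch of the detector; the model names below are illustrative, and the only contract is the case-insensitive `gpt-oss` substring match implemented above:

```rust
// Illustrative model names; only the case-insensitive "gpt-oss" substring matters.
assert!(HarmonyDetector::is_harmony_model("gpt-oss-120b"));
assert!(HarmonyDetector::is_harmony_model("openai/GPT-OSS-20b"));
assert!(!HarmonyDetector::is_harmony_model("llama-3.1-8b-instruct"));
assert!(!HarmonyDetector::is_harmony_model("gpt-4o")); // shorter than "gpt-oss", no match
```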
//! Harmony pipeline implementation
//!
//! This module provides support for GPT-OSS models that use Harmony encoding/parsing.
//! The Harmony protocol uses a channel-based approach with three channels:
//! - **analysis**: Reasoning/thinking content (optional)
//! - **commentary**: Tool calls (optional)
//! - **final**: Final response text (required)
//!
//! ## Architecture
//!
//! The Harmony implementation is structured as follows:
//!
//! - **detector**: Model detection (is this a Harmony-capable model?)
//! - **builder**: Request encoding (convert Chat/Responses → input_ids)
//! - **parser**: Response parsing (output_ids → channels)
//! - **types**: Shared type definitions
//!
//! ## Usage
//!
//! ```ignore
//! use sglang_router_rs::routers::grpc::harmony::{HarmonyDetector, HarmonyBuilder};
//!
//! // Detect if model supports Harmony
//! if HarmonyDetector::is_harmony_model("gpt-oss-120b") {
//! // Build Harmony request
//! let builder = HarmonyBuilder::new();
//! let output = builder.build_from_chat(&request)?;
//! // ... use output.input_ids for gRPC request
//! }
//! ```
pub mod builder;
pub mod detector;
pub mod parser;
pub mod processor;
pub mod responses;
pub mod stages;
pub mod streaming;
pub mod types;
// Re-export main types for convenience
pub use builder::HarmonyBuilder;
pub use detector::HarmonyDetector;
pub use parser::HarmonyParserAdapter;
pub use processor::{HarmonyResponseProcessor, ResponsesIterationResult};
pub use responses::{serve_harmony_responses, HarmonyResponsesContext};
pub use stages::{
HarmonyPreparationStage, HarmonyRequestBuildingStage, HarmonyResponseProcessingStage,
};
pub use streaming::HarmonyStreamingProcessor;
pub use types::{
FunctionDelta, HarmonyBuildOutput, HarmonyChannelDelta, HarmonyChannelOutput, HarmonyMessage,
ToolCallDelta,
};
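For orientation, a completed assistant turn in the channel format looks schematically like this (illustrative, not byte-exact; the commentary header and the `<|return|>`/`<|call|>` stop tokens mirror the strings used by the structural-tag builder and stop-token injection elsewhere in this PR, and `get_weather` is a hypothetical tool):

```text
<|channel|>analysis<|message|>...reasoning about the request...<|end|>
<|start|>assistant<|channel|>final<|message|>Here is the answer.<|return|>

(or, for a tool call, generation ends at the <|call|> stop token:)
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"city":"Paris"}<|call|>
```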
//! Harmony response parser
//!
//! Adapter for openai_harmony::StreamableParser that handles channel-based parsing.
use openai_harmony::{chat::Role, HarmonyEncoding, StreamableParser};
use uuid::Uuid;
use super::types::{HarmonyChannelDelta, HarmonyChannelOutput};
use crate::protocols::common::{FunctionCallResponse, ToolCall};
/// Get the global Harmony encoding
///
/// References the same encoding used by the builder for consistency
fn get_harmony_encoding() -> &'static HarmonyEncoding {
use super::builder::get_harmony_encoding;
get_harmony_encoding()
}
/// Harmony parser adapter
///
/// Wraps openai_harmony::StreamableParser and provides methods for parsing
/// complete responses and streaming chunks.
pub struct HarmonyParserAdapter {
parser: StreamableParser,
prev_recipient: Option<String>,
}
impl HarmonyParserAdapter {
/// Create a new Harmony parser
pub fn new() -> Result<Self, String> {
let encoding = get_harmony_encoding();
let parser = StreamableParser::new(encoding.clone(), Some(Role::Assistant))
.map_err(|e| format!("Failed to create StreamableParser: {}", e))?;
Ok(Self {
parser,
prev_recipient: None,
})
}
/// Extract text from message content (private helper)
///
/// Filters text content from a message's content array and joins them into a single string.
///
/// # Arguments
///
/// * `content` - The content array from a Harmony message
///
/// # Returns
///
/// Joined text string from all text content items
fn extract_text_from_content(content: &[openai_harmony::chat::Content]) -> String {
content
.iter()
.filter_map(|c| match c {
openai_harmony::chat::Content::Text(tc) => Some(tc.text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
/// Handle incomplete content from parser state (private helper)
///
/// Checks for any remaining incomplete content in the parser and appends it
/// to the appropriate channel (analysis or final_text).
///
/// # Arguments
///
/// * `parser` - Reference to the StreamableParser
/// * `analysis` - Mutable reference to analysis content
/// * `final_text` - Mutable reference to final text content
fn handle_incomplete_content(
parser: &StreamableParser,
analysis: &mut Option<String>,
final_text: &mut String,
) {
if let Ok(current_content) = parser.current_content() {
if !current_content.is_empty() {
let current_channel = parser.current_channel();
match current_channel.as_deref() {
Some("analysis") => {
*analysis = Some(current_content);
}
Some("final") | None => {
final_text.push_str(&current_content);
}
_ => {}
}
}
}
}
/// Parse messages into channel outputs (private helper)
///
/// Extracts analysis, commentary (tool calls), and final text from Harmony messages.
/// This is the core parsing logic shared by both parse_complete and finalize.
///
/// # Arguments
///
/// * `messages` - The messages to parse from the Harmony parser
///
/// # Returns
///
/// Tuple of (analysis, commentary, final_text)
fn parse_messages(
messages: &[openai_harmony::chat::Message],
) -> (Option<String>, Option<Vec<ToolCall>>, String) {
let mut analysis = None;
let mut commentary: Option<Vec<ToolCall>> = None;
let mut final_text = String::new();
for msg in messages {
// Filter: Only process assistant messages
if msg.author.role != Role::Assistant {
continue;
}
let channel = msg.channel.as_deref().unwrap_or("");
let recipient = msg.recipient.as_deref();
match channel {
"analysis" => {
// Process each content item
// For Chat API, we join them into a single reasoning_content
let text = Self::extract_text_from_content(&msg.content);
if !text.is_empty() {
analysis = Some(text);
}
}
"commentary" => {
// Handle different recipient types
if let Some(recipient_str) = recipient {
if recipient_str.starts_with("functions.") {
let function_name = recipient_str.strip_prefix("functions.").unwrap();
// Process each content item separately
for content in &msg.content {
if let openai_harmony::chat::Content::Text(tc) = content {
let call_id = format!("call_{}", Uuid::new_v4());
let tool_call = ToolCall {
id: call_id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: function_name.to_string(),
arguments: Some(tc.text.clone()),
},
};
match commentary.as_mut() {
Some(calls) => calls.push(tool_call),
None => commentary = Some(vec![tool_call]),
}
}
}
} else if recipient_str.starts_with("python")
|| recipient_str.starts_with("browser")
|| recipient_str.starts_with("container")
{
// Built-in tools → treat as reasoning
// For Chat API, we add to analysis content
let text = Self::extract_text_from_content(&msg.content);
if !text.is_empty() {
// Append to analysis (built-in tools are reasoning)
match analysis.as_mut() {
Some(existing) => {
existing.push('\n');
existing.push_str(&text);
}
None => analysis = Some(text),
}
}
}
// Unknown recipients raise an error in the Python reference implementation;
// here we silently ignore them (logging can be added later)
}
}
"final" => {
// Process final channel content
let text = Self::extract_text_from_content(&msg.content);
final_text.push_str(&text);
}
_ => {
// Unknown channel, append to final text as fallback
let text = Self::extract_text_from_content(&msg.content);
final_text.push_str(&text);
}
}
}
(analysis, commentary, final_text)
}
/// Parse complete response
///
/// Parses all output token IDs and returns the complete channel output
/// containing analysis, commentary (tool calls), and final text.
///
/// # Arguments
///
/// * `output_ids` - The complete output token IDs from the model
/// * `finish_reason` - The finish reason from GenerateComplete ("stop", "length", etc.)
/// * `matched_stop` - Optional matched stop token information from GenerateComplete
///
/// # Returns
///
/// Complete HarmonyChannelOutput with all three channels parsed
pub fn parse_complete(
&mut self,
output_ids: &[u32],
finish_reason: String,
matched_stop: Option<serde_json::Value>,
) -> Result<HarmonyChannelOutput, String> {
// Feed all tokens to the parser
for &token_id in output_ids {
self.parser.process(token_id).map_err(|e| {
// Log the full output_ids context on error
tracing::error!(
token_id = token_id,
output_ids = ?output_ids,
error = %e,
"Harmony parser failed to process token"
);
format!("Failed to process token {}: {}", token_id, e)
})?;
}
// Extract all completed messages from the parser
let messages = self.parser.messages();
// Parse messages into channel outputs using shared helper
let (mut analysis, commentary, mut final_text) = Self::parse_messages(messages);
// Check for incomplete content in parser state
Self::handle_incomplete_content(&self.parser, &mut analysis, &mut final_text);
// Determine finish reason: override to "tool_calls" if commentary has tool calls
let final_finish_reason = if commentary.is_some() {
"tool_calls".to_string()
} else {
finish_reason
};
Ok(HarmonyChannelOutput {
analysis,
commentary,
final_text,
finish_reason: final_finish_reason,
matched_stop,
})
}
/// Get all messages from the parser
///
/// Returns the raw messages extracted by the Harmony parser.
/// Used for validation checks.
pub fn get_messages(&self) -> Vec<openai_harmony::chat::Message> {
self.parser.messages().to_vec()
}
/// Parse streaming chunk
///
/// Parses incremental token IDs and returns a delta with any new content
/// from the analysis, commentary, or final channels.
///
/// # Arguments
///
/// * `chunk_ids` - New token IDs from the current chunk
///
/// # Returns
///
/// Optional HarmonyChannelDelta if there's new content to emit
pub fn parse_chunk(
&mut self,
chunk_ids: &[u32],
) -> Result<Option<HarmonyChannelDelta>, String> {
let mut has_delta = false;
let mut analysis_delta = None;
let mut commentary_delta = None;
let mut final_delta = None;
// Track message count before processing
let prev_message_count = self.parser.messages().len();
// Accumulate delta text for commentary channel
let mut accumulated_delta = String::new();
// Process each token
for &token_id in chunk_ids {
self.parser
.process(token_id)
.map_err(|e| format!("Failed to process token {}: {}", token_id, e))?;
// Check for content delta
if let Ok(Some(delta_text)) = self.parser.last_content_delta() {
has_delta = true;
// Determine which channel this delta belongs to
let channel = self.parser.current_channel();
match channel.as_deref() {
Some("analysis") => {
analysis_delta = Some(delta_text);
}
Some("final") | None => {
final_delta = Some(delta_text);
}
Some("commentary") => {
// Accumulate delta for commentary
accumulated_delta.push_str(&delta_text);
}
_ => {}
}
}
}
// Handle commentary channel tool call deltas
if self.parser.current_channel().as_deref() == Some("commentary") {
if let Some(cur_recipient) = self.parser.current_recipient() {
if cur_recipient.starts_with("functions.") {
has_delta = true;
// Count completed tool calls for index
let base_index = self
.parser
.messages()
.iter()
.filter(|msg| {
msg.channel.as_deref() == Some("commentary")
&& msg
.recipient
.as_deref()
.is_some_and(|r| r.starts_with("functions."))
})
.count();
// Check if recipient changed (new tool call)
let recipient_changed = self.prev_recipient.as_deref() != Some(&cur_recipient);
if recipient_changed {
// NEW tool call: emit name + id
let tool_name = cur_recipient.strip_prefix("functions.").unwrap();
let call_id = format!("call_{}", Uuid::new_v4());
commentary_delta = Some(super::types::ToolCallDelta {
index: base_index,
id: Some(call_id),
function: Some(super::types::FunctionDelta {
name: Some(tool_name.to_string()),
arguments: Some(String::new()),
}),
});
// Update prev_recipient
self.prev_recipient = Some(cur_recipient);
} else if !accumulated_delta.is_empty() {
// CONTINUING tool call: emit arguments delta
commentary_delta = Some(super::types::ToolCallDelta {
index: base_index,
id: None,
function: Some(super::types::FunctionDelta {
name: None,
arguments: Some(accumulated_delta),
}),
});
}
}
}
}
// Check if new messages were completed
let current_message_count = self.parser.messages().len();
let is_final = current_message_count > prev_message_count;
if has_delta {
Ok(Some(HarmonyChannelDelta {
analysis_delta,
commentary_delta,
final_delta,
is_final,
}))
} else {
Ok(None)
}
}
/// Finalize parsing
///
/// Called at the end of streaming to get the final state and any
/// remaining content.
///
/// # Arguments
///
/// * `finish_reason` - The finish reason from GenerateComplete ("stop", "length", etc.)
/// * `matched_stop` - Optional matched stop token information from GenerateComplete
///
/// # Returns
///
/// Final HarmonyChannelOutput with complete parsed content
pub fn finalize(
&mut self,
finish_reason: String,
matched_stop: Option<serde_json::Value>,
) -> Result<HarmonyChannelOutput, String> {
// Extract all completed messages
let messages = self.parser.messages();
// Parse messages into channel outputs using shared helper
let (mut analysis, commentary, mut final_text) = Self::parse_messages(messages);
// Check for remaining incomplete content
Self::handle_incomplete_content(&self.parser, &mut analysis, &mut final_text);
// Determine finish reason: override to "tool_calls" if commentary has tool calls
let final_finish_reason = if commentary.is_some() {
"tool_calls".to_string()
} else {
finish_reason
};
Ok(HarmonyChannelOutput {
analysis,
commentary,
final_text,
finish_reason: final_finish_reason,
matched_stop,
})
}
/// Reset parser state
///
/// Resets the parser to initial state for reuse
pub fn reset(&mut self) -> Result<(), String> {
// Create a new parser instance (StreamableParser doesn't have a reset method)
let encoding = get_harmony_encoding();
self.parser = StreamableParser::new(encoding.clone(), Some(Role::Assistant))
.map_err(|e| format!("Failed to reset parser: {}", e))?;
self.prev_recipient = None;
Ok(())
}
}
impl Default for HarmonyParserAdapter {
fn default() -> Self {
Self::new().expect("Failed to create default parser")
}
}
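A minimal sketch of driving the adapter for a non-streaming response; the `output_ids` and finish reason stand in for fields of a `GenerateComplete` message (streaming follows the same shape, calling `parse_chunk` per chunk and `finalize` at end of stream):

```rust
// Sketch: output_ids would come from a GenerateComplete proto message (hypothetical here).
fn handle_complete(output_ids: &[u32]) -> Result<(), String> {
    let mut parser = HarmonyParserAdapter::new()?;
    let parsed = parser.parse_complete(output_ids, "stop".to_string(), None)?;
    if let Some(reasoning) = &parsed.analysis {
        println!("reasoning: {reasoning}");
    }
    if let Some(calls) = &parsed.commentary {
        // finish_reason is overridden to "tool_calls" when commentary is present
        println!("{} tool call(s), finish_reason={}", calls.len(), parsed.finish_reason);
    }
    println!("final: {}", parsed.final_text);
    Ok(())
}
```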
//! Harmony response processor for non-streaming responses
use std::sync::Arc;
use axum::response::Response;
use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId};
use super::HarmonyParserAdapter;
use crate::{
grpc_client::proto,
protocols::{
chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse},
common::{ToolCall, Usage},
responses::{
ResponseContentPart, ResponseOutputItem, ResponseReasoningContent, ResponseStatus,
ResponseUsage, ResponsesRequest, ResponsesResponse, ResponsesUsage,
},
},
routers::grpc::{
context::{DispatchMetadata, ExecutionResult},
utils,
},
};
/// Processor for non-streaming Harmony responses
///
/// Collects all output tokens from execution and parses them using
/// HarmonyParserAdapter to extract the complete response.
pub struct HarmonyResponseProcessor;
impl HarmonyResponseProcessor {
/// Create a new Harmony response processor
pub fn new() -> Self {
Self
}
/// Collect responses from ExecutionResult (similar to regular processor)
async fn collect_responses(
execution_result: ExecutionResult,
) -> Result<Vec<proto::GenerateComplete>, Response> {
match execution_result {
ExecutionResult::Single { mut stream } => {
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
stream.mark_completed();
Ok(responses)
}
ExecutionResult::Dual { prefill, decode } => {
// For Harmony we currently rely only on decode stream for outputs
let mut decode_stream = *decode;
let responses =
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
prefill.mark_completed();
decode_stream.mark_completed();
Ok(responses)
}
}
}
/// Process a non-streaming Harmony chat response
pub async fn process_non_streaming_chat_response(
&self,
execution_result: ExecutionResult,
chat_request: Arc<ChatCompletionRequest>,
dispatch: DispatchMetadata,
) -> Result<ChatCompletionResponse, Response> {
// Collect all completed responses (one per choice)
let all_responses = Self::collect_responses(execution_result).await?;
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
}
// Build choices by parsing output with HarmonyParserAdapter
let mut choices: Vec<ChatChoice> = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
// Convert matched_stop from proto to JSON
let matched_stop = complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
serde_json::json!(id)
}
MatchedStopStr(s) => {
serde_json::json!(s)
}
});
// Parse Harmony channels with HarmonyParserAdapter
let mut parser = HarmonyParserAdapter::new().map_err(|e| {
utils::internal_error_message(format!("Failed to create Harmony parser: {}", e))
})?;
// Parse Harmony channels with finish_reason and matched_stop
let parsed = parser
.parse_complete(
&complete.output_ids,
complete.finish_reason.clone(),
matched_stop.clone(),
)
.map_err(|e| {
utils::internal_error_message(format!("Harmony parsing failed: {}", e))
})?;
// Build response message (assistant)
let message = ChatCompletionMessage {
role: "assistant".to_string(),
content: (!parsed.final_text.is_empty()).then_some(parsed.final_text),
tool_calls: parsed.commentary,
reasoning_content: parsed.analysis,
};
let finish_reason = parsed.finish_reason;
choices.push(ChatChoice {
index: index as u32,
message,
logprobs: None,
finish_reason: Some(finish_reason),
matched_stop,
hidden_states: None,
});
}
// Build usage from proto fields
let prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
};
// Final ChatCompletionResponse
let response = ChatCompletionResponse {
id: dispatch.request_id.clone(),
object: "chat.completion".to_string(),
created: dispatch.created,
model: chat_request.model.clone(),
choices,
usage: Some(usage),
system_fingerprint: dispatch.weight_version.clone(),
};
Ok(response)
}
}
impl Default for HarmonyResponseProcessor {
fn default() -> Self {
Self::new()
}
}
/// Result of processing a single Responses API iteration
///
/// Used by the MCP tool loop to determine whether to continue
/// executing tools or return the final response.
pub enum ResponsesIterationResult {
/// Tool calls found in commentary channel - continue MCP loop
ToolCallsFound {
tool_calls: Vec<ToolCall>,
analysis: Option<String>, // For streaming emission
partial_text: String, // For streaming emission
},
/// No tool calls - return final ResponsesResponse
Completed {
response: Box<ResponsesResponse>,
usage: Usage,
},
}
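A control-flow sketch of how the MCP loop (driven by `serve_harmony_responses`) is meant to consume this enum; the helpers named here are hypothetical stand-ins, not part of this PR:

```rust
// Sketch only: run_iteration() and execute_mcp_tool() are hypothetical stand-ins
// for the encode -> generate -> parse round trip and the MCP dispatch.
loop {
    match run_iteration().await? {
        ResponsesIterationResult::ToolCallsFound { tool_calls, .. } => {
            for call in &tool_calls {
                // Execute via MCP, then append a function_call_output item so the
                // next Harmony encoding sees the tool result.
                let _output = execute_mcp_tool(call).await?;
            }
            // Loop continues: rebuild input_ids with HarmonyBuilder and generate again.
        }
        ResponsesIterationResult::Completed { response, .. } => break Ok(*response),
    }
}
```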
impl HarmonyResponseProcessor {
/// Process a single Responses API iteration
///
/// Parses Harmony channels and determines if tool calls are present.
/// If tool calls found, returns ToolCallsFound for MCP loop to execute.
/// If no tool calls, builds final ResponsesResponse.
///
/// # Arguments
///
/// * `execution_result` - The execution result from the model
/// * `responses_request` - The original Responses API request
/// * `dispatch` - Dispatch metadata for request tracking
///
/// # Returns
///
/// ResponsesIterationResult indicating whether to continue loop or return
pub async fn process_responses_iteration(
&self,
execution_result: ExecutionResult,
responses_request: Arc<ResponsesRequest>,
dispatch: DispatchMetadata,
) -> Result<ResponsesIterationResult, Response> {
// Collect all completed responses
let all_responses = Self::collect_responses(execution_result).await?;
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
}
// For Responses API, we only process the first response (n=1)
let complete = all_responses
.first()
.ok_or_else(|| utils::internal_error_static("No complete response"))?;
// Parse Harmony channels
let mut parser = HarmonyParserAdapter::new().map_err(|e| {
utils::internal_error_message(format!("Failed to create Harmony parser: {}", e))
})?;
// Convert matched_stop from proto to JSON
let matched_stop = complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
serde_json::json!(id)
}
MatchedStopStr(s) => {
serde_json::json!(s)
}
});
let parsed = parser
.parse_complete(
&complete.output_ids,
complete.finish_reason.clone(),
matched_stop,
)
.map_err(|e| utils::internal_error_message(format!("Harmony parsing failed: {}", e)))?;
// VALIDATION: Check if model incorrectly generated Tool role messages
// This happens when the model copies the format of tool result messages
// instead of continuing as assistant. This is a model hallucination bug.
let messages = parser.get_messages();
let tool_messages_generated = messages.iter().any(|msg| {
msg.author.role == openai_harmony::chat::Role::Tool
&& msg.recipient.as_deref() == Some("assistant")
});
if tool_messages_generated {
tracing::warn!(
"Model generated Tool->Assistant message instead of Assistant message. \
This is a model hallucination bug where it copies tool result format."
);
}
// Check for tool calls in commentary channel
if let Some(tool_calls) = parsed.commentary {
// Tool calls found - return for MCP loop execution
return Ok(ResponsesIterationResult::ToolCallsFound {
tool_calls,
analysis: parsed.analysis,
partial_text: parsed.final_text,
});
}
// No tool calls - build final ResponsesResponse
let mut output: Vec<ResponseOutputItem> = Vec::new();
// Map analysis channel → ResponseOutputItem::Reasoning
if let Some(analysis) = parsed.analysis {
let reasoning_item = ResponseOutputItem::Reasoning {
id: format!("reasoning_{}", dispatch.request_id),
summary: vec![],
content: vec![ResponseReasoningContent::ReasoningText { text: analysis }],
status: Some("completed".to_string()),
};
output.push(reasoning_item);
}
// Map final channel → ResponseOutputItem::Message
if !parsed.final_text.is_empty() {
let message_item = ResponseOutputItem::Message {
id: format!("msg_{}", dispatch.request_id),
role: "assistant".to_string(),
content: vec![ResponseContentPart::OutputText {
text: parsed.final_text,
annotations: vec![],
logprobs: None,
}],
status: "completed".to_string(),
};
output.push(message_item);
}
// Build usage
let prompt_tokens = complete.prompt_tokens as u32;
let completion_tokens = complete.completion_tokens as u32;
let usage = Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
};
// Build ResponsesResponse with all required fields
let response = ResponsesResponse {
id: dispatch.request_id.clone(),
object: "response".to_string(),
created_at: dispatch.created as i64,
status: ResponseStatus::Completed,
error: None,
incomplete_details: None,
instructions: responses_request.instructions.clone(),
max_output_tokens: responses_request.max_output_tokens,
model: responses_request.model.clone(),
output,
parallel_tool_calls: responses_request.parallel_tool_calls.unwrap_or(true),
previous_response_id: responses_request.previous_response_id.clone(),
reasoning: None, // Set by caller if needed
store: responses_request.store.unwrap_or(true),
temperature: responses_request.temperature,
text: None,
tool_choice: responses_request
.tool_choice
.as_ref()
.map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
.unwrap_or_else(|| "auto".to_string()),
tools: responses_request.tools.clone().unwrap_or_default(),
top_p: responses_request.top_p,
truncation: None,
usage: Some(ResponsesUsage::Modern(ResponseUsage {
input_tokens: prompt_tokens,
output_tokens: completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
input_tokens_details: None,
output_tokens_details: None,
})),
user: None,
metadata: responses_request.metadata.clone().unwrap_or_default(),
};
Ok(ResponsesIterationResult::Completed {
response: Box::new(response),
usage,
})
}
}
... (diff collapsed)
//! Harmony-specific pipeline stages
//!
//! These stages replace their regular counterparts in the Harmony pipeline:
//! - HarmonyPreparationStage: Harmony encoding instead of chat template + tokenization
//! - HarmonyRequestBuildingStage: Token-based request building
//! - HarmonyResponseProcessingStage: Harmony channel parsing
pub mod preparation;
pub mod request_building;
pub mod response_processing;
pub use preparation::HarmonyPreparationStage;
pub use request_building::HarmonyRequestBuildingStage;
pub use response_processing::HarmonyResponseProcessingStage;
//! Harmony Preparation Stage: Harmony encoding for chat and generate requests
use async_trait::async_trait;
use axum::response::Response;
use serde_json::json;
use super::super::HarmonyBuilder;
use crate::{
protocols::{
chat::ChatCompletionRequest,
common::{Tool, ToolChoice, ToolChoiceValue},
responses::ResponsesRequest,
},
routers::grpc::{
context::{PreparationOutput, RequestContext, RequestType},
stages::PipelineStage,
utils,
},
};
/// Harmony Preparation stage: Encode requests using Harmony protocol
///
/// Replaces the regular PreparationStage for Harmony models.
/// Converts chat/generate requests to Harmony-encoded token_ids and extraction_text.
pub struct HarmonyPreparationStage {
builder: HarmonyBuilder,
}
impl HarmonyPreparationStage {
/// Create a new Harmony preparation stage
pub fn new() -> Self {
Self {
builder: HarmonyBuilder::new(),
}
}
}
impl Default for HarmonyPreparationStage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PipelineStage for HarmonyPreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Clone Arc before match to avoid borrow checker issues
// Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
let is_responses = matches!(&ctx.input.request_type, RequestType::Responses(_));
if is_chat {
let request_arc = ctx.chat_request_arc();
self.prepare_chat(ctx, &request_arc).await?;
} else if is_responses {
let request_arc = ctx.responses_request_arc();
self.prepare_responses(ctx, &request_arc).await?;
} else {
return Err(utils::bad_request_error(
"Only Chat and Responses requests supported in Harmony pipeline".to_string(),
));
}
Ok(None)
}
fn name(&self) -> &'static str {
"HarmonyPreparation"
}
}
impl HarmonyPreparationStage {
/// Prepare a chat completion request using Harmony encoding
async fn prepare_chat(
&self,
ctx: &mut RequestContext,
request: &ChatCompletionRequest,
) -> Result<Option<Response>, Response> {
// Validate - reject logprobs
if request.logprobs {
return Err(utils::bad_request_error(
"logprobs are not supported for Harmony models".to_string(),
));
}
// Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request);
// Step 2: Build tool constraints
let tool_constraints = if let Some(tools) = body_ref.tools.as_ref() {
Self::generate_harmony_structural_tag(tools, &body_ref.tool_choice).map_err(|e| *e)?
} else {
None
};
// Step 3: Build via Harmony
let build_output = self
.builder
.build_from_chat(&body_ref)
.map_err(|e| utils::bad_request_error(format!("Harmony build failed: {}", e)))?;
// Step 4: Store results
ctx.state.preparation = Some(PreparationOutput {
original_text: None,
token_ids: build_output.input_ids,
processed_messages: None,
tool_constraints,
filtered_request: if matches!(body_ref, std::borrow::Cow::Owned(_)) {
Some(body_ref.into_owned())
} else {
None
},
harmony_mode: true,
selection_text: Some(build_output.selection_text),
harmony_messages: Some(build_output.harmony_messages),
harmony_stop_ids: Some(build_output.stop_token_ids),
});
Ok(None)
}
/// Prepare a responses API request using Harmony encoding
///
/// For responses API, we build from conversation history using the same Harmony
/// encoding that the builder provides. This handles the MCP loop integration.
pub async fn prepare_responses(
&self,
ctx: &mut RequestContext,
request: &ResponsesRequest,
) -> Result<Option<Response>, Response> {
// Build via Harmony from responses API request
let build_output = self
.builder
.build_from_responses(request)
.map_err(|e| utils::bad_request_error(format!("Harmony build failed: {}", e)))?;
// Store results in preparation output
ctx.state.preparation = Some(PreparationOutput {
original_text: None,
token_ids: build_output.input_ids,
processed_messages: None,
tool_constraints: None,
filtered_request: None,
harmony_mode: true,
selection_text: Some(build_output.selection_text),
harmony_messages: Some(build_output.harmony_messages),
harmony_stop_ids: Some(build_output.stop_token_ids),
});
Ok(None)
}
/// Generate Harmony structural tag for tool constraints
///
/// Uses structural tags with `triggered_tags` format to force Harmony format output.
/// This ensures the model outputs in Harmony format (with channels) even when constrained.
fn generate_harmony_structural_tag(
tools: &[Tool],
tool_choice: &Option<ToolChoice>,
) -> Result<Option<(String, String)>, Box<Response>> {
let Some(choice) = tool_choice.as_ref() else {
return Ok(None);
};
match choice {
ToolChoice::Function { function, .. } => {
let tag = Self::build_harmony_structural_tag(tools, Some(&function.name))?;
Ok(Some(("structural_tag".to_string(), tag)))
}
ToolChoice::Value(ToolChoiceValue::Required) => {
let tag = Self::build_harmony_structural_tag(tools, None)?;
Ok(Some(("structural_tag".to_string(), tag)))
}
ToolChoice::AllowedTools { mode, .. } => {
if mode == "required" {
let tag = Self::build_harmony_structural_tag(tools, None)?;
Ok(Some(("structural_tag".to_string(), tag)))
} else {
Ok(None)
}
}
_ => Ok(None),
}
}
/// Build Harmony structural tag for tool calling constraints
fn build_harmony_structural_tag(
tools: &[Tool],
specific_function: Option<&str>,
) -> Result<String, Box<Response>> {
let mut tags = Vec::new();
// Filter tools if specific function requested
let tools_to_use: Vec<&Tool> = if let Some(func_name) = specific_function {
tools
.iter()
.filter(|t| t.function.name == func_name)
.collect()
} else {
tools.iter().collect()
};
// Validate specific function exists
if specific_function.is_some() && tools_to_use.is_empty() {
return Err(Box::new(utils::bad_request_error(format!(
"Tool '{}' not found in tools list",
specific_function.unwrap()
))));
}
// Build tags for each tool
for tool in tools_to_use {
let tool_name = &tool.function.name;
let params_schema = &tool.function.parameters;
tags.push(json!({
"begin": format!("<|channel|>commentary to=functions.{}<|constrain|>json<|message|>", tool_name),
"content": {
"type": "json_schema",
"json_schema": params_schema
},
"end": "" // `end` is empty because <|call|> comes naturally from Harmony stop tokens
}));
}
let stop_after_first = specific_function.is_some();
let structural_tag = json!({
"format": {
"type": "triggered_tags",
"triggers": ["<|channel|>commentary"],
"tags": tags,
"at_least_one": true,
"stop_after_first": stop_after_first
}
});
serde_json::to_string(&structural_tag).map_err(|e| {
Box::new(utils::internal_error_message(format!(
"Failed to serialize structural tag: {}",
e
)))
})
}
}
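For a single hypothetical `get_weather` tool under `tool_choice: "required"`, `build_harmony_structural_tag` serializes to roughly the following (schema abbreviated; `stop_after_first` becomes `true` only when a specific function is forced):

```json
{
  "format": {
    "type": "triggered_tags",
    "triggers": ["<|channel|>commentary"],
    "tags": [
      {
        "begin": "<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>",
        "content": {
          "type": "json_schema",
          "json_schema": { "type": "object", "properties": { "city": { "type": "string" } } }
        },
        "end": ""
      }
    ],
    "at_least_one": true,
    "stop_after_first": false
  }
}
```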
//! Harmony Request Building Stage: Build gRPC request from Harmony-encoded tokens
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use rand::Rng;
use tracing::debug;
use uuid::Uuid;
use crate::{
core::Worker,
grpc_client::proto::{DisaggregatedParams, GenerateRequest},
routers::grpc::{
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
stages::PipelineStage,
utils,
},
};
/// Harmony Request Building stage: Convert Harmony tokens to gRPC request
///
/// Takes the Harmony-encoded input_ids from preparation and builds a proto::GenerateRequest.
/// Unlike regular request building, this uses token_ids directly (Harmony encoding handles messages).
pub struct HarmonyRequestBuildingStage {
inject_pd_metadata: bool,
}
impl HarmonyRequestBuildingStage {
/// Create a new Harmony request building stage
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
/// Inject PD (prefill-decode) bootstrap metadata
fn inject_bootstrap_metadata(
&self,
request: &mut GenerateRequest,
prefill_worker: &Arc<dyn Worker>,
) {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected Harmony bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
}
#[async_trait]
impl PipelineStage for HarmonyRequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Get preparation output
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation not completed"))?;
// Get clients
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
// Generate request_id based on request type
let request_id = match &ctx.input.request_type {
RequestType::Chat(_) => format!("chatcmpl-{}", Uuid::new_v4()),
RequestType::Responses(_) => format!("responses-{}", Uuid::new_v4()),
RequestType::Generate(_) => {
return Err(utils::bad_request_error(
"Generate requests are not supported with Harmony models".to_string(),
));
}
};
// Build gRPC request using token_ids directly (Harmony encoding already handled message rendering)
// Use a placeholder for original_text; Harmony uses input_ids for tokenization
let placeholder_processed_text = "[harmony]".to_string();
let mut proto_request = match &ctx.input.request_type {
RequestType::Chat(request) => {
// Use filtered request if present from preparation; otherwise original
let body = prep.filtered_request.as_ref().unwrap_or(request.as_ref());
builder_client
.build_generate_request(
request_id,
body,
placeholder_processed_text,
prep.token_ids.clone(),
None,
prep.tool_constraints.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?
}
RequestType::Responses(request) => builder_client
.build_generate_request_from_responses(
request_id,
request.as_ref(),
placeholder_processed_text,
prep.token_ids.clone(),
prep.harmony_stop_ids.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?,
_ => unreachable!(),
};
// Inject Harmony stop token IDs into sampling params for ALL Harmony requests
// These stop tokens (<|return|> and <|call|>) prevent the model from generating
// malformed Harmony sequences
if let Some(harmony_stops) = &prep.harmony_stop_ids {
if let Some(params) = proto_request.sampling_params.as_mut() {
params.stop_token_ids.extend_from_slice(harmony_stops);
debug!(
stop_token_count = harmony_stops.len(),
"Injected Harmony stop tokens into sampling params"
);
}
}
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let Some(WorkerSelection::Dual { prefill, .. }) = ctx.state.workers.as_ref() {
self.inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"HarmonyRequestBuilding"
}
}
//! Harmony Response Processing Stage: Parse Harmony channels to ChatCompletionResponse
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use super::super::{HarmonyResponseProcessor, HarmonyStreamingProcessor};
use crate::routers::grpc::{
context::{FinalResponse, RequestContext, RequestType},
stages::PipelineStage,
utils,
};
/// Harmony Response Processing stage: Parse and format Harmony responses
///
/// Takes output tokens from execution and parses them using HarmonyParserAdapter
/// to extract analysis, tool calls, and final response text from Harmony channels.
pub struct HarmonyResponseProcessingStage {
processor: HarmonyResponseProcessor,
streaming_processor: Arc<HarmonyStreamingProcessor>,
}
impl HarmonyResponseProcessingStage {
/// Create a new Harmony response processing stage
pub fn new() -> Self {
Self {
processor: HarmonyResponseProcessor::new(),
streaming_processor: Arc::new(HarmonyStreamingProcessor::new()),
}
}
}
impl Default for HarmonyResponseProcessingStage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PipelineStage for HarmonyResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Get execution result (output tokens from model)
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let is_streaming = ctx.is_streaming();
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// Check request type to determine which processor method to call
match &ctx.input.request_type {
RequestType::Chat(_) => {
// For streaming, delegate to streaming processor and return SSE response
if is_streaming {
return Ok(Some(
self.streaming_processor
.clone()
.process_streaming_chat_response(
execution_result,
ctx.chat_request_arc(),
dispatch,
),
));
}
// For non-streaming, delegate to Harmony response processor to build ChatCompletionResponse
let chat_request = ctx.chat_request_arc();
let response = self
.processor
.process_non_streaming_chat_response(execution_result, chat_request, dispatch)
.await?;
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
Ok(None)
}
RequestType::Responses(_) => {
// For Responses API, process iteration and store result
// Streaming not yet supported for Responses API
if is_streaming {
return Err(utils::internal_error_static(
"Streaming not yet supported for Responses API",
));
}
let responses_request = ctx.responses_request_arc();
let iteration_result = self
.processor
.process_responses_iteration(execution_result, responses_request, dispatch)
.await?;
ctx.state.response.responses_iteration_result = Some(iteration_result);
Ok(None)
}
RequestType::Generate(_) => Err(utils::internal_error_static(
"Generate requests not supported in Harmony pipeline",
)),
}
}
fn name(&self) -> &'static str {
"HarmonyResponseProcessing"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_response_processing_stage_creation() {
let stage = HarmonyResponseProcessingStage::new();
assert_eq!(stage.name(), "HarmonyResponseProcessing");
}
}
... (diff collapsed)
//! Shared types for Harmony pipeline
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::protocols::common::ToolCall;
/// Harmony message format
///
/// Represents messages in the Harmony encoding format with role and content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HarmonyMessage {
pub role: String,
pub content: String,
}
impl HarmonyMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: role.into(),
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
/// Convert from openai_harmony::chat::Message to our simplified HarmonyMessage
pub fn from_openai_harmony(msg: openai_harmony::chat::Message) -> Self {
use openai_harmony::chat::Content;
// Extract role as string
let role = match msg.author.role {
openai_harmony::chat::Role::User => "user",
openai_harmony::chat::Role::Assistant => "assistant",
openai_harmony::chat::Role::System => "system",
openai_harmony::chat::Role::Developer => "developer",
openai_harmony::chat::Role::Tool => "tool",
}
.to_string();
// Extract text content from all Content::Text parts
let content = msg
.content
.iter()
.filter_map(|c| match c {
Content::Text(tc) => Some(tc.text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
Self { role, content }
}
}
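A trivial sketch of the constructors above; this layer is plain role/content data, with no Harmony encoding involved:

```rust
// History tracking uses simplified role/content pairs.
let history = vec![
    HarmonyMessage::system("You are a helpful assistant."),
    HarmonyMessage::user("What's the weather in Paris?"),
    HarmonyMessage::assistant("Let me check that for you."),
];
assert_eq!(history[1].role, "user");
```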
/// Output from Harmony encoding process
///
/// Contains the encoded input_ids, stop tokens, selection text for worker routing,
/// and the Harmony message history.
#[derive(Debug, Clone)]
pub struct HarmonyBuildOutput {
/// Encoded token IDs to send to the model
pub input_ids: Vec<u32>,
/// Stop token IDs for this model (injected into sampling params)
pub stop_token_ids: Vec<u32>,
/// Selection text for worker routing (concise snippet from last user message)
pub selection_text: String,
/// Harmony messages for this conversation (used for history tracking)
pub harmony_messages: Vec<HarmonyMessage>,
}
/// Parsed output from all three Harmony channels
///
/// Represents the complete response after parsing analysis, commentary, and final channels.
#[derive(Debug, Clone)]
pub struct HarmonyChannelOutput {
/// Analysis/reasoning content (from analysis channel)
pub analysis: Option<String>,
/// Tool calls (from commentary channel)
pub commentary: Option<Vec<ToolCall>>,
/// Final text content (from final channel)
pub final_text: String,
/// Finish reason
pub finish_reason: String,
/// Matched stop token (if any)
pub matched_stop: Option<Value>,
}
/// Streaming delta for SSE responses
///
/// Represents incremental updates as tokens are parsed from the stream.
#[derive(Debug, Clone)]
pub struct HarmonyChannelDelta {
/// Delta for analysis/reasoning content
pub analysis_delta: Option<String>,
/// Delta for tool calls
pub commentary_delta: Option<ToolCallDelta>,
/// Delta for final text content
pub final_delta: Option<String>,
/// Whether this is the final delta
pub is_final: bool,
}
/// Tool call delta for streaming
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
pub index: usize,
pub id: Option<String>,
pub function: Option<FunctionDelta>,
}
/// Function call delta for streaming
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDelta {
pub name: Option<String>,
pub arguments: Option<String>,
}
@@ -3,11 +3,13 @@
use crate::{grpc_client::proto, protocols::common::StringOrArray};
pub mod context;
+pub mod harmony;
pub mod pd_router;
pub mod pipeline;
pub mod processing;
pub mod responses;
pub mod router;
+pub mod stages;
pub mod streaming;
pub mod utils;
...
@@ -14,9 +14,7 @@ use tracing::debug;
use super::{context::SharedComponents, pipeline::RequestPipeline};
use crate::{
app_context::AppContext,
-config::types::RetryConfig,
core::{ConnectionMode, WorkerRegistry, WorkerType},
-policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
@@ -26,26 +24,13 @@ use crate::{
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
-reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
-tokenizer::traits::Tokenizer,
-tool_parser::ParserFactory as ToolParserFactory,
};
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[derive(Clone)]
-#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
worker_registry: Arc<WorkerRegistry>,
-policy_registry: Arc<PolicyRegistry>,
-tokenizer: Arc<dyn Tokenizer>,
-reasoning_parser_factory: ReasoningParserFactory,
-tool_parser_factory: ToolParserFactory,
-dp_aware: bool,
-api_key: Option<String>,
-retry_config: RetryConfig,
-configured_reasoning_parser: Option<String>,
-configured_tool_parser: Option<String>,
pipeline: RequestPipeline,
shared_components: Arc<SharedComponents>,
}
@@ -94,15 +79,6 @@ impl GrpcPDRouter {
Ok(GrpcPDRouter {
worker_registry,
-policy_registry,
-tokenizer,
-reasoning_parser_factory,
-tool_parser_factory,
-dp_aware: ctx.router_config.dp_aware,
-api_key: ctx.router_config.api_key.clone(),
-retry_config: ctx.router_config.effective_retry_config(),
-configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
-configured_tool_parser: ctx.configured_tool_parser.clone(),
pipeline,
shared_components,
})
@@ -174,7 +150,6 @@ impl std::fmt::Debug for GrpcPDRouter {
f.debug_struct("GrpcPDRouter")
.field("prefill_workers_count", &prefill_workers.len())
.field("decode_workers_count", &decode_workers.len())
-.field("dp_aware", &self.dp_aware)
.finish()
}
}
@@ -255,19 +230,19 @@ impl RouterTrait for GrpcPDRouter {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
-async fn route_classify(
+async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,
-_body: &ClassifyRequest,
+_body: &EmbeddingRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
-async fn route_embeddings(
+async fn route_classify(
&self,
_headers: Option<&HeaderMap>,
-_body: &EmbeddingRequest,
+_body: &ClassifyRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
...