Unverified Commit 03b3e89a authored by Simo Lin, committed by GitHub

[router] Harmony Pipeline: Chat Completion & Responses API with MCP Support (#12153)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 9ff9fa7f
@@ -83,6 +83,7 @@ oracle = { version = "0.6.3", features = ["chrono"] }
subtle = "2.6"
rustpython-parser = "0.4.0"
num-traits = "0.2"
openai-harmony = { git = "https://github.com/openai/harmony", tag = "v0.0.4" }
# gRPC and Protobuf dependencies
tonic = { version = "0.14.2", features = ["gzip", "transport"] }
...
-# SGLang Router Makefile
+# Model Gateway Makefile
# Provides convenient shortcuts for common development tasks
# Check if sccache is available and set RUSTC_WRAPPER accordingly
@@ -13,14 +13,14 @@ endif
.PHONY: help bench bench-quick bench-baseline bench-compare test build clean
help: ## Show this help message
-@echo "SGLang Router Development Commands"
+@echo "Model Gateway Development Commands"
@echo "=================================="
@echo ""
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}'
@echo ""
build: ## Build the project in release mode
-@echo "Building SGLang Router..."
+@echo "Building SGLang Model Gateway..."
@cargo build --release
test: ## Run all tests
@@ -59,11 +59,11 @@ check: ## Run cargo check and clippy
@echo "Running cargo check..."
@cargo check
@echo "Running clippy..."
-@cargo clippy
+@cargo clippy --all-targets --all-features -- -D warnings
fmt: ## Format code with rustfmt
@echo "Formatting code..."
-@cargo fmt
+@rustup run nightly cargo fmt
# Development workflow shortcuts
dev-setup: build test ## Set up development environment
...
@@ -16,6 +16,7 @@ use crate::protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
responses::ResponsesRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
};
@@ -301,6 +302,42 @@ impl SglangSchedulerClient {
Ok(grpc_request)
}
/// Build a GenerateRequest from ResponsesRequest (OpenAI Responses API)
pub fn build_generate_request_from_responses(
&self,
request_id: String,
body: &ResponsesRequest,
processed_text: String,
token_ids: Vec<u32>,
harmony_stop_ids: Option<Vec<u32>>,
) -> Result<proto::GenerateRequest, String> {
// Build sampling params from ResponsesRequest
let mut sampling_params = self.build_grpc_sampling_params_from_responses(body)?;
// Inject Harmony stop token IDs if provided
if let Some(stop_ids) = harmony_stop_ids {
sampling_params.stop_token_ids = stop_ids;
}
let grpc_request = proto::GenerateRequest {
request_id,
tokenized: Some(proto::TokenizedInput {
original_text: processed_text,
input_ids: token_ids,
}),
mm_inputs: None, // Responses API doesn't support multimodal yet
sampling_params: Some(sampling_params),
return_logprob: false, // Responses API uses top_logprobs field instead
logprob_start_len: -1,
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
return_hidden_states: false,
stream: body.stream.unwrap_or(false),
..Default::default()
};
Ok(grpc_request)
}
/// Build gRPC SamplingParams from OpenAI request
fn build_grpc_sampling_params(
&self,
@@ -400,6 +437,37 @@ impl SglangSchedulerClient {
}
}
/// Build gRPC SamplingParams from ResponsesRequest
fn build_grpc_sampling_params_from_responses(
&self,
request: &ResponsesRequest,
) -> Result<proto::SamplingParams, String> {
// ResponsesRequest doesn't have stop sequences in the same way
// Tools are handled externally by MCP loop, not via constraints
let max_new_tokens = request.max_output_tokens.map(|v| v as i32);
Ok(proto::SamplingParams {
temperature: request.temperature.unwrap_or(1.0),
top_p: request.top_p.unwrap_or(1.0),
top_k: -1, // ResponsesRequest doesn't expose top_k
min_p: 0.0, // ResponsesRequest doesn't expose min_p
frequency_penalty: 0.0, // ResponsesRequest doesn't expose frequency_penalty
presence_penalty: 0.0, // ResponsesRequest doesn't expose presence_penalty
repetition_penalty: 1.0, // ResponsesRequest doesn't expose repetition_penalty
max_new_tokens,
stop: vec![], // No stop sequences in Responses API
stop_token_ids: vec![], // Handled by Harmony stop tokens
skip_special_tokens: false, // Keep special tokens for Harmony
spaces_between_special_tokens: true,
ignore_eos: false,
no_stop_trim: false,
n: 1, // Responses API doesn't support n>1
constraint: None, // No constraints - tools handled by MCP
..Default::default()
})
}
fn build_single_constraint_from_plain(
params: &GenerateSamplingParams,
) -> Result<Option<proto::sampling_params::Constraint>, String> {
...
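The stop-token injection above is small but load-bearing: whatever stop ids the Harmony builder derives (for `<|return|>` and `<|call|>`) replace the request's own `stop_token_ids`, so generation halts exactly at Harmony message boundaries. A minimal sketch of that override, using a simplified stand-in for `proto::SamplingParams` (the token ids are illustrative, not the real values):

```rust
/// Simplified stand-in for proto::SamplingParams.
#[derive(Default, Debug)]
struct SamplingParams {
    stop_token_ids: Vec<u32>,
}

fn main() {
    let mut params = SamplingParams {
        stop_token_ids: vec![42], // whatever the request carried
    };
    // Harmony stop ids for <|return|> / <|call|> (illustrative values only).
    let harmony_stop_ids: Option<Vec<u32>> = Some(vec![200_002, 200_012]);
    if let Some(stop_ids) = harmony_stop_ids {
        params.stop_token_ids = stop_ids; // full replacement, not a merge
    }
    assert_eq!(params.stop_token_ids, vec![200_002, 200_012]);
}
```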
@@ -674,7 +674,6 @@ pub struct ChatMessageDelta {
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatMessageDelta,
-#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
...
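Dropping `skip_serializing_if` on `logprobs` means a `None` value now serializes as an explicit `"logprobs": null` instead of being omitted, mirroring OpenAI-style `chat.completion.chunk` choices that carry the field explicitly. A minimal sketch with a simplified stand-in struct:

```rust
use serde::Serialize;

/// Simplified stand-in for ChatStreamChoice.
#[derive(Serialize)]
struct ChatStreamChoice {
    index: u32,
    // No skip_serializing_if: a None value serializes as an explicit null.
    logprobs: Option<u32>,
}

fn main() {
    let choice = ChatStreamChoice { index: 0, logprobs: None };
    // Prints: {"index":0,"logprobs":null}
    println!("{}", serde_json::to_string(&choice).unwrap());
}
```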
@@ -21,6 +21,9 @@ use super::common::{
pub struct ResponseTool {
#[serde(rename = "type")]
pub r#type: ResponseToolType,
// Function tool fields (used when type == "function")
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<crate::protocols::common::Function>,
// MCP-specific fields (used when type == "mcp")
#[serde(skip_serializing_if = "Option::is_none")]
pub server_url: Option<String>,
@@ -40,6 +43,7 @@ impl Default for ResponseTool {
fn default() -> Self {
Self {
r#type: ResponseToolType::WebSearchPreview,
function: None,
server_url: None,
authorization: None,
server_label: None,
@@ -53,6 +57,7 @@ impl Default for ResponseTool {
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
Function,
WebSearchPreview,
CodeInterpreter,
Mcp,
@@ -134,6 +139,13 @@ pub enum ResponseInputOutputItem {
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_call_output")]
FunctionCallOutput {
call_id: String,
output: String,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(untagged)]
SimpleInputMessage {
content: StringOrContentParts,
@@ -499,7 +511,7 @@ pub struct ResponsesRequest {
pub store: Option<bool>,
/// Whether to stream the response
-#[serde(skip_serializing_if = "Option::is_none")]
+#[serde(default)]
pub stream: Option<bool>,
/// Temperature for sampling
@@ -678,6 +690,9 @@ impl GenerationRequest for ResponsesRequest {
ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
Some(arguments.clone())
}
ResponseInputOutputItem::FunctionCallOutput { output, .. } => {
Some(output.clone())
}
})
.collect::<Vec<String>>()
.join(" "),
...
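The new `function_call_output` variant lets clients feed tool results back in the Responses API's native shape. Assuming the enum is internally tagged on `type` (as the `rename` attributes suggest), a stripped-down sketch of how such an item deserializes:

```rust
use serde::Deserialize;

/// Stripped-down stand-in for ResponseInputOutputItem,
/// assuming internal tagging on "type".
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum InputItem {
    FunctionCallOutput {
        call_id: String,
        output: String,
        #[serde(default)]
        status: Option<String>,
    },
}

fn main() {
    let json = r#"{
        "type": "function_call_output",
        "call_id": "call_abc123",
        "output": "{\"temperature\": 21}"
    }"#;
    let item: InputItem = serde_json::from_str(json).unwrap();
    println!("{:?}", item);
}
```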
@@ -15,6 +15,7 @@ use crate::{
protocols::{
chat::{ChatCompletionRequest, ChatCompletionResponse},
generate::{GenerateRequest, GenerateResponse},
responses::ResponsesRequest,
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
tokenizer::{stop::StopSequenceDecoder, traits::Tokenizer},
@@ -53,6 +54,7 @@ pub struct RequestInput {
pub enum RequestType {
Chat(Arc<ChatCompletionRequest>),
Generate(Arc<GenerateRequest>),
Responses(Arc<ResponsesRequest>),
}
/// Shared components (injected once at creation)
@@ -104,6 +106,19 @@ pub struct PreparationOutput {
/// Filtered request (if tools were filtered)
pub filtered_request: Option<ChatCompletionRequest>,
// Harmony-specific fields
/// Whether this is a Harmony request (default: false)
pub harmony_mode: bool,
/// Selection text for worker routing (Harmony only)
pub selection_text: Option<String>,
/// Harmony messages for history tracking (Harmony only)
pub harmony_messages: Option<Vec<super::harmony::HarmonyMessage>>,
/// Stop token IDs for Harmony models
pub harmony_stop_ids: Option<Vec<u32>>,
}
/// Worker selection (Step 2)
@@ -155,6 +170,16 @@ pub struct ResponseState {
/// Final processed response
pub final_response: Option<FinalResponse>,
/// Responses API iteration result (Harmony only, for tool loop orchestration)
pub responses_iteration_result: Option<super::harmony::ResponsesIterationResult>,
// Harmony-specific parser state
/// Harmony parser for non-streaming (single parser for all indices)
pub harmony_parser: Option<super::harmony::HarmonyParserAdapter>,
/// Harmony parsers for streaming (one per index for n>1 support)
pub harmony_parser_per_index: Option<HashMap<usize, super::harmony::HarmonyParserAdapter>>,
}
/// Streaming state (per-choice tracking)
@@ -217,6 +242,24 @@ impl RequestContext {
}
}
/// Create context for Responses API request
pub fn for_responses(
request: Arc<ResponsesRequest>,
headers: Option<HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Self {
Self {
input: RequestInput {
request_type: RequestType::Responses(request),
headers,
model_id,
},
components,
state: ProcessingState::default(),
}
}
/// Get reference to original request (type-safe)
pub fn request(&self) -> &RequestType {
&self.input.request_type
@@ -254,11 +297,28 @@ impl RequestContext {
}
}
/// Get responses request (panics if not responses)
pub fn responses_request(&self) -> &ResponsesRequest {
match &self.input.request_type {
RequestType::Responses(req) => req.as_ref(),
_ => panic!("Expected responses request"),
}
}
/// Get Arc clone of responses request (panics if not responses)
pub fn responses_request_arc(&self) -> Arc<ResponsesRequest> {
match &self.input.request_type {
RequestType::Responses(req) => Arc::clone(req),
_ => panic!("Expected responses request"),
}
}
/// Check if request is streaming
pub fn is_streaming(&self) -> bool {
match &self.input.request_type {
RequestType::Chat(req) => req.stream,
RequestType::Generate(req) => req.stream,
RequestType::Responses(req) => req.stream.unwrap_or(false),
}
}
}
...
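With the earlier `#[serde(default)]` swap on `stream`, a request body that omits the field deserializes cleanly to `None`, and `is_streaming()` above resolves that to `false`. A quick sketch with a hypothetical trimmed struct:

```rust
use serde::Deserialize;

/// Trimmed stand-in for ResponsesRequest.
#[derive(Deserialize)]
struct ResponsesRequestLite {
    #[serde(default)]
    stream: Option<bool>,
}

fn main() {
    let req: ResponsesRequestLite = serde_json::from_str("{}").unwrap();
    assert!(!req.stream.unwrap_or(false)); // missing field => non-streaming
}
```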
//! Harmony request builder
//!
//! Handles encoding of Chat/Responses requests into Harmony format using openai-harmony library.
use std::sync::OnceLock;
use chrono::Local;
use openai_harmony::{
chat::{
Author, ChannelConfig, Content, Conversation, DeveloperContent, Message as HarmonyMessage,
ReasoningEffort, Role, SystemContent, TextContent, ToolDescription,
},
HarmonyEncoding, HarmonyEncodingName,
};
use tracing::debug;
use super::types::HarmonyBuildOutput;
use crate::protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
common::{ContentPart, Tool},
responses::{
ReasoningEffort as ResponsesReasoningEffort, ResponseContentPart, ResponseInput,
ResponseInputOutputItem, ResponseReasoningContent, ResponseTool, ResponseToolType,
ResponsesRequest, StringOrContentParts,
},
};
/// Global Harmony encoding (lazy-initialized)
static HARMONY_ENCODING: OnceLock<HarmonyEncoding> = OnceLock::new();
/// Get or initialize the Harmony encoding
///
/// Uses HarmonyGptOss encoding which supports the gpt-oss model family.
pub(super) fn get_harmony_encoding() -> &'static HarmonyEncoding {
HARMONY_ENCODING.get_or_init(|| {
openai_harmony::load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)
.expect("Failed to load Harmony encoding")
})
}
/// Built-in tools that are added to the system message
const BUILTIN_TOOLS: &[&str] = &["web_search_preview", "code_interpreter", "container"];
/// Trait for tool-like objects that can be converted to Harmony ToolDescription
trait ToolLike {
/// Check if this is a built-in tool (should be skipped in developer message)
#[allow(dead_code)]
fn is_builtin(&self) -> bool;
/// Check if this is a custom tool (function or MCP)
fn is_custom(&self) -> bool;
/// Convert to ToolDescription
fn to_tool_description(&self) -> Option<ToolDescription>;
}
/// Implement ToolLike for Chat Completion Tool
impl ToolLike for Tool {
fn is_builtin(&self) -> bool {
matches!(
self.tool_type.as_str(),
"web_search_preview" | "code_interpreter" | "container"
)
}
fn is_custom(&self) -> bool {
matches!(self.tool_type.as_str(), "mcp" | "function")
}
fn to_tool_description(&self) -> Option<ToolDescription> {
Some(ToolDescription::new(
self.function.name.clone(),
self.function.description.clone().unwrap_or_default(),
Some(self.function.parameters.clone()),
))
}
}
/// Implement ToolLike for Responses API Tool
impl ToolLike for ResponseTool {
fn is_builtin(&self) -> bool {
matches!(
self.r#type,
ResponseToolType::WebSearchPreview | ResponseToolType::CodeInterpreter
)
}
fn is_custom(&self) -> bool {
matches!(
self.r#type,
ResponseToolType::Mcp | ResponseToolType::Function
)
}
fn to_tool_description(&self) -> Option<ToolDescription> {
self.function.as_ref().map(|func| {
ToolDescription::new(
func.name.clone(),
func.description.clone().unwrap_or_default(),
Some(func.parameters.clone()),
)
})
}
}
/// Returns true if at least one tool type is not a built-in tool
/// (i.e., a developer-defined function or MCP tool).
fn has_custom_tools(tool_types: &[&str]) -> bool {
!tool_types.iter().all(|t| BUILTIN_TOOLS.contains(t))
}
/// Harmony request builder
///
/// Converts OpenAI-format requests into Harmony-encoded format with input_ids,
/// stop tokens, and selection text for worker routing.
pub struct HarmonyBuilder {
encoding: &'static HarmonyEncoding,
}
impl HarmonyBuilder {
/// Create a new Harmony builder
pub fn new() -> Self {
Self {
encoding: get_harmony_encoding(),
}
}
/// Build Harmony request from Chat Completion request
///
/// # Arguments
///
/// * `request` - The ChatCompletionRequest to encode
///
/// # Returns
///
/// HarmonyBuildOutput containing input_ids, stop_token_ids, selection_text, and messages
pub fn build_from_chat(
&self,
request: &ChatCompletionRequest,
) -> Result<HarmonyBuildOutput, String> {
let mut all_messages = Vec::new();
let sys_msg = self.build_system_message_from_chat(request);
all_messages.push(sys_msg);
let dev_msg = self.build_developer_message_from_chat(request.tools.as_ref());
all_messages.push(dev_msg);
let mut user_messages = self.convert_chat_messages(&request.messages)?;
all_messages.append(&mut user_messages);
let conversation = Conversation::from_messages(all_messages.clone());
let token_ids = self
.encoding
.render_conversation_for_completion(&conversation, Role::Assistant, None)
.map_err(|e| format!("Failed to encode Harmony conversation: {}", e))?;
let selection_text = self.extract_selection_text(&all_messages);
// Get stop tokens for Harmony assistant actions (<|return|> and <|call|>)
let stop_token_ids: Vec<u32> = self
.encoding
.stop_tokens_for_assistant_actions()
.into_iter()
.flat_map(|set| set.into_iter())
.collect();
Ok(HarmonyBuildOutput {
input_ids: token_ids,
stop_token_ids,
selection_text,
harmony_messages: all_messages
.into_iter()
.map(super::types::HarmonyMessage::from_openai_harmony)
.collect(),
})
}
/// Build Harmony request from Responses request
///
/// # Arguments
///
/// * `request` - The ResponsesRequest to encode
///
/// # Returns
///
/// HarmonyBuildOutput containing input_ids, stop_token_ids, selection_text, and messages
pub fn build_from_responses(
&self,
request: &ResponsesRequest,
) -> Result<HarmonyBuildOutput, String> {
let all_messages = self.construct_input_messages_with_harmony(request)?;
let conversation = Conversation::from_messages(all_messages.clone());
let token_ids = self
.encoding
.render_conversation_for_completion(&conversation, Role::Assistant, None)
.map_err(|e| format!("Failed to encode Harmony conversation: {}", e))?;
let selection_text = self.extract_selection_text(&all_messages);
// Get stop tokens for Harmony assistant actions (<|return|> and <|call|>)
let stop_token_ids: Vec<u32> = self
.encoding
.stop_tokens_for_assistant_actions()
.into_iter()
.flat_map(|set| set.into_iter())
.collect();
// Decode tokens to see what the model actually receives
let decoded_text = self
.encoding
.tokenizer()
.decode_utf8(&token_ids)
.unwrap_or_else(|_| "<decode error>".to_string());
debug!(
token_count = token_ids.len(),
token_preview = ?&token_ids[..token_ids.len().min(20)],
decoded_length = decoded_text.len(),
"Encoded conversation to tokens - decoded text follows:"
);
debug!("DECODED_TEXT_START\n{}\nDECODED_TEXT_END", decoded_text);
Ok(HarmonyBuildOutput {
input_ids: token_ids,
stop_token_ids,
selection_text,
harmony_messages: all_messages
.into_iter()
.map(super::types::HarmonyMessage::from_openai_harmony)
.collect(),
})
}
/// Build system message from ChatCompletionRequest
/// Build system message with common logic
///
/// # Arguments
/// * `reasoning_effort` - Optional reasoning effort level
/// * `has_tools` - Whether custom tools are present
fn build_system_message(
&self,
reasoning_effort: Option<ReasoningEffort>,
has_tools: bool,
) -> HarmonyMessage {
let mut sys_content = SystemContent::new();
// Add reasoning_effort if provided
if let Some(effort) = reasoning_effort {
sys_content = sys_content.with_reasoning_effort(effort);
}
// Set conversation start date (always current date)
sys_content =
sys_content.with_conversation_start_date(Local::now().format("%Y-%m-%d").to_string());
// If no tools, remove "commentary" from valid channels
if !has_tools {
if let Some(channel_config) = &sys_content.channel_config {
let valid_channels: Vec<String> = channel_config
.valid_channels
.iter()
.filter(|c| c.as_str() != "commentary")
.cloned()
.collect();
sys_content = sys_content
.with_channel_config(ChannelConfig::require_channels(valid_channels));
}
}
HarmonyMessage::from_role_and_content(Role::System, sys_content)
}
fn build_system_message_from_chat(&self, request: &ChatCompletionRequest) -> HarmonyMessage {
let reasoning_effort = request
.reasoning_effort
.as_deref()
.map(|effort| match effort {
"high" => ReasoningEffort::High,
"medium" => ReasoningEffort::Medium,
"low" => ReasoningEffort::Low,
_ => ReasoningEffort::Medium,
});
let has_tools = request.tools.is_some();
self.build_system_message(reasoning_effort, has_tools)
}
/// Build system message from ResponsesRequest
///
/// # Arguments
/// * `request` - The ResponsesRequest
/// * `with_custom_tools` - Whether custom tools (beyond built-ins) are present
fn build_system_message_from_responses(
&self,
request: &ResponsesRequest,
with_custom_tools: bool,
) -> HarmonyMessage {
let reasoning_effort = request
.reasoning
.as_ref()
.and_then(|r| r.effort.as_ref())
.map(|effort| match effort {
ResponsesReasoningEffort::High => ReasoningEffort::High,
ResponsesReasoningEffort::Medium => ReasoningEffort::Medium,
ResponsesReasoningEffort::Low => ReasoningEffort::Low,
});
self.build_system_message(reasoning_effort, with_custom_tools)
}
/// Build developer message with common logic
///
/// Filters out built-in tools and converts custom tools to ToolDescription
///
/// # Arguments
/// * `tools` - Optional list of tools
/// * `instructions` - Optional instructions (Responses API only)
fn build_developer_message<T: ToolLike>(
&self,
tools: Option<&Vec<T>>,
instructions: Option<&str>,
) -> HarmonyMessage {
let mut dev_content = DeveloperContent::new();
// Add instructions if provided (Responses API only)
if let Some(instructions) = instructions {
dev_content = dev_content.with_instructions(instructions.to_string());
}
// Early return if no tools
let Some(tools) = tools else {
return HarmonyMessage::from_role_and_content(Role::Developer, dev_content);
};
// Filter to custom tools and convert to ToolDescription
let tool_descriptions: Vec<ToolDescription> = tools
.iter()
.filter(|t| t.is_custom())
.filter_map(|t| t.to_tool_description())
.collect();
// Add function tools to developer content
if !tool_descriptions.is_empty() {
dev_content = dev_content.with_function_tools(tool_descriptions);
}
HarmonyMessage::from_role_and_content(Role::Developer, dev_content)
}
fn build_developer_message_from_chat(&self, tools: Option<&Vec<Tool>>) -> HarmonyMessage {
self.build_developer_message(tools, None)
}
/// Build developer message from Responses request
///
/// # Arguments
/// * `instructions` - Optional instructions (Responses API specific)
/// * `tools` - Optional list of tools
fn build_developer_message_from_responses(
&self,
instructions: Option<&str>,
tools: Option<&Vec<ResponseTool>>,
) -> HarmonyMessage {
self.build_developer_message(tools, instructions)
}
/// Construct input messages for Responses API with Harmony
///
/// Handles both new conversations and continuations of previous responses.
///
/// This handles:
/// - New conversation: system message, developer message, and user input
/// - Continuing conversation: loads previous messages, cleans up chain-of-thoughts
/// - MCP tool allowlisting for special tool types
/// - Complex response input parsing with function call tracking
///
/// # Arguments
/// * `request` - The ResponsesRequest
/// * `prev_response` - Optional previous response to continue from
fn construct_input_messages_with_harmony(
&self,
request: &ResponsesRequest,
) -> Result<Vec<HarmonyMessage>, String> {
let mut all_messages = Vec::new();
// Handle new vs continuing conversation
if request.previous_response_id.is_none() {
// New conversation
let tool_types: Vec<&str> = request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.map(|tool| match tool.r#type {
ResponseToolType::Function => "function",
ResponseToolType::WebSearchPreview => "web_search_preview",
ResponseToolType::CodeInterpreter => "code_interpreter",
ResponseToolType::Mcp => "mcp",
})
.collect()
})
.unwrap_or_default();
let with_custom_tools = has_custom_tools(&tool_types);
// Add system message
let sys_msg = self.build_system_message_from_responses(request, with_custom_tools);
all_messages.push(sys_msg);
// Add developer message only if we have custom tools
if with_custom_tools {
let dev_msg = self.build_developer_message_from_responses(
request.instructions.as_deref(),
request.tools.as_ref(),
);
all_messages.push(dev_msg);
}
} else {
// Continue the previous conversation
// NOTE: Previous messages are loaded by serve_harmony_responses() before calling this method.
// The request.input will already contain the conversation history when previous_response_id was set.
// We just proceed with parsing the input items as normal.
debug!("Continuing conversation (history already loaded in request.input)");
}
// Append the new input
// Responses API supports simple text inputs without chat format
match &request.input {
ResponseInput::Text(text) => {
let user_msg = HarmonyMessage {
author: Author {
role: Role::User,
name: None,
},
recipient: None,
content: vec![Content::Text(TextContent { text: text.clone() })],
channel: None,
content_type: None,
};
all_messages.push(user_msg);
}
ResponseInput::Items(items) => {
// Track function calls for looking up call_id → name mapping
let mut prev_outputs: Vec<&ResponseInputOutputItem> = Vec::new();
for item in items {
let msg = self.parse_response_item_to_harmony_message(item, &prev_outputs)?;
all_messages.push(msg);
// Track function tool calls so that function_call_output can find the name
if matches!(item, ResponseInputOutputItem::FunctionToolCall { .. }) {
prev_outputs.push(item);
}
}
}
}
debug!(
message_count = all_messages.len(),
"Constructed Harmony messages for Responses API"
);
Ok(all_messages)
}
/// Parse a ResponseInputOutputItem into a HarmonyMessage
///
/// Handles conversion of various response item types (messages, function calls, reasoning, etc.)
/// to Harmony message format.
///
/// # Arguments
/// * `item` - The ResponseInputOutputItem to parse
/// * `prev_outputs` - Previous items for looking up function call names (for function_call_output)
fn parse_response_item_to_harmony_message(
&self,
item: &ResponseInputOutputItem,
prev_outputs: &[&ResponseInputOutputItem],
) -> Result<HarmonyMessage, String> {
match item {
// Regular message (user or assistant)
ResponseInputOutputItem::Message { role, content, .. } => {
let harmony_role = match role.as_str() {
"user" => Role::User,
"assistant" => Role::Assistant,
"system" => Role::System,
_ => Role::User, // Default to user for unknown roles
};
// Extract text from content parts
let text_parts: Vec<String> = content
.iter()
.filter_map(|part| match part {
ResponseContentPart::OutputText { text, .. } => Some(text.clone()),
ResponseContentPart::InputText { text } => Some(text.clone()),
ResponseContentPart::Unknown => None,
})
.collect();
let text = text_parts.join("\n");
Ok(HarmonyMessage {
author: Author {
role: harmony_role,
name: None,
},
recipient: None,
content: vec![Content::Text(TextContent { text })],
channel: None,
content_type: None,
})
}
// Reasoning content (chain-of-thought)
ResponseInputOutputItem::Reasoning { content, .. } => {
// Extract reasoning text
let reasoning_texts: Vec<String> = content
.iter()
.map(|rc| match rc {
ResponseReasoningContent::ReasoningText { text } => text.clone(),
})
.collect();
let text = reasoning_texts.join("\n");
// Reasoning goes in the "analysis" channel for Harmony
Ok(HarmonyMessage {
author: Author {
role: Role::Assistant,
name: None,
},
recipient: None,
content: vec![Content::Text(TextContent { text })],
channel: Some("analysis".to_string()),
content_type: None,
})
}
// Function tool call (with optional output)
ResponseInputOutputItem::FunctionToolCall {
name,
arguments,
output,
..
} => {
// If there's an output, this represents the tool result
// Otherwise, it's the tool call itself
if let Some(output_str) = output {
// Tool result - use Tool role with "functions.{name}" as author name
// IMPORTANT: Must include recipient="assistant" for parser to recognize it.
// We keep channel=None to minimize what the model might copy.
let author_name = format!("functions.{}", name);
debug!(
tool_name = %name,
author_name = %author_name,
output_preview = %output_str.chars().take(100).collect::<String>(),
"Building tool result message with Tool role (recipient=assistant, no channel)"
);
Ok(HarmonyMessage {
author: Author {
role: Role::Tool,
name: Some(author_name),
},
recipient: Some("assistant".to_string()),
content: vec![Content::Text(TextContent {
text: output_str.clone(),
})],
channel: None,
content_type: None,
})
} else {
// Tool call - assistant message in commentary channel with recipient
// i.e., channel = "commentary", recipient = "functions.{name}"
let recipient = format!("functions.{}", name);
debug!(
tool_name = %name,
recipient = %recipient,
"Building tool call message with recipient"
);
Ok(HarmonyMessage {
author: Author {
role: Role::Assistant,
name: None,
},
recipient: Some(recipient),
content: vec![Content::Text(TextContent {
text: arguments.clone(),
})],
channel: Some("commentary".to_string()),
content_type: Some("json".to_string()),
})
}
}
// Function call output (separate from call) - requires looking up the original call
ResponseInputOutputItem::FunctionCallOutput {
call_id, output, ..
} => {
// Search prev_outputs in reverse order to find the matching function call
let call = prev_outputs
.iter()
.rev()
.find_map(|item| match item {
ResponseInputOutputItem::FunctionToolCall { id, name, .. }
if id == call_id =>
{
Some(name.clone())
}
_ => None,
})
.ok_or_else(|| format!("No function call found for call_id: {}", call_id))?;
// Create Tool message with "functions.{name}" prefix
// IMPORTANT: Must include recipient="assistant" for parser to recognize it.
// We keep channel=None to minimize what the model might copy.
Ok(HarmonyMessage {
author: Author {
role: Role::Tool,
name: Some(format!("functions.{}", call)),
},
recipient: Some("assistant".to_string()),
content: vec![Content::Text(TextContent {
text: output.clone(),
})],
channel: None,
content_type: None,
})
}
// Simple input message (usually user message)
ResponseInputOutputItem::SimpleInputMessage { content, role, .. } => {
let harmony_role = match role.as_str() {
"user" => Role::User,
"assistant" => Role::Assistant,
"system" => Role::System,
_ => Role::User,
};
let text = match content {
StringOrContentParts::String(s) => s.clone(),
StringOrContentParts::Array(parts) => {
// Extract text from content parts
parts
.iter()
.filter_map(|part| match part {
ResponseContentPart::OutputText { text, .. } => Some(text.clone()),
ResponseContentPart::InputText { text } => Some(text.clone()),
ResponseContentPart::Unknown => None,
})
.collect::<Vec<_>>()
.join("\n")
}
};
Ok(HarmonyMessage {
author: Author {
role: harmony_role,
name: None,
},
recipient: None,
content: vec![Content::Text(TextContent { text })],
channel: None,
content_type: None,
})
}
}
}
/// Convert OpenAI ChatMessage format to Harmony messages
///
/// - Assistant messages with tool_calls create multiple messages (one per tool call)
/// - Tool role messages use Role::Tool with proper author
/// - Tool-related messages use channel="commentary"
fn convert_chat_messages(
&self,
messages: &[ChatMessage],
) -> Result<Vec<HarmonyMessage>, String> {
let mut harmony_messages = Vec::new();
// Build a map of tool_call_id -> function_name for tool responses
let mut tool_call_map = std::collections::HashMap::new();
for msg in messages {
if let ChatMessage::Assistant {
tool_calls: Some(calls),
..
} = msg
{
for call in calls {
tool_call_map.insert(call.id.clone(), call.function.name.clone());
}
}
}
for msg in messages {
match msg {
ChatMessage::System { content, name } => {
// System messages stay as-is
let harmony_msg = HarmonyMessage {
author: Author {
role: Role::System,
name: name.clone(),
},
recipient: None,
content: vec![Content::Text(TextContent {
text: content.clone(),
})],
channel: None,
content_type: None,
};
harmony_messages.push(harmony_msg);
}
ChatMessage::User { content, name } => {
// Extract text from user content
let text = match content {
UserMessageContent::Text(text) => text.clone(),
UserMessageContent::Parts(parts) => {
// For multimodal content, extract text parts
parts
.iter()
.filter_map(|part| {
if let ContentPart::Text { text } = part {
Some(text.as_str())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n")
}
};
let harmony_msg = HarmonyMessage {
author: Author {
role: Role::User,
name: name.clone(),
},
recipient: None,
content: vec![Content::Text(TextContent { text })],
channel: None,
content_type: None,
};
harmony_messages.push(harmony_msg);
}
ChatMessage::Assistant {
content,
name,
tool_calls,
reasoning_content,
} => {
if let Some(calls) = tool_calls {
// Create one message per tool call with channel="commentary"
for call in calls {
let function_name = &call.function.name;
let arguments = call.function.arguments.clone().unwrap_or_default();
let tool_call_msg = HarmonyMessage {
author: Author {
role: Role::Assistant,
name: name.clone(),
},
recipient: Some(format!("functions.{}", function_name)),
content: vec![Content::Text(TextContent { text: arguments })],
channel: Some("commentary".to_string()),
content_type: Some("json".to_string()),
};
harmony_messages.push(tool_call_msg);
}
} else {
// Regular assistant message with content
// Combine content with reasoning if present
let mut text = content.clone().unwrap_or_default();
if let Some(reasoning) = reasoning_content {
if !text.is_empty() {
text.push('\n');
}
text.push_str(reasoning);
}
let harmony_msg = HarmonyMessage {
author: Author {
role: Role::Assistant,
name: name.clone(),
},
recipient: None,
content: vec![Content::Text(TextContent { text })],
channel: Some("final".to_string()),
content_type: None,
};
harmony_messages.push(harmony_msg);
}
}
ChatMessage::Tool {
content,
tool_call_id,
} => {
// Look up the function name from the tool_call_id
let function_name = tool_call_map
.get(tool_call_id)
.cloned()
.unwrap_or_else(|| tool_call_id.clone());
// Tool result - Must include recipient="assistant" for parser to recognize it.
// We keep channel=None to minimize what the model might copy.
let harmony_msg = HarmonyMessage {
author: Author {
role: Role::Tool,
name: Some(format!("functions.{}", function_name)),
},
recipient: Some("assistant".to_string()),
content: vec![Content::Text(TextContent {
text: content.clone(),
})],
channel: None,
content_type: None,
};
harmony_messages.push(harmony_msg);
}
ChatMessage::Function { content, name } => {
// Function messages also use Role::Tool
// Tool result - Must include recipient="assistant" for parser to recognize it.
// We keep channel=None to minimize what the model might copy.
let harmony_msg = HarmonyMessage {
author: Author {
role: Role::Tool,
name: Some(format!("functions.{}", name)),
},
recipient: Some("assistant".to_string()),
content: vec![Content::Text(TextContent {
text: content.clone(),
})],
channel: None,
content_type: None,
};
harmony_messages.push(harmony_msg);
}
}
}
Ok(harmony_messages)
}
/// Extract selection text for worker routing
///
/// Uses the last user message for load balancing
fn extract_selection_text(&self, messages: &[HarmonyMessage]) -> String {
// Find the last user message
if let Some(last_user_msg) = messages.iter().rev().find(|m| m.author.role == Role::User) {
// Extract full text from content
return last_user_msg
.content
.iter()
.filter_map(|c| match c {
Content::Text(tc) => Some(tc.text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
}
// Fallback: concatenate all text
messages
.iter()
.flat_map(|m| &m.content)
.filter_map(|c| match c {
Content::Text(tc) => Some(tc.text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join(" ")
}
}
impl Default for HarmonyBuilder {
fn default() -> Self {
Self::new()
}
}
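For orientation, the openai-harmony calls the builder leans on (`load_harmony_encoding`, `Conversation::from_messages`, `render_conversation_for_completion`) can be exercised standalone. A minimal sketch along the lines of the upstream crate's README, assuming `anyhow` for error handling:

```rust
use openai_harmony::{
    chat::{Conversation, Message, Role},
    load_harmony_encoding, HarmonyEncodingName,
};

fn main() -> anyhow::Result<()> {
    // Same encoding the builder caches in its OnceLock.
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)?;
    let conversation = Conversation::from_messages([
        Message::from_role_and_content(Role::User, "What is 2 + 2?"),
    ]);
    // Token ids ready to hand to the scheduler as TokenizedInput.input_ids.
    let input_ids =
        encoding.render_conversation_for_completion(&conversation, Role::Assistant, None)?;
    println!("{} prompt tokens", input_ids.len());
    Ok(())
}
```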
//! Harmony model detection
/// Harmony model detector
///
/// Detects if a model name indicates support for Harmony encoding/parsing.
pub struct HarmonyDetector;
impl HarmonyDetector {
pub fn is_harmony_model(model_name: &str) -> bool {
// Case-insensitive substring search without heap allocation
// More efficient than to_lowercase() which allocates a new String
model_name
.as_bytes()
.windows(7) // "gpt-oss".len()
.any(|window| window.eq_ignore_ascii_case(b"gpt-oss"))
}
}
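A few hypothetical checks showing what the byte-window comparison accepts and rejects:

```rust
#[test]
fn detects_gpt_oss_names() {
    // Matches anywhere in the model name, ignoring ASCII case.
    assert!(HarmonyDetector::is_harmony_model("openai/gpt-oss-120b"));
    assert!(HarmonyDetector::is_harmony_model("GPT-OSS-20B-mxfp4"));
    // No "gpt-oss" substring: routed through the regular pipeline.
    assert!(!HarmonyDetector::is_harmony_model("gpt-4o"));
    assert!(!HarmonyDetector::is_harmony_model("llama-3.1-8b"));
}
```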
//! Harmony pipeline implementation
//!
//! This module provides support for GPT-OSS models that use Harmony encoding/parsing.
//! The Harmony protocol uses a channel-based approach with three channels:
//! - **analysis**: Reasoning/thinking content (optional)
//! - **commentary**: Tool calls (optional)
//! - **final**: Final response text (required)
//!
//! ## Architecture
//!
//! The Harmony implementation is structured as follows:
//!
//! - **detector**: Model detection (is this a Harmony-capable model?)
//! - **builder**: Request encoding (convert Chat/Responses → input_ids)
//! - **parser**: Response parsing (output_ids → channels)
//! - **types**: Shared type definitions
//!
//! ## Usage
//!
//! ```ignore
//! use sglang_router_rs::routers::grpc::harmony::{HarmonyDetector, HarmonyBuilder};
//!
//! // Detect if model supports Harmony
//! if HarmonyDetector::is_harmony_model("gpt-oss-120b") {
//! // Build Harmony request
//! let builder = HarmonyBuilder::new();
//! let output = builder.build_from_chat(&request)?;
//! // ... use output.input_ids for gRPC request
//! }
//! ```
pub mod builder;
pub mod detector;
pub mod parser;
pub mod processor;
pub mod responses;
pub mod stages;
pub mod streaming;
pub mod types;
// Re-export main types for convenience
pub use builder::HarmonyBuilder;
pub use detector::HarmonyDetector;
pub use parser::HarmonyParserAdapter;
pub use processor::{HarmonyResponseProcessor, ResponsesIterationResult};
pub use responses::{serve_harmony_responses, HarmonyResponsesContext};
pub use stages::{
HarmonyPreparationStage, HarmonyRequestBuildingStage, HarmonyResponseProcessingStage,
};
pub use streaming::HarmonyStreamingProcessor;
pub use types::{
FunctionDelta, HarmonyBuildOutput, HarmonyChannelDelta, HarmonyChannelOutput, HarmonyMessage,
ToolCallDelta,
};
//! Harmony response parser
//!
//! Adapter for openai_harmony::StreamableParser that handles channel-based parsing.
use openai_harmony::{chat::Role, HarmonyEncoding, StreamableParser};
use uuid::Uuid;
use super::types::{HarmonyChannelDelta, HarmonyChannelOutput};
use crate::protocols::common::{FunctionCallResponse, ToolCall};
/// Get the global Harmony encoding
///
/// References the same encoding used by the builder for consistency
fn get_harmony_encoding() -> &'static HarmonyEncoding {
use super::builder::get_harmony_encoding;
get_harmony_encoding()
}
/// Harmony parser adapter
///
/// Wraps openai_harmony::StreamableParser and provides methods for parsing
/// complete responses and streaming chunks.
pub struct HarmonyParserAdapter {
parser: StreamableParser,
prev_recipient: Option<String>,
}
impl HarmonyParserAdapter {
/// Create a new Harmony parser
pub fn new() -> Result<Self, String> {
let encoding = get_harmony_encoding();
let parser = StreamableParser::new(encoding.clone(), Some(Role::Assistant))
.map_err(|e| format!("Failed to create StreamableParser: {}", e))?;
Ok(Self {
parser,
prev_recipient: None,
})
}
/// Extract text from message content (private helper)
///
/// Filters text content from a message's content array and joins them into a single string.
///
/// # Arguments
///
/// * `content` - The content array from a Harmony message
///
/// # Returns
///
/// Joined text string from all text content items
fn extract_text_from_content(content: &[openai_harmony::chat::Content]) -> String {
content
.iter()
.filter_map(|c| match c {
openai_harmony::chat::Content::Text(tc) => Some(tc.text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
/// Handle incomplete content from parser state (private helper)
///
/// Checks for any remaining incomplete content in the parser and appends it
/// to the appropriate channel (analysis or final_text).
///
/// # Arguments
///
/// * `parser` - Reference to the StreamableParser
/// * `analysis` - Mutable reference to analysis content
/// * `final_text` - Mutable reference to final text content
fn handle_incomplete_content(
parser: &StreamableParser,
analysis: &mut Option<String>,
final_text: &mut String,
) {
if let Ok(current_content) = parser.current_content() {
if !current_content.is_empty() {
let current_channel = parser.current_channel();
match current_channel.as_deref() {
Some("analysis") => {
*analysis = Some(current_content);
}
Some("final") | None => {
final_text.push_str(&current_content);
}
_ => {}
}
}
}
}
/// Parse messages into channel outputs (private helper)
///
/// Extracts analysis, commentary (tool calls), and final text from Harmony messages.
/// This is the core parsing logic shared by both parse_complete and finalize.
///
/// # Arguments
///
/// * `messages` - The messages to parse from the Harmony parser
///
/// # Returns
///
/// Tuple of (analysis, commentary, final_text)
fn parse_messages(
messages: &[openai_harmony::chat::Message],
) -> (Option<String>, Option<Vec<ToolCall>>, String) {
let mut analysis = None;
let mut commentary: Option<Vec<ToolCall>> = None;
let mut final_text = String::new();
for msg in messages {
// Filter: Only process assistant messages
if msg.author.role != Role::Assistant {
continue;
}
let channel = msg.channel.as_deref().unwrap_or("");
let recipient = msg.recipient.as_deref();
match channel {
"analysis" => {
// Process each content item
// For Chat API, we join them into a single reasoning_content
let text = Self::extract_text_from_content(&msg.content);
if !text.is_empty() {
analysis = Some(text);
}
}
"commentary" => {
// Handle different recipient types
if let Some(recipient_str) = recipient {
if recipient_str.starts_with("functions.") {
let function_name = recipient_str.strip_prefix("functions.").unwrap();
// Process each content item separately
for content in &msg.content {
if let openai_harmony::chat::Content::Text(tc) = content {
let call_id = format!("call_{}", Uuid::new_v4());
let tool_call = ToolCall {
id: call_id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: function_name.to_string(),
arguments: Some(tc.text.clone()),
},
};
match commentary.as_mut() {
Some(calls) => calls.push(tool_call),
None => commentary = Some(vec![tool_call]),
}
}
}
} else if recipient_str.starts_with("python")
|| recipient_str.starts_with("browser")
|| recipient_str.starts_with("container")
{
// Built-in tools → treat as reasoning
// For Chat API, we add to analysis content
let text = Self::extract_text_from_content(&msg.content);
if !text.is_empty() {
// Append to analysis (built-in tools are reasoning)
match analysis.as_mut() {
Some(existing) => {
existing.push('\n');
existing.push_str(&text);
}
None => analysis = Some(text),
}
}
}
// An unknown recipient would be an error in the reference implementation;
// for now we silently ignore it (logging can be added later)
}
}
"final" => {
// Process final channel content
let text = Self::extract_text_from_content(&msg.content);
final_text.push_str(&text);
}
_ => {
// Unknown channel, append to final text as fallback
let text = Self::extract_text_from_content(&msg.content);
final_text.push_str(&text);
}
}
}
(analysis, commentary, final_text)
}
/// Parse complete response
///
/// Parses all output token IDs and returns the complete channel output
/// containing analysis, commentary (tool calls), and final text.
///
/// # Arguments
///
/// * `output_ids` - The complete output token IDs from the model
/// * `finish_reason` - The finish reason from GenerateComplete ("stop", "length", etc.)
/// * `matched_stop` - Optional matched stop token information from GenerateComplete
///
/// # Returns
///
/// Complete HarmonyChannelOutput with all three channels parsed
pub fn parse_complete(
&mut self,
output_ids: &[u32],
finish_reason: String,
matched_stop: Option<serde_json::Value>,
) -> Result<HarmonyChannelOutput, String> {
// Feed all tokens to the parser
for &token_id in output_ids {
self.parser.process(token_id).map_err(|e| {
// Log the full output_ids context on error
tracing::error!(
token_id = token_id,
output_ids = ?output_ids,
error = %e,
"Harmony parser failed to process token"
);
format!("Failed to process token {}: {}", token_id, e)
})?;
}
// Extract all completed messages from the parser
let messages = self.parser.messages();
// Parse messages into channel outputs using shared helper
let (mut analysis, commentary, mut final_text) = Self::parse_messages(messages);
// Check for incomplete content in parser state
Self::handle_incomplete_content(&self.parser, &mut analysis, &mut final_text);
// Determine finish reason: override to "tool_calls" if commentary has tool calls
let final_finish_reason = if commentary.is_some() {
"tool_calls".to_string()
} else {
finish_reason
};
Ok(HarmonyChannelOutput {
analysis,
commentary,
final_text,
finish_reason: final_finish_reason,
matched_stop,
})
}
/// Get all messages from the parser
///
/// Returns the raw messages extracted by the Harmony parser.
/// Used for validation checks.
pub fn get_messages(&self) -> Vec<openai_harmony::chat::Message> {
self.parser.messages().to_vec()
}
/// Parse streaming chunk
///
/// Parses incremental token IDs and returns a delta with any new content
/// from the analysis, commentary, or final channels.
///
/// # Arguments
///
/// * `chunk_ids` - New token IDs from the current chunk
///
/// # Returns
///
/// Optional HarmonyChannelDelta if there's new content to emit
pub fn parse_chunk(
&mut self,
chunk_ids: &[u32],
) -> Result<Option<HarmonyChannelDelta>, String> {
let mut has_delta = false;
let mut analysis_delta = None;
let mut commentary_delta = None;
let mut final_delta = None;
// Track message count before processing
let prev_message_count = self.parser.messages().len();
// Accumulate delta text for commentary channel
let mut accumulated_delta = String::new();
// Process each token
for &token_id in chunk_ids {
self.parser
.process(token_id)
.map_err(|e| format!("Failed to process token {}: {}", token_id, e))?;
// Check for content delta
if let Ok(Some(delta_text)) = self.parser.last_content_delta() {
has_delta = true;
// Determine which channel this delta belongs to
let channel = self.parser.current_channel();
match channel.as_deref() {
Some("analysis") => {
analysis_delta = Some(delta_text);
}
Some("final") | None => {
final_delta = Some(delta_text);
}
Some("commentary") => {
// Accumulate delta for commentary
accumulated_delta.push_str(&delta_text);
}
_ => {}
}
}
}
// Handle commentary channel tool call deltas
if self.parser.current_channel().as_deref() == Some("commentary") {
if let Some(cur_recipient) = self.parser.current_recipient() {
if cur_recipient.starts_with("functions.") {
has_delta = true;
// Count completed tool calls for index
let base_index = self
.parser
.messages()
.iter()
.filter(|msg| {
msg.channel.as_deref() == Some("commentary")
&& msg
.recipient
.as_deref()
.is_some_and(|r| r.starts_with("functions."))
})
.count();
// Check if recipient changed (new tool call)
let recipient_changed = self.prev_recipient.as_deref() != Some(&cur_recipient);
if recipient_changed {
// NEW tool call: emit name + id
let tool_name = cur_recipient.strip_prefix("functions.").unwrap();
let call_id = format!("call_{}", Uuid::new_v4());
commentary_delta = Some(super::types::ToolCallDelta {
index: base_index,
id: Some(call_id),
function: Some(super::types::FunctionDelta {
name: Some(tool_name.to_string()),
arguments: Some(String::new()),
}),
});
// Update prev_recipient
self.prev_recipient = Some(cur_recipient);
} else if !accumulated_delta.is_empty() {
// CONTINUING tool call: emit arguments delta
commentary_delta = Some(super::types::ToolCallDelta {
index: base_index,
id: None,
function: Some(super::types::FunctionDelta {
name: None,
arguments: Some(accumulated_delta),
}),
});
}
}
}
}
// Check if new messages were completed
let current_message_count = self.parser.messages().len();
let is_final = current_message_count > prev_message_count;
if has_delta {
Ok(Some(HarmonyChannelDelta {
analysis_delta,
commentary_delta,
final_delta,
is_final,
}))
} else {
Ok(None)
}
}
/// Finalize parsing
///
/// Called at the end of streaming to get the final state and any
/// remaining content.
///
/// # Arguments
///
/// * `finish_reason` - The finish reason from GenerateComplete ("stop", "length", etc.)
/// * `matched_stop` - Optional matched stop token information from GenerateComplete
///
/// # Returns
///
/// Final HarmonyChannelOutput with complete parsed content
pub fn finalize(
&mut self,
finish_reason: String,
matched_stop: Option<serde_json::Value>,
) -> Result<HarmonyChannelOutput, String> {
// Extract all completed messages
let messages = self.parser.messages();
// Parse messages into channel outputs using shared helper
let (mut analysis, commentary, mut final_text) = Self::parse_messages(messages);
// Check for remaining incomplete content
Self::handle_incomplete_content(&self.parser, &mut analysis, &mut final_text);
// Determine finish reason: override to "tool_calls" if commentary has tool calls
let final_finish_reason = if commentary.is_some() {
"tool_calls".to_string()
} else {
finish_reason
};
Ok(HarmonyChannelOutput {
analysis,
commentary,
final_text,
finish_reason: final_finish_reason,
matched_stop,
})
}
/// Reset parser state
///
/// Resets the parser to initial state for reuse
pub fn reset(&mut self) -> Result<(), String> {
// Create a new parser instance (StreamableParser doesn't have a reset method)
let encoding = get_harmony_encoding();
self.parser = StreamableParser::new(encoding.clone(), Some(Role::Assistant))
.map_err(|e| format!("Failed to reset parser: {}", e))?;
self.prev_recipient = None;
Ok(())
}
}
impl Default for HarmonyParserAdapter {
fn default() -> Self {
Self::new().expect("Failed to create default parser")
}
}
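The `ToolCallDelta`/`FunctionDelta` stream above follows the usual OpenAI convention: the first chunk for a call carries `id` and `name`, and later chunks append `arguments` fragments at the same `index`. A hypothetical client-side fold (names and shapes are stand-ins, not the router's types):

```rust
#[derive(Default, Debug)]
struct AssembledCall {
    id: String,
    name: String,
    arguments: String,
}

fn main() {
    // (id, name, arguments_fragment) as they might arrive over SSE.
    let deltas = [
        (Some("call_1"), Some("get_weather"), ""),
        (None, None, "{\"city\":"),
        (None, None, "\"Paris\"}"),
    ];
    let mut call = AssembledCall::default();
    for (id, name, args) in deltas {
        if let Some(id) = id { call.id = id.to_string(); }
        if let Some(name) = name { call.name = name.to_string(); }
        call.arguments.push_str(args);
    }
    // AssembledCall { id: "call_1", name: "get_weather", arguments: "{\"city\":\"Paris\"}" }
    println!("{:?}", call);
}
```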
//! Harmony response processor for non-streaming responses
use std::sync::Arc;
use axum::response::Response;
use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId};
use super::HarmonyParserAdapter;
use crate::{
grpc_client::proto,
protocols::{
chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse},
common::{ToolCall, Usage},
responses::{
ResponseContentPart, ResponseOutputItem, ResponseReasoningContent, ResponseStatus,
ResponseUsage, ResponsesRequest, ResponsesResponse, ResponsesUsage,
},
},
routers::grpc::{
context::{DispatchMetadata, ExecutionResult},
utils,
},
};
/// Processor for non-streaming Harmony responses
///
/// Collects all output tokens from execution and parses them using
/// HarmonyParserAdapter to extract the complete response.
pub struct HarmonyResponseProcessor;
impl HarmonyResponseProcessor {
/// Create a new Harmony response processor
pub fn new() -> Self {
Self
}
/// Collect responses from ExecutionResult (similar to regular processor)
async fn collect_responses(
execution_result: ExecutionResult,
) -> Result<Vec<proto::GenerateComplete>, Response> {
match execution_result {
ExecutionResult::Single { mut stream } => {
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
stream.mark_completed();
Ok(responses)
}
ExecutionResult::Dual { prefill, decode } => {
// For Harmony we currently rely only on decode stream for outputs
let mut decode_stream = *decode;
let responses =
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
prefill.mark_completed();
decode_stream.mark_completed();
Ok(responses)
}
}
}
/// Process a non-streaming Harmony chat response
pub async fn process_non_streaming_chat_response(
&self,
execution_result: ExecutionResult,
chat_request: Arc<ChatCompletionRequest>,
dispatch: DispatchMetadata,
) -> Result<ChatCompletionResponse, Response> {
// Collect all completed responses (one per choice)
let all_responses = Self::collect_responses(execution_result).await?;
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
}
// Build choices by parsing output with HarmonyParserAdapter
let mut choices: Vec<ChatChoice> = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
// Convert matched_stop from proto to JSON
let matched_stop = complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
serde_json::json!(id)
}
MatchedStopStr(s) => {
serde_json::json!(s)
}
});
// Parse Harmony channels with HarmonyParserAdapter
let mut parser = HarmonyParserAdapter::new().map_err(|e| {
utils::internal_error_message(format!("Failed to create Harmony parser: {}", e))
})?;
// Parse Harmony channels with finish_reason and matched_stop
let parsed = parser
.parse_complete(
&complete.output_ids,
complete.finish_reason.clone(),
matched_stop.clone(),
)
.map_err(|e| {
utils::internal_error_message(format!("Harmony parsing failed: {}", e))
})?;
// Build response message (assistant)
let message = ChatCompletionMessage {
role: "assistant".to_string(),
content: (!parsed.final_text.is_empty()).then_some(parsed.final_text),
tool_calls: parsed.commentary,
reasoning_content: parsed.analysis,
};
let finish_reason = parsed.finish_reason;
choices.push(ChatChoice {
index: index as u32,
message,
logprobs: None,
finish_reason: Some(finish_reason),
matched_stop,
hidden_states: None,
});
}
// Build usage from proto fields
let prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
};
// Final ChatCompletionResponse
let response = ChatCompletionResponse {
id: dispatch.request_id.clone(),
object: "chat.completion".to_string(),
created: dispatch.created,
model: chat_request.model.clone(),
choices,
usage: Some(usage),
system_fingerprint: dispatch.weight_version.clone(),
};
Ok(response)
}
}
impl Default for HarmonyResponseProcessor {
fn default() -> Self {
Self::new()
}
}
/// Result of processing a single Responses API iteration
///
/// Used by the MCP tool loop to determine whether to continue
/// executing tools or return the final response.
pub enum ResponsesIterationResult {
/// Tool calls found in commentary channel - continue MCP loop
ToolCallsFound {
tool_calls: Vec<ToolCall>,
analysis: Option<String>, // For streaming emission
partial_text: String, // For streaming emission
},
/// No tool calls - return final ResponsesResponse
Completed {
response: Box<ResponsesResponse>,
usage: Usage,
},
}
impl HarmonyResponseProcessor {
/// Process a single Responses API iteration
///
/// Parses Harmony channels and determines if tool calls are present.
/// If tool calls found, returns ToolCallsFound for MCP loop to execute.
/// If no tool calls, builds final ResponsesResponse.
///
/// # Arguments
///
/// * `execution_result` - The execution result from the model
/// * `responses_request` - The original Responses API request
/// * `dispatch` - Dispatch metadata for request tracking
///
/// # Returns
///
/// ResponsesIterationResult indicating whether to continue loop or return
pub async fn process_responses_iteration(
&self,
execution_result: ExecutionResult,
responses_request: Arc<ResponsesRequest>,
dispatch: DispatchMetadata,
) -> Result<ResponsesIterationResult, Response> {
// Collect all completed responses
let all_responses = Self::collect_responses(execution_result).await?;
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
}
// For Responses API, we only process the first response (n=1)
let complete = all_responses
.first()
.ok_or_else(|| utils::internal_error_static("No complete response"))?;
// Parse Harmony channels
let mut parser = HarmonyParserAdapter::new().map_err(|e| {
utils::internal_error_message(format!("Failed to create Harmony parser: {}", e))
})?;
// Convert matched_stop from proto to JSON
let matched_stop = complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
serde_json::json!(id)
}
MatchedStopStr(s) => {
serde_json::json!(s)
}
});
let parsed = parser
.parse_complete(
&complete.output_ids,
complete.finish_reason.clone(),
matched_stop,
)
.map_err(|e| utils::internal_error_message(format!("Harmony parsing failed: {}", e)))?;
// VALIDATION: Check if model incorrectly generated Tool role messages
// This happens when the model copies the format of tool result messages
// instead of continuing as assistant. This is a model hallucination bug.
let messages = parser.get_messages();
let tool_messages_generated = messages.iter().any(|msg| {
msg.author.role == openai_harmony::chat::Role::Tool
&& msg.recipient.as_deref() == Some("assistant")
});
if tool_messages_generated {
tracing::warn!(
"Model generated Tool->Assistant message instead of Assistant message. \
This is a model hallucination bug where it copies tool result format."
);
}
// Check for tool calls in commentary channel
if let Some(tool_calls) = parsed.commentary {
// Tool calls found - return for MCP loop execution
return Ok(ResponsesIterationResult::ToolCallsFound {
tool_calls,
analysis: parsed.analysis,
partial_text: parsed.final_text,
});
}
// No tool calls - build final ResponsesResponse
let mut output: Vec<ResponseOutputItem> = Vec::new();
// Map analysis channel → ResponseOutputItem::Reasoning
if let Some(analysis) = parsed.analysis {
let reasoning_item = ResponseOutputItem::Reasoning {
id: format!("reasoning_{}", dispatch.request_id),
summary: vec![],
content: vec![ResponseReasoningContent::ReasoningText { text: analysis }],
status: Some("completed".to_string()),
};
output.push(reasoning_item);
}
// Map final channel → ResponseOutputItem::Message
if !parsed.final_text.is_empty() {
let message_item = ResponseOutputItem::Message {
id: format!("msg_{}", dispatch.request_id),
role: "assistant".to_string(),
content: vec![ResponseContentPart::OutputText {
text: parsed.final_text,
annotations: vec![],
logprobs: None,
}],
status: "completed".to_string(),
};
output.push(message_item);
}
// Build usage
let prompt_tokens = complete.prompt_tokens as u32;
let completion_tokens = complete.completion_tokens as u32;
let usage = Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
};
// Build ResponsesResponse with all required fields
let response = ResponsesResponse {
id: dispatch.request_id.clone(),
object: "response".to_string(),
created_at: dispatch.created as i64,
status: ResponseStatus::Completed,
error: None,
incomplete_details: None,
instructions: responses_request.instructions.clone(),
max_output_tokens: responses_request.max_output_tokens,
model: responses_request.model.clone(),
output,
parallel_tool_calls: responses_request.parallel_tool_calls.unwrap_or(true),
previous_response_id: responses_request.previous_response_id.clone(),
reasoning: None, // Set by caller if needed
store: responses_request.store.unwrap_or(true),
temperature: responses_request.temperature,
text: None,
tool_choice: responses_request
.tool_choice
.as_ref()
.map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
.unwrap_or_else(|| "auto".to_string()),
tools: responses_request.tools.clone().unwrap_or_default(),
top_p: responses_request.top_p,
truncation: None,
usage: Some(ResponsesUsage::Modern(ResponseUsage {
input_tokens: prompt_tokens,
output_tokens: completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
input_tokens_details: None,
output_tokens_details: None,
})),
user: None,
metadata: responses_request.metadata.clone().unwrap_or_default(),
};
Ok(ResponsesIterationResult::Completed {
response: Box::new(response),
usage,
})
}
}
//! Harmony Responses API implementation with multi-turn MCP tool support
//!
//! This module implements the Harmony Responses API orchestration logic,
//! coordinating full pipeline execution with MCP tool support for multi-turn conversations.
//!
//! ## Architecture
//!
//! Multi-turn pipeline orchestration (NOT just a tool loop):
//! - Serves Harmony Responses API requests end-to-end
//! - Each iteration executes the FULL pipeline (worker selection + client acquisition + execution + parsing)
//! - Handles MCP tool execution and history building between iterations
//! - Clean separation: serving orchestration (this file) vs. pipeline stages (stages/)
//!
//! ## Flow
//!
//! ```text
//! loop {
//! // Execute through FULL pipeline
//! let result = pipeline.execute_harmony_responses(&request, &ctx).await?;
//!
//! match result {
//! ToolCallsFound { tool_calls, .. } => {
//! // Execute MCP tools
//! // Build next request with tool results
//! // Continue loop
//! }
//! Completed { response, .. } => {
//! return Ok(response);
//! }
//! }
//! }
//! ```
use std::sync::Arc;
use axum::response::Response;
use serde_json::Value as JsonValue;
use crate::{
data_connector::{ResponseId, ResponseStorage},
mcp::McpManager,
protocols::{
common::{Function, ToolCall},
responses::{
ResponseInput, ResponseInputOutputItem, ResponseTool, ResponsesRequest,
ResponsesResponse, StringOrContentParts,
},
},
routers::grpc::{
context::SharedComponents, harmony::processor::ResponsesIterationResult,
pipeline::RequestPipeline, utils,
},
};
/// Maximum number of tool execution iterations to prevent infinite loops
const MAX_TOOL_ITERATIONS: usize = 10;
/// Record of a single MCP tool call execution
///
/// Stores metadata needed to build mcp_call output items for Responses API format
#[derive(Debug, Clone)]
struct McpCallRecord {
    /// Tool call ID (stored for potential future use; new IDs are currently generated when building output items)
#[allow(dead_code)]
call_id: String,
/// Tool name
tool_name: String,
/// JSON-encoded arguments
arguments: String,
/// JSON-encoded output/result
output: String,
/// Whether execution succeeded
success: bool,
/// Error message if execution failed
error: Option<String>,
}
/// Tracking structure for MCP tool calls across iterations
///
/// Accumulates all MCP tool call metadata during multi-turn conversation
/// so we can build proper mcp_list_tools and mcp_call output items.
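///
/// Illustrative usage (sketch, with a hypothetical `get_weather` tool):
///
/// ```ignore
/// let mut tracking = McpCallTracking::new("sglang-mcp".to_string());
/// tracking.record_call(
///     "call_1".to_string(),
///     "get_weather".to_string(),
///     r#"{"city":"Paris"}"#.to_string(),
///     r#"{"temp_c":21}"#.to_string(),
///     true,
///     None,
/// );
/// assert_eq!(tracking.total_calls(), 1);
/// ```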
#[derive(Debug, Clone)]
struct McpCallTracking {
/// MCP server label (e.g., "sglang-mcp")
server_label: String,
/// All tool call records across all iterations
tool_calls: Vec<McpCallRecord>,
}
impl McpCallTracking {
fn new(server_label: String) -> Self {
Self {
server_label,
tool_calls: Vec::new(),
}
}
fn record_call(
&mut self,
call_id: String,
tool_name: String,
arguments: String,
output: String,
success: bool,
error: Option<String>,
) {
self.tool_calls.push(McpCallRecord {
call_id,
tool_name,
arguments,
output,
success,
error,
});
}
fn total_calls(&self) -> usize {
self.tool_calls.len()
}
}
/// Context for Harmony Responses execution with MCP tool support
///
/// Contains all dependencies needed for multi-turn Responses API execution.
/// Cheap to clone (all Arc references).
#[derive(Clone)]
pub struct HarmonyResponsesContext {
/// Pipeline for executing Harmony requests
pub pipeline: Arc<RequestPipeline>,
/// Shared components (tokenizer, parsers)
pub components: Arc<SharedComponents>,
/// MCP manager for tool execution
pub mcp_manager: Arc<McpManager>,
/// Response storage for loading conversation history
pub response_storage: Arc<dyn ResponseStorage>,
/// Optional streaming sender (for future streaming support)
pub stream_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<String, String>>>,
}
impl HarmonyResponsesContext {
/// Create a new Harmony Responses context
pub fn new(
pipeline: Arc<RequestPipeline>,
components: Arc<SharedComponents>,
mcp_manager: Arc<McpManager>,
response_storage: Arc<dyn ResponseStorage>,
) -> Self {
Self {
pipeline,
components,
mcp_manager,
response_storage,
stream_tx: None,
}
}
/// Create with streaming support
pub fn with_streaming(
pipeline: Arc<RequestPipeline>,
components: Arc<SharedComponents>,
mcp_manager: Arc<McpManager>,
response_storage: Arc<dyn ResponseStorage>,
stream_tx: tokio::sync::mpsc::UnboundedSender<Result<String, String>>,
) -> Self {
Self {
pipeline,
components,
mcp_manager,
response_storage,
stream_tx: Some(stream_tx),
}
}
}
/// Execute Harmony Responses API request with multi-turn MCP tool support
///
/// This function orchestrates the multi-turn conversation flow:
/// 1. Execute request through full pipeline
/// 2. Check for tool calls in commentary channel
/// 3. If tool calls are found:
///    - Execute MCP tools
///    - Build next request with tool results
///    - Repeat from step 1 (full pipeline re-execution)
/// 4. If no tool calls are found, return the final response
///
/// # Architecture
///
/// Uses **external loop pattern**: wraps full pipeline execution rather than
/// implementing loop inside pipeline. Each iteration goes through:
/// - Worker selection (fresh selection based on current context)
/// - Client acquisition (new gRPC client if worker changed)
/// - Request building (Harmony prefill with complete history)
/// - Execution (model generation)
/// - Response processing (parse channels, detect tool calls)
///
/// # Arguments
///
/// * `ctx` - Harmony Responses context with pipeline, components, MCP manager
/// * `request` - Initial Responses API request
///
/// # Returns
///
/// Final ResponsesResponse after all tool iterations complete
///
/// # Errors
///
/// Returns error if:
/// - Max iterations exceeded (10 iterations)
/// - Pipeline execution fails
/// - MCP tool execution fails
/// - Response building fails
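///
/// # Example
///
/// A usage sketch (assumes the Arc-wrapped dependencies are already
/// constructed elsewhere):
///
/// ```ignore
/// let ctx = HarmonyResponsesContext::new(pipeline, components, mcp_manager, response_storage);
/// let response = serve_harmony_responses(&ctx, request).await?;
/// println!("{} output items", response.output.len());
/// ```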
pub async fn serve_harmony_responses(
ctx: &HarmonyResponsesContext,
request: ResponsesRequest,
) -> Result<ResponsesResponse, Response> {
// Load previous conversation history if previous_response_id is set
let mut current_request = load_previous_messages(ctx, request).await?;
let mut iteration_count = 0;
// Check if request has MCP tools - if so, ensure dynamic client is registered
// and add static MCP tools to the request
use crate::{
protocols::responses::ResponseToolType, routers::openai::mcp::ensure_request_mcp_client,
};
let has_mcp_tools = current_request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
// Initialize MCP call tracking (will be passed to processor for final response)
let mut mcp_tracking = if has_mcp_tools {
Some(McpCallTracking::new("sglang-mcp".to_string()))
} else {
None
};
if has_mcp_tools {
// Ensure dynamic MCP client is registered for request-scoped tools
if let Some(tools) = &current_request.tools {
ensure_request_mcp_client(&ctx.mcp_manager, tools).await;
}
// Add static MCP tools from inventory to the request
// (similar to non-Harmony pipeline pattern)
let mcp_tools = ctx.mcp_manager.list_tools();
if !mcp_tools.is_empty() {
let mcp_response_tools = convert_mcp_tools_to_response_tools(&mcp_tools);
let mut all_tools = current_request.tools.clone().unwrap_or_default();
all_tools.extend(mcp_response_tools);
current_request.tools = Some(all_tools);
tracing::debug!(
mcp_tool_count = mcp_tools.len(),
total_tool_count = current_request.tools.as_ref().map(|t| t.len()).unwrap_or(0),
"Request has MCP tools - added static MCP tools to Harmony Responses request"
);
}
}
loop {
iteration_count += 1;
// Safety check: prevent infinite loops
if iteration_count > MAX_TOOL_ITERATIONS {
return Err(utils::internal_error_message(format!(
"Maximum tool iterations ({}) exceeded",
MAX_TOOL_ITERATIONS
)));
}
tracing::debug!(
iteration = iteration_count,
"Harmony Responses serving iteration"
);
// Execute through full pipeline
// This includes:
// - HarmonyPreparationStage (builder.rs: construct_input_messages_with_harmony)
// - WorkerSelectionStage (FRESH selection based on current context)
// - ClientAcquisitionStage (NEW gRPC client if needed)
// - HarmonyRequestBuildingStage (encode to token_ids)
// - RequestExecutionStage (model generation)
// - HarmonyResponseProcessingStage (processor.rs: process_responses_iteration)
let iteration_result = ctx
.pipeline
.execute_harmony_responses(&current_request, ctx)
.await?;
match iteration_result {
ResponsesIterationResult::ToolCallsFound {
tool_calls,
analysis,
partial_text,
} => {
tracing::debug!(
tool_call_count = tool_calls.len(),
has_analysis = analysis.is_some(),
partial_text_len = partial_text.len(),
"Tool calls found in commentary channel"
);
// TODO: Streaming support - emit intermediate chunks
// if let Some(tx) = &ctx.stream_tx {
// emit_intermediate_chunks(tx, &analysis, &partial_text, iteration_count).await?;
// }
                // Execute MCP tools via the MCP manager.
                // If a tool does not exist, call_tool() returns an error for that call.
let tool_results = if let Some(ref mut tracking) = mcp_tracking {
execute_mcp_tools(&ctx.mcp_manager, &tool_calls, tracking).await?
} else {
// Should never happen (we only get tool_calls when has_mcp_tools=true)
return Err(utils::internal_error_static(
"Tool calls found but MCP tracking not initialized",
));
};
// Build next request with appended history
current_request = build_next_request_with_tools(
current_request,
tool_calls,
tool_results,
analysis,
partial_text,
)
.map_err(|e| *e)?;
// Continue loop - next iteration will select workers and execute
}
ResponsesIterationResult::Completed {
mut response,
usage,
} => {
tracing::debug!(
output_items = response.output.len(),
input_tokens = usage.prompt_tokens,
output_tokens = usage.completion_tokens,
has_mcp_tracking = mcp_tracking.is_some(),
"Harmony Responses serving completed - no more tool calls"
);
// Inject MCP output items if MCP tools were available
// (even if no tools were called, we still list available tools)
if let Some(tracking) = mcp_tracking {
inject_mcp_metadata(&mut response, &tracking, &ctx.mcp_manager);
tracing::debug!(
mcp_calls = tracking.total_calls(),
output_items_after = response.output.len(),
"Injected MCP metadata into final response"
);
}
// No tool calls - this is the final response
// TODO: Accumulate usage across all iterations if needed
return Ok(*response);
}
}
}
}
/// Execute MCP tools and collect results
///
/// Executes each tool call sequentially via the MCP manager.
/// Tool execution errors are returned as error results to the model
/// (allows model to handle gracefully).
///
/// # Arguments
///
/// * `mcp_manager` - MCP manager for tool execution
/// * `tool_calls` - Tool calls from commentary channel
/// * `tracking` - Accumulator for MCP call records (used to build mcp_call output items)
///
/// # Returns
///
/// Vector of tool results (one per tool call)
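///
/// On failure, the result delivered back to the model is an error payload of
/// the form (sketch):
///
/// ```text
/// {"error": "Tool execution failed: <message>"}
/// ```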
async fn execute_mcp_tools(
mcp_manager: &Arc<McpManager>,
tool_calls: &[ToolCall],
tracking: &mut McpCallTracking,
) -> Result<Vec<ToolResult>, Response> {
let mut results = Vec::new();
for tool_call in tool_calls {
tracing::debug!(
tool_name = %tool_call.function.name,
call_id = %tool_call.id,
"Executing MCP tool"
);
// Parse tool arguments from JSON string
let args_str = tool_call.function.arguments.as_deref().unwrap_or("{}");
let args: JsonValue = serde_json::from_str(args_str).map_err(|e| {
utils::internal_error_message(format!(
"Invalid tool arguments JSON for tool '{}': {}",
tool_call.function.name, e
))
})?;
// Execute tool via MCP manager
// Convert JsonValue to ToolArgs via Option<Map> (MCP manager expects this)
let args_map = if let JsonValue::Object(map) = args {
Some(map)
} else {
None
};
match mcp_manager
.call_tool(&tool_call.function.name, args_map)
.await
{
Ok(mcp_result) => {
tracing::debug!(
tool_name = %tool_call.function.name,
call_id = %tool_call.id,
"Tool execution succeeded"
);
// Extract content from MCP result
let output = if let Some(content) = mcp_result.content.first() {
// TODO: Handle different content types (text, image, resource)
// For now, serialize the entire content item
serde_json::to_value(content).unwrap_or_else(
|_| serde_json::json!({"error": "Failed to serialize tool result"}),
)
} else {
serde_json::json!({"result": "success"})
};
let is_error = mcp_result.is_error.unwrap_or(false);
let output_str = serde_json::to_string(&output)
.unwrap_or_else(|_| r#"{"error": "Failed to serialize output"}"#.to_string());
// Record this call in tracking
tracking.record_call(
tool_call.id.clone(),
tool_call.function.name.clone(),
args_str.to_string(),
output_str.clone(),
!is_error,
if is_error {
Some(output_str.clone())
} else {
None
},
);
results.push(ToolResult {
call_id: tool_call.id.clone(),
tool_name: tool_call.function.name.clone(),
output,
is_error,
});
}
Err(e) => {
tracing::warn!(
tool_name = %tool_call.function.name,
call_id = %tool_call.id,
error = %e,
"Tool execution failed"
);
let error_msg = format!("Tool execution failed: {}", e);
let error_output = serde_json::json!({
"error": error_msg.clone()
});
let error_output_str = serde_json::to_string(&error_output)
.unwrap_or_else(|_| format!(r#"{{"error": "{}"}}"#, error_msg));
// Record failed call in tracking
tracking.record_call(
tool_call.id.clone(),
tool_call.function.name.clone(),
args_str.to_string(),
error_output_str.clone(),
false,
Some(error_msg),
);
// Return error result to model (let it handle gracefully)
results.push(ToolResult {
call_id: tool_call.id.clone(),
tool_name: tool_call.function.name.clone(),
output: error_output,
is_error: true,
});
}
}
}
Ok(results)
}
/// Build next request with tool results appended to history
///
/// Constructs a new ResponsesRequest with:
/// 1. Original input items (preserved)
/// 2. Assistant message with analysis (reasoning) + partial_text + tool_calls
/// 3. Tool result messages for each tool execution
///
/// # Arguments
///
/// * `request` - Current request (contains original input)
/// * `tool_calls` - Tool calls from commentary channel
/// * `tool_results` - Results from MCP tool execution
/// * `analysis` - Analysis channel content (becomes reasoning content)
/// * `partial_text` - Final channel content (becomes message content)
///
/// # Returns
///
/// New ResponsesRequest with appended history
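///
/// Resulting item order (sketch, for one iteration with one tool call):
///
/// ```text
/// [ ..original items..,
///   Reasoning { content: analysis },           // if analysis present
///   Message { role: "assistant", .. },         // if partial_text non-empty
///   FunctionToolCall { output: Some(..), status: "completed" | "failed" } ]
/// ```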
fn build_next_request_with_tools(
mut request: ResponsesRequest,
tool_calls: Vec<ToolCall>,
tool_results: Vec<ToolResult>,
analysis: Option<String>,
partial_text: String,
) -> Result<ResponsesRequest, Box<Response>> {
use uuid::Uuid;
use crate::protocols::responses::{
ResponseContentPart, ResponseInputOutputItem, ResponseReasoningContent,
};
// Get current input items (or empty vec if Text variant)
let mut items = match request.input {
ResponseInput::Items(items) => items,
ResponseInput::Text(text) => {
// Convert text to items format
vec![ResponseInputOutputItem::SimpleInputMessage {
content: StringOrContentParts::String(text),
role: "user".to_string(),
r#type: None,
}]
}
};
// Build assistant response item with reasoning + content + tool calls
// This represents what the model generated in this iteration
let assistant_id = format!("msg_{}", Uuid::new_v4());
// Add reasoning if present (from analysis channel)
if let Some(analysis_text) = analysis {
items.push(ResponseInputOutputItem::Reasoning {
id: format!("reasoning_{}", assistant_id),
summary: vec![],
content: vec![ResponseReasoningContent::ReasoningText {
text: analysis_text,
}],
status: Some("completed".to_string()),
});
}
// Add message content if present (from final channel)
if !partial_text.is_empty() {
items.push(ResponseInputOutputItem::Message {
id: assistant_id.clone(),
role: "assistant".to_string(),
content: vec![ResponseContentPart::OutputText {
text: partial_text,
annotations: vec![],
logprobs: None,
}],
status: Some("completed".to_string()),
});
}
// Add function tool calls (from commentary channel)
for tool_call in tool_calls {
items.push(ResponseInputOutputItem::FunctionToolCall {
id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
arguments: tool_call
.function
.arguments
.unwrap_or_else(|| "{}".to_string()),
output: None, // Output will be added next
status: Some("in_progress".to_string()),
});
}
// Add tool results
for tool_result in tool_results {
// Serialize tool output to string
let output_str = serde_json::to_string(&tool_result.output).unwrap_or_else(|e| {
format!("{{\"error\": \"Failed to serialize tool output: {}\"}}", e)
});
// Update the corresponding tool call with output and completed status
// Find and update the matching FunctionToolCall
if let Some(ResponseInputOutputItem::FunctionToolCall {
output,
status,
..
}) = items
.iter_mut()
.find(|item| matches!(item, ResponseInputOutputItem::FunctionToolCall { id, .. } if id == &tool_result.call_id))
{
*output = Some(output_str);
*status = if tool_result.is_error {
Some("failed".to_string())
} else {
Some("completed".to_string())
};
}
}
// Update request with new items
request.input = ResponseInput::Items(items);
Ok(request)
}
/// Tool execution result
///
/// Contains the result of executing a single MCP tool.
struct ToolResult {
/// Tool call ID (for matching with request)
call_id: String,
/// Tool name
#[allow(dead_code)] // Kept for documentation and future use
tool_name: String,
/// Tool output (JSON value)
output: JsonValue,
/// Whether this is an error result
is_error: bool,
}
/// Convert MCP tools to Responses API tool format
///
/// Converts MCP Tool entries (from rmcp SDK) to ResponseTool format so the model
/// knows about available MCP tools when making tool calls.
///
/// # Arguments
///
/// * `mcp_tools` - MCP tools from the MCP manager inventory (rmcp::model::Tool)
///
/// # Returns
///
/// Vector of ResponseTool entries in MCP format
fn convert_mcp_tools_to_response_tools(mcp_tools: &[crate::mcp::Tool]) -> Vec<ResponseTool> {
use serde_json::Value;
use crate::protocols::responses::ResponseToolType;
mcp_tools
.iter()
.map(|tool_info| ResponseTool {
r#type: ResponseToolType::Mcp,
function: Some(Function {
name: tool_info.name.to_string(),
description: tool_info.description.as_ref().map(|d| d.to_string()),
parameters: Value::Object((*tool_info.input_schema).clone()),
strict: None,
}),
server_url: None, // MCP tools from inventory don't have individual server URLs
authorization: None,
server_label: None,
server_description: tool_info.description.as_ref().map(|d| d.to_string()),
require_approval: None,
allowed_tools: None,
})
.collect()
}
/// Inject MCP metadata into final response
///
/// Adds mcp_list_tools and mcp_call output items to the response output array.
/// Following non-Harmony pipeline pattern:
/// 1. Prepend mcp_list_tools at the beginning
/// 2. Append all mcp_call items at the end
///
/// # Arguments
///
/// * `response` - Final response to modify
/// * `tracking` - MCP call tracking data
/// * `mcp_manager` - MCP manager for listing tools
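///
/// Resulting output layout (sketch):
///
/// ```text
/// [ McpListTools { .. },        // prepended at index 0
///   ..model output items..,
///   McpCall { .. }, ... ]       // one per tracked tool call, appended
/// ```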
fn inject_mcp_metadata(
response: &mut ResponsesResponse,
tracking: &McpCallTracking,
mcp_manager: &Arc<McpManager>,
) {
use serde_json::{json, Value};
use uuid::Uuid;
use crate::protocols::responses::{McpToolInfo, ResponseOutputItem};
// Build mcp_list_tools item
let tools = mcp_manager.list_tools();
let tools_info: Vec<McpToolInfo> = tools
.iter()
.map(|t| McpToolInfo {
name: t.name.to_string(),
description: t.description.as_ref().map(|d| d.to_string()),
input_schema: Value::Object((*t.input_schema).clone()),
annotations: Some(json!({
"read_only": false
})),
})
.collect();
let mcp_list_tools = ResponseOutputItem::McpListTools {
id: format!("mcpl_{}", Uuid::new_v4()),
server_label: tracking.server_label.clone(),
tools: tools_info,
};
// Build mcp_call items for each tracked call
let mcp_call_items: Vec<ResponseOutputItem> = tracking
.tool_calls
.iter()
.map(|record| ResponseOutputItem::McpCall {
id: format!("mcp_{}", Uuid::new_v4()),
status: if record.success {
"completed"
} else {
"failed"
}
.to_string(),
approval_request_id: None,
arguments: record.arguments.clone(),
error: record.error.clone(),
name: record.tool_name.clone(),
output: record.output.clone(),
server_label: tracking.server_label.clone(),
})
.collect();
// Inject into response output:
// 1. Prepend mcp_list_tools at the beginning
response.output.insert(0, mcp_list_tools);
// 2. Append all mcp_call items at the end
response.output.extend(mcp_call_items);
}
/// Load previous conversation messages from storage
///
/// If the request has `previous_response_id`, loads the response chain from storage
/// and prepends the conversation history to the request input items.
///
/// # Arguments
///
/// * `ctx` - Harmony Responses context with response_storage
/// * `request` - Current request (may have previous_response_id set)
///
/// # Returns
///
/// Modified request with conversation history prepended to input items
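///
/// Illustrative effect (sketch; the response ID is hypothetical):
///
/// ```text
/// before: previous_response_id = Some("resp_123"), input = Text("hi")
/// after:  previous_response_id = None,
///         input = Items([..stored history items.., SimpleInputMessage("hi")])
/// ```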
async fn load_previous_messages(
ctx: &HarmonyResponsesContext,
request: ResponsesRequest,
) -> Result<ResponsesRequest, Response> {
let Some(ref prev_id_str) = request.previous_response_id else {
// No previous_response_id, return request as-is
return Ok(request);
};
let prev_id = ResponseId::from(prev_id_str.as_str());
// Load response chain from storage
let chain = ctx
.response_storage
.get_response_chain(&prev_id, None)
.await
.map_err(|e| {
utils::internal_error_message(format!(
"Failed to load previous response chain for {}: {}",
prev_id_str, e
))
})?;
// Build conversation history from stored responses
let mut history_items = Vec::new();
// Helper to deserialize and collect items from a JSON array
let deserialize_items =
|arr: &serde_json::Value, item_type: &str| -> Vec<ResponseInputOutputItem> {
arr.as_array()
.into_iter()
.flat_map(|items| items.iter())
.filter_map(|item| {
serde_json::from_value::<ResponseInputOutputItem>(item.clone())
.map_err(|e| {
tracing::warn!(
"Failed to deserialize stored {} item: {}. Item: {}",
item_type,
e,
item
);
})
.ok()
})
.collect()
};
for stored in chain.responses.iter() {
history_items.extend(deserialize_items(&stored.input, "input"));
history_items.extend(deserialize_items(&stored.output, "output"));
}
tracing::debug!(
previous_response_id = %prev_id_str,
history_items_count = history_items.len(),
"Loaded conversation history from previous response"
);
// Build modified request with history prepended
let mut modified_request = request;
// Convert current input to items format
let all_items = match modified_request.input {
ResponseInput::Items(items) => {
// Prepend history to existing items
let mut combined = history_items;
combined.extend(items);
combined
}
ResponseInput::Text(text) => {
// Convert text to item and prepend history
history_items.push(ResponseInputOutputItem::SimpleInputMessage {
content: StringOrContentParts::String(text),
role: "user".to_string(),
r#type: None,
});
history_items
}
};
// Update request with combined items and clear previous_response_id
modified_request.input = ResponseInput::Items(all_items);
modified_request.previous_response_id = None;
Ok(modified_request)
}
// TODO: Implement streaming support
// /// Emit intermediate streaming chunks for analysis and partial text
// ///
// /// Emits SSE chunks for Responses API streaming:
// /// - Reasoning chunks for analysis channel
// /// - Message chunks for partial text from final channel
// ///
// /// # Arguments
// ///
// /// * `tx` - Streaming sender
// /// * `analysis` - Analysis channel content
// /// * `partial_text` - Final channel content
// /// * `iteration` - Current iteration number
// async fn emit_intermediate_chunks(
// tx: &tokio::sync::mpsc::UnboundedSender<Result<String, String>>,
// analysis: &Option<String>,
// partial_text: &str,
// iteration: usize,
// ) -> Result<(), Response> {
// // TODO: Implement streaming emission
// // - Emit reasoning chunks for analysis
// // - Emit message chunks for partial_text
// // - Follow OpenAI Responses streaming format (14 SSE event types)
// Ok(())
// }
//! Harmony-specific pipeline stages
//!
//! These stages replace their regular counterparts in the Harmony pipeline:
//! - HarmonyPreparationStage: Harmony encoding instead of chat template + tokenization
//! - HarmonyRequestBuildingStage: Token-based request building
//! - HarmonyResponseProcessingStage: Harmony channel parsing
pub mod preparation;
pub mod request_building;
pub mod response_processing;
pub use preparation::HarmonyPreparationStage;
pub use request_building::HarmonyRequestBuildingStage;
pub use response_processing::HarmonyResponseProcessingStage;
//! Harmony Preparation Stage: Harmony encoding for chat and generate requests
use async_trait::async_trait;
use axum::response::Response;
use serde_json::json;
use super::super::HarmonyBuilder;
use crate::{
protocols::{
chat::ChatCompletionRequest,
common::{Tool, ToolChoice, ToolChoiceValue},
responses::ResponsesRequest,
},
routers::grpc::{
context::{PreparationOutput, RequestContext, RequestType},
stages::PipelineStage,
utils,
},
};
/// Harmony Preparation stage: Encode requests using Harmony protocol
///
/// Replaces the regular PreparationStage for Harmony models.
/// Converts chat/generate requests to Harmony-encoded token_ids and extraction_text.
pub struct HarmonyPreparationStage {
builder: HarmonyBuilder,
}
impl HarmonyPreparationStage {
/// Create a new Harmony preparation stage
pub fn new() -> Self {
Self {
builder: HarmonyBuilder::new(),
}
}
}
impl Default for HarmonyPreparationStage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PipelineStage for HarmonyPreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Clone Arc before match to avoid borrow checker issues
// Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
let is_responses = matches!(&ctx.input.request_type, RequestType::Responses(_));
if is_chat {
let request_arc = ctx.chat_request_arc();
self.prepare_chat(ctx, &request_arc).await?;
} else if is_responses {
let request_arc = ctx.responses_request_arc();
self.prepare_responses(ctx, &request_arc).await?;
} else {
return Err(utils::bad_request_error(
"Only Chat and Responses requests supported in Harmony pipeline".to_string(),
));
}
Ok(None)
}
fn name(&self) -> &'static str {
"HarmonyPreparation"
}
}
impl HarmonyPreparationStage {
/// Prepare a chat completion request using Harmony encoding
async fn prepare_chat(
&self,
ctx: &mut RequestContext,
request: &ChatCompletionRequest,
) -> Result<Option<Response>, Response> {
// Validate - reject logprobs
if request.logprobs {
return Err(utils::bad_request_error(
"logprobs are not supported for Harmony models".to_string(),
));
}
// Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request);
// Step 2: Build tool constraints
let tool_constraints = if let Some(tools) = body_ref.tools.as_ref() {
Self::generate_harmony_structural_tag(tools, &body_ref.tool_choice).map_err(|e| *e)?
} else {
None
};
// Step 3: Build via Harmony
let build_output = self
.builder
.build_from_chat(&body_ref)
.map_err(|e| utils::bad_request_error(format!("Harmony build failed: {}", e)))?;
// Step 4: Store results
ctx.state.preparation = Some(PreparationOutput {
original_text: None,
token_ids: build_output.input_ids,
processed_messages: None,
tool_constraints,
filtered_request: if matches!(body_ref, std::borrow::Cow::Owned(_)) {
Some(body_ref.into_owned())
} else {
None
},
harmony_mode: true,
selection_text: Some(build_output.selection_text),
harmony_messages: Some(build_output.harmony_messages),
harmony_stop_ids: Some(build_output.stop_token_ids),
});
Ok(None)
}
/// Prepare a responses API request using Harmony encoding
///
    /// For the Responses API, we build from the conversation history using the same
    /// Harmony encoding the builder provides; because the MCP loop re-runs the full
    /// pipeline, tool results appended to the history are re-encoded on each iteration.
pub async fn prepare_responses(
&self,
ctx: &mut RequestContext,
request: &ResponsesRequest,
) -> Result<Option<Response>, Response> {
// Build via Harmony from responses API request
let build_output = self
.builder
.build_from_responses(request)
.map_err(|e| utils::bad_request_error(format!("Harmony build failed: {}", e)))?;
// Store results in preparation output
ctx.state.preparation = Some(PreparationOutput {
original_text: None,
token_ids: build_output.input_ids,
processed_messages: None,
tool_constraints: None,
filtered_request: None,
harmony_mode: true,
selection_text: Some(build_output.selection_text),
harmony_messages: Some(build_output.harmony_messages),
harmony_stop_ids: Some(build_output.stop_token_ids),
});
Ok(None)
}
/// Generate Harmony structural tag for tool constraints
///
/// Uses structural tags with `triggered_tags` format to force Harmony format output.
/// This ensures the model outputs in Harmony format (with channels) even when constrained.
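    ///
    /// Illustrative serialized tag (sketch) for `tool_choice: required` with one
    /// available tool named `get_weather` (hypothetical; the schema object stands
    /// in for the tool's `parameters`):
    ///
    /// ```json
    /// {
    ///   "format": {
    ///     "type": "triggered_tags",
    ///     "triggers": ["<|channel|>commentary"],
    ///     "tags": [{
    ///       "begin": "<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>",
    ///       "content": { "type": "json_schema", "json_schema": { "type": "object" } },
    ///       "end": ""
    ///     }],
    ///     "at_least_one": true,
    ///     "stop_after_first": false
    ///   }
    /// }
    /// ```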
fn generate_harmony_structural_tag(
tools: &[Tool],
tool_choice: &Option<ToolChoice>,
) -> Result<Option<(String, String)>, Box<Response>> {
let Some(choice) = tool_choice.as_ref() else {
return Ok(None);
};
match choice {
ToolChoice::Function { function, .. } => {
let tag = Self::build_harmony_structural_tag(tools, Some(&function.name))?;
Ok(Some(("structural_tag".to_string(), tag)))
}
ToolChoice::Value(ToolChoiceValue::Required) => {
let tag = Self::build_harmony_structural_tag(tools, None)?;
Ok(Some(("structural_tag".to_string(), tag)))
}
ToolChoice::AllowedTools { mode, .. } => {
if mode == "required" {
let tag = Self::build_harmony_structural_tag(tools, None)?;
Ok(Some(("structural_tag".to_string(), tag)))
} else {
Ok(None)
}
}
_ => Ok(None),
}
}
/// Build Harmony structural tag for tool calling constraints
fn build_harmony_structural_tag(
tools: &[Tool],
specific_function: Option<&str>,
) -> Result<String, Box<Response>> {
let mut tags = Vec::new();
// Filter tools if specific function requested
let tools_to_use: Vec<&Tool> = if let Some(func_name) = specific_function {
tools
.iter()
.filter(|t| t.function.name == func_name)
.collect()
} else {
tools.iter().collect()
};
// Validate specific function exists
if specific_function.is_some() && tools_to_use.is_empty() {
return Err(Box::new(utils::bad_request_error(format!(
"Tool '{}' not found in tools list",
specific_function.unwrap()
))));
}
// Build tags for each tool
for tool in tools_to_use {
let tool_name = &tool.function.name;
let params_schema = &tool.function.parameters;
tags.push(json!({
"begin": format!("<|channel|>commentary to=functions.{}<|constrain|>json<|message|>", tool_name),
"content": {
"type": "json_schema",
"json_schema": params_schema
},
"end": "" // `end` is empty because <|call|> comes naturally from Harmony stop tokens
}));
}
let stop_after_first = specific_function.is_some();
let structural_tag = json!({
"format": {
"type": "triggered_tags",
"triggers": ["<|channel|>commentary"],
"tags": tags,
"at_least_one": true,
"stop_after_first": stop_after_first
}
});
serde_json::to_string(&structural_tag).map_err(|e| {
Box::new(utils::internal_error_message(format!(
"Failed to serialize structural tag: {}",
e
)))
})
}
}
//! Harmony Request Building Stage: Build gRPC request from Harmony-encoded tokens
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use rand::Rng;
use tracing::debug;
use uuid::Uuid;
use crate::{
core::Worker,
grpc_client::proto::{DisaggregatedParams, GenerateRequest},
routers::grpc::{
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
stages::PipelineStage,
utils,
},
};
/// Harmony Request Building stage: Convert Harmony tokens to gRPC request
///
/// Takes the Harmony-encoded input_ids from preparation and builds a proto::GenerateRequest.
/// Unlike regular request building, this uses token_ids directly (Harmony encoding handles messages).
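///
/// Data flow (sketch):
///
/// ```text
/// preparation.token_ids ──► proto::GenerateRequest
///   + harmony_stop_ids appended to sampling_params.stop_token_ids
///   + DisaggregatedParams injected when running in PD mode
/// ```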
pub struct HarmonyRequestBuildingStage {
inject_pd_metadata: bool,
}
impl HarmonyRequestBuildingStage {
/// Create a new Harmony request building stage
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
/// Inject PD (prefill-decode) bootstrap metadata
fn inject_bootstrap_metadata(
&self,
request: &mut GenerateRequest,
prefill_worker: &Arc<dyn Worker>,
) {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected Harmony bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
}
#[async_trait]
impl PipelineStage for HarmonyRequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Get preparation output
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation not completed"))?;
// Get clients
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
// Generate request_id based on request type
let request_id = match &ctx.input.request_type {
RequestType::Chat(_) => format!("chatcmpl-{}", Uuid::new_v4()),
RequestType::Responses(_) => format!("responses-{}", Uuid::new_v4()),
RequestType::Generate(_) => {
return Err(utils::bad_request_error(
"Generate requests are not supported with Harmony models".to_string(),
));
}
};
// Build gRPC request using token_ids directly (Harmony encoding already handled message rendering)
// Use a placeholder for original_text; Harmony uses input_ids for tokenization
let placeholder_processed_text = "[harmony]".to_string();
let mut proto_request = match &ctx.input.request_type {
RequestType::Chat(request) => {
// Use filtered request if present from preparation; otherwise original
let body = prep.filtered_request.as_ref().unwrap_or(request.as_ref());
builder_client
.build_generate_request(
request_id,
body,
placeholder_processed_text,
prep.token_ids.clone(),
None,
prep.tool_constraints.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?
}
RequestType::Responses(request) => builder_client
.build_generate_request_from_responses(
request_id,
request.as_ref(),
placeholder_processed_text,
prep.token_ids.clone(),
prep.harmony_stop_ids.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?,
_ => unreachable!(),
};
// Inject Harmony stop token IDs into sampling params for ALL Harmony requests
// These stop tokens (<|return|> and <|call|>) prevent the model from generating
// malformed Harmony sequences
if let Some(harmony_stops) = &prep.harmony_stop_ids {
if let Some(params) = proto_request.sampling_params.as_mut() {
params.stop_token_ids.extend_from_slice(harmony_stops);
debug!(
stop_token_count = harmony_stops.len(),
"Injected Harmony stop tokens into sampling params"
);
}
}
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let Some(WorkerSelection::Dual { prefill, .. }) = ctx.state.workers.as_ref() {
self.inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"HarmonyRequestBuilding"
}
}
//! Harmony Response Processing Stage: Parse Harmony channels to ChatCompletionResponse
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use super::super::{HarmonyResponseProcessor, HarmonyStreamingProcessor};
use crate::routers::grpc::{
context::{FinalResponse, RequestContext, RequestType},
stages::PipelineStage,
utils,
};
/// Harmony Response Processing stage: Parse and format Harmony responses
///
/// Takes output tokens from execution and parses them using HarmonyParserAdapter
/// to extract analysis, tool calls, and final response text from Harmony channels.
pub struct HarmonyResponseProcessingStage {
processor: HarmonyResponseProcessor,
streaming_processor: Arc<HarmonyStreamingProcessor>,
}
impl HarmonyResponseProcessingStage {
/// Create a new Harmony response processing stage
pub fn new() -> Self {
Self {
processor: HarmonyResponseProcessor::new(),
streaming_processor: Arc::new(HarmonyStreamingProcessor::new()),
}
}
}
impl Default for HarmonyResponseProcessingStage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PipelineStage for HarmonyResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Get execution result (output tokens from model)
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
let is_streaming = ctx.is_streaming();
let dispatch = ctx
.state
.dispatch
.as_ref()
.cloned()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// Check request type to determine which processor method to call
match &ctx.input.request_type {
RequestType::Chat(_) => {
// For streaming, delegate to streaming processor and return SSE response
if is_streaming {
return Ok(Some(
self.streaming_processor
.clone()
.process_streaming_chat_response(
execution_result,
ctx.chat_request_arc(),
dispatch,
),
));
}
// For non-streaming, delegate to Harmony response processor to build ChatCompletionResponse
let chat_request = ctx.chat_request_arc();
let response = self
.processor
.process_non_streaming_chat_response(execution_result, chat_request, dispatch)
.await?;
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
Ok(None)
}
RequestType::Responses(_) => {
// For Responses API, process iteration and store result
// Streaming not yet supported for Responses API
if is_streaming {
return Err(utils::internal_error_static(
"Streaming not yet supported for Responses API",
));
}
let responses_request = ctx.responses_request_arc();
let iteration_result = self
.processor
.process_responses_iteration(execution_result, responses_request, dispatch)
.await?;
ctx.state.response.responses_iteration_result = Some(iteration_result);
Ok(None)
}
RequestType::Generate(_) => Err(utils::internal_error_static(
"Generate requests not supported in Harmony pipeline",
)),
}
}
fn name(&self) -> &'static str {
"HarmonyResponseProcessing"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_response_processing_stage_creation() {
let stage = HarmonyResponseProcessingStage::new();
assert_eq!(stage.name(), "HarmonyResponseProcessing");
}
}
//! Harmony streaming response processor
use std::{
collections::{hash_map::Entry::Vacant, HashMap},
io,
sync::Arc,
};
use axum::{body::Body, http::StatusCode, response::Response};
use bytes::Bytes;
use http::header::{HeaderValue, CONTENT_TYPE};
use proto::{
generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId},
generate_response::Response::{Chunk, Complete},
};
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use tracing::error;
use super::{types::HarmonyChannelDelta, HarmonyParserAdapter};
use crate::{
grpc_client::{proto, sglang_scheduler::AbortOnDropStream},
protocols::{
chat::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
},
common::{FunctionCallDelta, ToolCallDelta, Usage},
},
routers::grpc::context,
};
/// Processor for streaming Harmony responses
///
/// Returns an SSE stream that parses Harmony tokens incrementally and
/// emits ChatCompletionChunk events for streaming responses.
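///
/// Emitted SSE framing (sketch; payloads abbreviated):
///
/// ```text
/// data: {"object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant","content":""}, ...}]}
/// data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"Hel"}, ...}]}
/// data: {"object":"chat.completion.chunk","choices":[{"delta":{}, "finish_reason":"stop", ...}]}
/// data: [DONE]
/// ```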
pub struct HarmonyStreamingProcessor;
impl HarmonyStreamingProcessor {
/// Create a new Harmony streaming processor
pub fn new() -> Self {
Self
}
/// Process a streaming Harmony Chat Completion response
///
/// Returns an SSE response with streaming token updates.
pub fn process_streaming_chat_response(
self: Arc<Self>,
execution_result: context::ExecutionResult,
chat_request: Arc<ChatCompletionRequest>,
dispatch: context::DispatchMetadata,
) -> Response {
// Create SSE channel
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
// Spawn background task based on execution mode
match execution_result {
context::ExecutionResult::Single { stream } => {
tokio::spawn(async move {
let result =
Self::process_single_stream(stream, dispatch, chat_request, &tx).await;
if let Err(e) = result {
error!("Harmony streaming error: {}", e);
let error_chunk = format!(
"data: {}\n\n",
json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(Bytes::from(error_chunk)));
}
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
});
}
context::ExecutionResult::Dual { prefill, decode } => {
tokio::spawn(async move {
let result =
Self::process_dual_stream(prefill, *decode, dispatch, chat_request, &tx)
.await;
if let Err(e) = result {
error!("Harmony dual streaming error: {}", e);
let error_chunk = format!(
"data: {}\n\n",
json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(Bytes::from(error_chunk)));
}
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
});
}
}
// Return SSE response
Self::build_sse_response(rx)
}
/// Process streaming chunks from a single stream
async fn process_single_stream(
mut grpc_stream: AbortOnDropStream,
dispatch: context::DispatchMetadata,
original_request: Arc<ChatCompletionRequest>,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// Per-index state management (for n>1 support)
let mut parsers: HashMap<u32, HarmonyParserAdapter> = HashMap::new();
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut finish_reasons: HashMap<u32, Option<String>> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<serde_json::Value>> = HashMap::new();
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let stream_options = &original_request.stream_options;
// Process stream
while let Some(result) = grpc_stream.next().await {
let response = result.map_err(|e| format!("Stream error: {}", e))?;
match response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Initialize parser for this index if needed
if let Vacant(e) = parsers.entry(index) {
e.insert(
HarmonyParserAdapter::new()
.map_err(|e| format!("Failed to create parser: {}", e))?,
);
is_firsts.insert(index, true);
}
// Track token counts
*completion_tokens.entry(index).or_insert(0) += 1;
// Parse chunk via Harmony parser
let parser = parsers
.get_mut(&index)
.ok_or("Parser not found for index")?;
let delta_result = parser
.parse_chunk(&chunk.token_ids)
.map_err(|e| format!("Parse error: {}", e))?;
// Emit SSE event if there's a delta
if let Some(delta) = delta_result {
let is_first = is_firsts.get(&index).copied().unwrap_or(false);
Self::emit_chunk_delta(
&delta,
index,
is_first,
&dispatch,
&original_request,
tx,
)?;
if is_first {
is_firsts.insert(index, false);
}
}
}
Some(Complete(complete)) => {
let index = complete.index;
// Store final metadata
finish_reasons.insert(index, Some(complete.finish_reason.clone()));
matched_stops.insert(
index,
complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
serde_json::json!(id)
}
MatchedStopStr(s) => {
serde_json::json!(s)
}
}),
);
prompt_tokens.insert(index, complete.prompt_tokens as u32);
*completion_tokens.entry(index).or_insert(0) =
complete.completion_tokens as u32;
// Finalize parser and emit final chunk
if let Some(parser) = parsers.get_mut(&index) {
let matched_stop = matched_stops.get(&index).and_then(|m| m.clone());
let final_output = parser
.finalize(complete.finish_reason.clone(), matched_stop.clone())
.map_err(|e| format!("Finalize error: {}", e))?;
Self::emit_final_chunk(
index,
&final_output.finish_reason,
final_output.matched_stop.as_ref(),
&dispatch,
&original_request,
tx,
)?;
}
}
Some(proto::generate_response::Response::Error(err)) => {
return Err(format!("Server error: {}", err.message));
}
None => {}
}
}
// Emit final usage if requested
if let Some(true) = stream_options.as_ref().and_then(|so| so.include_usage) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
Self::emit_usage_chunk(
total_prompt,
total_completion,
&dispatch,
&original_request,
tx,
)?;
}
// Mark stream as completed successfully to prevent abort on drop
grpc_stream.mark_completed();
Ok(())
}
/// Process streaming chunks from dual streams (prefill + decode)
async fn process_dual_stream(
mut prefill_stream: AbortOnDropStream,
mut decode_stream: AbortOnDropStream,
dispatch: context::DispatchMetadata,
original_request: Arc<ChatCompletionRequest>,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// Phase 1: Process prefill stream (collect metadata)
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
while let Some(result) = prefill_stream.next().await {
let response = result.map_err(|e| format!("Prefill stream error: {}", e))?;
if let Some(Complete(complete)) = response.response {
prompt_tokens.insert(complete.index, complete.prompt_tokens as u32);
}
}
// Phase 2: Process decode stream (same as single stream)
let mut parsers: HashMap<u32, HarmonyParserAdapter> = HashMap::new();
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut finish_reasons: HashMap<u32, Option<String>> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<serde_json::Value>> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let stream_options = &original_request.stream_options;
while let Some(result) = decode_stream.next().await {
let response = result.map_err(|e| format!("Decode stream error: {}", e))?;
match response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Initialize parser for this index if needed
if let Vacant(e) = parsers.entry(index) {
e.insert(
HarmonyParserAdapter::new()
.map_err(|e| format!("Failed to create parser: {}", e))?,
);
is_firsts.insert(index, true);
}
*completion_tokens.entry(index).or_insert(0) += 1;
let parser = parsers
.get_mut(&index)
.ok_or("Parser not found for index")?;
let delta_result = parser
.parse_chunk(&chunk.token_ids)
.map_err(|e| format!("Parse error: {}", e))?;
if let Some(delta) = delta_result {
let is_first = is_firsts.get(&index).copied().unwrap_or(false);
Self::emit_chunk_delta(
&delta,
index,
is_first,
&dispatch,
&original_request,
tx,
)?;
if is_first {
is_firsts.insert(index, false);
}
}
}
Some(Complete(complete)) => {
let index = complete.index;
finish_reasons.insert(index, Some(complete.finish_reason.clone()));
matched_stops.insert(
index,
complete.matched_stop.as_ref().map(|m| match m {
MatchedTokenId(id) => {
json!(id)
}
MatchedStopStr(s) => {
json!(s)
}
}),
);
*completion_tokens.entry(index).or_insert(0) =
complete.completion_tokens as u32;
if let Some(parser) = parsers.get_mut(&index) {
let matched_stop = matched_stops.get(&index).and_then(|m| m.clone());
let final_output = parser
.finalize(complete.finish_reason.clone(), matched_stop.clone())
.map_err(|e| format!("Finalize error: {}", e))?;
Self::emit_final_chunk(
index,
&final_output.finish_reason,
final_output.matched_stop.as_ref(),
&dispatch,
&original_request,
tx,
)?;
}
}
Some(proto::generate_response::Response::Error(err)) => {
return Err(format!("Server error: {}", err.message));
}
None => {}
}
}
decode_stream.mark_completed();
        // Mark prefill stream as completed AFTER decode completes successfully.
        // This ensures that if the client disconnects during decode, BOTH streams send an abort.
prefill_stream.mark_completed();
// Emit final usage if requested
if let Some(true) = stream_options.as_ref().and_then(|so| so.include_usage) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
Self::emit_usage_chunk(
total_prompt,
total_completion,
&dispatch,
&original_request,
tx,
)?;
}
Ok(())
}
/// Emit a chunk delta from Harmony channels
fn emit_chunk_delta(
delta: &HarmonyChannelDelta,
index: u32,
is_first: bool,
dispatch: &context::DispatchMetadata,
original_request: &ChatCompletionRequest,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// On first chunk, emit role announcement separately
if is_first {
let role_chunk = ChatCompletionStreamResponse {
id: dispatch.request_id.clone(),
object: "chat.completion.chunk".to_string(),
created: dispatch.created,
model: original_request.model.clone(),
system_fingerprint: dispatch.weight_version.clone(),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(String::new()),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let chunk_json = serde_json::to_string(&role_chunk)
.map_err(|e| format!("JSON serialization error: {}", e))?;
let sse_data = format!("data: {}\n\n", chunk_json);
tx.send(Ok(Bytes::from(sse_data)))
.map_err(|_| "Failed to send role chunk".to_string())?;
}
// Emit content delta (role is always None for content chunks)
let chat_delta = ChatMessageDelta {
role: None,
content: delta.final_delta.clone(),
tool_calls: delta.commentary_delta.as_ref().map(|tc_delta| {
vec![ToolCallDelta {
index: tc_delta.index as u32,
id: tc_delta.id.clone(),
tool_type: tc_delta.id.as_ref().map(|_| "function".to_string()),
function: tc_delta.function.as_ref().map(|f| FunctionCallDelta {
name: f.name.clone(),
arguments: f.arguments.clone(),
}),
}]
}),
reasoning_content: delta.analysis_delta.clone(),
};
// Build and emit chunk
let chunk = ChatCompletionStreamResponse {
id: dispatch.request_id.clone(),
object: "chat.completion.chunk".to_string(),
created: dispatch.created,
model: original_request.model.clone(),
system_fingerprint: dispatch.weight_version.clone(),
choices: vec![ChatStreamChoice {
index,
delta: chat_delta,
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let chunk_json = serde_json::to_string(&chunk)
.map_err(|e| format!("JSON serialization error: {}", e))?;
let sse_data = format!("data: {}\n\n", chunk_json);
tx.send(Ok(Bytes::from(sse_data)))
.map_err(|_| "Failed to send chunk".to_string())?;
Ok(())
}
/// Emit final chunk with finish_reason
fn emit_final_chunk(
index: u32,
finish_reason: &str,
matched_stop: Option<&serde_json::Value>,
dispatch: &context::DispatchMetadata,
original_request: &ChatCompletionRequest,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
let chunk = ChatCompletionStreamResponse {
id: dispatch.request_id.clone(),
object: "chat.completion.chunk".to_string(),
created: dispatch.created,
model: original_request.model.clone(),
system_fingerprint: dispatch.weight_version.clone(),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: None,
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(finish_reason.to_string()),
matched_stop: matched_stop.cloned(),
}],
usage: None,
};
let chunk_json = serde_json::to_string(&chunk)
.map_err(|e| format!("JSON serialization error: {}", e))?;
let sse_data = format!("data: {}\n\n", chunk_json);
tx.send(Ok(Bytes::from(sse_data)))
.map_err(|_| "Failed to send final chunk".to_string())?;
Ok(())
}
/// Emit usage chunk at the end
fn emit_usage_chunk(
prompt_tokens: u32,
completion_tokens: u32,
dispatch: &context::DispatchMetadata,
original_request: &ChatCompletionRequest,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
let usage_chunk = ChatCompletionStreamResponse {
id: dispatch.request_id.clone(),
object: "chat.completion.chunk".to_string(),
created: dispatch.created,
model: original_request.model.clone(),
system_fingerprint: dispatch.weight_version.clone(),
choices: vec![],
usage: Some(Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
}),
};
let chunk_json = serde_json::to_string(&usage_chunk)
.map_err(|e| format!("JSON serialization error: {}", e))?;
let sse_data = format!("data: {}\n\n", chunk_json);
tx.send(Ok(Bytes::from(sse_data)))
.map_err(|_| "Failed to send usage chunk".to_string())?;
Ok(())
}
/// Build SSE response from receiver
fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, io::Error>>) -> Response {
let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);
Response::builder()
.status(StatusCode::OK)
.header(
CONTENT_TYPE,
HeaderValue::from_static("text/event-stream; charset=utf-8"),
)
.header("Cache-Control", HeaderValue::from_static("no-cache"))
.header("Connection", HeaderValue::from_static("keep-alive"))
.body(body)
.unwrap()
}
}
impl Default for HarmonyStreamingProcessor {
fn default() -> Self {
Self::new()
}
}
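// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the SSE framing used by the
// emit_* helpers above. Each serialized chunk is wrapped as `data: <json>\n\n`,
// and the stream is conventionally terminated with `data: [DONE]\n\n`.
// `frame_sse` is a hypothetical helper named here for illustration only.
#[allow(dead_code)]
fn frame_sse(chunk_json: &str) -> Bytes {
    Bytes::from(format!("data: {}\n\n", chunk_json))
}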
//! Shared types for Harmony pipeline
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::protocols::common::ToolCall;
/// Harmony message format
///
/// Represents messages in the Harmony encoding format with role and content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HarmonyMessage {
pub role: String,
pub content: String,
}
impl HarmonyMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: role.into(),
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
/// Convert from openai_harmony::chat::Message to our simplified HarmonyMessage
pub fn from_openai_harmony(msg: openai_harmony::chat::Message) -> Self {
use openai_harmony::chat::Content;
// Extract role as string
let role = match msg.author.role {
openai_harmony::chat::Role::User => "user",
openai_harmony::chat::Role::Assistant => "assistant",
openai_harmony::chat::Role::System => "system",
openai_harmony::chat::Role::Developer => "developer",
openai_harmony::chat::Role::Tool => "tool",
}
.to_string();
// Extract text content from all Content::Text parts
let content = msg
.content
.iter()
.filter_map(|c| match c {
Content::Text(tc) => Some(tc.text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
Self { role, content }
}
}
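// Illustrative usage of the constructors above (the conversation content is
// invented for this example and is not part of the patch):
#[allow(dead_code)]
fn example_history() -> Vec<HarmonyMessage> {
    vec![
        HarmonyMessage::system("You are a helpful assistant."),
        HarmonyMessage::user("What is the capital of France?"),
        HarmonyMessage::assistant("Paris."),
    ]
}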
/// Output from Harmony encoding process
///
/// Contains the encoded input_ids, stop tokens, selection text for worker routing,
/// and the Harmony message history.
#[derive(Debug, Clone)]
pub struct HarmonyBuildOutput {
/// Encoded token IDs to send to the model
pub input_ids: Vec<u32>,
/// Stop token IDs for this model (injected into sampling params)
pub stop_token_ids: Vec<u32>,
/// Selection text for worker routing (concise snippet from last user message)
pub selection_text: String,
/// Harmony messages for this conversation (used for history tracking)
pub harmony_messages: Vec<HarmonyMessage>,
}
/// Parsed output from all three Harmony channels
///
/// Represents the complete response after parsing analysis, commentary, and final channels.
#[derive(Debug, Clone)]
pub struct HarmonyChannelOutput {
/// Analysis/reasoning content (from analysis channel)
pub analysis: Option<String>,
/// Tool calls (from commentary channel)
pub commentary: Option<Vec<ToolCall>>,
/// Final text content (from final channel)
pub final_text: String,
/// Finish reason
pub finish_reason: String,
/// Matched stop token (if any)
pub matched_stop: Option<Value>,
}
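// Hypothetical helper (illustration only, not in this patch): how the three
// channels typically map onto an OpenAI-style chat message — analysis becomes
// reasoning_content, commentary becomes tool_calls, and final becomes content.
#[allow(dead_code)]
fn into_chat_parts(out: HarmonyChannelOutput) -> (Option<String>, Option<Vec<ToolCall>>, String) {
    (out.analysis, out.commentary, out.final_text)
}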
/// Streaming delta for SSE responses
///
/// Represents incremental updates as tokens are parsed from the stream.
#[derive(Debug, Clone)]
pub struct HarmonyChannelDelta {
/// Delta for analysis/reasoning content
pub analysis_delta: Option<String>,
/// Delta for tool calls
pub commentary_delta: Option<ToolCallDelta>,
/// Delta for final text content
pub final_delta: Option<String>,
/// Whether this is the final delta
pub is_final: bool,
}
/// Tool call delta for streaming
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
pub index: usize,
pub id: Option<String>,
pub function: Option<FunctionDelta>,
}
/// Function call delta for streaming
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDelta {
pub name: Option<String>,
pub arguments: Option<String>,
}
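// Hedged sketch (not part of this patch): one way a consumer might fold
// streaming ToolCallDelta fragments into a complete call. `AccumulatedCall`
// and `apply_tool_call_delta` are hypothetical names introduced only for
// this example.
#[allow(dead_code)]
#[derive(Debug, Default)]
struct AccumulatedCall {
    id: Option<String>,
    name: Option<String>,
    arguments: String,
}

#[allow(dead_code)]
fn apply_tool_call_delta(acc: &mut AccumulatedCall, delta: &ToolCallDelta) {
    // The id and function name typically arrive once, on the first fragment.
    if delta.id.is_some() {
        acc.id = delta.id.clone();
    }
    if let Some(f) = &delta.function {
        if f.name.is_some() {
            acc.name = f.name.clone();
        }
        // Argument text streams in as JSON fragments; concatenate in arrival order.
        if let Some(args) = &f.arguments {
            acc.arguments.push_str(args);
        }
    }
}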
@@ -3,11 +3,13 @@
 use crate::{grpc_client::proto, protocols::common::StringOrArray};
 
 pub mod context;
+pub mod harmony;
 pub mod pd_router;
 pub mod pipeline;
 pub mod processing;
 pub mod responses;
 pub mod router;
+pub mod stages;
 pub mod streaming;
 pub mod utils;
...
@@ -14,9 +14,7 @@ use tracing::debug;
 use super::{context::SharedComponents, pipeline::RequestPipeline};
 use crate::{
     app_context::AppContext,
-    config::types::RetryConfig,
     core::{ConnectionMode, WorkerRegistry, WorkerType},
-    policies::PolicyRegistry,
     protocols::{
         chat::ChatCompletionRequest,
         classify::ClassifyRequest,
@@ -26,26 +24,13 @@ use crate::{
         rerank::RerankRequest,
         responses::{ResponsesGetParams, ResponsesRequest},
     },
-    reasoning_parser::ParserFactory as ReasoningParserFactory,
     routers::RouterTrait,
-    tokenizer::traits::Tokenizer,
-    tool_parser::ParserFactory as ToolParserFactory,
 };
 
 /// gRPC PD (Prefill-Decode) router implementation for SGLang
 #[derive(Clone)]
-#[allow(dead_code)] // Fields will be used once implementation is complete
 pub struct GrpcPDRouter {
     worker_registry: Arc<WorkerRegistry>,
-    policy_registry: Arc<PolicyRegistry>,
-    tokenizer: Arc<dyn Tokenizer>,
-    reasoning_parser_factory: ReasoningParserFactory,
-    tool_parser_factory: ToolParserFactory,
-    dp_aware: bool,
-    api_key: Option<String>,
-    retry_config: RetryConfig,
-    configured_reasoning_parser: Option<String>,
-    configured_tool_parser: Option<String>,
     pipeline: RequestPipeline,
     shared_components: Arc<SharedComponents>,
 }
...
@@ -94,15 +79,6 @@ impl GrpcPDRouter {
         Ok(GrpcPDRouter {
             worker_registry,
-            policy_registry,
-            tokenizer,
-            reasoning_parser_factory,
-            tool_parser_factory,
-            dp_aware: ctx.router_config.dp_aware,
-            api_key: ctx.router_config.api_key.clone(),
-            retry_config: ctx.router_config.effective_retry_config(),
-            configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
-            configured_tool_parser: ctx.configured_tool_parser.clone(),
             pipeline,
             shared_components,
         })
...
@@ -174,7 +150,6 @@ impl std::fmt::Debug for GrpcPDRouter {
         f.debug_struct("GrpcPDRouter")
             .field("prefill_workers_count", &prefill_workers.len())
             .field("decode_workers_count", &decode_workers.len())
-            .field("dp_aware", &self.dp_aware)
             .finish()
     }
 }
...
@@ -255,19 +230,19 @@ impl RouterTrait for GrpcPDRouter {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
 
-    async fn route_classify(
+    async fn route_embeddings(
         &self,
         _headers: Option<&HeaderMap>,
-        _body: &ClassifyRequest,
+        _body: &EmbeddingRequest,
         _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
 
-    async fn route_embeddings(
+    async fn route_classify(
         &self,
         _headers: Option<&HeaderMap>,
-        _body: &EmbeddingRequest,
+        _body: &ClassifyRequest,
         _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
...