"git@developer.sourcefind.cn:OpenDAS/nerfacc.git" did not exist on "2e7ad6e0bd797fa3e05bcdf4301adf2ad1e3a7f1"
Unverified Commit 83804bc6, authored by Chang Su, committed by GitHub

[router][grpc] Restructure modules and code clean up (#12598)

parent d5fa019c
@@ -251,9 +251,10 @@ pub enum ResponseOutputItem {
 // Configuration Enums
 // ============================================================================
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
 #[serde(rename_all = "snake_case")]
 pub enum ServiceTier {
+    #[default]
     Auto,
     Default,
     Flex,
@@ -261,25 +262,14 @@ pub enum ServiceTier {
     Priority,
 }
-impl Default for ServiceTier {
-    fn default() -> Self {
-        Self::Auto
-    }
-}
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
 #[serde(rename_all = "snake_case")]
 pub enum Truncation {
     Auto,
+    #[default]
     Disabled,
 }
-impl Default for Truncation {
-    fn default() -> Self {
-        Self::Disabled
-    }
-}
 #[derive(Debug, Clone, Deserialize, Serialize)]
 #[serde(rename_all = "snake_case")]
 pub enum ResponseStatus {
...
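The hunk above replaces hand-written `impl Default` blocks with derived enum defaults: `#[derive(Default)]` plus a `#[default]` attribute on the chosen variant, stable since Rust 1.62. A standalone illustration of the pattern:

```rust
// Derived enum defaults: the #[default] attribute marks which variant
// Default::default() returns, replacing a manual impl Default block.
#[derive(Debug, Default, PartialEq)]
enum Truncation {
    Auto,
    #[default]
    Disabled,
}

fn main() {
    // Equivalent to the removed `impl Default for Truncation`.
    assert_eq!(Truncation::default(), Truncation::Disabled);
}
```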
//! Shared code for both regular and harmony routers
pub mod response_collection;
pub mod response_formatting;
pub mod responses;
pub mod stages;
//! Shared response collection logic
//!
//! This module contains common logic for collecting responses from execution results.
//! Both regular and harmony processors use these functions to avoid duplication.
use axum::response::Response;

use crate::{
    grpc_client::proto,
    routers::grpc::{context::ExecutionResult, error, utils},
};

/// Collect and merge responses from execution result
///
/// Handles both Single and Dual (prefill-decode) execution modes.
/// For Dual mode, merges prefill input_logprobs into decode responses if requested.
///
/// # Arguments
/// * `execution_result` - The execution result containing stream(s)
/// * `merge_logprobs` - Whether to merge prefill input_logprobs (for chat with logprobs=true)
///
/// # Returns
/// Vector of GenerateComplete responses, one per index (n parameter)
pub async fn collect_responses(
    execution_result: ExecutionResult,
    merge_logprobs: bool,
) -> Result<Vec<proto::GenerateComplete>, Response> {
    let all_responses = match execution_result {
        ExecutionResult::Single { mut stream } => {
            let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
            stream.mark_completed();
            responses
        }
        ExecutionResult::Dual {
            mut prefill,
            decode,
        } => {
            // Collect prefill for input_logprobs (don't mark completed yet)
            let prefill_responses =
                utils::collect_stream_responses(&mut prefill, "Prefill").await?;
            // Collect decode for actual output (don't mark completed yet)
            let mut decode_stream = *decode;
            let mut decode_responses =
                utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
            // Mark both streams as completed now that both succeeded
            prefill.mark_completed();
            decode_stream.mark_completed();
            // Merge prefill input_logprobs if requested
            if merge_logprobs {
                merge_prefill_logprobs(&prefill_responses, &mut decode_responses);
            }
            decode_responses
        }
    };
    if all_responses.is_empty() {
        return Err(error::internal_error("No responses from server"));
    }
    Ok(all_responses)
}

/// Merge prefill input_logprobs into decode responses
///
/// Takes input_logprobs from the first prefill response and copies them
/// into all decode responses. This is used in PD mode when logprobs are requested.
fn merge_prefill_logprobs(
    prefill_responses: &[proto::GenerateComplete],
    decode_responses: &mut [proto::GenerateComplete],
) {
    if let Some(prefill_input_logprobs) = prefill_responses
        .first()
        .and_then(|r| r.input_logprobs.clone())
    {
        for response in decode_responses.iter_mut() {
            response.input_logprobs = Some(prefill_input_logprobs.clone());
        }
    }
}
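A minimal test-style sketch of the merge behavior. This assumes the prost-generated `GenerateComplete` message, and the payload behind `input_logprobs`, implement `Default`, as prost-derived types normally do; it is an illustration, not part of the commit:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    // Sketch only: relies on prost-derived Default for proto types.
    #[test]
    fn prefill_logprobs_are_copied_into_every_decode_response() {
        let prefill = vec![proto::GenerateComplete {
            input_logprobs: Some(Default::default()),
            ..Default::default()
        }];
        let mut decode = vec![
            proto::GenerateComplete::default(),
            proto::GenerateComplete::default(),
        ];
        merge_prefill_logprobs(&prefill, &mut decode);
        // Every decode response now carries the prefill's input logprobs.
        assert!(decode.iter().all(|r| r.input_logprobs.is_some()));
    }
}
```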
//! Shared response formatting logic
//!
//! This module contains common logic for formatting responses, including:
//! - Usage calculation from gRPC responses
//! - ChatCompletionResponse construction
use crate::{
    grpc_client::proto,
    protocols::{
        chat::{ChatChoice, ChatCompletionResponse},
        common::Usage,
    },
    routers::grpc::context::DispatchMetadata,
};

/// Build usage information from collected gRPC responses
///
/// Sums prompt_tokens and completion_tokens across all responses.
/// Typically used with the n>1 parameter where multiple completions are generated.
///
/// # Arguments
/// * `responses` - Vector of GenerateComplete responses from the backend
///
/// # Returns
/// Usage object with aggregated token counts
pub fn build_usage(responses: &[proto::GenerateComplete]) -> Usage {
    let total_prompt_tokens: u32 = responses.iter().map(|r| r.prompt_tokens as u32).sum();
    let total_completion_tokens: u32 = responses.iter().map(|r| r.completion_tokens as u32).sum();
    Usage {
        prompt_tokens: total_prompt_tokens,
        completion_tokens: total_completion_tokens,
        total_tokens: total_prompt_tokens + total_completion_tokens,
        completion_tokens_details: None,
    }
}
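A test-style sketch of the aggregation, again assuming prost-derived `Default` on `GenerateComplete`. With `n = 2`, each completion reports its own `prompt_tokens`, so the prompt side is counted once per choice:

```rust
#[cfg(test)]
mod usage_tests {
    use super::*;

    #[test]
    fn usage_is_summed_across_all_choices() {
        // Two completions of the same prompt (n = 2).
        let responses = vec![
            proto::GenerateComplete {
                prompt_tokens: 12,
                completion_tokens: 7,
                ..Default::default()
            },
            proto::GenerateComplete {
                prompt_tokens: 12,
                completion_tokens: 9,
                ..Default::default()
            },
        ];
        let usage = build_usage(&responses);
        assert_eq!(usage.prompt_tokens, 24); // counted once per choice
        assert_eq!(usage.completion_tokens, 16);
        assert_eq!(usage.total_tokens, 40);
    }
}
```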
/// Build final ChatCompletionResponse from processed choices
///
/// Constructs the OpenAI-compatible response object with all metadata.
///
/// # Arguments
/// * `choices` - Processed chat choices (after parsing, logprobs, etc.)
/// * `dispatch` - Dispatch metadata (request_id, created timestamp, etc.)
/// * `model` - Model name to include in response
/// * `usage` - Token usage information
///
/// # Returns
/// Complete ChatCompletionResponse ready to send to client
pub fn build_chat_response(
    choices: Vec<ChatChoice>,
    dispatch: &DispatchMetadata,
    model: String,
    usage: Usage,
) -> ChatCompletionResponse {
    ChatCompletionResponse {
        id: dispatch.request_id.clone(),
        object: "chat.completion".to_string(),
        created: dispatch.created,
        model,
        choices,
        usage: Some(usage),
        system_fingerprint: dispatch.weight_version.clone(),
    }
}
//! Shared response handlers for both regular and harmony implementations
//!
//! These handlers are used by both pipelines for retrieving and cancelling responses.
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
};
use serde_json::json;
use tracing::{debug, error, warn};

use crate::{
    data_connector::ResponseId, routers::grpc::regular::responses::context::ResponsesContext,
};

/// Implementation for GET /v1/responses/{response_id}
///
/// Retrieves a stored response from the database.
/// Used by both regular and harmony implementations.
pub async fn get_response_impl(ctx: &ResponsesContext, response_id: &str) -> Response {
    let resp_id = ResponseId::from(response_id);
    // Retrieve response from storage
    match ctx.response_storage.get_response(&resp_id).await {
        Ok(Some(stored_response)) => axum::Json(stored_response.raw_response).into_response(),
        Ok(None) => (
            StatusCode::NOT_FOUND,
            axum::Json(json!({
                "error": {
                    "message": format!("Response with id '{}' not found", response_id),
                    "type": "not_found_error",
                    "code": "response_not_found"
                }
            })),
        )
            .into_response(),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            axum::Json(json!({
                "error": {
                    "message": format!("Failed to retrieve response: {}", e),
                    "type": "internal_error"
                }
            })),
        )
            .into_response(),
    }
}

/// Implementation for POST /v1/responses/{response_id}/cancel
///
/// Cancels a background response if it's still in progress.
pub async fn cancel_response_impl(ctx: &ResponsesContext, response_id: &str) -> Response {
    let resp_id = ResponseId::from(response_id);
    // Retrieve response from storage to check if it exists and get current status
    match ctx.response_storage.get_response(&resp_id).await {
        Ok(Some(stored_response)) => {
            // Check current status - only queued or in_progress responses can be cancelled
            let current_status = stored_response
                .raw_response
                .get("status")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown");
            match current_status {
                "queued" | "in_progress" => {
                    // Attempt to abort the background task
                    let mut tasks = ctx.background_tasks.write().await;
                    if let Some(task_info) = tasks.remove(response_id) {
                        // Abort the Rust task immediately
                        task_info.handle.abort();
                        // Abort the Python/scheduler request via gRPC (if client is available)
                        let client_opt = task_info.client.read().await;
                        if let Some(ref client) = *client_opt {
                            if let Err(e) = client
                                .abort_request(
                                    task_info.grpc_request_id.clone(),
                                    "User cancelled via API".to_string(),
                                )
                                .await
                            {
                                warn!(
                                    "Failed to abort Python request {}: {}",
                                    task_info.grpc_request_id, e
                                );
                            } else {
                                debug!(
                                    "Successfully aborted Python request: {}",
                                    task_info.grpc_request_id
                                );
                            }
                        } else {
                            debug!("Client not yet available for abort, request may not have started yet");
                        }
                        // Task was found and aborted
                        (
                            StatusCode::OK,
                            axum::Json(json!({
                                "id": response_id,
                                "status": "cancelled",
                                "message": "Background task has been cancelled"
                            })),
                        )
                            .into_response()
                    } else {
                        // Task handle not found but status is queued/in_progress
                        // This can happen if: (1) task crashed, or (2) storage persistence failed
                        error!(
                            "Response {} has status '{}' but task handle is missing. Task may have crashed or storage update failed.",
                            response_id, current_status
                        );
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            axum::Json(json!({
                                "error": {
                                    "message": "Internal error: background task completed but failed to update status in storage",
                                    "type": "internal_error",
                                    "code": "status_update_failed"
                                }
                            })),
                        )
                            .into_response()
                    }
                }
                "completed" => (
                    StatusCode::BAD_REQUEST,
                    axum::Json(json!({
                        "error": {
                            "message": "Cannot cancel completed response",
                            "type": "invalid_request_error",
                            "code": "response_already_completed"
                        }
                    })),
                )
                    .into_response(),
                "failed" => (
                    StatusCode::BAD_REQUEST,
                    axum::Json(json!({
                        "error": {
                            "message": "Cannot cancel failed response",
                            "type": "invalid_request_error",
                            "code": "response_already_failed"
                        }
                    })),
                )
                    .into_response(),
                "cancelled" => (
                    StatusCode::OK,
                    axum::Json(json!({
                        "id": response_id,
                        "status": "cancelled",
                        "message": "Response was already cancelled"
                    })),
                )
                    .into_response(),
                _ => {
                    // Unknown status
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        axum::Json(json!({
                            "error": {
                                "message": format!("Unknown response status: {}", current_status),
                                "type": "internal_error"
                            }
                        })),
                    )
                        .into_response()
                }
            }
        }
        Ok(None) => (
            StatusCode::NOT_FOUND,
            axum::Json(json!({
                "error": {
                    "message": format!("Response with id '{}' not found", response_id),
                    "type": "not_found_error",
                    "code": "response_not_found"
                }
            })),
        )
            .into_response(),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            axum::Json(json!({
                "error": {
                    "message": format!("Failed to retrieve response: {}", e),
                    "type": "internal_error"
                }
            })),
        )
            .into_response(),
    }
}
//! Shared response functionality used by both regular and harmony implementations
pub mod handlers;
pub mod streaming;
pub use handlers::{cancel_response_impl, get_response_impl};
pub use streaming::{OutputItemType, ResponseStreamEventEmitter};
@@ -7,7 +7,7 @@ use serde_json::json;
 use tokio::sync::mpsc;
 use uuid::Uuid;
-use crate::protocols::chat::ChatCompletionStreamResponse;
+use crate::{mcp, protocols::chat::ChatCompletionStreamResponse};
 pub enum OutputItemType {
     Message,
@@ -30,10 +30,6 @@ struct OutputItemState {
     status: ItemStatus,
 }
-// ============================================================================
-// Streaming Event Emitter
-// ============================================================================
 /// OpenAI-compatible event emitter for /v1/responses streaming
 ///
 /// Manages state and sequence numbers to emit proper event types:
@@ -66,7 +62,7 @@ pub struct ResponseStreamEventEmitter {
     has_emitted_content_part_added: bool,
     // MCP call tracking
     mcp_call_accumulated_args: HashMap<String, String>,
-    // Output item tracking (NEW)
+    // Output item tracking
     output_items: Vec<OutputItemState>,
     next_output_index: usize,
     current_message_output_index: Option<usize>, // Tracks output_index of current message
@@ -248,7 +244,7 @@ impl ResponseStreamEventEmitter {
     pub fn emit_mcp_list_tools_completed(
         &mut self,
         output_index: usize,
-        tools: &[crate::mcp::Tool],
+        tools: &[mcp::Tool],
     ) -> serde_json::Value {
         let tool_items: Vec<_> = tools
             .iter()
@@ -331,7 +327,7 @@ impl ResponseStreamEventEmitter {
         })
     }
-    pub(super) fn emit_mcp_call_failed(
+    pub fn emit_mcp_call_failed(
        &mut self,
         output_index: usize,
         item_id: &str,
@@ -453,7 +449,7 @@ impl ResponseStreamEventEmitter {
     }
     /// Process a chunk and emit appropriate events
-    pub(super) fn process_chunk(
+    pub fn process_chunk(
         &mut self,
         chunk: &ChatCompletionStreamResponse,
         tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
...
//! Common helper functions shared across stages
use std::sync::Arc;

use proto::DisaggregatedParams;
use rand::Rng;
use tracing::debug;

use crate::{core::Worker, grpc_client::proto};

/// Inject PD bootstrap metadata into a gRPC request
///
/// Used by both chat and generate request building stages when in PD mode.
pub fn inject_bootstrap_metadata(
    request: &mut proto::GenerateRequest,
    prefill_worker: &Arc<dyn Worker>,
) {
    let hostname = prefill_worker.bootstrap_host();
    let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
    // Generate room ID for bootstrap
    let room_id = rand::rng().random_range(0..i32::MAX);
    // Create DisaggregatedParams
    let disagg_params = DisaggregatedParams {
        bootstrap_host: hostname.to_string(),
        bootstrap_port: bootstrap_port as i32,
        bootstrap_room: room_id,
    };
    // Inject metadata directly into request
    request.disaggregated_params = Some(disagg_params);
    debug!(
        "Injected bootstrap metadata: host={}, port={}, room={}",
        hostname, bootstrap_port, room_id
    );
}
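For reference, a sketch of what a request carries after injection, with hypothetical values. The host comes from the prefill worker, the port falls back to 8998 when the worker reports none, and the room is a random `i32` that lets prefill and decode rendezvous; it also assumes prost-derived `Default` on `GenerateRequest`:

```rust
use crate::grpc_client::proto::{DisaggregatedParams, GenerateRequest};

// Illustrative only; values here are hypothetical.
fn example_injected_request() -> GenerateRequest {
    let mut request = GenerateRequest::default(); // prost-derived Default
    request.disaggregated_params = Some(DisaggregatedParams {
        bootstrap_host: "10.0.0.1".to_string(), // prefill_worker.bootstrap_host()
        bootstrap_port: 8998,                   // fallback when no port is reported
        bootstrap_room: 42,                     // normally random in 0..i32::MAX
    });
    request
}
```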
-//! Pipeline stages for gRPC router request processing
+//! Common pipeline stages shared across all endpoints and model types
 //!
-//! This module defines the core pipeline abstraction and individual processing stages
-//! that transform a RequestContext through its lifecycle.
+//! These stages are endpoint-agnostic and model-agnostic:
+//! - Worker selection
+//! - Client acquisition
+//! - Dispatch metadata generation
+//! - Request execution
 use async_trait::async_trait;
 use axum::response::Response;
 use crate::routers::grpc::context::RequestContext;
-// ============================================================================
-// Pipeline Trait
-// ============================================================================
 /// Trait for pipeline stages that process requests
 #[async_trait]
 pub trait PipelineStage: Send + Sync {
@@ -27,26 +26,14 @@ pub trait PipelineStage: Send + Sync {
     fn name(&self) -> &'static str;
 }
-// ============================================================================
-// Stage Modules
-// ============================================================================
 mod client_acquisition;
 mod dispatch_metadata;
-mod preparation;
-mod request_building;
+pub mod helpers;
 mod request_execution;
-mod response_processing;
 mod worker_selection;
-// ============================================================================
-// Public Exports
-// ============================================================================
+// Export stage implementations
 pub use client_acquisition::ClientAcquisitionStage;
 pub use dispatch_metadata::DispatchMetadataStage;
-pub use preparation::PreparationStage;
-pub use request_building::RequestBuildingStage;
 pub use request_execution::{ExecutionMode, RequestExecutionStage};
-pub use response_processing::ResponseProcessingStage;
 pub use worker_selection::{WorkerSelectionMode, WorkerSelectionStage};
@@ -109,7 +109,6 @@ impl WorkerSelectionStage {
         false, // get all workers, we'll filter by is_available() next
     );
-    // Filter by availability (health + circuit breaker)
     let available: Vec<Arc<dyn Worker>> = workers
         .iter()
         .filter(|w| w.is_available())
...
@@ -11,7 +11,7 @@ use serde_json::Value;
 use crate::{
     core::Worker,
-    grpc_client::{proto, SglangSchedulerClient},
+    grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient},
     protocols::{
         chat::{ChatCompletionRequest, ChatCompletionResponse},
         generate::{GenerateRequest, GenerateResponse},
@@ -22,23 +22,14 @@ use crate::{
     tool_parser::ParserFactory as ToolParserFactory,
 };
-// ============================================================================
-// Core Context Types
-// ============================================================================
 /// Main request processing context
 ///
 /// This is the single source of truth for all request state as it flows
 /// through the pipeline stages. Uses Rust's type system to enforce proper
 /// stage ordering at compile time.
 pub struct RequestContext {
-    // === Input (Immutable) ===
     pub input: RequestInput,
-    // === Shared Components (Immutable References) ===
     pub components: Arc<SharedComponents>,
-    // === Processing State (Mutable, evolves through pipeline) ===
     pub state: ProcessingState,
 }
@@ -86,10 +77,6 @@ pub struct ProcessingState {
     pub response: ResponseState,
 }
-// ============================================================================
-// Stage-Specific Output Types
-// ============================================================================
 /// Output from preparation stage (Step 1)
 pub struct PreparationOutput {
     /// Original text (for chat) or resolved text (for generate)
@@ -201,10 +188,6 @@ pub struct StreamingState {
     pub has_tool_calls: HashMap<u32, bool>,
 }
-// ============================================================================
-// Context Builders
-// ============================================================================
 impl RequestContext {
     /// Create context for chat completion request
     pub fn for_chat(
@@ -323,14 +306,6 @@ impl RequestContext {
     }
 }
-// ============================================================================
-// Default Implementations
-// ============================================================================
-// ============================================================================
-// Helper Methods
-// ============================================================================
 impl WorkerSelection {
     pub fn is_dual(&self) -> bool {
         matches!(self, Self::Dual { .. })
@@ -428,12 +403,6 @@ impl ClientSelection {
     }
 }
-// ============================================================================
-// Execution and Response Types
-// ============================================================================
-use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
 /// Result of request execution (streams from workers)
 /// Uses AbortOnDropStream to automatically abort on cancellation
 pub enum ExecutionResult {
...
@@ -17,8 +17,9 @@ use crate::{
         },
     },
     routers::grpc::{
+        common::{response_collection, response_formatting},
         context::{DispatchMetadata, ExecutionResult},
-        error, utils,
+        error,
     },
 };
@@ -34,28 +35,6 @@ impl HarmonyResponseProcessor {
         Self
     }
-    /// Collect responses from ExecutionResult (similar to regular processor)
-    async fn collect_responses(
-        execution_result: ExecutionResult,
-    ) -> Result<Vec<proto::GenerateComplete>, Response> {
-        match execution_result {
-            ExecutionResult::Single { mut stream } => {
-                let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
-                stream.mark_completed();
-                Ok(responses)
-            }
-            ExecutionResult::Dual { prefill, decode } => {
-                // For Harmony we currently rely only on decode stream for outputs
-                let mut decode_stream = *decode;
-                let responses =
-                    utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
-                prefill.mark_completed();
-                decode_stream.mark_completed();
-                Ok(responses)
-            }
-        }
-    }
     /// Process a non-streaming Harmony chat response
     pub async fn process_non_streaming_chat_response(
         &self,
@@ -64,7 +43,7 @@ impl HarmonyResponseProcessor {
         dispatch: DispatchMetadata,
     ) -> Result<ChatCompletionResponse, Response> {
         // Collect all completed responses (one per choice)
-        let all_responses = Self::collect_responses(execution_result).await?;
+        let all_responses = response_collection::collect_responses(execution_result, false).await?;
         if all_responses.is_empty() {
             return Err(error::internal_error("No responses from server"));
         }
@@ -117,28 +96,15 @@ impl HarmonyResponseProcessor {
         }
         // Build usage from proto fields
-        let prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
-        let completion_tokens: u32 = all_responses
-            .iter()
-            .map(|r| r.completion_tokens as u32)
-            .sum();
-        let usage = Usage {
-            prompt_tokens,
-            completion_tokens,
-            total_tokens: prompt_tokens + completion_tokens,
-            completion_tokens_details: None,
-        };
+        let usage = response_formatting::build_usage(&all_responses);
         // Final ChatCompletionResponse
-        let response = ChatCompletionResponse {
-            id: dispatch.request_id.clone(),
-            object: "chat.completion".to_string(),
-            created: dispatch.created,
-            model: chat_request.model.clone(),
-            choices,
-            usage: Some(usage),
-            system_fingerprint: dispatch.weight_version.clone(),
-        };
+        let response = response_formatting::build_chat_response(
+            choices,
+            &dispatch,
+            chat_request.model.clone(),
+            usage,
+        );
         Ok(response)
     }
@@ -191,7 +157,7 @@ impl HarmonyResponseProcessor {
         dispatch: DispatchMetadata,
     ) -> Result<ResponsesIterationResult, Response> {
         // Collect all completed responses
-        let all_responses = Self::collect_responses(execution_result).await?;
+        let all_responses = response_collection::collect_responses(execution_result, false).await?;
         if all_responses.is_empty() {
             return Err(error::internal_error("No responses from server"));
         }
@@ -280,14 +246,7 @@ impl HarmonyResponseProcessor {
         }
         // Build usage
-        let prompt_tokens = complete.prompt_tokens as u32;
-        let completion_tokens = complete.completion_tokens as u32;
-        let usage = Usage {
-            prompt_tokens,
-            completion_tokens,
-            total_tokens: prompt_tokens + completion_tokens,
-            completion_tokens_details: None,
-        };
+        let usage = response_formatting::build_usage(std::slice::from_ref(complete));
         // Build ResponsesResponse with all required fields
         let response = ResponsesResponse {
@@ -316,9 +275,9 @@ impl HarmonyResponseProcessor {
             top_p: responses_request.top_p,
             truncation: None,
             usage: Some(ResponsesUsage::Modern(ResponseUsage {
-                input_tokens: prompt_tokens,
-                output_tokens: completion_tokens,
-                total_tokens: prompt_tokens + completion_tokens,
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                total_tokens: usage.total_tokens,
                 input_tokens_details: None,
                 output_tokens_details: None,
             })),
...
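One detail worth noting in the hunk above: `std::slice::from_ref` adapts the single `complete` response to the slice-based `build_usage` without copying or allocating. A standalone illustration:

```rust
fn main() {
    let x: u32 = 5;
    // &T -> &[T] of length 1, no allocation or copy.
    let s: &[u32] = std::slice::from_ref(&x);
    assert_eq!(s.len(), 1);
    assert_eq!(s[0], 5);
}
```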
@@ -37,12 +37,14 @@
 //! for complete architecture, rationale, and implementation details.
 use std::{
+    io,
     sync::Arc,
     time::{SystemTime, UNIX_EPOCH},
 };
 use axum::{body::Body, http::StatusCode, response::Response};
-use serde_json::Value as JsonValue;
+use bytes::Bytes;
+use serde_json::{from_str, from_value, json, to_string, to_value, Value};
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, warn};
@@ -50,21 +52,22 @@ use uuid::Uuid;
 use crate::{
     data_connector::{ResponseId, ResponseStorage},
-    mcp::McpManager,
+    mcp::{self, McpManager},
     protocols::{
         common::{Function, ToolCall},
         responses::{
-            ResponseInput, ResponseInputOutputItem, ResponseTool, ResponseToolType,
+            McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
+            ResponseOutputItem, ResponseReasoningContent, ResponseTool, ResponseToolType,
             ResponsesRequest, ResponsesResponse, StringOrContentParts,
         },
     },
     routers::{
         grpc::{
+            common::responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
             context::SharedComponents,
             error,
             harmony::processor::ResponsesIterationResult,
             pipeline::RequestPipeline,
-            responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
         },
         openai::mcp::ensure_request_mcp_client,
     },
@@ -414,10 +417,6 @@ pub async fn serve_harmony_responses_stream(
         Err(err_response) => return err_response,
     };
-    use std::io;
-    use bytes::Bytes;
     // Create SSE channel
     let (tx, rx) = mpsc::unbounded_channel();
     let stream = UnboundedReceiverStream::new(rx);
@@ -444,7 +443,7 @@ pub async fn serve_harmony_responses_stream(
     // Helper to emit error and return
     let emit_error = |tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, error_msg: &str| {
         // Create error event manually since emit_failed doesn't exist
-        let event = serde_json::json!({
+        let event = json!({
             "type": "response.failed",
             "response_id": response_id_for_error,
             "error": {
@@ -452,7 +451,7 @@ pub async fn serve_harmony_responses_stream(
                 "type": "internal_error"
             }
         });
-        let sse_data = format!("data: {}\n\n", serde_json::to_string(&event).unwrap());
+        let sse_data = format!("data: {}\n\n", to_string(&event).unwrap());
         let _ = tx.send(Ok(Bytes::from(sse_data)));
     };
@@ -517,7 +516,6 @@ pub async fn serve_harmony_responses_stream(
     let tool_items: Vec<_> = mcp_tools
         .iter()
         .map(|t| {
-            use serde_json::{json, Value};
             json!({
                 "name": t.name,
                 "description": t.description,
@@ -527,7 +525,7 @@ pub async fn serve_harmony_responses_stream(
         .collect();
     // Emit output_item.added
-    let item = serde_json::json!({
+    let item = json!({
         "id": item_id,
         "type": "mcp_list_tools",
         "server_label": "sglang-mcp",
@@ -552,7 +550,7 @@ pub async fn serve_harmony_responses_stream(
     }
     // Emit output_item.done
-    let item_done = serde_json::json!({
+    let item_done = json!({
         "id": item_id,
         "type": "mcp_list_tools",
         "server_label": "sglang-mcp",
@@ -677,7 +675,7 @@ pub async fn serve_harmony_responses_stream(
     );
     // Emit response.completed with usage
-    let usage_json = serde_json::json!({
+    let usage_json = json!({
         "prompt_tokens": usage.prompt_tokens,
         "completion_tokens": usage.completion_tokens,
         "total_tokens": usage.total_tokens,
@@ -733,7 +731,7 @@ async fn execute_mcp_tools(
     // Parse tool arguments from JSON string
     let args_str = tool_call.function.arguments.as_deref().unwrap_or("{}");
-    let args: JsonValue = serde_json::from_str(args_str).map_err(|e| {
+    let args: Value = from_str(args_str).map_err(|e| {
         error::internal_error(format!(
             "Invalid tool arguments JSON for tool '{}': {}",
             tool_call.function.name, e
@@ -741,8 +739,7 @@ async fn execute_mcp_tools(
     })?;
     // Execute tool via MCP manager
-    // Convert JsonValue to ToolArgs via Option<Map> (MCP manager expects this)
-    let args_map = if let JsonValue::Object(map) = args {
+    let args_map = if let Value::Object(map) = args {
         Some(map)
     } else {
         None
@@ -763,15 +760,14 @@ async fn execute_mcp_tools(
     let output = if let Some(content) = mcp_result.content.first() {
         // TODO: Handle different content types (text, image, resource)
         // For now, serialize the entire content item
-        serde_json::to_value(content).unwrap_or_else(
-            |_| serde_json::json!({"error": "Failed to serialize tool result"}),
-        )
+        to_value(content)
+            .unwrap_or_else(|_| json!({"error": "Failed to serialize tool result"}))
     } else {
-        serde_json::json!({"result": "success"})
+        json!({"result": "success"})
     };
     let is_error = mcp_result.is_error.unwrap_or(false);
-    let output_str = serde_json::to_string(&output)
+    let output_str = to_string(&output)
         .unwrap_or_else(|_| r#"{"error": "Failed to serialize output"}"#.to_string());
     // Record this call in tracking
@@ -804,10 +800,10 @@ async fn execute_mcp_tools(
     );
     let error_msg = format!("Tool execution failed: {}", e);
-    let error_output = serde_json::json!({
+    let error_output = json!({
         "error": error_msg.clone()
     });
-    let error_output_str = serde_json::to_string(&error_output)
+    let error_output_str = to_string(&error_output)
         .unwrap_or_else(|_| format!(r#"{{"error": "{}"}}"#, error_msg));
     // Record failed call in tracking
@@ -859,12 +855,6 @@ fn build_next_request_with_tools(
     analysis: Option<String>,
     partial_text: String,
 ) -> Result<ResponsesRequest, Box<Response>> {
-    use uuid::Uuid;
-    use crate::protocols::responses::{
-        ResponseContentPart, ResponseInputOutputItem, ResponseReasoningContent,
-    };
     // Get current input items (or empty vec if Text variant)
     let mut items = match request.input {
         ResponseInput::Items(items) => items,
@@ -926,7 +916,7 @@ fn build_next_request_with_tools(
     // Add tool results
     for tool_result in tool_results {
         // Serialize tool output to string
-        let output_str = serde_json::to_string(&tool_result.output).unwrap_or_else(|e| {
+        let output_str = to_string(&tool_result.output).unwrap_or_else(|e| {
             format!("{{\"error\": \"Failed to serialize tool output: {}\"}}", e)
         });
@@ -967,7 +957,7 @@ struct ToolResult {
     tool_name: String,
     /// Tool output (JSON value)
-    output: JsonValue,
+    output: Value,
     /// Whether this is an error result
     is_error: bool,
@@ -985,11 +975,7 @@ struct ToolResult {
 /// # Returns
 ///
 /// Vector of ResponseTool entries in MCP format
-pub fn convert_mcp_tools_to_response_tools(mcp_tools: &[crate::mcp::Tool]) -> Vec<ResponseTool> {
-    use serde_json::Value;
-    use crate::protocols::responses::ResponseToolType;
+pub fn convert_mcp_tools_to_response_tools(mcp_tools: &[mcp::Tool]) -> Vec<ResponseTool> {
     mcp_tools
         .iter()
         .map(|tool_info| ResponseTool {
@@ -1027,11 +1013,6 @@ fn inject_mcp_metadata(
     tracking: &McpCallTracking,
     mcp_manager: &Arc<McpManager>,
 ) {
-    use serde_json::{json, Value};
-    use uuid::Uuid;
-    use crate::protocols::responses::{McpToolInfo, ResponseOutputItem};
     // Build mcp_list_tools item
     let tools = mcp_manager.list_tools();
     let tools_info: Vec<McpToolInfo> = tools
@@ -1121,23 +1102,22 @@ async fn load_previous_messages(
     let mut history_items = Vec::new();
     // Helper to deserialize and collect items from a JSON array
-    let deserialize_items =
-        |arr: &serde_json::Value, item_type: &str| -> Vec<ResponseInputOutputItem> {
-            arr.as_array()
-                .into_iter()
-                .flat_map(|items| items.iter())
-                .filter_map(|item| {
-                    serde_json::from_value::<ResponseInputOutputItem>(item.clone())
-                        .map_err(|e| {
-                            warn!(
-                                "Failed to deserialize stored {} item: {}. Item: {}",
-                                item_type, e, item
-                            );
-                        })
-                        .ok()
-                })
-                .collect()
-        };
+    let deserialize_items = |arr: &Value, item_type: &str| -> Vec<ResponseInputOutputItem> {
+        arr.as_array()
+            .into_iter()
+            .flat_map(|items| items.iter())
+            .filter_map(|item| {
+                from_value::<ResponseInputOutputItem>(item.clone())
+                    .map_err(|e| {
+                        warn!(
+                            "Failed to deserialize stored {} item: {}. Item: {}",
+                            item_type, e, item
+                        );
+                    })
+                    .ok()
+            })
+            .collect()
+    };
     for stored in chain.responses.iter() {
         history_items.extend(deserialize_items(&stored.input, "input"));
...
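The `emit_error` helper in this hunk frames each event the way the rest of the stream does: one `data: <json>\n\n` chunk per server-sent event. A minimal standalone sketch of that framing, needing only the serde_json crate:

```rust
use serde_json::json;

fn main() {
    let event = json!({
        "type": "response.failed",
        "error": { "message": "boom", "type": "internal_error" }
    });
    // SSE framing: a "data:" field terminated by a blank line.
    let frame = format!("data: {}\n\n", event);
    assert!(frame.starts_with("data: {"));
    assert!(frame.ends_with("\n\n"));
}
```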
@@ -12,10 +12,9 @@ use crate::{
         responses::ResponsesRequest,
     },
     routers::grpc::{
+        common::stages::PipelineStage,
         context::{PreparationOutput, RequestContext, RequestType},
-        error,
-        stages::PipelineStage,
-        utils,
+        error, utils,
     },
 };
...
 //! Harmony Request Building Stage: Build gRPC request from Harmony-encoded tokens
-use std::sync::Arc;
 use async_trait::async_trait;
 use axum::response::Response;
-use rand::Rng;
 use tracing::debug;
 use uuid::Uuid;
-use crate::{
-    core::Worker,
-    grpc_client::proto::{DisaggregatedParams, GenerateRequest},
-    routers::grpc::{
-        context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
-        error,
-        stages::PipelineStage,
-    },
+use crate::routers::grpc::{
+    common::stages::{helpers, PipelineStage},
+    context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
+    error,
 };
 /// Harmony Request Building stage: Convert Harmony tokens to gRPC request
@@ -31,34 +24,6 @@ impl HarmonyRequestBuildingStage {
     pub fn new(inject_pd_metadata: bool) -> Self {
         Self { inject_pd_metadata }
     }
-    /// Inject PD (prefill-decode) bootstrap metadata
-    fn inject_bootstrap_metadata(
-        &self,
-        request: &mut GenerateRequest,
-        prefill_worker: &Arc<dyn Worker>,
-    ) {
-        let hostname = prefill_worker.bootstrap_host();
-        let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
-        // Generate room ID for bootstrap
-        let room_id = rand::rng().random_range(0..i32::MAX);
-        // Create DisaggregatedParams
-        let disagg_params = DisaggregatedParams {
-            bootstrap_host: hostname.to_string(),
-            bootstrap_port: bootstrap_port as i32,
-            bootstrap_room: room_id,
-        };
-        // Inject metadata directly into request
-        request.disaggregated_params = Some(disagg_params);
-        debug!(
-            "Injected Harmony bootstrap metadata: host={}, port={}, room={}",
-            hostname, bootstrap_port, room_id
-        );
-    }
 }
 #[async_trait]
@@ -94,7 +59,6 @@ impl PipelineStage for HarmonyRequestBuildingStage {
     };
     // Build gRPC request using token_ids directly (Harmony encoding already handled message rendering)
-    // Use a placeholder for original_text; Harmony uses input_ids for tokenization
     let placeholder_processed_text = "[harmony]".to_string();
     let mut proto_request = match &ctx.input.request_type {
@@ -141,7 +105,7 @@ impl PipelineStage for HarmonyRequestBuildingStage {
     // Inject PD metadata if needed
     if self.inject_pd_metadata {
         if let Some(WorkerSelection::Dual { prefill, .. }) = ctx.state.workers.as_ref() {
-            self.inject_bootstrap_metadata(&mut proto_request, prefill);
+            helpers::inject_bootstrap_metadata(&mut proto_request, prefill);
         }
     }
...
@@ -7,9 +7,9 @@ use axum::response::Response;
 use super::super::{HarmonyResponseProcessor, HarmonyStreamingProcessor};
 use crate::routers::grpc::{
+    common::stages::PipelineStage,
     context::{FinalResponse, RequestContext, RequestType},
     error,
-    stages::PipelineStage,
 };
 /// Harmony Response Processing stage: Parse and format Harmony responses
...
@@ -31,8 +31,8 @@ use crate::{
         responses::{ResponseStatus, ResponseUsage, ResponsesResponse, ResponsesUsage},
     },
     routers::grpc::{
+        common::responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
         context,
-        responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
     },
 };
 /// Processor for streaming Harmony responses
...