"git@developer.sourcefind.cn:change/sglang.git" did not exist on "cdae77b03dfc6fec3863630550b45bbfc789f957"
Unverified commit 83804bc6, authored by Chang Su and committed by GitHub

[router][grpc] Restructure modules and code clean up (#12598)

parent d5fa019c
@@ -251,9 +251,10 @@ pub enum ResponseOutputItem {
// Configuration Enums
// ============================================================================
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
#[default]
Auto,
Default,
Flex,
@@ -261,25 +262,14 @@ pub enum ServiceTier {
Priority,
}
impl Default for ServiceTier {
fn default() -> Self {
Self::Auto
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
Auto,
#[default]
Disabled,
}
impl Default for Truncation {
fn default() -> Self {
Self::Disabled
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
......
//! Shared code for both regular and harmony routers
pub mod response_collection;
pub mod response_formatting;
pub mod responses;
pub mod stages;
//! Shared response collection logic
//!
//! This module contains common logic for collecting responses from execution results.
//! Both regular and harmony processors use these functions to avoid duplication.
use axum::response::Response;
use crate::{
grpc_client::proto,
routers::grpc::{context::ExecutionResult, error, utils},
};
/// Collect and merge responses from execution result
///
/// Handles both Single and Dual (prefill-decode) execution modes.
/// For Dual mode, merges prefill input_logprobs into decode responses if requested.
///
/// # Arguments
/// * `execution_result` - The execution result containing stream(s)
/// * `merge_logprobs` - Whether to merge prefill input_logprobs (for chat with logprobs=true)
///
/// # Returns
/// Vector of GenerateComplete responses, one per index (n parameter)
pub async fn collect_responses(
execution_result: ExecutionResult,
merge_logprobs: bool,
) -> Result<Vec<proto::GenerateComplete>, Response> {
let all_responses = match execution_result {
ExecutionResult::Single { mut stream } => {
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
stream.mark_completed();
responses
}
ExecutionResult::Dual {
mut prefill,
decode,
} => {
// Collect prefill for input_logprobs (don't mark completed yet)
let prefill_responses =
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
// Collect decode for actual output (don't mark completed yet)
let mut decode_stream = *decode;
let mut decode_responses =
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
// Mark both streams as completed now that both succeeded
prefill.mark_completed();
decode_stream.mark_completed();
// Merge prefill input_logprobs if requested
if merge_logprobs {
merge_prefill_logprobs(&prefill_responses, &mut decode_responses);
}
decode_responses
}
};
if all_responses.is_empty() {
return Err(error::internal_error("No responses from server"));
}
Ok(all_responses)
}
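// Usage sketch (illustrative): a response processor hands the execution
// result straight to this helper; the second argument mirrors whether the
// incoming chat request asked for logprobs (`logprobs_requested` is a
// stand-in name here).
//
//     let all_responses =
//         response_collection::collect_responses(execution_result, logprobs_requested)
//             .await?;
//     // One GenerateComplete per generated choice (the request's n parameter).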
/// Merge prefill input_logprobs into decode responses
///
/// Takes input_logprobs from the first prefill response and copies them
/// into all decode responses. This is used in PD mode when logprobs are requested.
fn merge_prefill_logprobs(
prefill_responses: &[proto::GenerateComplete],
decode_responses: &mut [proto::GenerateComplete],
) {
if let Some(prefill_input_logprobs) = prefill_responses
.first()
.and_then(|r| r.input_logprobs.clone())
{
for response in decode_responses.iter_mut() {
response.input_logprobs = Some(prefill_input_logprobs.clone());
}
}
}
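// Worked example: in PD mode with logprobs enabled and n=3, the prefill
// stream yields a single set of input_logprobs for the shared prompt; the
// loop above clones that set into each of the three decode responses, so
// every choice carries prompt logprobs even though the prompt ran only once.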
//! Shared response formatting logic
//!
//! This module contains common logic for formatting responses, including:
//! - Usage calculation from gRPC responses
//! - ChatCompletionResponse construction
use crate::{
grpc_client::proto,
protocols::{
chat::{ChatChoice, ChatCompletionResponse},
common::Usage,
},
routers::grpc::context::DispatchMetadata,
};
/// Build usage information from collected gRPC responses
///
/// Sums prompt_tokens and completion_tokens across all responses.
/// Typically used with n>1 parameter where multiple completions are generated.
///
/// # Arguments
/// * `responses` - Vector of GenerateComplete responses from the backend
///
/// # Returns
/// Usage object with aggregated token counts
pub fn build_usage(responses: &[proto::GenerateComplete]) -> Usage {
let total_prompt_tokens: u32 = responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = responses.iter().map(|r| r.completion_tokens as u32).sum();
Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
}
}
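// Worked example: prompt_tokens is summed once per response, so for n=2
// with two completions each reporting prompt_tokens=12 and completion_tokens
// of 30 and 34, the result is Usage { prompt_tokens: 24,
// completion_tokens: 64, total_tokens: 88 }.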
/// Build final ChatCompletionResponse from processed choices
///
/// Constructs the OpenAI-compatible response object with all metadata.
///
/// # Arguments
/// * `choices` - Processed chat choices (after parsing, logprobs, etc.)
/// * `dispatch` - Dispatch metadata (request_id, created timestamp, etc.)
/// * `model` - Model name to include in response
/// * `usage` - Token usage information
///
/// # Returns
/// Complete ChatCompletionResponse ready to send to client
pub fn build_chat_response(
choices: Vec<ChatChoice>,
dispatch: &DispatchMetadata,
model: String,
usage: Usage,
) -> ChatCompletionResponse {
ChatCompletionResponse {
id: dispatch.request_id.clone(),
object: "chat.completion".to_string(),
created: dispatch.created,
model,
choices,
usage: Some(usage),
system_fingerprint: dispatch.weight_version.clone(),
}
}
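// Usage sketch (mirroring the harmony processor further down in this diff):
// the two helpers compose directly once choices have been parsed.
//
//     let usage = response_formatting::build_usage(&all_responses);
//     let response = response_formatting::build_chat_response(
//         choices,
//         &dispatch,
//         chat_request.model.clone(),
//         usage,
//     );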
//! Shared response handlers for both regular and harmony implementations
//!
//! These handlers are used by both pipelines for retrieving and cancelling responses.
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use serde_json::json;
use tracing::{debug, error, warn};
use crate::{
data_connector::ResponseId, routers::grpc::regular::responses::context::ResponsesContext,
};
/// Implementation for GET /v1/responses/{response_id}
///
/// Retrieves a stored response from the database.
/// Used by both regular and harmony implementations.
pub async fn get_response_impl(ctx: &ResponsesContext, response_id: &str) -> Response {
let resp_id = ResponseId::from(response_id);
// Retrieve response from storage
match ctx.response_storage.get_response(&resp_id).await {
Ok(Some(stored_response)) => axum::Json(stored_response.raw_response).into_response(),
Ok(None) => (
StatusCode::NOT_FOUND,
axum::Json(json!({
"error": {
"message": format!("Response with id '{}' not found", response_id),
"type": "not_found_error",
"code": "response_not_found"
}
})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": format!("Failed to retrieve response: {}", e),
"type": "internal_error"
}
})),
)
.into_response(),
}
}
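// Response shapes produced above:
//   200 -> the stored raw_response JSON, returned as-is
//   404 -> {"error": {"message": "Response with id '...' not found",
//           "type": "not_found_error", "code": "response_not_found"}}
//   500 -> {"error": {"message": "Failed to retrieve response: ...",
//           "type": "internal_error"}}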
/// Implementation for POST /v1/responses/{response_id}/cancel
///
/// Cancels a background response if it's still in progress.
pub async fn cancel_response_impl(ctx: &ResponsesContext, response_id: &str) -> Response {
let resp_id = ResponseId::from(response_id);
// Retrieve response from storage to check if it exists and get current status
match ctx.response_storage.get_response(&resp_id).await {
Ok(Some(stored_response)) => {
// Check current status - only queued or in_progress responses can be cancelled
let current_status = stored_response
.raw_response
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
match current_status {
"queued" | "in_progress" => {
// Attempt to abort the background task
let mut tasks = ctx.background_tasks.write().await;
if let Some(task_info) = tasks.remove(response_id) {
// Abort the Rust task immediately
task_info.handle.abort();
// Abort the Python/scheduler request via gRPC (if client is available)
let client_opt = task_info.client.read().await;
if let Some(ref client) = *client_opt {
if let Err(e) = client
.abort_request(
task_info.grpc_request_id.clone(),
"User cancelled via API".to_string(),
)
.await
{
warn!(
"Failed to abort Python request {}: {}",
task_info.grpc_request_id, e
);
} else {
debug!(
"Successfully aborted Python request: {}",
task_info.grpc_request_id
);
}
} else {
debug!("Client not yet available for abort, request may not have started yet");
}
// Task was found and aborted
(
StatusCode::OK,
axum::Json(json!({
"id": response_id,
"status": "cancelled",
"message": "Background task has been cancelled"
})),
)
.into_response()
} else {
// Task handle not found but status is queued/in_progress
// This can happen if: (1) task crashed, or (2) storage persistence failed
error!(
"Response {} has status '{}' but task handle is missing. Task may have crashed or storage update failed.",
response_id, current_status
);
(
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": "Internal error: background task completed but failed to update status in storage",
"type": "internal_error",
"code": "status_update_failed"
}
})),
)
.into_response()
}
}
"completed" => (
StatusCode::BAD_REQUEST,
axum::Json(json!({
"error": {
"message": "Cannot cancel completed response",
"type": "invalid_request_error",
"code": "response_already_completed"
}
})),
)
.into_response(),
"failed" => (
StatusCode::BAD_REQUEST,
axum::Json(json!({
"error": {
"message": "Cannot cancel failed response",
"type": "invalid_request_error",
"code": "response_already_failed"
}
})),
)
.into_response(),
"cancelled" => (
StatusCode::OK,
axum::Json(json!({
"id": response_id,
"status": "cancelled",
"message": "Response was already cancelled"
})),
)
.into_response(),
_ => {
// Unknown status
(
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": format!("Unknown response status: {}", current_status),
"type": "internal_error"
}
})),
)
.into_response()
}
}
}
Ok(None) => (
StatusCode::NOT_FOUND,
axum::Json(json!({
"error": {
"message": format!("Response with id '{}' not found", response_id),
"type": "not_found_error",
"code": "response_not_found"
}
})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": format!("Failed to retrieve response: {}", e),
"type": "internal_error"
}
})),
)
.into_response(),
}
}
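// Status handling summary for the cancel flow above:
//   queued | in_progress -> abort the task handle, best-effort gRPC abort,
//                           then 200 {"status": "cancelled"}
//   completed            -> 400 response_already_completed
//   failed               -> 400 response_already_failed
//   cancelled            -> 200 (idempotent; reports "already cancelled")
//   anything else        -> 500 unknown status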
//! Shared response functionality used by both regular and harmony implementations
pub mod handlers;
pub mod streaming;
pub use handlers::{cancel_response_impl, get_response_impl};
pub use streaming::{OutputItemType, ResponseStreamEventEmitter};
@@ -7,7 +7,7 @@ use serde_json::json;
use tokio::sync::mpsc;
use uuid::Uuid;
use crate::protocols::chat::ChatCompletionStreamResponse;
use crate::{mcp, protocols::chat::ChatCompletionStreamResponse};
pub enum OutputItemType {
Message,
@@ -30,10 +30,6 @@ struct OutputItemState {
status: ItemStatus,
}
// ============================================================================
// Streaming Event Emitter
// ============================================================================
/// OpenAI-compatible event emitter for /v1/responses streaming
///
/// Manages state and sequence numbers to emit proper event types:
@@ -66,7 +62,7 @@ pub struct ResponseStreamEventEmitter {
has_emitted_content_part_added: bool,
// MCP call tracking
mcp_call_accumulated_args: HashMap<String, String>,
// Output item tracking (NEW)
// Output item tracking
output_items: Vec<OutputItemState>,
next_output_index: usize,
current_message_output_index: Option<usize>, // Tracks output_index of current message
@@ -248,7 +244,7 @@ impl ResponseStreamEventEmitter {
pub fn emit_mcp_list_tools_completed(
&mut self,
output_index: usize,
tools: &[crate::mcp::Tool],
tools: &[mcp::Tool],
) -> serde_json::Value {
let tool_items: Vec<_> = tools
.iter()
@@ -331,7 +327,7 @@ impl ResponseStreamEventEmitter {
})
}
pub(super) fn emit_mcp_call_failed(
pub fn emit_mcp_call_failed(
&mut self,
output_index: usize,
item_id: &str,
@@ -453,7 +449,7 @@ impl ResponseStreamEventEmitter {
}
/// Process a chunk and emit appropriate events
pub(super) fn process_chunk(
pub fn process_chunk(
&mut self,
chunk: &ChatCompletionStreamResponse,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
......
//! Common helper functions shared across stages
use std::sync::Arc;
use proto::DisaggregatedParams;
use rand::Rng;
use tracing::debug;
use crate::{core::Worker, grpc_client::proto};
/// Inject PD bootstrap metadata into a gRPC request
///
/// Used by both chat and generate request building stages when in PD mode.
pub fn inject_bootstrap_metadata(
request: &mut proto::GenerateRequest,
prefill_worker: &Arc<dyn Worker>,
) {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
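// Usage sketch (as in the harmony request-building stage later in this
// diff): the helper is only applied in PD mode, with the prefill worker
// taken from WorkerSelection::Dual.
//
//     if self.inject_pd_metadata {
//         if let Some(WorkerSelection::Dual { prefill, .. }) = ctx.state.workers.as_ref() {
//             helpers::inject_bootstrap_metadata(&mut proto_request, prefill);
//         }
//     }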
//! Pipeline stages for gRPC router request processing
//! Common pipeline stages shared across all endpoints and model types
//!
//! This module defines the core pipeline abstraction and individual processing stages
//! that transform a RequestContext through its lifecycle.
//! These stages are endpoint-agnostic and model-agnostic:
//! - Worker selection
//! - Client acquisition
//! - Dispatch metadata generation
//! - Request execution
use async_trait::async_trait;
use axum::response::Response;
use crate::routers::grpc::context::RequestContext;
// ============================================================================
// Pipeline Trait
// ============================================================================
/// Trait for pipeline stages that process requests
#[async_trait]
pub trait PipelineStage: Send + Sync {
@@ -27,26 +26,14 @@ pub trait PipelineStage: Send + Sync {
fn name(&self) -> &'static str;
}
// ============================================================================
// Stage Modules
// ============================================================================
mod client_acquisition;
mod dispatch_metadata;
mod preparation;
mod request_building;
pub mod helpers;
mod request_execution;
mod response_processing;
mod worker_selection;
// ============================================================================
// Public Exports
// ============================================================================
// Export stage implementations
pub use client_acquisition::ClientAcquisitionStage;
pub use dispatch_metadata::DispatchMetadataStage;
pub use preparation::PreparationStage;
pub use request_building::RequestBuildingStage;
pub use request_execution::{ExecutionMode, RequestExecutionStage};
pub use response_processing::ResponseProcessingStage;
pub use worker_selection::{WorkerSelectionMode, WorkerSelectionStage};
@@ -109,7 +109,6 @@ impl WorkerSelectionStage {
false, // get all workers, we'll filter by is_available() next
);
// Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
......
@@ -11,7 +11,7 @@ use serde_json::Value;
use crate::{
core::Worker,
grpc_client::{proto, SglangSchedulerClient},
grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient},
protocols::{
chat::{ChatCompletionRequest, ChatCompletionResponse},
generate::{GenerateRequest, GenerateResponse},
@@ -22,23 +22,14 @@ use crate::{
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================
// Core Context Types
// ============================================================================
/// Main request processing context
///
/// This is the single source of truth for all request state as it flows
/// through the pipeline stages. Uses Rust's type system to enforce proper
/// stage ordering at compile time.
pub struct RequestContext {
// === Input (Immutable) ===
pub input: RequestInput,
// === Shared Components (Immutable References) ===
pub components: Arc<SharedComponents>,
// === Processing State (Mutable, evolves through pipeline) ===
pub state: ProcessingState,
}
@@ -86,10 +77,6 @@ pub struct ProcessingState {
pub response: ResponseState,
}
// ============================================================================
// Stage-Specific Output Types
// ============================================================================
/// Output from preparation stage (Step 1)
pub struct PreparationOutput {
/// Original text (for chat) or resolved text (for generate)
@@ -201,10 +188,6 @@ pub struct StreamingState {
pub has_tool_calls: HashMap<u32, bool>,
}
// ============================================================================
// Context Builders
// ============================================================================
impl RequestContext {
/// Create context for chat completion request
pub fn for_chat(
@@ -323,14 +306,6 @@ impl RequestContext {
}
}
// ============================================================================
// Default Implementations
// ============================================================================
// ============================================================================
// Helper Methods
// ============================================================================
impl WorkerSelection {
pub fn is_dual(&self) -> bool {
matches!(self, Self::Dual { .. })
@@ -428,12 +403,6 @@ impl ClientSelection {
}
}
// ============================================================================
// Execution and Response Types
// ============================================================================
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
/// Result of request execution (streams from workers)
/// Uses AbortOnDropStream to automatically abort on cancellation
pub enum ExecutionResult {
......
@@ -17,8 +17,9 @@ use crate::{
},
},
routers::grpc::{
common::{response_collection, response_formatting},
context::{DispatchMetadata, ExecutionResult},
error, utils,
error,
},
};
@@ -34,28 +35,6 @@ impl HarmonyResponseProcessor {
Self
}
/// Collect responses from ExecutionResult (similar to regular processor)
async fn collect_responses(
execution_result: ExecutionResult,
) -> Result<Vec<proto::GenerateComplete>, Response> {
match execution_result {
ExecutionResult::Single { mut stream } => {
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
stream.mark_completed();
Ok(responses)
}
ExecutionResult::Dual { prefill, decode } => {
// For Harmony we currently rely only on decode stream for outputs
let mut decode_stream = *decode;
let responses =
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
prefill.mark_completed();
decode_stream.mark_completed();
Ok(responses)
}
}
}
/// Process a non-streaming Harmony chat response
pub async fn process_non_streaming_chat_response(
&self,
@@ -64,7 +43,7 @@ impl HarmonyResponseProcessor {
dispatch: DispatchMetadata,
) -> Result<ChatCompletionResponse, Response> {
// Collect all completed responses (one per choice)
let all_responses = Self::collect_responses(execution_result).await?;
let all_responses = response_collection::collect_responses(execution_result, false).await?;
if all_responses.is_empty() {
return Err(error::internal_error("No responses from server"));
}
@@ -117,28 +96,15 @@ impl HarmonyResponseProcessor {
}
// Build usage from proto fields
let prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
};
let usage = response_formatting::build_usage(&all_responses);
// Final ChatCompletionResponse
let response = ChatCompletionResponse {
id: dispatch.request_id.clone(),
object: "chat.completion".to_string(),
created: dispatch.created,
model: chat_request.model.clone(),
let response = response_formatting::build_chat_response(
choices,
usage: Some(usage),
system_fingerprint: dispatch.weight_version.clone(),
};
&dispatch,
chat_request.model.clone(),
usage,
);
Ok(response)
}
@@ -191,7 +157,7 @@ impl HarmonyResponseProcessor {
dispatch: DispatchMetadata,
) -> Result<ResponsesIterationResult, Response> {
// Collect all completed responses
let all_responses = Self::collect_responses(execution_result).await?;
let all_responses = response_collection::collect_responses(execution_result, false).await?;
if all_responses.is_empty() {
return Err(error::internal_error("No responses from server"));
}
@@ -280,14 +246,7 @@ impl HarmonyResponseProcessor {
}
// Build usage
let prompt_tokens = complete.prompt_tokens as u32;
let completion_tokens = complete.completion_tokens as u32;
let usage = Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: None,
};
let usage = response_formatting::build_usage(std::slice::from_ref(complete));
// Build ResponsesResponse with all required fields
let response = ResponsesResponse {
@@ -316,9 +275,9 @@ impl HarmonyResponseProcessor {
top_p: responses_request.top_p,
truncation: None,
usage: Some(ResponsesUsage::Modern(ResponseUsage {
input_tokens: prompt_tokens,
output_tokens: completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
input_tokens_details: None,
output_tokens_details: None,
})),
......
@@ -37,12 +37,14 @@
//! for complete architecture, rationale, and implementation details.
use std::{
io,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use axum::{body::Body, http::StatusCode, response::Response};
use serde_json::Value as JsonValue;
use bytes::Bytes;
use serde_json::{from_str, from_value, json, to_string, to_value, Value};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, warn};
@@ -50,21 +52,22 @@ use uuid::Uuid;
use crate::{
data_connector::{ResponseId, ResponseStorage},
mcp::McpManager,
mcp::{self, McpManager},
protocols::{
common::{Function, ToolCall},
responses::{
ResponseInput, ResponseInputOutputItem, ResponseTool, ResponseToolType,
McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
ResponseOutputItem, ResponseReasoningContent, ResponseTool, ResponseToolType,
ResponsesRequest, ResponsesResponse, StringOrContentParts,
},
},
routers::{
grpc::{
common::responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
context::SharedComponents,
error,
harmony::processor::ResponsesIterationResult,
pipeline::RequestPipeline,
responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
},
openai::mcp::ensure_request_mcp_client,
},
@@ -414,10 +417,6 @@ pub async fn serve_harmony_responses_stream(
Err(err_response) => return err_response,
};
use std::io;
use bytes::Bytes;
// Create SSE channel
let (tx, rx) = mpsc::unbounded_channel();
let stream = UnboundedReceiverStream::new(rx);
@@ -444,7 +443,7 @@ pub async fn serve_harmony_responses_stream(
// Helper to emit error and return
let emit_error = |tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>, error_msg: &str| {
// Create error event manually since emit_failed doesn't exist
let event = serde_json::json!({
let event = json!({
"type": "response.failed",
"response_id": response_id_for_error,
"error": {
@@ -452,7 +451,7 @@ pub async fn serve_harmony_responses_stream(
"type": "internal_error"
}
});
let sse_data = format!("data: {}\n\n", serde_json::to_string(&event).unwrap());
let sse_data = format!("data: {}\n\n", to_string(&event).unwrap());
let _ = tx.send(Ok(Bytes::from(sse_data)));
};
@@ -517,7 +516,6 @@ pub async fn serve_harmony_responses_stream(
let tool_items: Vec<_> = mcp_tools
.iter()
.map(|t| {
use serde_json::{json, Value};
json!({
"name": t.name,
"description": t.description,
@@ -527,7 +525,7 @@ pub async fn serve_harmony_responses_stream(
.collect();
// Emit output_item.added
let item = serde_json::json!({
let item = json!({
"id": item_id,
"type": "mcp_list_tools",
"server_label": "sglang-mcp",
@@ -552,7 +550,7 @@ pub async fn serve_harmony_responses_stream(
}
// Emit output_item.done
let item_done = serde_json::json!({
let item_done = json!({
"id": item_id,
"type": "mcp_list_tools",
"server_label": "sglang-mcp",
@@ -677,7 +675,7 @@ pub async fn serve_harmony_responses_stream(
);
// Emit response.completed with usage
let usage_json = serde_json::json!({
let usage_json = json!({
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
"total_tokens": usage.total_tokens,
@@ -733,7 +731,7 @@ async fn execute_mcp_tools(
// Parse tool arguments from JSON string
let args_str = tool_call.function.arguments.as_deref().unwrap_or("{}");
let args: JsonValue = serde_json::from_str(args_str).map_err(|e| {
let args: Value = from_str(args_str).map_err(|e| {
error::internal_error(format!(
"Invalid tool arguments JSON for tool '{}': {}",
tool_call.function.name, e
@@ -741,8 +739,7 @@ async fn execute_mcp_tools(
})?;
// Execute tool via MCP manager
// Convert JsonValue to ToolArgs via Option<Map> (MCP manager expects this)
let args_map = if let JsonValue::Object(map) = args {
let args_map = if let Value::Object(map) = args {
Some(map)
} else {
None
@@ -763,15 +760,14 @@ async fn execute_mcp_tools(
let output = if let Some(content) = mcp_result.content.first() {
// TODO: Handle different content types (text, image, resource)
// For now, serialize the entire content item
serde_json::to_value(content).unwrap_or_else(
|_| serde_json::json!({"error": "Failed to serialize tool result"}),
)
to_value(content)
.unwrap_or_else(|_| json!({"error": "Failed to serialize tool result"}))
} else {
serde_json::json!({"result": "success"})
json!({"result": "success"})
};
let is_error = mcp_result.is_error.unwrap_or(false);
let output_str = serde_json::to_string(&output)
let output_str = to_string(&output)
.unwrap_or_else(|_| r#"{"error": "Failed to serialize output"}"#.to_string());
// Record this call in tracking
@@ -804,10 +800,10 @@ async fn execute_mcp_tools(
);
let error_msg = format!("Tool execution failed: {}", e);
let error_output = serde_json::json!({
let error_output = json!({
"error": error_msg.clone()
});
let error_output_str = serde_json::to_string(&error_output)
let error_output_str = to_string(&error_output)
.unwrap_or_else(|_| format!(r#"{{"error": "{}"}}"#, error_msg));
// Record failed call in tracking
@@ -859,12 +855,6 @@ fn build_next_request_with_tools(
analysis: Option<String>,
partial_text: String,
) -> Result<ResponsesRequest, Box<Response>> {
use uuid::Uuid;
use crate::protocols::responses::{
ResponseContentPart, ResponseInputOutputItem, ResponseReasoningContent,
};
// Get current input items (or empty vec if Text variant)
let mut items = match request.input {
ResponseInput::Items(items) => items,
@@ -926,7 +916,7 @@ fn build_next_request_with_tools(
// Add tool results
for tool_result in tool_results {
// Serialize tool output to string
let output_str = serde_json::to_string(&tool_result.output).unwrap_or_else(|e| {
let output_str = to_string(&tool_result.output).unwrap_or_else(|e| {
format!("{{\"error\": \"Failed to serialize tool output: {}\"}}", e)
});
@@ -967,7 +957,7 @@ struct ToolResult {
tool_name: String,
/// Tool output (JSON value)
output: JsonValue,
output: Value,
/// Whether this is an error result
is_error: bool,
@@ -985,11 +975,7 @@ struct ToolResult {
/// # Returns
///
/// Vector of ResponseTool entries in MCP format
pub fn convert_mcp_tools_to_response_tools(mcp_tools: &[crate::mcp::Tool]) -> Vec<ResponseTool> {
use serde_json::Value;
use crate::protocols::responses::ResponseToolType;
pub fn convert_mcp_tools_to_response_tools(mcp_tools: &[mcp::Tool]) -> Vec<ResponseTool> {
mcp_tools
.iter()
.map(|tool_info| ResponseTool {
@@ -1027,11 +1013,6 @@ fn inject_mcp_metadata(
tracking: &McpCallTracking,
mcp_manager: &Arc<McpManager>,
) {
use serde_json::{json, Value};
use uuid::Uuid;
use crate::protocols::responses::{McpToolInfo, ResponseOutputItem};
// Build mcp_list_tools item
let tools = mcp_manager.list_tools();
let tools_info: Vec<McpToolInfo> = tools
@@ -1121,23 +1102,22 @@ async fn load_previous_messages(
let mut history_items = Vec::new();
// Helper to deserialize and collect items from a JSON array
let deserialize_items =
|arr: &serde_json::Value, item_type: &str| -> Vec<ResponseInputOutputItem> {
arr.as_array()
.into_iter()
.flat_map(|items| items.iter())
.filter_map(|item| {
serde_json::from_value::<ResponseInputOutputItem>(item.clone())
.map_err(|e| {
warn!(
"Failed to deserialize stored {} item: {}. Item: {}",
item_type, e, item
);
})
.ok()
})
.collect()
};
let deserialize_items = |arr: &Value, item_type: &str| -> Vec<ResponseInputOutputItem> {
arr.as_array()
.into_iter()
.flat_map(|items| items.iter())
.filter_map(|item| {
from_value::<ResponseInputOutputItem>(item.clone())
.map_err(|e| {
warn!(
"Failed to deserialize stored {} item: {}. Item: {}",
item_type, e, item
);
})
.ok()
})
.collect()
};
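// Example: if one stored item in a response's output array fails to
// deserialize, the closure above logs a warning for that item and keeps the
// rest, so a single bad record cannot poison the whole conversation history.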
for stored in chain.responses.iter() {
history_items.extend(deserialize_items(&stored.input, "input"));
......
@@ -12,10 +12,9 @@ use crate::{
responses::ResponsesRequest,
},
routers::grpc::{
common::stages::PipelineStage,
context::{PreparationOutput, RequestContext, RequestType},
error,
stages::PipelineStage,
utils,
error, utils,
},
};
......
//! Harmony Request Building Stage: Build gRPC request from Harmony-encoded tokens
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use rand::Rng;
use tracing::debug;
use uuid::Uuid;
use crate::{
core::Worker,
grpc_client::proto::{DisaggregatedParams, GenerateRequest},
routers::grpc::{
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
error,
stages::PipelineStage,
},
use crate::routers::grpc::{
common::stages::{helpers, PipelineStage},
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
error,
};
/// Harmony Request Building stage: Convert Harmony tokens to gRPC request
@@ -31,34 +24,6 @@ impl HarmonyRequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
/// Inject PD (prefill-decode) bootstrap metadata
fn inject_bootstrap_metadata(
&self,
request: &mut GenerateRequest,
prefill_worker: &Arc<dyn Worker>,
) {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected Harmony bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
}
#[async_trait]
@@ -94,7 +59,6 @@ impl PipelineStage for HarmonyRequestBuildingStage {
};
// Build gRPC request using token_ids directly (Harmony encoding already handled message rendering)
// Use a placeholder for original_text; Harmony uses input_ids for tokenization
let placeholder_processed_text = "[harmony]".to_string();
let mut proto_request = match &ctx.input.request_type {
@@ -141,7 +105,7 @@ impl PipelineStage for HarmonyRequestBuildingStage {
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let Some(WorkerSelection::Dual { prefill, .. }) = ctx.state.workers.as_ref() {
self.inject_bootstrap_metadata(&mut proto_request, prefill);
helpers::inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
......
@@ -7,9 +9,9 @@ use axum::response::Response;
use super::super::{HarmonyResponseProcessor, HarmonyStreamingProcessor};
use crate::routers::grpc::{
common::stages::PipelineStage,
context::{FinalResponse, RequestContext, RequestType},
error,
stages::PipelineStage,
};
/// Harmony Response Processing stage: Parse and format Harmony responses
......
@@ -31,8 +31,8 @@ use crate::{
responses::{ResponseStatus, ResponseUsage, ResponsesResponse, ResponsesUsage},
},
routers::grpc::{
common::responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
context,
responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
},
};
/// Processor for streaming Harmony responses
......