Unverified commit 83804bc6, authored by Chang Su, committed by GitHub

[router][grpc] Restructure modules and code clean up (#12598)

parent d5fa019c
 //! Shared types for Harmony pipeline

+use openai_harmony::chat::Content;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -36,8 +37,6 @@ impl HarmonyMessage {
     /// Convert from openai_harmony::chat::Message to our simplified HarmonyMessage
     pub fn from_openai_harmony(msg: openai_harmony::chat::Message) -> Self {
-        use openai_harmony::chat::Content;
-
         // Extract role as string
         let role = match msg.author.role {
             openai_harmony::chat::Role::User => "user",
...
@@ -2,16 +2,14 @@
 use crate::{grpc_client::proto, protocols::common::StringOrArray};

+pub mod common;
 pub mod context;
 pub mod error;
 pub mod harmony;
 pub mod pd_router;
 pub mod pipeline;
-pub mod processing;
-pub mod responses;
+pub mod regular;
 pub mod router;
-pub mod stages;
-pub mod streaming;
 pub mod utils;

 /// Processed chat messages ready for gRPC generation
...
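Note: for orientation, this commit moves code shared by both pipelines under `common` and the tokenizer-path (non-Harmony) code under `regular`. A compile-checkable sketch of the resulting layout follows, with module bodies elided; the `common` submodule set is inferred from import paths appearing later in this diff, so treat it as an approximation rather than the authoritative tree.

// Sketch of the module tree after this commit (bodies elided).
mod grpc {
    pub mod common {
        // shared by the regular and Harmony pipelines
        pub mod response_collection {}
        pub mod response_formatting {}
        pub mod responses { pub mod streaming {} }
        pub mod stages {}
    }
    pub mod regular {
        // regular tokenizer-based (non-Harmony) models
        pub mod processor {}
        pub mod responses {}
        pub mod stages {}
        pub mod streaming {}
    }
    pub mod context {}
    pub mod error {}
    pub mod harmony {}
    pub mod pd_router {}
    pub mod pipeline {}
    pub mod router {}
    pub mod utils {}
}

fn main() {} // placeholder so the sketch compiles standalone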
-// PD (Prefill-Decode) gRPC Router Implementation
-
 use std::sync::Arc;

 use async_trait::async_trait;
@@ -161,7 +159,6 @@ impl RouterTrait for GrpcPDRouter {
     }

     async fn health_generate(&self, _req: Request<Body>) -> Response {
-        // TODO: Implement actual generation test for gRPC PD mode
         (
             StatusCode::NOT_IMPLEMENTED,
             "Health generate not yet implemented for gRPC PD",
...
@@ -3,15 +3,17 @@
 //! This module defines the RequestPipeline orchestrator that coordinates
 //! the execution of pipeline stages from request preparation to response delivery.

-use std::{collections::HashMap, sync::Arc};
+use std::sync::Arc;

 use axum::response::{IntoResponse, Response};
-use tokio::sync::RwLock;
-use tracing::{debug, error};
+use tracing::error;

-// Import all stage types from the stages module
-use super::stages::*;
-use super::{context::*, error, harmony, processing, responses::BackgroundTaskInfo, streaming};
+use super::{
+    common::stages::*,
+    context::*,
+    error, harmony,
+    regular::{processor, stages::*, streaming},
+};
 use crate::{
     core::WorkerRegistry,
     policies::PolicyRegistry,
@@ -24,10 +26,6 @@ use crate::{
     tool_parser::ParserFactory as ToolParserFactory,
 };

-// ============================================================================
-// Pipeline Orchestrator
-// ============================================================================
-
 /// Generic request pipeline for all request types
 ///
 /// Orchestrates all stages from request preparation to response delivery.
@@ -48,8 +46,7 @@ impl RequestPipeline {
         configured_tool_parser: Option<String>,
         configured_reasoning_parser: Option<String>,
     ) -> Self {
-        // Create response processor
-        let processor = processing::ResponseProcessor::new(
+        let processor = processor::ResponseProcessor::new(
             tokenizer.clone(),
             tool_parser_factory.clone(),
             reasoning_parser_factory.clone(),
@@ -57,7 +54,6 @@ impl RequestPipeline {
             configured_reasoning_parser.clone(),
         );

-        // Create streaming processor
         let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
             tokenizer,
             tool_parser_factory,
@@ -67,7 +63,7 @@ impl RequestPipeline {
         ));

         let stages: Vec<Box<dyn PipelineStage>> = vec![
-            Box::new(PreparationStage),
+            Box::new(PreparationStage::new()),
             Box::new(WorkerSelectionStage::new(
                 worker_registry,
                 policy_registry,
@@ -153,8 +149,7 @@ impl RequestPipeline {
         configured_tool_parser: Option<String>,
         configured_reasoning_parser: Option<String>,
     ) -> Self {
-        // Create response processor
-        let processor = processing::ResponseProcessor::new(
+        let processor = processor::ResponseProcessor::new(
             tokenizer.clone(),
             tool_parser_factory.clone(),
             reasoning_parser_factory.clone(),
@@ -162,7 +157,6 @@ impl RequestPipeline {
             configured_reasoning_parser.clone(),
         );

-        // Create streaming processor
         let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
             tokenizer,
             tool_parser_factory,
@@ -172,7 +166,7 @@ impl RequestPipeline {
         ));

         let stages: Vec<Box<dyn PipelineStage>> = vec![
-            Box::new(PreparationStage),
+            Box::new(PreparationStage::new()),
             Box::new(WorkerSelectionStage::new(
                 worker_registry,
                 policy_registry,
@@ -200,7 +194,6 @@ impl RequestPipeline {
     ) -> Response {
         let mut ctx = RequestContext::for_chat(request, headers, model_id, components);

-        // Execute each stage in sequence
         for (idx, stage) in self.stages.iter().enumerate() {
             match stage.execute(&mut ctx).await {
                 Ok(Some(response)) => {
@@ -208,7 +201,6 @@ impl RequestPipeline {
                     return response;
                 }
                 Ok(None) => {
-                    // Continue to next stage
                     continue;
                 }
                 Err(response) => {
@@ -224,7 +216,6 @@ impl RequestPipeline {
             }
         }

-        // Extract final response
         match ctx.state.response.final_response {
             Some(FinalResponse::Chat(response)) => axum::Json(response).into_response(),
             Some(FinalResponse::Generate(_)) => {
@@ -244,7 +235,6 @@ impl RequestPipeline {
     ) -> Response {
         let mut ctx = RequestContext::for_generate(request, headers, model_id, components);

-        // Execute each stage in sequence
         for (idx, stage) in self.stages.iter().enumerate() {
             match stage.execute(&mut ctx).await {
                 Ok(Some(response)) => {
@@ -252,7 +242,6 @@ impl RequestPipeline {
                     return response;
                 }
                 Ok(None) => {
-                    // Continue to next stage
                     continue;
                 }
                 Err(response) => {
@@ -268,7 +257,6 @@ impl RequestPipeline {
             }
         }

-        // Extract final response
         match ctx.state.response.final_response {
             Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
             Some(FinalResponse::Chat(_)) => {
@@ -280,25 +268,19 @@ impl RequestPipeline {
     /// Execute chat pipeline for responses endpoint
     ///
-    /// TODO: The support for background tasks is not scalable. Consider replacing this with
-    /// a better design in the future.
-    /// Used by ALL non-streaming /v1/responses requests (both sync and background modes).
-    /// Uses the same 7 pipeline stages as execute_chat(), with three differences:
+    /// Used by ALL non-streaming /v1/responses requests.
+    /// Uses the same 7 pipeline stages as execute_chat(), with two differences:
     /// 1. Returns Result<ChatCompletionResponse, Response> for tool_loop composition
     /// 2. Disallows streaming (responses endpoint uses different SSE format)
-    /// 3. Injects hooks for background task cancellation (only active when response_id provided)
     pub async fn execute_chat_for_responses(
         &self,
         request: Arc<ChatCompletionRequest>,
         headers: Option<http::HeaderMap>,
         model_id: Option<String>,
         components: Arc<SharedComponents>,
-        response_id: Option<String>,
-        background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
     ) -> Result<ChatCompletionResponse, Response> {
         let mut ctx = RequestContext::for_chat(request, headers, model_id, components);

-        // Execute each stage in sequence
         for (idx, stage) in self.stages.iter().enumerate() {
             match stage.execute(&mut ctx).await {
                 Ok(Some(_response)) => {
@@ -308,40 +290,6 @@ impl RequestPipeline {
                     ));
                 }
                 Ok(None) => {
-                    let stage_name = stage.name();
-
-                    // After ClientAcquisitionStage, store client for background task cancellation
-                    if stage_name == "ClientAcquisition" {
-                        if let (Some(ref clients), Some(ref resp_id), Some(ref tasks)) =
-                            (&ctx.state.clients, &response_id, &background_tasks)
-                        {
-                            let client_to_store = match clients {
-                                ClientSelection::Single { client } => client.clone(),
-                                ClientSelection::Dual { decode, .. } => decode.clone(),
-                            };
-                            if let Some(task_info) = tasks.write().await.get_mut(resp_id.as_str()) {
-                                *task_info.client.write().await = Some(client_to_store);
-                                debug!("Stored client for response_id: {}", resp_id);
-                            }
-                        }
-                    }
-
-                    // After DispatchMetadataStage, store grpc_request_id for background task cancellation
-                    if stage_name == "DispatchMetadata" {
-                        if let (Some(ref dispatch), Some(ref resp_id), Some(ref tasks)) =
-                            (&ctx.state.dispatch, &response_id, &background_tasks)
-                        {
-                            let grpc_request_id = dispatch.request_id.clone();
-                            if let Some(task_info) = tasks.write().await.get_mut(resp_id.as_str()) {
-                                task_info.grpc_request_id = grpc_request_id.clone();
-                                debug!("Stored grpc_request_id for response_id: {}", resp_id);
-                            }
-                        }
-                    }
-
-                    // Continue to next stage
                     continue;
                 }
                 Err(response) => {
@@ -357,7 +305,6 @@ impl RequestPipeline {
             }
         }

-        // Extract final response
         match ctx.state.response.final_response {
             Some(FinalResponse::Chat(response)) => Ok(response),
             Some(FinalResponse::Generate(_)) => {
@@ -367,26 +314,6 @@ impl RequestPipeline {
         }
     }

-    /// Execute Responses API pipeline
-    ///
-    /// TODO: Implement Responses API native execution
-    /// This is a stub to allow compilation. The actual implementation should:
-    /// 1. Support multi-turn MCP loop orchestration
-    /// 2. Handle tool call execution and result injection
-    /// 3. Emit proper SSE events for streaming mode
-    /// 4. Store responses in data connector
-    ///
-    /// For now, this returns an error indicating the feature is not implemented.
-    pub async fn execute_responses(
-        &self,
-        _request: Arc<crate::protocols::responses::ResponsesRequest>,
-        _headers: Option<http::HeaderMap>,
-        _model_id: Option<String>,
-        _components: Arc<SharedComponents>,
-    ) -> Response {
-        error::internal_error("Responses API execution not yet implemented")
-    }
-
     /// Execute Harmony Responses API request through all pipeline stages
     ///
     /// This method runs a single iteration of the Responses API request,
@@ -415,7 +342,6 @@ impl RequestPipeline {
             harmony_ctx.components.clone(),
         );

-        // Execute each pipeline stage in sequence
         for (idx, stage) in self.stages.iter().enumerate() {
             match stage.execute(&mut ctx).await {
                 Ok(Some(response)) => {
@@ -428,7 +354,6 @@ impl RequestPipeline {
                     return Err(response);
                 }
                 Ok(None) => {
-                    // Stage completed successfully, continue to next stage
                     continue;
                 }
                 Err(response) => {
@@ -472,7 +397,6 @@ impl RequestPipeline {
             harmony_ctx.components.clone(),
         );

-        // Execute pipeline stages up to dispatch (which creates the stream)
        for (idx, stage) in self.stages.iter().enumerate() {
             match stage.execute(&mut ctx).await {
                 Ok(Some(response)) => {
...
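Note: every driver loop in this file follows the same three-way contract, which the `PipelineStage` trait (shown later in this diff) encodes: `Ok(Some(response))` short-circuits with a finished HTTP response, `Ok(None)` advances to the next stage after mutating the context, and `Err(response)` aborts with an error response. A minimal self-contained sketch of that control flow, with `String` and a tiny `Ctx` standing in for `axum::response::Response` and `RequestContext`, and a synchronous trait standing in for the real async one:

// Stand-ins for axum::response::Response and RequestContext.
type Response = String;
struct Ctx { final_response: Option<Response> }

trait PipelineStage {
    // Ok(Some(_)): early exit with a finished response (e.g. an SSE stream).
    // Ok(None):    stage only mutated the context; run the next stage.
    // Err(_):      abort with an error response.
    fn execute(&self, ctx: &mut Ctx) -> Result<Option<Response>, Response>;
    fn name(&self) -> &'static str;
}

struct Preparation;
impl PipelineStage for Preparation {
    fn execute(&self, ctx: &mut Ctx) -> Result<Option<Response>, Response> {
        ctx.final_response = Some("done".into());
        Ok(None)
    }
    fn name(&self) -> &'static str { "Preparation" }
}

fn run(stages: &[Box<dyn PipelineStage>], ctx: &mut Ctx) -> Response {
    for stage in stages {
        match stage.execute(ctx) {
            Ok(Some(response)) => return response, // stage produced the final response
            Ok(None) => continue,                  // stage only mutated ctx
            Err(response) => return response,      // stage failed
        }
    }
    // Otherwise the last stage stored the final payload in the context.
    ctx.final_response.take().unwrap_or_else(|| "missing final response".into())
}

fn main() {
    let stages: Vec<Box<dyn PipelineStage>> = vec![Box::new(Preparation)];
    let mut ctx = Ctx { final_response: None };
    assert_eq!(run(&stages, &mut ctx), "done");
}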
+//! Regular (non-harmony) model processing
+//!
+//! This module contains all code specific to regular tokenizer-based models,
+//! including pipeline stages, response processing, and streaming.
+
+pub mod processor;
+pub mod responses;
+pub mod stages;
+pub mod streaming;
...
 //! Shared response processing logic for gRPC routers
 //!
 //! This module contains response processing functions that are shared between
-//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates.
+//! the regular router and PD router.

 use std::{sync::Arc, time::Instant};
@@ -9,18 +9,19 @@ use proto::generate_complete::MatchedStop;
 use serde_json::Value;
 use tracing::error;

-use super::{
-    context::{DispatchMetadata, ExecutionResult},
-    error, utils,
-};
 use crate::{
     grpc_client::proto,
     protocols::{
         chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse},
-        common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage},
+        common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue},
         generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse},
     },
     reasoning_parser::ParserFactory as ReasoningParserFactory,
+    routers::grpc::{
+        common::{response_collection, response_formatting},
+        context::{DispatchMetadata, ExecutionResult},
+        error, utils,
+    },
     tokenizer::{
         stop::{SequenceDecoderOutput, StopSequenceDecoder},
         traits::Tokenizer,
@@ -28,10 +29,6 @@ use crate::{
     tool_parser::ParserFactory as ToolParserFactory,
 };

-// ============================================================================
-// Response Processor - Main Entry Point
-// ============================================================================
-
 /// Unified response processor for both routers
 #[derive(Clone)]
 pub struct ResponseProcessor {
@@ -59,57 +56,6 @@ impl ResponseProcessor {
         }
     }

-    /// Helper to collect responses from execution result and merge logprobs if needed
-    async fn collect_and_merge_responses(
-        execution_result: ExecutionResult,
-        request_logprobs: bool,
-    ) -> Result<Vec<proto::GenerateComplete>, axum::response::Response> {
-        let all_responses = match execution_result {
-            ExecutionResult::Single { mut stream } => {
-                let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
-                stream.mark_completed();
-                responses
-            }
-            ExecutionResult::Dual {
-                mut prefill,
-                decode,
-            } => {
-                // Collect prefill for input_logprobs (don't mark completed yet)
-                let prefill_responses =
-                    utils::collect_stream_responses(&mut prefill, "Prefill").await?;
-
-                // Collect decode for actual output (don't mark completed yet)
-                let mut decode_stream = *decode;
-                let mut decode_responses =
-                    utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
-
-                // Mark both streams as completed now that both succeeded
-                prefill.mark_completed();
-                decode_stream.mark_completed();
-
-                // Merge prefill input_logprobs if requested
-                if request_logprobs {
-                    if let Some(prefill_input_logprobs) = prefill_responses
-                        .first()
-                        .and_then(|r| r.input_logprobs.clone())
-                    {
-                        for response in &mut decode_responses {
-                            response.input_logprobs = Some(prefill_input_logprobs.clone());
-                        }
-                    }
-                }
-
-                decode_responses
-            }
-        };
-
-        if all_responses.is_empty() {
-            return Err(error::internal_error("No responses from server"));
-        }
-
-        Ok(all_responses)
-    }
-
     /// Process a single choice from GenerateComplete response
     #[allow(clippy::too_many_arguments)]
     pub async fn process_single_choice(
@@ -151,7 +97,6 @@ impl ResponseProcessor {
         let mut reasoning_text: Option<String> = None;
         let mut processed_text = final_text;

-        // Check if reasoning parsing is enabled and parser is available
         if original_request.separate_reasoning && reasoning_parser_available {
             let pooled_parser = utils::get_reasoning_parser(
                 &self.reasoning_parser_factory,
@@ -275,7 +220,7 @@ impl ResponseProcessor {
     ) -> Result<ChatCompletionResponse, axum::response::Response> {
         // Collect all responses from the execution result
         let all_responses =
-            Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
+            response_collection::collect_responses(execution_result, request_logprobs).await?;

         let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
@@ -341,28 +286,15 @@ impl ResponseProcessor {
         }

         // Build usage
-        let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
-        let total_completion_tokens: u32 = all_responses
-            .iter()
-            .map(|r| r.completion_tokens as u32)
-            .sum();
-        let usage = Usage {
-            prompt_tokens: total_prompt_tokens,
-            completion_tokens: total_completion_tokens,
-            total_tokens: total_prompt_tokens + total_completion_tokens,
-            completion_tokens_details: None,
-        };
+        let usage = response_formatting::build_usage(&all_responses);

         // Build final ChatCompletionResponse
-        let response = ChatCompletionResponse {
-            id: dispatch.request_id.clone(),
-            object: "chat.completion".to_string(),
-            created: dispatch.created,
-            model: dispatch.model.clone(),
-            choices,
-            usage: Some(usage),
-            system_fingerprint: dispatch.weight_version.clone(),
-        };
+        let response = response_formatting::build_chat_response(
+            choices,
+            &dispatch,
+            dispatch.model.clone(),
+            usage,
+        );

         Ok(response)
     }
@@ -436,7 +368,7 @@ impl ResponseProcessor {
     ) -> Result<Vec<GenerateResponse>, axum::response::Response> {
         // Collect all responses from the execution result
         let all_responses =
-            Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
+            response_collection::collect_responses(execution_result, request_logprobs).await?;

         // Process each completion
         let mut result_array = Vec::new();
@@ -474,7 +406,7 @@ impl ResponseProcessor {
         }

         let output_ids = std::mem::take(&mut complete.output_ids);
-        let finish_reason_str = std::mem::take(&mut complete.finish_reason);
+        let finish_reason_str = complete.finish_reason.to_string();

         // Parse finish_reason from string to proper type
         let finish_reason =
...
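Note: the inline usage arithmetic deleted above gives a precise spec for the extracted helper. Assuming `response_formatting::build_usage` keeps that behavior (the helper's body is not shown in this diff), it reduces to the sketch below; the `GenerateComplete` stand-in is simplified, and the `Usage` field names are taken from the removed code:

// Simplified stand-ins for proto::GenerateComplete and protocols::common::Usage.
struct GenerateComplete { prompt_tokens: i32, completion_tokens: i32 }
#[derive(Debug, PartialEq)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
    completion_tokens_details: Option<()>,
}

// Same aggregation as the deleted inline code: sum per-choice token counts,
// then total = prompt + completion.
fn build_usage(all_responses: &[GenerateComplete]) -> Usage {
    let prompt: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
    let completion: u32 = all_responses.iter().map(|r| r.completion_tokens as u32).sum();
    Usage {
        prompt_tokens: prompt,
        completion_tokens: completion,
        total_tokens: prompt + completion,
        completion_tokens_details: None,
    }
}

fn main() {
    let responses = vec![
        GenerateComplete { prompt_tokens: 12, completion_tokens: 5 },
        GenerateComplete { prompt_tokens: 12, completion_tokens: 9 },
    ];
    // Note the prompt count is summed per choice, exactly as the old inline code did.
    assert_eq!(build_usage(&responses).total_tokens, 38); // 24 prompt + 14 completion
}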
-//! gRPC Router `/v1/responses` endpoint implementation
+//! Regular gRPC Router `/v1/responses` endpoint implementation
 //!
-//! This module handles all responses-specific logic including:
+//! This module handles all responses-specific logic for the regular (non-Harmony) pipeline including:
 //! - Request validation
 //! - Conversation history and response chain loading
-//! - Background mode execution
 //! - Streaming support
 //! - MCP tool loop wrapper
 //! - Response persistence
@@ -12,11 +11,10 @@
 pub mod context;
 mod conversions;
 mod handlers;
-pub mod streaming;
 pub mod tool_loop;
 pub mod types;

 // Public exports
 pub use context::ResponsesContext;
-pub use handlers::{cancel_response_impl, get_response_impl, route_responses};
+pub use handlers::route_responses;
 pub use types::BackgroundTaskInfo;
...
@@ -12,23 +12,30 @@ use axum::{
     response::Response,
 };
 use bytes::Bytes;
+use futures_util::StreamExt;
 use serde_json::json;
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, warn};
 use uuid::Uuid;

-use super::{
-    super::error,
-    conversions,
-    streaming::{OutputItemType, ResponseStreamEventEmitter},
-};
-use crate::protocols::{
-    chat::ChatCompletionResponse,
-    common::{Tool, ToolChoice, ToolChoiceValue},
-    responses::{
-        McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
-        ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest, ResponsesResponse,
-    },
-};
+use super::conversions;
+use crate::{
+    mcp::{self, McpManager},
+    protocols::{
+        chat::{
+            ChatChoice, ChatCompletionMessage, ChatCompletionResponse, ChatCompletionStreamResponse,
+        },
+        common::{Function, FunctionCallResponse, Tool, ToolCall, ToolChoice, ToolChoiceValue},
+        responses::{
+            self, McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
+            ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest,
+            ResponsesResponse,
+        },
+    },
+    routers::grpc::{
+        common::responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
+        error,
+    },
+};
@@ -155,10 +162,7 @@ fn generate_mcp_id(prefix: &str) -> String {
 }

 /// Build mcp_list_tools output item
-fn build_mcp_list_tools_item(
-    mcp: &Arc<crate::mcp::McpManager>,
-    server_label: &str,
-) -> ResponseOutputItem {
+fn build_mcp_list_tools_item(mcp: &Arc<McpManager>, server_label: &str) -> ResponseOutputItem {
     let tools = mcp.list_tools();
     let tools_info: Vec<McpToolInfo> = tools
         .iter()
@@ -263,8 +267,6 @@ pub(super) async fn execute_tool_loop(
             headers.clone(),
             model_id.clone(),
             ctx.components.clone(),
-            response_id.clone(),
-            Some(ctx.background_tasks.clone()),
         )
         .await?;
@@ -358,10 +360,9 @@ pub(super) async fn execute_tool_loop(
             content: vec![ResponseContentPart::InputText { text: text.clone() }],
             status: Some("completed".to_string()),
         }],
-        ResponseInput::Items(items) => items
-            .iter()
-            .map(crate::protocols::responses::normalize_input_item)
-            .collect(),
+        ResponseInput::Items(items) => {
+            items.iter().map(responses::normalize_input_item).collect()
+        }
     };

     // Append all conversation history (function calls and outputs)
@@ -830,10 +831,9 @@ async fn execute_tool_loop_streaming_internal(
             content: vec![ResponseContentPart::InputText { text: text.clone() }],
             status: Some("completed".to_string()),
         }],
-        ResponseInput::Items(items) => items
-            .iter()
-            .map(crate::protocols::responses::normalize_input_item)
-            .collect(),
+        ResponseInput::Items(items) => {
+            items.iter().map(responses::normalize_input_item).collect()
+        }
     };

     input_items.extend_from_slice(&state.conversation_history);
@@ -911,13 +911,13 @@
 }

 /// Convert MCP tools to Chat API tool format
-fn convert_mcp_tools_to_chat_tools(mcp_tools: &[crate::mcp::Tool]) -> Vec<Tool> {
+fn convert_mcp_tools_to_chat_tools(mcp_tools: &[mcp::Tool]) -> Vec<Tool> {
     use serde_json::Value;

     mcp_tools
         .iter()
         .map(|tool_info| Tool {
             tool_type: "function".to_string(),
-            function: crate::protocols::common::Function {
+            function: Function {
                 name: tool_info.name.to_string(),
                 description: tool_info.description.as_ref().map(|d| d.to_string()),
                 parameters: Value::Object((*tool_info.input_schema).clone()),
@@ -933,10 +933,6 @@ async fn convert_and_accumulate_stream(
     emitter: &mut ResponseStreamEventEmitter,
     tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
 ) -> Result<ChatCompletionResponse, String> {
-    use futures_util::StreamExt;
-
-    use crate::protocols::chat::ChatCompletionStreamResponse;
-
     let mut accumulator = ChatResponseAccumulator::new();
     let mut stream = body.into_data_stream();
@@ -971,7 +967,7 @@ struct ChatResponseAccumulator {
     id: String,
     model: String,
     content: String,
-    tool_calls: HashMap<usize, crate::protocols::common::ToolCall>,
+    tool_calls: HashMap<usize, ToolCall>,
     finish_reason: Option<String>,
 }
@@ -986,7 +982,7 @@ impl ChatResponseAccumulator {
         }
     }

-    fn process_chunk(&mut self, chunk: &crate::protocols::chat::ChatCompletionStreamResponse) {
+    fn process_chunk(&mut self, chunk: &ChatCompletionStreamResponse) {
         if !chunk.id.is_empty() {
             self.id = chunk.id.clone();
         }
@@ -1004,15 +1000,13 @@
         if let Some(tool_call_deltas) = &choice.delta.tool_calls {
             for delta in tool_call_deltas {
                 let index = delta.index as usize;
-                let entry = self.tool_calls.entry(index).or_insert_with(|| {
-                    crate::protocols::common::ToolCall {
-                        id: String::new(),
-                        tool_type: "function".to_string(),
-                        function: crate::protocols::common::FunctionCallResponse {
-                            name: String::new(),
-                            arguments: Some(String::new()),
-                        },
-                    }
-                });
+                let entry = self.tool_calls.entry(index).or_insert_with(|| ToolCall {
+                    id: String::new(),
+                    tool_type: "function".to_string(),
+                    function: FunctionCallResponse {
+                        name: String::new(),
+                        arguments: Some(String::new()),
+                    },
+                });

                 if let Some(id) = &delta.id {
@@ -1048,9 +1042,9 @@
             object: "chat.completion".to_string(),
             created: chrono::Utc::now().timestamp() as u64,
             model: self.model,
-            choices: vec![crate::protocols::chat::ChatChoice {
+            choices: vec![ChatChoice {
                 index: 0,
-                message: crate::protocols::chat::ChatCompletionMessage {
+                message: ChatCompletionMessage {
                     role: "assistant".to_string(),
                     content: if self.content.is_empty() {
                         None
...
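Note: `ChatResponseAccumulator::process_chunk` above merges streamed deltas keyed by tool-call index. The merge rule for argument fragments is cut off by this hunk, so the sketch below assumes the standard OpenAI-style accumulation (ids and names overwrite when present, argument JSON fragments concatenate); the type names are simplified stand-ins:

use std::collections::HashMap;

// Simplified stand-ins for the ToolCall and delta types used above.
#[derive(Default, Debug)]
struct ToolCallAcc { id: String, name: String, arguments: String }
struct ToolCallDelta {
    index: usize,
    id: Option<String>,
    name: Option<String>,
    arguments: Option<String>,
}

fn apply(acc: &mut HashMap<usize, ToolCallAcc>, delta: ToolCallDelta) {
    let entry = acc.entry(delta.index).or_default();
    if let Some(id) = delta.id { entry.id = id; }         // ids overwrite
    if let Some(name) = delta.name { entry.name = name; } // names overwrite
    if let Some(args) = delta.arguments {
        entry.arguments.push_str(&args);                  // argument JSON streams in fragments
    }
}

fn main() {
    let mut acc = HashMap::new();
    apply(&mut acc, ToolCallDelta {
        index: 0,
        id: Some("call_1".into()),
        name: Some("get_weather".into()),
        arguments: Some("{\"city\":".into()),
    });
    apply(&mut acc, ToolCallDelta { index: 0, id: None, name: None, arguments: Some("\"Paris\"}".into()) });
    assert_eq!(acc[&0].arguments, "{\"city\":\"Paris\"}");
}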
+//! Chat endpoint pipeline stages
+//!
+//! These stages handle chat-specific preprocessing, request building, and response processing.
+//! They work with any model type by using injected model adapters.
+
+mod preparation;
+mod request_building;
+mod response_processing;
+
+pub use preparation::ChatPreparationStage;
+pub use request_building::ChatRequestBuildingStage;
+pub use response_processing::ChatResponseProcessingStage;
...
+//! Chat preparation stage: Filter tools, process messages, tokenize, build constraints
+
+use std::borrow::Cow;
+
+use async_trait::async_trait;
+use axum::response::Response;
+
+use crate::{
+    protocols::chat::ChatCompletionRequest,
+    routers::grpc::{
+        common::stages::PipelineStage,
+        context::{PreparationOutput, RequestContext},
+        error, utils,
+    },
+};
+
+/// Chat preparation stage
+///
+/// Extracts chat-specific preparation logic from the old unified PreparationStage.
+/// This is a direct extraction without architectural changes.
+pub struct ChatPreparationStage;
+
+#[async_trait]
+impl PipelineStage for ChatPreparationStage {
+    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
+        let request = ctx.chat_request_arc();
+        self.prepare_chat(ctx, &request).await?;
+        Ok(None)
+    }
+
+    fn name(&self) -> &'static str {
+        "ChatPreparation"
+    }
+}
+
+impl ChatPreparationStage {
+    async fn prepare_chat(
+        &self,
+        ctx: &mut RequestContext,
+        request: &ChatCompletionRequest,
+    ) -> Result<(), Response> {
+        // Step 1: Filter tools if needed
+        let body_ref = utils::filter_tools_for_request(request);
+
+        // Step 2: Process messages and apply chat template
+        let processed_messages =
+            match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
+                Ok(msgs) => msgs,
+                Err(e) => {
+                    return Err(error::bad_request(e));
+                }
+            };
+
+        // Step 3: Tokenize the processed text
+        let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
+            Ok(encoding) => encoding,
+            Err(e) => {
+                return Err(error::internal_error(format!("Tokenization failed: {}", e)));
+            }
+        };
+        let token_ids = encoding.token_ids().to_vec();
+
+        // Step 4: Build tool constraints if needed
+        let tool_call_constraint = if let Some(tools) = body_ref.tools.as_ref() {
+            utils::generate_tool_constraints(tools, &request.tool_choice, &request.model)
+                .map_err(|e| error::bad_request(format!("Invalid tool configuration: {}", e)))?
+        } else {
+            None
+        };
+
+        // Step 5: Create stop sequence decoder (build once, reuse in non-stream)
+        let stop_decoder = utils::create_stop_decoder(
+            &ctx.components.tokenizer,
+            request.stop.as_ref(),
+            request.stop_token_ids.as_ref(),
+            request.skip_special_tokens,
+            request.no_stop_trim,
+        );
+
+        // Store results in context
+        ctx.state.preparation = Some(PreparationOutput {
+            original_text: Some(processed_messages.text.clone()),
+            token_ids,
+            processed_messages: Some(processed_messages),
+            tool_constraints: tool_call_constraint,
+            filtered_request: if matches!(body_ref, Cow::Owned(_)) {
+                Some(body_ref.into_owned())
+            } else {
+                None
+            },
+            // Harmony fields (not used for regular preparation)
+            harmony_mode: false,
+            selection_text: None,
+            harmony_messages: None,
+            harmony_stop_ids: None,
+        });
+
+        // Store stop decoder for reuse in response processing
+        ctx.state.response.stop_decoder = Some(stop_decoder);
+
+        Ok(())
+    }
+}
...
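Note: the `filtered_request` bookkeeping in the stage above relies on `filter_tools_for_request` returning a `Cow`: the request is cloned only when tool filtering actually modified it. A minimal illustration of that pattern, using a hypothetical `filter` function and `Request` type in place of the real ones:

use std::borrow::Cow;

#[derive(Clone, Debug, PartialEq)]
struct Request { tools: Vec<String> }

// Hypothetical stand-in for utils::filter_tools_for_request: borrow when
// nothing changes, clone-and-modify only when filtering is needed.
fn filter(req: &Request) -> Cow<'_, Request> {
    if req.tools.iter().all(|t| t != "blocked") {
        Cow::Borrowed(req) // common case: zero-copy
    } else {
        let mut owned = req.clone();
        owned.tools.retain(|t| t != "blocked");
        Cow::Owned(owned)
    }
}

fn main() {
    let req = Request { tools: vec!["search".into()] };
    let body_ref = filter(&req);
    // Mirrors the stage above: persist a filtered copy only if one was made.
    let filtered_request = if matches!(body_ref, Cow::Owned(_)) {
        Some(body_ref.into_owned())
    } else {
        None
    };
    assert!(filtered_request.is_none()); // nothing was filtered, nothing was cloned
}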
+//! Chat request building stage: Build proto GenerateRequest for chat requests
+
+use async_trait::async_trait;
+use axum::response::Response;
+use uuid::Uuid;
+
+use crate::routers::grpc::{
+    common::stages::{helpers, PipelineStage},
+    context::{ClientSelection, RequestContext, WorkerSelection},
+    error,
+};
+
+/// Chat request building stage
+///
+/// Extracts chat-specific request building logic from the old unified RequestBuildingStage.
+pub struct ChatRequestBuildingStage {
+    inject_pd_metadata: bool,
+}
+
+impl ChatRequestBuildingStage {
+    pub fn new(inject_pd_metadata: bool) -> Self {
+        Self { inject_pd_metadata }
+    }
+}
+
+#[async_trait]
+impl PipelineStage for ChatRequestBuildingStage {
+    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
+        let prep = ctx
+            .state
+            .preparation
+            .as_ref()
+            .ok_or_else(|| error::internal_error("Preparation not completed"))?;
+        let clients = ctx
+            .state
+            .clients
+            .as_ref()
+            .ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
+
+        let chat_request = ctx.chat_request_arc();
+
+        // Get client for building request (use prefill client if PD mode)
+        let builder_client = match clients {
+            ClientSelection::Single { client } => client,
+            ClientSelection::Dual { prefill, .. } => prefill,
+        };
+
+        // Build chat request
+        let request_id = format!("chatcmpl-{}", Uuid::new_v4());
+        let body_ref = prep.filtered_request.as_ref().unwrap_or(&chat_request);
+        let mut proto_request = builder_client
+            .build_generate_request(
+                request_id,
+                body_ref,
+                prep.processed_messages.as_ref().unwrap().text.clone(),
+                prep.token_ids.clone(),
+                prep.processed_messages
+                    .as_ref()
+                    .unwrap()
+                    .multimodal_inputs
+                    .clone(),
+                prep.tool_constraints.clone(),
+            )
+            .map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?;
+
+        // Inject PD metadata if needed
+        if self.inject_pd_metadata {
+            if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
+                helpers::inject_bootstrap_metadata(&mut proto_request, prefill);
+            }
+        }
+
+        ctx.state.proto_request = Some(proto_request);
+        Ok(None)
+    }
+
+    fn name(&self) -> &'static str {
+        "ChatRequestBuilding"
+    }
+}
...
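Note: both request-building stages (chat above, generate later in this diff) pick the client that builds the proto request with the same rule. A small self-contained sketch of that selection, with enum shapes mirrored from the `ClientSelection` matches above and a dummy `Client` type:

// Shapes mirrored from the ClientSelection matches above; Client is a dummy.
#[derive(Clone, Debug, PartialEq)]
struct Client(&'static str);

enum ClientSelection {
    Single { client: Client },
    Dual { prefill: Client, decode: Client },
}

// In PD (prefill-decode) mode the *prefill* client builds the request;
// bootstrap metadata for the decode side is injected afterwards.
fn builder_client(clients: &ClientSelection) -> &Client {
    match clients {
        ClientSelection::Single { client } => client,
        ClientSelection::Dual { prefill, .. } => prefill,
    }
}

fn main() {
    let pd = ClientSelection::Dual { prefill: Client("prefill"), decode: Client("decode") };
    assert_eq!(builder_client(&pd), &Client("prefill"));
}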
+//! Chat response processing stage: Handles both streaming and non-streaming responses
+//!
+//! - For streaming: Spawns background task and returns SSE response (early exit)
+//! - For non-streaming: Collects all responses and builds final ChatCompletionResponse
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use axum::response::Response;
+
+use crate::routers::grpc::{
+    common::stages::PipelineStage,
+    context::{FinalResponse, RequestContext},
+    error,
+    regular::{processor, streaming},
+};
+
+/// Chat response processing stage
+///
+/// Extracts chat-specific response processing logic from the old unified ResponseProcessingStage.
+pub struct ChatResponseProcessingStage {
+    processor: processor::ResponseProcessor,
+    streaming_processor: Arc<streaming::StreamingProcessor>,
+}
+
+impl ChatResponseProcessingStage {
+    pub fn new(
+        processor: processor::ResponseProcessor,
+        streaming_processor: Arc<streaming::StreamingProcessor>,
+    ) -> Self {
+        Self {
+            processor,
+            streaming_processor,
+        }
+    }
+}
+
+#[async_trait]
+impl PipelineStage for ChatResponseProcessingStage {
+    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
+        self.process_chat_response(ctx).await
+    }
+
+    fn name(&self) -> &'static str {
+        "ChatResponseProcessing"
+    }
+}
+
+impl ChatResponseProcessingStage {
+    async fn process_chat_response(
+        &self,
+        ctx: &mut RequestContext,
+    ) -> Result<Option<Response>, Response> {
+        let is_streaming = ctx.is_streaming();
+
+        // Extract execution result
+        let execution_result = ctx
+            .state
+            .response
+            .execution_result
+            .take()
+            .ok_or_else(|| error::internal_error("No execution result"))?;
+
+        // Get dispatch metadata (needed by both streaming and non-streaming)
+        let dispatch = ctx
+            .state
+            .dispatch
+            .as_ref()
+            .ok_or_else(|| error::internal_error("Dispatch metadata not set"))?
+            .clone();
+
+        if is_streaming {
+            // Streaming: Use StreamingProcessor and return SSE response (done)
+            return Ok(Some(
+                self.streaming_processor.clone().process_streaming_response(
+                    execution_result,
+                    ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
+                    dispatch,
+                ),
+            ));
+        }
+
+        // Non-streaming: Delegate to ResponseProcessor
+        let request_logprobs = ctx.chat_request().logprobs;
+        let chat_request = ctx.chat_request_arc();
+        let stop_decoder = ctx
+            .state
+            .response
+            .stop_decoder
+            .as_mut()
+            .ok_or_else(|| error::internal_error("Stop decoder not initialized"))?;
+
+        let response = self
+            .processor
+            .process_non_streaming_chat_response(
+                execution_result,
+                chat_request,
+                dispatch,
+                stop_decoder,
+                request_logprobs,
+            )
+            .await?;
+
+        // Store the final response
+        ctx.state.response.final_response = Some(FinalResponse::Chat(response));
+        Ok(None)
+    }
+}
...
+//! Generate endpoint pipeline stages
+//!
+//! These stages handle generate-specific preprocessing, request building, and response processing.
+//! They work with any model type by using injected model adapters.
+
+mod preparation;
+mod request_building;
+mod response_processing;
+
+pub use preparation::GeneratePreparationStage;
+pub use request_building::GenerateRequestBuildingStage;
+pub use response_processing::GenerateResponseProcessingStage;
...
-//! Preparation stage: Filter tools, process messages, tokenize, build constraints
+//! Generate preparation stage: Resolve input, tokenize, create stop decoder

-use std::{borrow::Cow, sync::Arc};
+use std::sync::Arc;

 use async_trait::async_trait;
 use axum::response::Response;

-use super::PipelineStage;
 use crate::{
-    protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
+    protocols::{common::InputIds, generate::GenerateRequest},
     routers::grpc::{
-        context::{PreparationOutput, RequestContext, RequestType},
+        common::stages::PipelineStage,
+        context::{PreparationOutput, RequestContext},
         error, utils,
     },
     tokenizer::traits::Tokenizer,
 };

-/// Preparation stage: Filter tools, process messages, tokenize, build constraints
-pub struct PreparationStage;
+/// Generate preparation stage
+///
+/// Extracts generate-specific preparation logic from the old unified PreparationStage.
+/// This is a direct extraction without architectural changes.
+pub struct GeneratePreparationStage;

 #[async_trait]
-impl PipelineStage for PreparationStage {
+impl PipelineStage for GeneratePreparationStage {
     async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
-        // Clone Arc before match to avoid borrow checker issues
-        // (matching borrows ctx, but prepare_* methods need mutable borrow)
-        // Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
-        let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
-
-        if is_chat {
-            let request_arc = ctx.chat_request_arc();
-            self.prepare_chat(ctx, &request_arc).await?;
-        } else {
-            let request_arc = ctx.generate_request_arc();
-            self.prepare_generate(ctx, &request_arc).await?;
-        }
-
+        let request = ctx.generate_request_arc();
+        self.prepare_generate(ctx, &request).await?;
         Ok(None)
     }

     fn name(&self) -> &'static str {
-        "Preparation"
+        "GeneratePreparation"
     }
 }

-impl PreparationStage {
-    async fn prepare_chat(
-        &self,
-        ctx: &mut RequestContext,
-        request: &ChatCompletionRequest,
-    ) -> Result<(), Response> {
-        // Step 1: Filter tools if needed
-        let body_ref = utils::filter_tools_for_request(request);
-
-        // Step 2: Process messages and apply chat template
-        let processed_messages =
-            match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
-                Ok(msgs) => msgs,
-                Err(e) => {
-                    return Err(error::bad_request(e));
-                }
-            };
-
-        // Step 3: Tokenize the processed text
-        let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
-            Ok(encoding) => encoding,
-            Err(e) => {
-                return Err(error::internal_error(format!("Tokenization failed: {}", e)));
-            }
-        };
-        let token_ids = encoding.token_ids().to_vec();
-
-        // Step 4: Build tool constraints if needed
-        let tool_call_constraint = if let Some(tools) = body_ref.tools.as_ref() {
-            utils::generate_tool_constraints(tools, &request.tool_choice, &request.model)
-                .map_err(|e| error::bad_request(format!("Invalid tool configuration: {}", e)))?
-        } else {
-            None
-        };
-
-        // Step 5: Create stop sequence decoder (build once, reuse in non-stream)
-        let stop_decoder = utils::create_stop_decoder(
-            &ctx.components.tokenizer,
-            request.stop.as_ref(),
-            request.stop_token_ids.as_ref(),
-            request.skip_special_tokens,
-            request.no_stop_trim,
-        );
-
-        // Store results in context
-        ctx.state.preparation = Some(PreparationOutput {
-            original_text: Some(processed_messages.text.clone()),
-            token_ids,
-            processed_messages: Some(processed_messages),
-            tool_constraints: tool_call_constraint,
-            filtered_request: if matches!(body_ref, Cow::Owned(_)) {
-                Some(body_ref.into_owned())
-            } else {
-                None
-            },
-            // Harmony fields (not used for regular preparation)
-            harmony_mode: false,
-            selection_text: None,
-            harmony_messages: None,
-            harmony_stop_ids: None,
-        });
-
-        // Store stop decoder for reuse in response processing
-        ctx.state.response.stop_decoder = Some(stop_decoder);
-
-        Ok(())
-    }
-
+impl GeneratePreparationStage {
     async fn prepare_generate(
         &self,
         ctx: &mut RequestContext,
...
+//! Generate request building stage: Build proto GenerateRequest for generate requests
+
+use async_trait::async_trait;
+use axum::response::Response;
+use uuid::Uuid;
+
+use crate::routers::grpc::{
+    common::stages::{helpers, PipelineStage},
+    context::{ClientSelection, RequestContext, WorkerSelection},
+    error,
+};
+
+/// Generate request building stage
+///
+/// Extracts generate-specific request building logic from the old unified RequestBuildingStage.
+pub struct GenerateRequestBuildingStage {
+    inject_pd_metadata: bool,
+}
+
+impl GenerateRequestBuildingStage {
+    pub fn new(inject_pd_metadata: bool) -> Self {
+        Self { inject_pd_metadata }
+    }
+}
+
+#[async_trait]
+impl PipelineStage for GenerateRequestBuildingStage {
+    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
+        let prep = ctx
+            .state
+            .preparation
+            .as_ref()
+            .ok_or_else(|| error::internal_error("Preparation not completed"))?;
+        let clients = ctx
+            .state
+            .clients
+            .as_ref()
+            .ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
+
+        let generate_request = ctx.generate_request_arc();
+
+        // Get client for building request (use prefill client if PD mode)
+        let builder_client = match clients {
+            ClientSelection::Single { client } => client,
+            ClientSelection::Dual { prefill, .. } => prefill,
+        };
+
+        // Build generate request
+        let request_id = generate_request
+            .rid
+            .clone()
+            .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
+        let mut proto_request = builder_client
+            .build_plain_generate_request(
+                request_id,
+                &generate_request,
+                prep.original_text.clone(),
+                prep.token_ids.clone(),
+            )
+            .map_err(error::bad_request)?;
+
+        // Inject PD metadata if needed
+        if self.inject_pd_metadata {
+            if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
+                helpers::inject_bootstrap_metadata(&mut proto_request, prefill);
+            }
+        }
+
+        ctx.state.proto_request = Some(proto_request);
+        Ok(None)
+    }
+
+    fn name(&self) -> &'static str {
+        "GenerateRequestBuilding"
+    }
+}
...
-//! Response processing stage: Handles both streaming and non-streaming responses
-//!
-//! - For streaming: Spawns background task and returns SSE response (early exit)
-//! - For non-streaming: Collects all responses and builds final ChatCompletionResponse
+//! Generate response processing stage: Handles both streaming and non-streaming responses

 use std::{sync::Arc, time::Instant};

 use async_trait::async_trait;
 use axum::response::Response;

-use super::PipelineStage;
 use crate::routers::grpc::{
-    context::{FinalResponse, RequestContext, RequestType},
-    error, processing, streaming,
+    common::stages::PipelineStage,
+    context::{FinalResponse, RequestContext},
+    error,
+    regular::{processor, streaming},
 };

-/// Response processing stage: Handles both streaming and non-streaming responses
+/// Generate response processing stage
 ///
-/// - For streaming: Spawns background task and returns SSE response (early exit)
-/// - For non-streaming: Collects all responses and builds final ChatCompletionResponse
-pub struct ResponseProcessingStage {
-    processor: processing::ResponseProcessor,
+/// Extracts generate-specific response processing logic from the old unified ResponseProcessingStage.
+pub struct GenerateResponseProcessingStage {
+    processor: processor::ResponseProcessor,
     streaming_processor: Arc<streaming::StreamingProcessor>,
 }

-impl ResponseProcessingStage {
+impl GenerateResponseProcessingStage {
     pub fn new(
-        processor: processing::ResponseProcessor,
+        processor: processor::ResponseProcessor,
         streaming_processor: Arc<streaming::StreamingProcessor>,
     ) -> Self {
         Self {
@@ -36,89 +33,17 @@
     }
 }

 #[async_trait]
-impl PipelineStage for ResponseProcessingStage {
+impl PipelineStage for GenerateResponseProcessingStage {
     async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
-        // Delegate to request-type specific processing
-        match &ctx.input.request_type {
-            RequestType::Chat(_) => self.process_chat_response(ctx).await,
-            RequestType::Generate(_) => self.process_generate_response(ctx).await,
-            RequestType::Responses(_) => Err(error::bad_request(
-                "Responses API processing must be handled by responses handler".to_string(),
-            )),
-        }
+        self.process_generate_response(ctx).await
     }

     fn name(&self) -> &'static str {
-        "ResponseProcessing"
+        "GenerateResponseProcessing"
     }
 }

-impl ResponseProcessingStage {
-    async fn process_chat_response(
-        &self,
-        ctx: &mut RequestContext,
-    ) -> Result<Option<Response>, Response> {
-        let is_streaming = ctx.is_streaming();
-
-        // Extract execution result
-        let execution_result = ctx
-            .state
-            .response
-            .execution_result
-            .take()
-            .ok_or_else(|| error::internal_error("No execution result"))?;
-
-        // Get dispatch metadata (needed by both streaming and non-streaming)
-        let dispatch = ctx
-            .state
-            .dispatch
-            .as_ref()
-            .ok_or_else(|| error::internal_error("Dispatch metadata not set"))?
-            .clone();
-
-        if is_streaming {
-            // Streaming: Use StreamingProcessor and return SSE response (done)
-            return Ok(Some(
-                self.streaming_processor.clone().process_streaming_response(
-                    execution_result,
-                    ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
-                    dispatch,
-                ),
-            ));
-        }
-
-        // Non-streaming: Delegate to ResponseProcessor
-        let request_logprobs = match &ctx.input.request_type {
-            RequestType::Chat(req) => req.logprobs,
-            _ => false,
-        };
-        let chat_request = ctx.chat_request_arc();
-        let stop_decoder = ctx
-            .state
-            .response
-            .stop_decoder
-            .as_mut()
-            .ok_or_else(|| error::internal_error("Stop decoder not initialized"))?;
-
-        let response = self
-            .processor
-            .process_non_streaming_chat_response(
-                execution_result,
-                chat_request,
-                dispatch,
-                stop_decoder,
-                request_logprobs,
-            )
-            .await?;
-
-        // Store the final response
-        ctx.state.response.final_response = Some(FinalResponse::Chat(response));
-        Ok(None)
-    }
-
+impl GenerateResponseProcessingStage {
     async fn process_generate_response(
         &self,
         ctx: &mut RequestContext,
...