Unverified commit 83804bc6 authored by Chang Su, committed by GitHub

[router][grpc] Restructure modules and code clean up (#12598)

parent d5fa019c
//! Shared types for Harmony pipeline
use openai_harmony::chat::Content;
use serde::{Deserialize, Serialize};
use serde_json::Value;
......@@ -36,8 +37,6 @@ impl HarmonyMessage {
/// Convert from openai_harmony::chat::Message to our simplified HarmonyMessage
pub fn from_openai_harmony(msg: openai_harmony::chat::Message) -> Self {
use openai_harmony::chat::Content;
// Extract role as string
let role = match msg.author.role {
openai_harmony::chat::Role::User => "user",
......
......@@ -2,16 +2,14 @@
use crate::{grpc_client::proto, protocols::common::StringOrArray};
pub mod common;
pub mod context;
pub mod error;
pub mod harmony;
pub mod pd_router;
pub mod pipeline;
pub mod processing;
pub mod responses;
pub mod regular;
pub mod router;
pub mod stages;
pub mod streaming;
pub mod utils;
/// Processed chat messages ready for gRPC generation
......
// PD (Prefill-Decode) gRPC Router Implementation
use std::sync::Arc;
use async_trait::async_trait;
......@@ -161,7 +159,6 @@ impl RouterTrait for GrpcPDRouter {
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
// TODO: Implement actual generation test for gRPC PD mode
(
StatusCode::NOT_IMPLEMENTED,
"Health generate not yet implemented for gRPC PD",
......
......@@ -3,15 +3,17 @@
//! This module defines the RequestPipeline orchestrator that coordinates
//! the execution of pipeline stages from request preparation to response delivery.
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use axum::response::{IntoResponse, Response};
use tokio::sync::RwLock;
use tracing::{debug, error};
use tracing::error;
// Import all stage types from the stages module
use super::stages::*;
use super::{context::*, error, harmony, processing, responses::BackgroundTaskInfo, streaming};
use super::{
common::stages::*,
context::*,
error, harmony,
regular::{processor, stages::*, streaming},
};
use crate::{
core::WorkerRegistry,
policies::PolicyRegistry,
......@@ -24,10 +26,6 @@ use crate::{
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================
// Pipeline Orchestrator
// ============================================================================
/// Generic request pipeline for all request types
///
/// Orchestrates all stages from request preparation to response delivery.
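(Editor's note: a minimal, self-contained sketch of the stage-chain pattern this doc comment describes. All types here are toy stand-ins; the real trait is async and operates on RequestContext/Response, as shown in the diff below.)

```rust
// Toy illustration of "run every stage; Ok(Some) = early exit, Ok(None) = continue".
trait PipelineStage {
    fn name(&self) -> &'static str;
    fn execute(&self, ctx: &mut Vec<String>) -> Result<Option<String>, String>;
}

struct Preparation;
impl PipelineStage for Preparation {
    fn name(&self) -> &'static str { "Preparation" }
    fn execute(&self, ctx: &mut Vec<String>) -> Result<Option<String>, String> {
        ctx.push("prepared".to_string());
        Ok(None) // nothing to return early; move on to the next stage
    }
}

fn run(stages: &[Box<dyn PipelineStage>], ctx: &mut Vec<String>) -> Result<String, String> {
    for stage in stages {
        match stage.execute(ctx) {
            Ok(Some(early)) => return Ok(early), // e.g. a streaming response
            Ok(None) => continue,                // stage finished, keep going
            Err(e) => return Err(format!("{} failed: {e}", stage.name())),
        }
    }
    Ok(ctx.join(","))
}

fn main() {
    let stages: Vec<Box<dyn PipelineStage>> = vec![Box::new(Preparation)];
    let mut ctx = Vec::new();
    println!("{}", run(&stages, &mut ctx).unwrap());
}
```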
......@@ -48,8 +46,7 @@ impl RequestPipeline {
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
) -> Self {
// Create response processor
let processor = processing::ResponseProcessor::new(
let processor = processor::ResponseProcessor::new(
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
......@@ -57,7 +54,6 @@ impl RequestPipeline {
configured_reasoning_parser.clone(),
);
// Create streaming processor
let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
tokenizer,
tool_parser_factory,
......@@ -67,7 +63,7 @@ impl RequestPipeline {
));
let stages: Vec<Box<dyn PipelineStage>> = vec![
Box::new(PreparationStage),
Box::new(PreparationStage::new()),
Box::new(WorkerSelectionStage::new(
worker_registry,
policy_registry,
......@@ -153,8 +149,7 @@ impl RequestPipeline {
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
) -> Self {
// Create response processor
let processor = processing::ResponseProcessor::new(
let processor = processor::ResponseProcessor::new(
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
......@@ -162,7 +157,6 @@ impl RequestPipeline {
configured_reasoning_parser.clone(),
);
// Create streaming processor
let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
tokenizer,
tool_parser_factory,
......@@ -172,7 +166,7 @@ impl RequestPipeline {
));
let stages: Vec<Box<dyn PipelineStage>> = vec![
Box::new(PreparationStage),
Box::new(PreparationStage::new()),
Box::new(WorkerSelectionStage::new(
worker_registry,
policy_registry,
......@@ -200,7 +194,6 @@ impl RequestPipeline {
) -> Response {
let mut ctx = RequestContext::for_chat(request, headers, model_id, components);
// Execute each stage in sequence
for (idx, stage) in self.stages.iter().enumerate() {
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
......@@ -208,7 +201,6 @@ impl RequestPipeline {
return response;
}
Ok(None) => {
// Continue to next stage
continue;
}
Err(response) => {
......@@ -224,7 +216,6 @@ impl RequestPipeline {
}
}
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Chat(response)) => axum::Json(response).into_response(),
Some(FinalResponse::Generate(_)) => {
......@@ -244,7 +235,6 @@ impl RequestPipeline {
) -> Response {
let mut ctx = RequestContext::for_generate(request, headers, model_id, components);
// Execute each stage in sequence
for (idx, stage) in self.stages.iter().enumerate() {
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
......@@ -252,7 +242,6 @@ impl RequestPipeline {
return response;
}
Ok(None) => {
// Continue to next stage
continue;
}
Err(response) => {
......@@ -268,7 +257,6 @@ impl RequestPipeline {
}
}
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
Some(FinalResponse::Chat(_)) => {
......@@ -280,25 +268,19 @@ impl RequestPipeline {
/// Execute chat pipeline for responses endpoint
///
/// TODO: The support for background tasks is not scalable. Consider replacing this with
/// a better design in the future.
/// Used by ALL non-streaming /v1/responses requests (both sync and background modes).
/// Uses the same 7 pipeline stages as execute_chat(), with three differences:
/// Used by ALL non-streaming /v1/responses requests.
/// Uses the same 7 pipeline stages as execute_chat(), with two differences:
/// 1. Returns Result<ChatCompletionResponse, Response> for tool_loop composition
/// 2. Disallows streaming (responses endpoint uses different SSE format)
/// 3. Injects hooks for background task cancellation (only active when response_id provided)
pub async fn execute_chat_for_responses(
&self,
request: Arc<ChatCompletionRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
response_id: Option<String>,
background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
) -> Result<ChatCompletionResponse, Response> {
let mut ctx = RequestContext::for_chat(request, headers, model_id, components);
// Execute each stage in sequence
for (idx, stage) in self.stages.iter().enumerate() {
match stage.execute(&mut ctx).await {
Ok(Some(_response)) => {
......@@ -308,40 +290,6 @@ impl RequestPipeline {
));
}
Ok(None) => {
let stage_name = stage.name();
// After ClientAcquisitionStage, store client for background task cancellation
if stage_name == "ClientAcquisition" {
if let (Some(ref clients), Some(ref resp_id), Some(ref tasks)) =
(&ctx.state.clients, &response_id, &background_tasks)
{
let client_to_store = match clients {
ClientSelection::Single { client } => client.clone(),
ClientSelection::Dual { decode, .. } => decode.clone(),
};
if let Some(task_info) = tasks.write().await.get_mut(resp_id.as_str()) {
*task_info.client.write().await = Some(client_to_store);
debug!("Stored client for response_id: {}", resp_id);
}
}
}
// After DispatchMetadataStage, store grpc_request_id for background task cancellation
if stage_name == "DispatchMetadata" {
if let (Some(ref dispatch), Some(ref resp_id), Some(ref tasks)) =
(&ctx.state.dispatch, &response_id, &background_tasks)
{
let grpc_request_id = dispatch.request_id.clone();
if let Some(task_info) = tasks.write().await.get_mut(resp_id.as_str()) {
task_info.grpc_request_id = grpc_request_id.clone();
debug!("Stored grpc_request_id for response_id: {}", resp_id);
}
}
}
// Continue to next stage
continue;
}
Err(response) => {
......@@ -357,7 +305,6 @@ impl RequestPipeline {
}
}
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Chat(response)) => Ok(response),
Some(FinalResponse::Generate(_)) => {
......@@ -367,26 +314,6 @@ impl RequestPipeline {
}
}
/// Execute Responses API pipeline
///
/// TODO: Implement Responses API native execution
/// This is a stub to allow compilation. The actual implementation should:
/// 1. Support multi-turn MCP loop orchestration
/// 2. Handle tool call execution and result injection
/// 3. Emit proper SSE events for streaming mode
/// 4. Store responses in data connector
///
/// For now, this returns an error indicating the feature is not implemented.
pub async fn execute_responses(
&self,
_request: Arc<crate::protocols::responses::ResponsesRequest>,
_headers: Option<http::HeaderMap>,
_model_id: Option<String>,
_components: Arc<SharedComponents>,
) -> Response {
error::internal_error("Responses API execution not yet implemented")
}
/// Execute Harmony Responses API request through all pipeline stages
///
/// This method runs a single iteration of the Responses API request,
......@@ -415,7 +342,6 @@ impl RequestPipeline {
harmony_ctx.components.clone(),
);
// Execute each pipeline stage in sequence
for (idx, stage) in self.stages.iter().enumerate() {
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
......@@ -428,7 +354,6 @@ impl RequestPipeline {
return Err(response);
}
Ok(None) => {
// Stage completed successfully, continue to next stage
continue;
}
Err(response) => {
......@@ -472,7 +397,6 @@ impl RequestPipeline {
harmony_ctx.components.clone(),
);
// Execute pipeline stages up to dispatch (which creates the stream)
for (idx, stage) in self.stages.iter().enumerate() {
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
......
//! Regular (non-harmony) model processing
//!
//! This module contains all code specific to regular tokenizer-based models,
//! including pipeline stages, response processing, and streaming.
pub mod processor;
pub mod responses;
pub mod stages;
pub mod streaming;
//! Shared response processing logic for gRPC routers
//!
//! This module contains response processing functions that are shared between
//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates.
//! the regular router and PD router.
use std::{sync::Arc, time::Instant};
......@@ -9,18 +9,19 @@ use proto::generate_complete::MatchedStop;
use serde_json::Value;
use tracing::error;
use super::{
context::{DispatchMetadata, ExecutionResult},
error, utils,
};
use crate::{
grpc_client::proto,
protocols::{
chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse},
common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage},
common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue},
generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::grpc::{
common::{response_collection, response_formatting},
context::{DispatchMetadata, ExecutionResult},
error, utils,
},
tokenizer::{
stop::{SequenceDecoderOutput, StopSequenceDecoder},
traits::Tokenizer,
......@@ -28,10 +29,6 @@ use crate::{
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================
// Response Processor - Main Entry Point
// ============================================================================
/// Unified response processor for both routers
#[derive(Clone)]
pub struct ResponseProcessor {
......@@ -59,57 +56,6 @@ impl ResponseProcessor {
}
}
/// Helper to collect responses from execution result and merge logprobs if needed
async fn collect_and_merge_responses(
execution_result: ExecutionResult,
request_logprobs: bool,
) -> Result<Vec<proto::GenerateComplete>, axum::response::Response> {
let all_responses = match execution_result {
ExecutionResult::Single { mut stream } => {
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
stream.mark_completed();
responses
}
ExecutionResult::Dual {
mut prefill,
decode,
} => {
// Collect prefill for input_logprobs (don't mark completed yet)
let prefill_responses =
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
// Collect decode for actual output (don't mark completed yet)
let mut decode_stream = *decode;
let mut decode_responses =
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
// Mark both streams as completed now that both succeeded
prefill.mark_completed();
decode_stream.mark_completed();
// Merge prefill input_logprobs if requested
if request_logprobs {
if let Some(prefill_input_logprobs) = prefill_responses
.first()
.and_then(|r| r.input_logprobs.clone())
{
for response in &mut decode_responses {
response.input_logprobs = Some(prefill_input_logprobs.clone());
}
}
}
decode_responses
}
};
if all_responses.is_empty() {
return Err(error::internal_error("No responses from server"));
}
Ok(all_responses)
}
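(Editor's note: a toy sketch of the prefill/decode logprob merge this helper performs, now moved behind response_collection::collect_responses. The Complete struct below is a simplified stand-in for proto::GenerateComplete, not the crate's real type.)

```rust
// When logprobs are requested in PD mode, only the prefill stream carries
// input_logprobs, so they are copied onto every decode response.
#[derive(Clone, Default)]
struct Complete { input_logprobs: Option<Vec<f32>> }

fn merge(prefill: &[Complete], decode: &mut [Complete], want_logprobs: bool) {
    if !want_logprobs { return; }
    if let Some(lp) = prefill.first().and_then(|r| r.input_logprobs.clone()) {
        for d in decode.iter_mut() {
            d.input_logprobs = Some(lp.clone());
        }
    }
}

fn main() {
    let prefill = vec![Complete { input_logprobs: Some(vec![-0.1, -0.2]) }];
    let mut decode = vec![Complete::default()];
    merge(&prefill, &mut decode, true);
    assert!(decode[0].input_logprobs.is_some());
}
```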
/// Process a single choice from GenerateComplete response
#[allow(clippy::too_many_arguments)]
pub async fn process_single_choice(
......@@ -151,7 +97,6 @@ impl ResponseProcessor {
let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Check if reasoning parsing is enabled and parser is available
if original_request.separate_reasoning && reasoning_parser_available {
let pooled_parser = utils::get_reasoning_parser(
&self.reasoning_parser_factory,
......@@ -275,7 +220,7 @@ impl ResponseProcessor {
) -> Result<ChatCompletionResponse, axum::response::Response> {
// Collect all responses from the execution result
let all_responses =
Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
response_collection::collect_responses(execution_result, request_logprobs).await?;
let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
......@@ -341,28 +286,15 @@ impl ResponseProcessor {
}
// Build usage
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
};
let usage = response_formatting::build_usage(&all_responses);
// Build final ChatCompletionResponse
let response = ChatCompletionResponse {
id: dispatch.request_id.clone(),
object: "chat.completion".to_string(),
created: dispatch.created,
model: dispatch.model.clone(),
let response = response_formatting::build_chat_response(
choices,
usage: Some(usage),
system_fingerprint: dispatch.weight_version.clone(),
};
&dispatch,
dispatch.model.clone(),
usage,
);
Ok(response)
}
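(Editor's note: a plausible shape for the extracted response_formatting::build_usage helper, inferred from the inline summation it replaces above. Structs and the function signature are assumptions, not the crate's actual API.)

```rust
// Aggregate token usage across all parallel choices of one request.
struct Complete { prompt_tokens: u32, completion_tokens: u32 }
struct Usage { prompt_tokens: u32, completion_tokens: u32, total_tokens: u32 }

fn build_usage(responses: &[Complete]) -> Usage {
    let prompt_tokens: u32 = responses.iter().map(|r| r.prompt_tokens).sum();
    let completion_tokens: u32 = responses.iter().map(|r| r.completion_tokens).sum();
    Usage { prompt_tokens, completion_tokens, total_tokens: prompt_tokens + completion_tokens }
}

fn main() {
    let u = build_usage(&[
        Complete { prompt_tokens: 10, completion_tokens: 5 },
        Complete { prompt_tokens: 10, completion_tokens: 7 },
    ]);
    assert_eq!((u.prompt_tokens, u.completion_tokens, u.total_tokens), (20, 12, 32));
}
```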
......@@ -436,7 +368,7 @@ impl ResponseProcessor {
) -> Result<Vec<GenerateResponse>, axum::response::Response> {
// Collect all responses from the execution result
let all_responses =
Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
response_collection::collect_responses(execution_result, request_logprobs).await?;
// Process each completion
let mut result_array = Vec::new();
......@@ -474,7 +406,7 @@ impl ResponseProcessor {
}
let output_ids = std::mem::take(&mut complete.output_ids);
let finish_reason_str = std::mem::take(&mut complete.finish_reason);
let finish_reason_str = complete.finish_reason.to_string();
// Parse finish_reason from string to proper type
let finish_reason =
......
......@@ -9,18 +9,19 @@
//! # Architecture
//!
//! This module orchestrates all request handling for the /v1/responses endpoint.
//! It supports three execution modes:
//! It supports two execution modes:
//!
//! 1. **Synchronous** - Returns complete response immediately
//! 2. **Background** - Returns queued response, executes in background task
//! 3. **Streaming** - Returns SSE stream with real-time events
//! 2. **Streaming** - Returns SSE stream with real-time events
//!
//! Note: Background mode is no longer supported. Requests with background=true
//! will be rejected with a 400 error.
//!
//! # Request Flow
//!
//! ```text
//! route_responses()
//! ├─► route_responses_sync() → route_responses_internal()
//! ├─► route_responses_background() → spawn(route_responses_internal())
//! └─► route_responses_streaming() → convert_chat_stream_to_responses_stream()
//!
//! route_responses_internal()
......@@ -31,10 +32,7 @@
//! └─► pipeline.execute_chat_for_responses()
//! ```
use std::{
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use std::sync::Arc;
use axum::{
body::Body,
......@@ -44,39 +42,36 @@ use axum::{
use bytes::Bytes;
use futures_util::StreamExt;
use serde_json::json;
use tokio::sync::{mpsc, RwLock};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use tracing::{debug, warn};
use uuid::Uuid;
use validator::Validate;
use super::{
conversions,
streaming::ResponseStreamEventEmitter,
tool_loop::{execute_tool_loop, execute_tool_loop_streaming},
types::BackgroundTaskInfo,
};
use crate::{
data_connector::{
ConversationId, ConversationItemStorage, ConversationStorage, ResponseId, ResponseStorage,
self, ConversationId, ConversationItemStorage, ConversationStorage, ResponseId,
ResponseStorage,
},
protocols::{
chat::ChatCompletionStreamResponse,
chat::{self, ChatCompletionStreamResponse},
common,
responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage,
self, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
ResponseReasoningContent, ResponseStatus, ResponsesRequest, ResponsesResponse,
ResponsesUsage,
},
},
routers::{
grpc::error,
grpc::{common::responses::streaming::ResponseStreamEventEmitter, error},
openai::{conversations::persist_conversation_items, mcp::ensure_request_mcp_client},
},
};
// ============================================================================
// Main Request Handler
// ============================================================================
/// Main handler for POST /v1/responses
///
/// Validates request, determines execution mode (sync/async/streaming), and delegates
......@@ -154,19 +149,17 @@ pub async fn route_responses(
.into_response();
}
// 3. Check for incompatible parameter combinations
let is_streaming = request.stream.unwrap_or(false);
// 3. Reject background mode (no longer supported)
let is_background = request.background.unwrap_or(false);
if is_streaming && is_background {
if is_background {
return (
StatusCode::BAD_REQUEST,
axum::Json(json!({
"error": {
"message": "Cannot use streaming with background mode. Please set either 'stream' or 'background' to false.",
"message": "Background mode is not supported. Please set 'background' to false or omit it.",
"type": "invalid_request_error",
"param": serde_json::Value::Null,
"code": "incompatible_parameters"
"param": "background",
"code": "unsupported_parameter"
}
})),
)
......@@ -174,10 +167,9 @@ pub async fn route_responses(
}
// 4. Route based on execution mode
let is_streaming = request.stream.unwrap_or(false);
if is_streaming {
route_responses_streaming(ctx, request, headers, model_id).await
} else if is_background {
route_responses_background(ctx, request, headers, model_id).await
} else {
route_responses_sync(ctx, request, headers, model_id, None).await
}
......@@ -284,134 +276,7 @@ async fn route_responses_internal(
Ok(responses_response)
}
// ============================================================================
// Background Mode Execution
// ============================================================================
/// Execute responses request in background mode
#[allow(clippy::too_many_arguments)]
async fn route_responses_background(
ctx: &super::context::ResponsesContext,
request: Arc<ResponsesRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
) -> Response {
// Generate response_id for background tracking
let response_id = format!("resp_{}", Uuid::new_v4());
// Get current timestamp
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
// Create queued response
let queued_response = ResponsesResponse {
id: response_id.clone(),
object: "response".to_string(),
created_at,
status: ResponseStatus::Queued,
error: None,
incomplete_details: None,
instructions: request.instructions.clone(),
max_output_tokens: request.max_output_tokens,
model: request.model.clone(),
output: Vec::new(),
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
previous_response_id: request.previous_response_id.clone(),
reasoning: None,
store: request.store.unwrap_or(true),
temperature: request.temperature,
text: None,
tool_choice: "auto".to_string(),
tools: request.tools.clone().unwrap_or_default(),
top_p: request.top_p,
truncation: None,
usage: None,
user: None,
safety_identifier: request.user.clone(),
metadata: request.metadata.clone().unwrap_or_default(),
};
// Persist queued response to storage
if let Ok(response_json) = serde_json::to_value(&queued_response) {
if let Err(e) = persist_conversation_items(
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.response_storage.clone(),
&response_json,
&request,
)
.await
{
warn!("Failed to persist queued response: {}", e);
}
}
// Spawn background task
let ctx_clone = ctx.clone();
let request_clone = request.clone();
let headers_clone = headers.clone();
let model_id_clone = model_id.clone();
let response_id_clone = response_id.clone();
let handle = tokio::task::spawn(async move {
// Execute synchronously (set background=false to prevent recursion)
let mut background_request = (*request_clone).clone();
background_request.background = Some(false);
match route_responses_internal(
&ctx_clone,
Arc::new(background_request),
headers_clone,
model_id_clone,
Some(response_id_clone.clone()),
)
.await
{
Ok(_) => {
debug!(
"Background response {} completed successfully",
response_id_clone
);
}
Err(response) => {
warn!(
"Background response {} failed with status {}",
response_id_clone,
response.status()
);
}
}
// Clean up task handle when done
ctx_clone
.background_tasks
.write()
.await
.remove(&response_id_clone);
});
// Store task info for cancellation support
ctx.background_tasks.write().await.insert(
response_id.clone(),
BackgroundTaskInfo {
handle,
grpc_request_id: String::new(), // Will be populated by pipeline at DispatchMetadataStage
client: Arc::new(RwLock::new(None)),
},
);
// Return queued response immediately
axum::Json(queued_response).into_response()
}
// ============================================================================
// Streaming Mode Execution
// ============================================================================
/// Execute streaming responses request
#[allow(clippy::too_many_arguments)]
async fn route_responses_streaming(
ctx: &super::context::ResponsesContext,
request: Arc<ResponsesRequest>,
......@@ -467,10 +332,9 @@ async fn route_responses_streaming(
/// 3. Converts ChatCompletionStreamResponse → ResponsesResponse delta
/// 4. Accumulates response state for final persistence
/// 5. Emits transformed SSE events in responses format
#[allow(clippy::too_many_arguments)]
async fn convert_chat_stream_to_responses_stream(
ctx: &super::context::ResponsesContext,
chat_request: Arc<crate::protocols::chat::ChatCompletionRequest>,
chat_request: Arc<chat::ChatCompletionRequest>,
headers: Option<http::HeaderMap>,
model_id: Option<String>,
original_request: &ResponsesRequest,
......@@ -557,7 +421,7 @@ async fn convert_chat_stream_to_responses_stream(
async fn process_and_transform_sse_stream(
body: Body,
original_request: ResponsesRequest,
_chat_request: Arc<crate::protocols::chat::ChatCompletionRequest>,
_chat_request: Arc<chat::ChatCompletionRequest>,
response_storage: Arc<dyn ResponseStorage>,
conversation_storage: Arc<dyn ConversationStorage>,
conversation_item_storage: Arc<dyn ConversationItemStorage>,
......@@ -673,7 +537,7 @@ struct StreamingResponseAccumulator {
// Completion state
finish_reason: Option<String>,
usage: Option<crate::protocols::common::Usage>,
usage: Option<common::Usage>,
// Original request for final response construction
original_request: ResponsesRequest,
......@@ -789,11 +653,9 @@ impl StreamingResponseAccumulator {
output.push(ResponseOutputItem::Reasoning {
id: format!("reasoning_{}", self.response_id),
summary: vec![],
content: vec![
crate::protocols::responses::ResponseReasoningContent::ReasoningText {
text: self.reasoning_buffer,
},
],
content: vec![ResponseReasoningContent::ReasoningText {
text: self.reasoning_buffer,
}],
status: Some("completed".to_string()),
});
}
......@@ -811,7 +673,7 @@ impl StreamingResponseAccumulator {
// Convert usage
let usage = self.usage.as_ref().map(|u| {
let usage_info = crate::protocols::common::UsageInfo {
let usage_info = common::UsageInfo {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
......@@ -878,8 +740,6 @@ async fn execute_without_mcp(
headers,
model_id,
ctx.components.clone(),
response_id.clone(),
Some(ctx.background_tasks.clone()),
)
.await?; // Preserve the Response error as-is
......@@ -973,9 +833,9 @@ async fn load_conversation_history(
// Load conversation history
const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
let params = crate::data_connector::ListParams {
let params = data_connector::ListParams {
limit: MAX_CONVERSATION_HISTORY_ITEMS,
order: crate::data_connector::SortOrder::Asc,
order: data_connector::SortOrder::Asc,
after: None,
};
......@@ -1014,8 +874,7 @@ async fn load_conversation_history(
ResponseInput::Items(current_items) => {
// Process all item types, converting SimpleInputMessage to Message
for item in current_items.iter() {
let normalized =
crate::protocols::responses::normalize_input_item(item);
let normalized = responses::normalize_input_item(item);
items.push(normalized);
}
}
......@@ -1050,7 +909,7 @@ async fn load_conversation_history(
ResponseInput::Items(current_items) => {
// Process all item types, converting SimpleInputMessage to Message
for item in current_items.iter() {
let normalized = crate::protocols::responses::normalize_input_item(item);
let normalized = responses::normalize_input_item(item);
items.push(normalized);
}
}
......@@ -1061,194 +920,3 @@ async fn load_conversation_history(
Ok(modified_request)
}
// ============================================================================
// GET Response Implementation
// ============================================================================
/// Implementation for GET /v1/responses/{response_id}
pub async fn get_response_impl(
ctx: &super::context::ResponsesContext,
response_id: &str,
) -> Response {
let resp_id = ResponseId::from(response_id);
// Retrieve response from storage
match ctx.response_storage.get_response(&resp_id).await {
Ok(Some(stored_response)) => axum::Json(stored_response.raw_response).into_response(),
Ok(None) => (
StatusCode::NOT_FOUND,
axum::Json(json!({
"error": {
"message": format!("Response with id '{}' not found", response_id),
"type": "not_found_error",
"code": "response_not_found"
}
})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": format!("Failed to retrieve response: {}", e),
"type": "internal_error"
}
})),
)
.into_response(),
}
}
// ============================================================================
// CANCEL Response Implementation
// ============================================================================
/// Implementation for POST /v1/responses/{response_id}/cancel
pub async fn cancel_response_impl(
ctx: &super::context::ResponsesContext,
response_id: &str,
) -> Response {
let resp_id = ResponseId::from(response_id);
// Retrieve response from storage to check if it exists and get current status
match ctx.response_storage.get_response(&resp_id).await {
Ok(Some(stored_response)) => {
// Check current status - only queued or in_progress responses can be cancelled
let current_status = stored_response
.raw_response
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
match current_status {
"queued" | "in_progress" => {
// Attempt to abort the background task
let mut tasks = ctx.background_tasks.write().await;
if let Some(task_info) = tasks.remove(response_id) {
// Abort the Rust task immediately
task_info.handle.abort();
// Abort the Python/scheduler request via gRPC (if client is available)
let client_opt = task_info.client.read().await;
if let Some(ref client) = *client_opt {
if let Err(e) = client
.abort_request(
task_info.grpc_request_id.clone(),
"User cancelled via API".to_string(),
)
.await
{
warn!(
"Failed to abort Python request {}: {}",
task_info.grpc_request_id, e
);
} else {
debug!(
"Successfully aborted Python request: {}",
task_info.grpc_request_id
);
}
} else {
debug!("Client not yet available for abort, request may not have started yet");
}
// Task was found and aborted
(
StatusCode::OK,
axum::Json(json!({
"id": response_id,
"status": "cancelled",
"message": "Background task has been cancelled"
})),
)
.into_response()
} else {
// Task handle not found but status is queued/in_progress
// This can happen if: (1) task crashed, or (2) storage persistence failed
error!(
"Response {} has status '{}' but task handle is missing. Task may have crashed or storage update failed.",
response_id, current_status
);
(
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": "Internal error: background task completed but failed to update status in storage",
"type": "internal_error",
"code": "status_update_failed"
}
})),
)
.into_response()
}
}
"completed" => (
StatusCode::BAD_REQUEST,
axum::Json(json!({
"error": {
"message": "Cannot cancel completed response",
"type": "invalid_request_error",
"code": "response_already_completed"
}
})),
)
.into_response(),
"failed" => (
StatusCode::BAD_REQUEST,
axum::Json(json!({
"error": {
"message": "Cannot cancel failed response",
"type": "invalid_request_error",
"code": "response_already_failed"
}
})),
)
.into_response(),
"cancelled" => (
StatusCode::OK,
axum::Json(json!({
"id": response_id,
"status": "cancelled",
"message": "Response was already cancelled"
})),
)
.into_response(),
_ => {
// Unknown status
(
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": format!("Unknown response status: {}", current_status),
"type": "internal_error"
}
})),
)
.into_response()
}
}
}
Ok(None) => (
StatusCode::NOT_FOUND,
axum::Json(json!({
"error": {
"message": format!("Response with id '{}' not found", response_id),
"type": "not_found_error",
"code": "response_not_found"
}
})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({
"error": {
"message": format!("Failed to retrieve response: {}", e),
"type": "internal_error"
}
})),
)
.into_response(),
}
}
//! gRPC Router `/v1/responses` endpoint implementation
//! Regular gRPC Router `/v1/responses` endpoint implementation
//!
//! This module handles all responses-specific logic including:
//! This module handles all responses-specific logic for the regular (non-Harmony) pipeline, including:
//! - Request validation
//! - Conversation history and response chain loading
//! - Background mode execution
//! - Streaming support
//! - MCP tool loop wrapper
//! - Response persistence
......@@ -12,11 +11,10 @@
pub mod context;
mod conversions;
mod handlers;
pub mod streaming;
pub mod tool_loop;
pub mod types;
// Public exports
pub use context::ResponsesContext;
pub use handlers::{cancel_response_impl, get_response_impl, route_responses};
pub use handlers::route_responses;
pub use types::BackgroundTaskInfo;
......@@ -12,23 +12,30 @@ use axum::{
response::Response,
};
use bytes::Bytes;
use futures_util::StreamExt;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, warn};
use uuid::Uuid;
use super::{
super::error,
conversions,
streaming::{OutputItemType, ResponseStreamEventEmitter},
};
use crate::protocols::{
chat::ChatCompletionResponse,
common::{Tool, ToolChoice, ToolChoiceValue},
responses::{
McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest, ResponsesResponse,
use super::conversions;
use crate::{
mcp::{self, McpManager},
protocols::{
chat::{
ChatChoice, ChatCompletionMessage, ChatCompletionResponse, ChatCompletionStreamResponse,
},
common::{Function, FunctionCallResponse, Tool, ToolCall, ToolChoice, ToolChoiceValue},
responses::{
self, McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
ResponseOutputItem, ResponseStatus, ResponseToolType, ResponsesRequest,
ResponsesResponse,
},
},
routers::grpc::{
common::responses::streaming::{OutputItemType, ResponseStreamEventEmitter},
error,
},
};
......@@ -155,10 +162,7 @@ fn generate_mcp_id(prefix: &str) -> String {
}
/// Build mcp_list_tools output item
fn build_mcp_list_tools_item(
mcp: &Arc<crate::mcp::McpManager>,
server_label: &str,
) -> ResponseOutputItem {
fn build_mcp_list_tools_item(mcp: &Arc<McpManager>, server_label: &str) -> ResponseOutputItem {
let tools = mcp.list_tools();
let tools_info: Vec<McpToolInfo> = tools
.iter()
......@@ -263,8 +267,6 @@ pub(super) async fn execute_tool_loop(
headers.clone(),
model_id.clone(),
ctx.components.clone(),
response_id.clone(),
Some(ctx.background_tasks.clone()),
)
.await?;
......@@ -358,10 +360,9 @@ pub(super) async fn execute_tool_loop(
content: vec![ResponseContentPart::InputText { text: text.clone() }],
status: Some("completed".to_string()),
}],
ResponseInput::Items(items) => items
.iter()
.map(crate::protocols::responses::normalize_input_item)
.collect(),
ResponseInput::Items(items) => {
items.iter().map(responses::normalize_input_item).collect()
}
};
// Append all conversation history (function calls and outputs)
......@@ -830,10 +831,9 @@ async fn execute_tool_loop_streaming_internal(
content: vec![ResponseContentPart::InputText { text: text.clone() }],
status: Some("completed".to_string()),
}],
ResponseInput::Items(items) => items
.iter()
.map(crate::protocols::responses::normalize_input_item)
.collect(),
ResponseInput::Items(items) => {
items.iter().map(responses::normalize_input_item).collect()
}
};
input_items.extend_from_slice(&state.conversation_history);
......@@ -911,13 +911,13 @@ async fn execute_tool_loop_streaming_internal(
}
/// Convert MCP tools to Chat API tool format
fn convert_mcp_tools_to_chat_tools(mcp_tools: &[crate::mcp::Tool]) -> Vec<Tool> {
fn convert_mcp_tools_to_chat_tools(mcp_tools: &[mcp::Tool]) -> Vec<Tool> {
use serde_json::Value;
mcp_tools
.iter()
.map(|tool_info| Tool {
tool_type: "function".to_string(),
function: crate::protocols::common::Function {
function: Function {
name: tool_info.name.to_string(),
description: tool_info.description.as_ref().map(|d| d.to_string()),
parameters: Value::Object((*tool_info.input_schema).clone()),
......@@ -933,10 +933,6 @@ async fn convert_and_accumulate_stream(
emitter: &mut ResponseStreamEventEmitter,
tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
) -> Result<ChatCompletionResponse, String> {
use futures_util::StreamExt;
use crate::protocols::chat::ChatCompletionStreamResponse;
let mut accumulator = ChatResponseAccumulator::new();
let mut stream = body.into_data_stream();
......@@ -971,7 +967,7 @@ struct ChatResponseAccumulator {
id: String,
model: String,
content: String,
tool_calls: HashMap<usize, crate::protocols::common::ToolCall>,
tool_calls: HashMap<usize, ToolCall>,
finish_reason: Option<String>,
}
......@@ -986,7 +982,7 @@ impl ChatResponseAccumulator {
}
}
fn process_chunk(&mut self, chunk: &crate::protocols::chat::ChatCompletionStreamResponse) {
fn process_chunk(&mut self, chunk: &ChatCompletionStreamResponse) {
if !chunk.id.is_empty() {
self.id = chunk.id.clone();
}
......@@ -1004,15 +1000,13 @@ impl ChatResponseAccumulator {
if let Some(tool_call_deltas) = &choice.delta.tool_calls {
for delta in tool_call_deltas {
let index = delta.index as usize;
let entry = self.tool_calls.entry(index).or_insert_with(|| {
crate::protocols::common::ToolCall {
id: String::new(),
tool_type: "function".to_string(),
function: crate::protocols::common::FunctionCallResponse {
name: String::new(),
arguments: Some(String::new()),
},
}
let entry = self.tool_calls.entry(index).or_insert_with(|| ToolCall {
id: String::new(),
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: String::new(),
arguments: Some(String::new()),
},
});
if let Some(id) = &delta.id {
......@@ -1048,9 +1042,9 @@ impl ChatResponseAccumulator {
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: self.model,
choices: vec![crate::protocols::chat::ChatChoice {
choices: vec![ChatChoice {
index: 0,
message: crate::protocols::chat::ChatCompletionMessage {
message: ChatCompletionMessage {
role: "assistant".to_string(),
content: if self.content.is_empty() {
None
......
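(Editor's note: a toy sketch of the streaming delta-accumulation pattern used by ChatResponseAccumulator above: tool-call fragments are merged by index, with argument text concatenated as chunks arrive. All types are stand-ins for the crate's delta/tool-call structs.)

```rust
use std::collections::HashMap;

#[derive(Default)]
struct ToolCallAcc { id: String, name: String, arguments: String }

struct Delta { index: usize, id: Option<String>, name: Option<String>, arguments: Option<String> }

fn accumulate(calls: &mut HashMap<usize, ToolCallAcc>, delta: Delta) {
    let entry = calls.entry(delta.index).or_default();
    if let Some(id) = delta.id { entry.id = id; }
    if let Some(name) = delta.name { entry.name = name; }
    if let Some(args) = delta.arguments { entry.arguments.push_str(&args); }
}

fn main() {
    let mut calls = HashMap::new();
    accumulate(&mut calls, Delta { index: 0, id: Some("call_1".into()), name: Some("get_weather".into()), arguments: Some("{\"city\":".into()) });
    accumulate(&mut calls, Delta { index: 0, id: None, name: None, arguments: Some("\"Paris\"}".into()) });
    assert_eq!(calls[&0].id, "call_1");
    assert_eq!(calls[&0].arguments, "{\"city\":\"Paris\"}");
}
```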
//! Chat endpoint pipeline stages
//!
//! These stages handle chat-specific preprocessing, request building, and response processing.
//! They work with any model type by using injected model adapters.
mod preparation;
mod request_building;
mod response_processing;
pub use preparation::ChatPreparationStage;
pub use request_building::ChatRequestBuildingStage;
pub use response_processing::ChatResponseProcessingStage;
//! Chat preparation stage: Filter tools, process messages, tokenize, build constraints
use std::borrow::Cow;
use async_trait::async_trait;
use axum::response::Response;
use crate::{
protocols::chat::ChatCompletionRequest,
routers::grpc::{
common::stages::PipelineStage,
context::{PreparationOutput, RequestContext},
error, utils,
},
};
/// Chat preparation stage
///
/// Extracts chat-specific preparation logic from the old unified PreparationStage.
/// This is a direct extraction without architectural changes.
pub struct ChatPreparationStage;
#[async_trait]
impl PipelineStage for ChatPreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let request = ctx.chat_request_arc();
self.prepare_chat(ctx, &request).await?;
Ok(None)
}
fn name(&self) -> &'static str {
"ChatPreparation"
}
}
impl ChatPreparationStage {
async fn prepare_chat(
&self,
ctx: &mut RequestContext,
request: &ChatCompletionRequest,
) -> Result<(), Response> {
// Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request);
// Step 2: Process messages and apply chat template
let processed_messages =
match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
Ok(msgs) => msgs,
Err(e) => {
return Err(error::bad_request(e));
}
};
// Step 3: Tokenize the processed text
let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return Err(error::internal_error(format!("Tokenization failed: {}", e)));
}
};
let token_ids = encoding.token_ids().to_vec();
// Step 4: Build tool constraints if needed
let tool_call_constraint = if let Some(tools) = body_ref.tools.as_ref() {
utils::generate_tool_constraints(tools, &request.tool_choice, &request.model)
.map_err(|e| error::bad_request(format!("Invalid tool configuration: {}", e)))?
} else {
None
};
// Step 5: Create stop sequence decoder (build once, reuse in non-stream)
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
request.stop.as_ref(),
request.stop_token_ids.as_ref(),
request.skip_special_tokens,
request.no_stop_trim,
);
// Store results in context
ctx.state.preparation = Some(PreparationOutput {
original_text: Some(processed_messages.text.clone()),
token_ids,
processed_messages: Some(processed_messages),
tool_constraints: tool_call_constraint,
filtered_request: if matches!(body_ref, Cow::Owned(_)) {
Some(body_ref.into_owned())
} else {
None
},
// Harmony fields (not used for regular preparation)
harmony_mode: false,
selection_text: None,
harmony_messages: None,
harmony_stop_ids: None,
});
// Store stop decoder for reuse in response processing
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
}
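(Editor's note: a small standalone sketch of the Cow borrow-or-own pattern that filter_tools_for_request relies on in the stage above: return a borrow when nothing changes, and an owned copy only when the request actually had to be filtered. The Request type and the "blocked" tool name are illustrative stand-ins.)

```rust
use std::borrow::Cow;

#[derive(Clone)]
struct Request { tools: Vec<String> }

fn filter_tools(req: &Request) -> Cow<'_, Request> {
    if req.tools.iter().any(|t| t == "blocked") {
        let mut owned = req.clone();
        owned.tools.retain(|t| t != "blocked");
        Cow::Owned(owned)
    } else {
        Cow::Borrowed(req)
    }
}

fn main() {
    let req = Request { tools: vec!["search".into(), "blocked".into()] };
    let filtered = filter_tools(&req);
    // Mirrors the `matches!(body_ref, Cow::Owned(_))` check in the stage:
    // an owned value means the request was actually modified.
    assert!(matches!(filtered, Cow::Owned(_)));
    assert_eq!(filtered.tools, vec!["search".to_string()]);
}
```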
//! Chat request building stage: Build proto GenerateRequest for chat requests
use async_trait::async_trait;
use axum::response::Response;
use uuid::Uuid;
use crate::routers::grpc::{
common::stages::{helpers, PipelineStage},
context::{ClientSelection, RequestContext, WorkerSelection},
error,
};
/// Chat request building stage
///
/// Extracts chat-specific request building logic from the old unified RequestBuildingStage.
pub struct ChatRequestBuildingStage {
inject_pd_metadata: bool,
}
impl ChatRequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
}
#[async_trait]
impl PipelineStage for ChatRequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| error::internal_error("Preparation not completed"))?;
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
let chat_request = ctx.chat_request_arc();
// Get client for building request (use prefill client if PD mode)
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
// Build chat request
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let body_ref = prep.filtered_request.as_ref().unwrap_or(&chat_request);
let mut proto_request = builder_client
.build_generate_request(
request_id,
body_ref,
prep.processed_messages.as_ref().unwrap().text.clone(),
prep.token_ids.clone(),
prep.processed_messages
.as_ref()
.unwrap()
.multimodal_inputs
.clone(),
prep.tool_constraints.clone(),
)
.map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?;
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
helpers::inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"ChatRequestBuilding"
}
}
//! Chat response processing stage: Handles both streaming and non-streaming responses
//!
//! - For streaming: Spawns background task and returns SSE response (early exit)
//! - For non-streaming: Collects all responses and builds final ChatCompletionResponse
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use crate::routers::grpc::{
common::stages::PipelineStage,
context::{FinalResponse, RequestContext},
error,
regular::{processor, streaming},
};
/// Chat response processing stage
///
/// Extracts chat-specific response processing logic from the old unified ResponseProcessingStage.
pub struct ChatResponseProcessingStage {
processor: processor::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
}
impl ChatResponseProcessingStage {
pub fn new(
processor: processor::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
Self {
processor,
streaming_processor,
}
}
}
#[async_trait]
impl PipelineStage for ChatResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
self.process_chat_response(ctx).await
}
fn name(&self) -> &'static str {
"ChatResponseProcessing"
}
}
impl ChatResponseProcessingStage {
async fn process_chat_response(
&self,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
let is_streaming = ctx.is_streaming();
// Extract execution result
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| error::internal_error("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| error::internal_error("Dispatch metadata not set"))?
.clone();
if is_streaming {
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_response(
execution_result,
ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
dispatch,
),
));
}
// Non-streaming: Delegate to ResponseProcessor
let request_logprobs = ctx.chat_request().logprobs;
let chat_request = ctx.chat_request_arc();
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| error::internal_error("Stop decoder not initialized"))?;
let response = self
.processor
.process_non_streaming_chat_response(
execution_result,
chat_request,
dispatch,
stop_decoder,
request_logprobs,
)
.await?;
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
Ok(None)
}
}
//! Generate endpoint pipeline stages
//!
//! These stages handle generate-specific preprocessing, request building, and response processing.
//! They work with any model type by using injected model adapters.
mod preparation;
mod request_building;
mod response_processing;
pub use preparation::GeneratePreparationStage;
pub use request_building::GenerateRequestBuildingStage;
pub use response_processing::GenerateResponseProcessingStage;
//! Preparation stage: Filter tools, process messages, tokenize, build constraints
//! Generate preparation stage: Resolve input, tokenize, create stop decoder
use std::{borrow::Cow, sync::Arc};
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use super::PipelineStage;
use crate::{
protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
protocols::{common::InputIds, generate::GenerateRequest},
routers::grpc::{
context::{PreparationOutput, RequestContext, RequestType},
common::stages::PipelineStage,
context::{PreparationOutput, RequestContext},
error, utils,
},
tokenizer::traits::Tokenizer,
};
/// Preparation stage: Filter tools, process messages, tokenize, build constraints
pub struct PreparationStage;
/// Generate preparation stage
///
/// Extracts generate-specific preparation logic from the old unified PreparationStage.
/// This is a direct extraction without architectural changes.
pub struct GeneratePreparationStage;
#[async_trait]
impl PipelineStage for PreparationStage {
impl PipelineStage for GeneratePreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Clone Arc before match to avoid borrow checker issues
// (matching borrows ctx, but prepare_* methods need mutable borrow)
// Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
if is_chat {
let request_arc = ctx.chat_request_arc();
self.prepare_chat(ctx, &request_arc).await?;
} else {
let request_arc = ctx.generate_request_arc();
self.prepare_generate(ctx, &request_arc).await?;
}
let request = ctx.generate_request_arc();
self.prepare_generate(ctx, &request).await?;
Ok(None)
}
fn name(&self) -> &'static str {
"Preparation"
"GeneratePreparation"
}
}
impl PreparationStage {
async fn prepare_chat(
&self,
ctx: &mut RequestContext,
request: &ChatCompletionRequest,
) -> Result<(), Response> {
// Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request);
// Step 2: Process messages and apply chat template
let processed_messages =
match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
Ok(msgs) => msgs,
Err(e) => {
return Err(error::bad_request(e));
}
};
// Step 3: Tokenize the processed text
let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return Err(error::internal_error(format!("Tokenization failed: {}", e)));
}
};
let token_ids = encoding.token_ids().to_vec();
// Step 4: Build tool constraints if needed
let tool_call_constraint = if let Some(tools) = body_ref.tools.as_ref() {
utils::generate_tool_constraints(tools, &request.tool_choice, &request.model)
.map_err(|e| error::bad_request(format!("Invalid tool configuration: {}", e)))?
} else {
None
};
// Step 5: Create stop sequence decoder (build once, reuse in non-stream)
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
request.stop.as_ref(),
request.stop_token_ids.as_ref(),
request.skip_special_tokens,
request.no_stop_trim,
);
// Store results in context
ctx.state.preparation = Some(PreparationOutput {
original_text: Some(processed_messages.text.clone()),
token_ids,
processed_messages: Some(processed_messages),
tool_constraints: tool_call_constraint,
filtered_request: if matches!(body_ref, Cow::Owned(_)) {
Some(body_ref.into_owned())
} else {
None
},
// Harmony fields (not used for regular preparation)
harmony_mode: false,
selection_text: None,
harmony_messages: None,
harmony_stop_ids: None,
});
// Store stop decoder for reuse in response processing
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
impl GeneratePreparationStage {
async fn prepare_generate(
&self,
ctx: &mut RequestContext,
......
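(Editor's note: a toy sketch of the "resolve input, tokenize" step named in the GeneratePreparationStage doc above. The GenerateInput enum and the tokenize closure are hypothetical stand-ins; the real request type and tokenizer live in the crate's protocols and tokenizer modules.)

```rust
// A generate request may carry raw text or pre-tokenized ids;
// only the text path needs the tokenizer.
enum GenerateInput { Text(String), InputIds(Vec<u32>) }

fn resolve_token_ids(input: GenerateInput, tokenize: impl Fn(&str) -> Vec<u32>) -> Vec<u32> {
    match input {
        GenerateInput::Text(t) => tokenize(&t),
        GenerateInput::InputIds(ids) => ids,
    }
}

fn main() {
    // Stand-in tokenizer: one token per byte.
    let toy_tokenizer = |s: &str| s.bytes().map(u32::from).collect::<Vec<u32>>();
    assert_eq!(resolve_token_ids(GenerateInput::InputIds(vec![1, 2, 3]), &toy_tokenizer), vec![1, 2, 3]);
    assert_eq!(resolve_token_ids(GenerateInput::Text("hi".into()), &toy_tokenizer).len(), 2);
}
```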
//! Generate request building stage: Build proto GenerateRequest for generate requests
use async_trait::async_trait;
use axum::response::Response;
use uuid::Uuid;
use crate::routers::grpc::{
common::stages::{helpers, PipelineStage},
context::{ClientSelection, RequestContext, WorkerSelection},
error,
};
/// Generate request building stage
///
/// Extracts generate-specific request building logic from the old unified RequestBuildingStage.
pub struct GenerateRequestBuildingStage {
inject_pd_metadata: bool,
}
impl GenerateRequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
}
#[async_trait]
impl PipelineStage for GenerateRequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| error::internal_error("Preparation not completed"))?;
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
let generate_request = ctx.generate_request_arc();
// Get client for building request (use prefill client if PD mode)
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
// Build generate request
let request_id = generate_request
.rid
.clone()
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
let mut proto_request = builder_client
.build_plain_generate_request(
request_id,
&generate_request,
prep.original_text.clone(),
prep.token_ids.clone(),
)
.map_err(error::bad_request)?;
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
helpers::inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"GenerateRequestBuilding"
}
}
//! Response processing stage: Handles both streaming and non-streaming responses
//!
//! - For streaming: Spawns background task and returns SSE response (early exit)
//! - For non-streaming: Collects all responses and builds final ChatCompletionResponse
//! Generate response processing stage: Handles both streaming and non-streaming responses
use std::{sync::Arc, time::Instant};
use async_trait::async_trait;
use axum::response::Response;
use super::PipelineStage;
use crate::routers::grpc::{
context::{FinalResponse, RequestContext, RequestType},
error, processing, streaming,
common::stages::PipelineStage,
context::{FinalResponse, RequestContext},
error,
regular::{processor, streaming},
};
/// Response processing stage: Handles both streaming and non-streaming responses
/// Generate response processing stage
///
/// - For streaming: Spawns background task and returns SSE response (early exit)
/// - For non-streaming: Collects all responses and builds final ChatCompletionResponse
pub struct ResponseProcessingStage {
processor: processing::ResponseProcessor,
/// Extracts generate-specific response processing logic from the old unified ResponseProcessingStage.
pub struct GenerateResponseProcessingStage {
processor: processor::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
}
impl ResponseProcessingStage {
impl GenerateResponseProcessingStage {
pub fn new(
processor: processing::ResponseProcessor,
processor: processor::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
Self {
......@@ -36,89 +33,17 @@ impl ResponseProcessingStage {
}
#[async_trait]
impl PipelineStage for ResponseProcessingStage {
impl PipelineStage for GenerateResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Delegate to request-type specific processing
match &ctx.input.request_type {
RequestType::Chat(_) => self.process_chat_response(ctx).await,
RequestType::Generate(_) => self.process_generate_response(ctx).await,
RequestType::Responses(_) => Err(error::bad_request(
"Responses API processing must be handled by responses handler".to_string(),
)),
}
self.process_generate_response(ctx).await
}
fn name(&self) -> &'static str {
"ResponseProcessing"
"GenerateResponseProcessing"
}
}
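(Editor's note: a toy contrast of the refactor visible above, where the old unified stage matched on RequestType at runtime and the new design gives each endpoint its own stage type. All names below are stand-ins for illustration only.)

```rust
enum RequestType { Chat, Generate }

// Old style: one stage, runtime dispatch on the request type.
fn unified_execute(req: &RequestType) -> &'static str {
    match req {
        RequestType::Chat => "chat response",
        RequestType::Generate => "generate response",
    }
}

// New style: a dedicated stage per endpoint, so the match disappears.
struct GenerateStage;
impl GenerateStage {
    fn execute(&self) -> &'static str { "generate response" }
}

fn main() {
    assert_eq!(unified_execute(&RequestType::Generate), GenerateStage.execute());
    assert_eq!(unified_execute(&RequestType::Chat), "chat response");
}
```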
impl ResponseProcessingStage {
async fn process_chat_response(
&self,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
let is_streaming = ctx.is_streaming();
// Extract execution result
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| error::internal_error("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| error::internal_error("Dispatch metadata not set"))?
.clone();
if is_streaming {
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_response(
execution_result,
ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
dispatch,
),
));
}
// Non-streaming: Delegate to ResponseProcessor
let request_logprobs = match &ctx.input.request_type {
RequestType::Chat(req) => req.logprobs,
_ => false,
};
let chat_request = ctx.chat_request_arc();
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| error::internal_error("Stop decoder not initialized"))?;
let response = self
.processor
.process_non_streaming_chat_response(
execution_result,
chat_request,
dispatch,
stop_decoder,
request_logprobs,
)
.await?;
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
Ok(None)
}
impl GenerateResponseProcessingStage {
async fn process_generate_response(
&self,
ctx: &mut RequestContext,
......