Unverified Commit 03b3e89a authored by Simo Lin, committed by GitHub

[router] Harmony Pipeline: Chat Completion & Responses API with MCP Support (#12153)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 9ff9fa7f
//! Pipeline stages for gRPC router request processing
//! Pipeline orchestrator for gRPC router request processing
//!
//! This module defines the core pipeline abstraction and individual processing stages
//! that transform a RequestContext through its lifecycle.
//! This module defines the RequestPipeline orchestrator that coordinates
//! the execution of pipeline stages from request preparation to response delivery.
use std::{
borrow::Cow,
collections::HashMap,
sync::Arc,
time::{Instant, SystemTime, UNIX_EPOCH},
};
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use axum::response::{IntoResponse, Response};
use proto::DisaggregatedParams;
use rand::Rng;
use tokio::sync::RwLock;
use tracing::{debug, error, warn};
use uuid::Uuid;
use tracing::{debug, error};
use super::{context::*, processing, responses::BackgroundTaskInfo, streaming, utils};
// Import all stage types from the stages module
use super::stages::*;
use super::{context::*, harmony, processing, responses::BackgroundTaskInfo, streaming, utils};
use crate::{
core::{ConnectionMode, Worker, WorkerRegistry, WorkerType},
grpc_client::proto,
core::WorkerRegistry,
policies::PolicyRegistry,
protocols::{
chat::{ChatCompletionRequest, ChatCompletionResponse},
common::InputIds,
generate::GenerateRequest,
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
@@ -33,875 +24,6 @@ use crate::{
tool_parser::ParserFactory as ToolParserFactory,
};
// ============================================================================
// Pipeline Trait
// ============================================================================
/// Trait for pipeline stages that process requests
#[async_trait]
pub trait PipelineStage: Send + Sync {
/// Execute this stage, mutating the context
///
/// Returns:
/// - `Ok(None)` - Continue to next stage
/// - `Ok(Some(response))` - Pipeline complete, return this response (e.g., streaming)
/// - `Err(response)` - Error occurred, return this error response
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response>;
/// Stage name for logging
fn name(&self) -> &'static str;
}
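// Illustrative example (editorial, not part of this commit): a minimal stage
// satisfying the three-way contract above. `NoOpStage` is a hypothetical name;
// `async_trait`, `RequestContext`, and `Response` are the items already in
// scope in this module.
pub struct NoOpStage;

#[async_trait]
impl PipelineStage for NoOpStage {
    async fn execute(&self, _ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        // Nothing to do: hand the context to the next stage.
        Ok(None)
    }

    fn name(&self) -> &'static str {
        "NoOp"
    }
}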
// ============================================================================
// Stage 1: Preparation
// ============================================================================
/// Preparation stage: Filter tools, process messages, tokenize, build constraints
pub struct PreparationStage;
#[async_trait]
impl PipelineStage for PreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Clone Arc before match to avoid borrow checker issues
// (matching borrows ctx, but prepare_* methods need mutable borrow)
// Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
if is_chat {
let request_arc = ctx.chat_request_arc();
self.prepare_chat(ctx, &request_arc).await?;
} else {
let request_arc = ctx.generate_request_arc();
self.prepare_generate(ctx, &request_arc).await?;
}
Ok(None)
}
fn name(&self) -> &'static str {
"Preparation"
}
}
impl PreparationStage {
async fn prepare_chat(
&self,
ctx: &mut RequestContext,
request: &ChatCompletionRequest,
) -> Result<(), Response> {
// Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request);
// Step 2: Process messages and apply chat template
let processed_messages =
match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
Ok(msgs) => msgs,
Err(e) => {
return Err(utils::bad_request_error(e));
}
};
// Step 3: Tokenize the processed text
let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Tokenization failed: {}",
e
)));
}
};
let token_ids = encoding.token_ids().to_vec();
// Step 4: Build tool constraints if needed
let tool_call_constraint = if let Some(tools) = body_ref.tools.as_ref() {
utils::generate_tool_constraints(tools, &request.tool_choice, &request.model).map_err(
|e| utils::bad_request_error(format!("Invalid tool configuration: {}", e)),
)?
} else {
None
};
// Step 5: Create stop sequence decoder (build once, reuse in non-stream)
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
request.stop.as_ref(),
request.stop_token_ids.as_ref(),
request.skip_special_tokens,
request.no_stop_trim,
);
// Store results in context
ctx.state.preparation = Some(PreparationOutput {
original_text: Some(processed_messages.text.clone()),
token_ids,
processed_messages: Some(processed_messages),
tool_constraints: tool_call_constraint,
filtered_request: if matches!(body_ref, Cow::Owned(_)) {
Some(body_ref.into_owned())
} else {
None
},
});
// Store stop decoder for reuse in response processing
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
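// Note (editorial): `filter_tools_for_request` returns a `Cow` over the request.
// `Cow::Borrowed` means no tools were filtered and later stages can keep reading
// the original request, while `Cow::Owned` carries a modified copy; that is why
// only the owned case is persisted into `filtered_request` above.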
async fn prepare_generate(
&self,
ctx: &mut RequestContext,
request: &GenerateRequest,
) -> Result<(), Response> {
// Resolve input (text, prompt, or input_ids)
let (original_text, token_ids) = match self.resolve_generate_input(ctx, request) {
Ok(res) => res,
Err(msg) => {
return Err(utils::bad_request_error(msg));
}
};
// Create stop sequence decoder for generate requests
let params = request.sampling_params.as_ref();
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
params.and_then(|p| p.stop.as_ref()),
params.and_then(|p| p.stop_token_ids.as_ref()),
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
);
ctx.state.preparation = Some(PreparationOutput {
original_text,
token_ids,
processed_messages: None,
tool_constraints: None,
filtered_request: None,
});
// Store stop decoder
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
fn resolve_generate_input(
&self,
ctx: &RequestContext,
request: &GenerateRequest,
) -> Result<(Option<String>, Vec<u32>), String> {
if let Some(text) = &request.text {
return self
.tokenize_single_text(&ctx.components.tokenizer, text)
.map(|(original, ids)| (Some(original), ids));
}
// Handle input_ids - validate and convert
if let Some(input_ids) = &request.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| u32::try_from(id))
.collect::<Result<Vec<u32>, _>>()
.map(|converted| (None, converted))
.map_err(|_| "input_ids must be non-negative".to_string()),
InputIds::Batch(_) => {
Err("Batch input_ids are not supported over gRPC generate yet".to_string())
}
};
}
Err("Either `text` or `input_ids` must be provided".to_string())
}
fn tokenize_single_text(
&self,
tokenizer: &Arc<dyn Tokenizer>,
text: &str,
) -> Result<(String, Vec<u32>), String> {
let encoding = tokenizer
.encode(text)
.map_err(|e| format!("Tokenization failed: {}", e))?;
Ok((text.to_string(), encoding.token_ids().to_vec()))
}
}
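// Illustrative test sketch (editorial, not part of this commit): demonstrates
// the i64 -> u32 conversion rule used by `resolve_generate_input`, where a
// single negative id fails the whole collect().
#[cfg(test)]
mod input_ids_conversion_sketch {
    #[test]
    fn negative_input_ids_are_rejected() {
        let ids: Vec<i64> = vec![1, 2, -3];
        let converted: Result<Vec<u32>, _> = ids.iter().map(|&id| u32::try_from(id)).collect();
        assert!(converted.is_err(), "one negative id rejects the batch");

        let ids: Vec<i64> = vec![1, 2, 3];
        let converted: Result<Vec<u32>, _> = ids.iter().map(|&id| u32::try_from(id)).collect();
        assert_eq!(converted.unwrap(), vec![1u32, 2, 3]);
    }
}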
// ============================================================================
// Stage 2: Worker Selection
// ============================================================================
/// Worker selection stage: Select appropriate worker(s) based on routing mode
pub struct WorkerSelectionStage {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
mode: WorkerSelectionMode,
}
pub enum WorkerSelectionMode {
/// Regular mode: select single worker
Regular,
/// PD mode: select prefill + decode workers
PrefillDecode,
}
impl WorkerSelectionStage {
pub fn new(
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
mode: WorkerSelectionMode,
) -> Self {
Self {
worker_registry,
policy_registry,
mode,
}
}
}
#[async_trait]
impl PipelineStage for WorkerSelectionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation stage not completed"))?;
let text = prep.original_text.as_deref();
let workers = match self.mode {
WorkerSelectionMode::Regular => {
match self.select_single_worker(ctx.input.model_id.as_deref(), text) {
Some(w) => WorkerSelection::Single { worker: w },
None => {
return Err(utils::service_unavailable_error(format!(
"No available workers for model: {:?}",
ctx.input.model_id
)));
}
}
}
WorkerSelectionMode::PrefillDecode => {
match self.select_pd_pair(ctx.input.model_id.as_deref(), text) {
Some((prefill, decode)) => WorkerSelection::Dual { prefill, decode },
None => {
return Err(utils::service_unavailable_error(format!(
"No available PD worker pairs for model: {:?}",
ctx.input.model_id
)));
}
}
}
};
ctx.state.workers = Some(workers);
Ok(None)
}
fn name(&self) -> &'static str {
"WorkerSelection"
}
}
impl WorkerSelectionStage {
fn select_single_worker(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model, filtered by connection mode
let workers = self.worker_registry.get_workers_filtered(
model_id,
Some(WorkerType::Regular),
Some(ConnectionMode::Grpc { port: None }),
false, // get all workers, we'll filter by is_available() next
);
// Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.cloned()
.collect();
if available.is_empty() {
return None;
}
// Get the appropriate policy for this model
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
// Select worker using the policy
let idx = policy.select_worker(&available, text)?;
Some(available[idx].clone())
}
fn select_pd_pair(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<(Arc<dyn Worker>, Arc<dyn Worker>)> {
let all_workers = self.worker_registry.get_workers_filtered(
model_id,
None,
Some(ConnectionMode::Grpc { port: None }), // Match any gRPC worker
false,
);
let (available_prefill, available_decode): (Vec<_>, Vec<_>) =
all_workers
.into_iter()
.fold((Vec::new(), Vec::new()), |mut acc, w| {
if w.is_available() {
match w.metadata().worker_type {
WorkerType::Prefill { .. } => acc.0.push(w),
WorkerType::Decode => acc.1.push(w),
_ => {}
}
}
acc
});
if available_prefill.is_empty() {
warn!("No available prefill workers");
return None;
}
if available_decode.is_empty() {
warn!("No available decode workers");
return None;
}
// Select using policies
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
let prefill_idx = policy.select_worker(&available_prefill, text)?;
let decode_idx = policy.select_worker(&available_decode, text)?;
Some((
available_prefill[prefill_idx].clone(),
available_decode[decode_idx].clone(),
))
}
}
// ============================================================================
// Stage 3: Client Acquisition
// ============================================================================
/// Client acquisition stage: Get gRPC clients from selected workers
pub struct ClientAcquisitionStage;
#[async_trait]
impl PipelineStage for ClientAcquisitionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let workers = ctx
.state
.workers
.as_ref()
.ok_or_else(|| utils::internal_error_static("Worker selection not completed"))?;
let clients = match workers {
WorkerSelection::Single { worker } => {
let client = utils::get_grpc_client_from_worker(worker).await?;
ClientSelection::Single { client }
}
WorkerSelection::Dual { prefill, decode } => {
let prefill_client = utils::get_grpc_client_from_worker(prefill).await?;
let decode_client = utils::get_grpc_client_from_worker(decode).await?;
ClientSelection::Dual {
prefill: prefill_client,
decode: decode_client,
}
}
};
ctx.state.clients = Some(clients);
Ok(None)
}
fn name(&self) -> &'static str {
"ClientAcquisition"
}
}
// ============================================================================
// Stage 4: Request Building
// ============================================================================
/// Request building stage: Build proto GenerateRequest
pub struct RequestBuildingStage {
inject_pd_metadata: bool,
}
impl RequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
}
#[async_trait]
impl PipelineStage for RequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation not completed"))?;
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
// Get client for building request (use prefill client if PD mode)
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
let mut proto_request = match &ctx.input.request_type {
RequestType::Chat(request) => {
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let body_ref = prep.filtered_request.as_ref().unwrap_or(request);
builder_client
.build_generate_request(
request_id,
body_ref,
prep.processed_messages.as_ref().unwrap().text.clone(),
prep.token_ids.clone(),
prep.processed_messages
.as_ref()
.unwrap()
.multimodal_inputs
.clone(),
prep.tool_constraints.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?
}
RequestType::Generate(request) => {
let request_id = request
.rid
.clone()
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
builder_client
.build_plain_generate_request(
request_id,
request,
prep.original_text.clone(),
prep.token_ids.clone(),
)
.map_err(utils::bad_request_error)?
}
};
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
self.inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"RequestBuilding"
}
}
impl RequestBuildingStage {
fn inject_bootstrap_metadata(
&self,
request: &mut proto::GenerateRequest,
prefill_worker: &Arc<dyn Worker>,
) {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
}
// ============================================================================
// Stage 5: Dispatch Metadata
// ============================================================================
/// Dispatch metadata stage: Prepare metadata for dispatch
pub struct DispatchMetadataStage;
#[async_trait]
impl PipelineStage for DispatchMetadataStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let proto_request = ctx
.state
.proto_request
.as_ref()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
let request_id = proto_request.request_id.clone();
let model = match &ctx.input.request_type {
RequestType::Chat(req) => req.model.clone(),
RequestType::Generate(_req) => {
// Generate requests don't have a model field
// Use model_id from input or default
ctx.input
.model_id
.clone()
.unwrap_or_else(|| "default".to_string())
}
};
let weight_version = ctx
.state
.workers
.as_ref()
.map(|w| match w {
WorkerSelection::Single { worker } => worker,
WorkerSelection::Dual { decode, .. } => decode,
})
.and_then(|w| w.metadata().labels.get("weight_version").cloned())
.unwrap_or_else(|| "default".to_string());
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
ctx.state.dispatch = Some(DispatchMetadata {
request_id,
model,
created,
weight_version: Some(weight_version),
is_streaming: ctx.is_streaming(),
});
Ok(None)
}
fn name(&self) -> &'static str {
"DispatchMetadata"
}
}
// ============================================================================
// Stage 6: Request Execution
// ============================================================================
/// Request execution stage: Execute gRPC requests (single or dual dispatch)
pub struct RequestExecutionStage {
mode: ExecutionMode,
}
pub enum ExecutionMode {
/// Regular mode: single worker execution
Single,
/// PD mode: dual dispatch to prefill + decode workers
DualDispatch,
}
impl RequestExecutionStage {
pub fn new(mode: ExecutionMode) -> Self {
Self { mode }
}
}
#[async_trait]
impl PipelineStage for RequestExecutionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let proto_request = ctx
.state
.proto_request
.take()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
let clients = ctx
.state
.clients
.as_mut()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
let result = match self.mode {
ExecutionMode::Single => self.execute_single(proto_request, clients).await?,
ExecutionMode::DualDispatch => {
self.execute_dual_dispatch(proto_request, clients).await?
}
};
// Store result in context for ResponseProcessingStage
ctx.state.response.execution_result = Some(result);
Ok(None)
}
fn name(&self) -> &'static str {
"RequestExecution"
}
}
impl RequestExecutionStage {
async fn execute_single(
&self,
proto_request: proto::GenerateRequest,
clients: &mut ClientSelection,
) -> Result<ExecutionResult, Response> {
let client = clients
.single_mut()
.ok_or_else(|| utils::internal_error_static("Expected single client but got dual"))?;
let stream = client.generate(proto_request).await.map_err(|e| {
utils::internal_error_message(format!("Failed to start generation: {}", e))
})?;
Ok(ExecutionResult::Single { stream })
}
async fn execute_dual_dispatch(
&self,
proto_request: proto::GenerateRequest,
clients: &mut ClientSelection,
) -> Result<ExecutionResult, Response> {
let (prefill_client, decode_client) = clients
.dual_mut()
.ok_or_else(|| utils::internal_error_static("Expected dual clients but got single"))?;
let prefill_request = proto_request.clone();
let decode_request = proto_request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Handle prefill result
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Prefill worker failed to start: {}",
e
)));
}
};
// Handle decode result
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Decode worker failed to start: {}",
e
)));
}
};
Ok(ExecutionResult::Dual {
prefill: prefill_stream,
decode: Box::new(decode_stream),
})
}
}
// ============================================================================
// Stage 7: Response Processing
// ============================================================================
/// Response processing stage: Handles both streaming and non-streaming responses
///
/// - For streaming: Spawns background task and returns SSE response (early exit)
/// - For non-streaming: Collects all responses and builds final ChatCompletionResponse
pub struct ResponseProcessingStage {
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
}
impl ResponseProcessingStage {
pub fn new(
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
Self {
processor,
streaming_processor,
}
}
}
#[async_trait]
impl PipelineStage for ResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Delegate to request-type specific processing
match &ctx.input.request_type {
RequestType::Chat(_) => return self.process_chat_response(ctx).await,
RequestType::Generate(_) => return self.process_generate_response(ctx).await,
}
}
fn name(&self) -> &'static str {
"ResponseProcessing"
}
}
impl ResponseProcessingStage {
async fn process_chat_response(
&self,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
let is_streaming = ctx.is_streaming();
// Extract execution result
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
.clone();
if is_streaming {
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_response(
execution_result,
ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
dispatch,
),
));
}
// Non-streaming: Delegate to ResponseProcessor
let request_logprobs = match &ctx.input.request_type {
RequestType::Chat(req) => req.logprobs,
_ => false,
};
let chat_request = ctx.chat_request_arc();
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
let response = self
.processor
.process_non_streaming_chat_response(
execution_result,
chat_request,
dispatch,
stop_decoder,
request_logprobs,
)
.await?;
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
Ok(None)
}
async fn process_generate_response(
&self,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
let start_time = Instant::now();
let is_streaming = ctx.is_streaming();
// Extract execution result
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
.clone();
if is_streaming {
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_generate(
execution_result,
ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
dispatch,
),
));
}
// Non-streaming: Delegate to ResponseProcessor
let request_logprobs = ctx.generate_request().return_logprob.unwrap_or(false);
let generate_request = ctx.generate_request_arc();
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
let result_array = self
.processor
.process_non_streaming_generate_response(
execution_result,
generate_request,
dispatch,
stop_decoder,
request_logprobs,
start_time,
)
.await?;
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
Ok(None)
}
}
// ============================================================================
// Pipeline Orchestrator
// ============================================================================
@@ -963,6 +85,64 @@ impl RequestPipeline {
}
}
/// Create a Harmony (single-worker) pipeline for Harmony-capable models
pub fn new_harmony(
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
_tokenizer: Arc<dyn Tokenizer>,
_tool_parser_factory: ToolParserFactory,
_reasoning_parser_factory: ReasoningParserFactory,
_configured_tool_parser: Option<String>,
_configured_reasoning_parser: Option<String>,
) -> Self {
let stages: Vec<Box<dyn PipelineStage>> = vec![
Box::new(harmony::stages::HarmonyPreparationStage::new()),
Box::new(WorkerSelectionStage::new(
worker_registry,
policy_registry,
WorkerSelectionMode::Regular,
)),
Box::new(ClientAcquisitionStage),
Box::new(harmony::stages::HarmonyRequestBuildingStage::new(false)),
Box::new(DispatchMetadataStage),
Box::new(RequestExecutionStage::new(ExecutionMode::Single)),
Box::new(harmony::stages::HarmonyResponseProcessingStage::new()),
];
Self {
stages: Arc::new(stages),
}
}
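// Note (editorial): the Harmony pipeline reuses the generic worker-selection,
// client-acquisition, dispatch-metadata, and execution stages unchanged; only
// preparation, request building, and response processing are swapped for
// Harmony-aware counterparts. `new_harmony_pd` below differs from `new_harmony`
// only in PD worker selection, bootstrap-metadata injection (the `true` flag
// passed to `HarmonyRequestBuildingStage::new`), and dual-dispatch execution.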
/// Create a Harmony PD (prefill-decode) pipeline
pub fn new_harmony_pd(
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
_tokenizer: Arc<dyn Tokenizer>,
_tool_parser_factory: ToolParserFactory,
_reasoning_parser_factory: ReasoningParserFactory,
_configured_tool_parser: Option<String>,
_configured_reasoning_parser: Option<String>,
) -> Self {
let stages: Vec<Box<dyn PipelineStage>> = vec![
Box::new(harmony::stages::HarmonyPreparationStage::new()),
Box::new(WorkerSelectionStage::new(
worker_registry,
policy_registry,
WorkerSelectionMode::PrefillDecode,
)),
Box::new(ClientAcquisitionStage),
Box::new(harmony::stages::HarmonyRequestBuildingStage::new(true)),
Box::new(DispatchMetadataStage),
Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)),
Box::new(harmony::stages::HarmonyResponseProcessingStage::new()),
];
Self {
stages: Arc::new(stages),
}
}
/// Create a PD (prefill-decode) pipeline
pub fn new_pd(
worker_registry: Arc<WorkerRegistry>,
@@ -1186,4 +366,92 @@ impl RequestPipeline {
None => Err(utils::internal_error_static("No response produced")),
}
}
/// Execute Responses API pipeline
///
/// TODO: Implement Responses API native execution
/// This is a stub to allow compilation. The actual implementation should:
/// 1. Support multi-turn MCP loop orchestration
/// 2. Handle tool call execution and result injection
/// 3. Emit proper SSE events for streaming mode
/// 4. Store responses in data connector
///
/// For now, this returns an error indicating the feature is not implemented.
pub async fn execute_responses(
&self,
_request: Arc<crate::protocols::responses::ResponsesRequest>,
_headers: Option<http::HeaderMap>,
_model_id: Option<String>,
_components: Arc<SharedComponents>,
) -> Response {
utils::internal_error_static("Responses API execution not yet implemented")
}
/// Execute Harmony Responses API request through all pipeline stages
///
/// This method runs a single iteration of the Responses API request,
/// returning either ToolCallsFound (continue the MCP tool loop) or Completed (final response).
///
/// Called by harmony::responses::serve_harmony_responses() for each iteration.
///
/// # Arguments
///
/// * `request` - Responses API request
/// * `ctx` - Harmony Responses context with MCP manager and components
///
/// # Returns
///
/// ResponsesIterationResult indicating whether to continue iteration or return
pub async fn execute_harmony_responses(
&self,
request: &crate::protocols::responses::ResponsesRequest,
harmony_ctx: &harmony::responses::HarmonyResponsesContext,
) -> Result<harmony::ResponsesIterationResult, Response> {
// Create RequestContext for this Responses request
let mut ctx = RequestContext::for_responses(
Arc::new(request.clone()),
None, // No headers needed for internal pipeline execution
None, // Model ID already set in request
harmony_ctx.components.clone(),
);
// Execute each pipeline stage in sequence
for (idx, stage) in self.stages.iter().enumerate() {
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
// Stage returned early response (e.g., streaming) - not expected for Responses iteration
error!(
"Stage {} ({}) returned unexpected response during Responses iteration",
idx + 1,
stage.name()
);
return Err(response);
}
Ok(None) => {
// Stage completed successfully, continue to next stage
continue;
}
Err(response) => {
// Stage failed
error!(
"Stage {} ({}) failed with status {}",
idx + 1,
stage.name(),
response.status()
);
return Err(response);
}
}
}
// Extract ResponsesIterationResult from context
// This should have been set by HarmonyResponseProcessingStage
ctx.state
.response
.responses_iteration_result
.take()
.ok_or_else(|| {
utils::internal_error_static("No ResponsesIterationResult produced by pipeline")
})
}
}
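// Sketch (assumption, editorial; not this commit's code): the caller-side shape
// of the multi-turn loop described above, written with stand-in types so the
// control flow is visible. The real loop in serve_harmony_responses() drives
// execute_harmony_responses() and harmony::ResponsesIterationResult, which
// carry richer payloads; `ToolLoopOutcome` and `drive_tool_loop` are
// hypothetical names used only for illustration.
enum ToolLoopOutcome {
    ToolCallsFound(Vec<String>), // pending tool calls: execute them, then iterate again
    Completed(String),           // final response: stop iterating
}

fn drive_tool_loop(mut run_iteration: impl FnMut() -> ToolLoopOutcome) -> String {
    loop {
        match run_iteration() {
            ToolLoopOutcome::ToolCallsFound(calls) => {
                // Execute the MCP tools and fold their results into the next request.
                let _ = calls;
            }
            ToolLoopOutcome::Completed(final_response) => return final_response,
        }
    }
}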
@@ -110,7 +110,7 @@ impl ResponseProcessor {
Ok(all_responses)
}
/// Process a single choice from GenerateComplete response (EXACT COPY from router.rs:1573-1725)
/// Process a single choice from GenerateComplete response
#[allow(clippy::too_many_arguments)]
pub async fn process_single_choice(
&self,
@@ -367,7 +367,7 @@ impl ResponseProcessor {
Ok(response)
}
/// Parse tool calls using model-specific parser (EXACT COPY from router.rs:296-361)
/// Parse tool calls using model-specific parser
pub async fn parse_tool_calls(
&self,
processed_text: &str,
......
@@ -12,7 +12,8 @@ use crate::protocols::{
common::{FunctionCallResponse, StreamOptions, ToolCall, UsageInfo},
responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage,
ResponseReasoningContent::ReasoningText, ResponseStatus, ResponsesRequest,
ResponsesResponse, ResponsesUsage, StringOrContentParts,
},
};
@@ -50,7 +51,6 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
match item {
ResponseInputOutputItem::SimpleInputMessage { content, role, .. } => {
// Convert SimpleInputMessage to chat message
use crate::protocols::responses::StringOrContentParts;
let text = match content {
StringOrContentParts::String(s) => s.clone(),
StringOrContentParts::Array(parts) => {
@@ -170,9 +170,7 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
let reasoning_text = content
.iter()
.map(|c| match c {
crate::protocols::responses::ResponseReasoningContent::ReasoningText { text } => {
text.as_str()
}
ReasoningText { text } => text.as_str(),
})
.collect::<Vec<_>>()
.join("\n");
@@ -184,6 +182,17 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
reasoning_content: Some(reasoning_text),
});
}
ResponseInputOutputItem::FunctionCallOutput {
call_id, output, ..
} => {
// Function call output - add as tool message
// Note: The function name is looked up from prev_outputs in the Harmony path;
// for the Chat path, the call_id alone is sufficient
messages.push(ChatMessage::Tool {
content: output.clone(),
tool_call_id: call_id.clone(),
});
}
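// Example (editorial): a FunctionCallOutput item with call_id "call_1" and
// output "42" becomes ChatMessage::Tool { content: "42", tool_call_id: "call_1" },
// the standard chat-completions shape for feeding a tool result back to the model.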
}
}
}
@@ -282,11 +291,9 @@ pub fn chat_to_responses(
output.push(ResponseOutputItem::Reasoning {
id: format!("reasoning_{}", chat_resp.id),
summary: vec![],
content: vec![
crate::protocols::responses::ResponseReasoningContent::ReasoningText {
text: reasoning.clone(),
},
],
content: vec![ReasoningText {
text: reasoning.clone(),
}],
status: Some("completed".to_string()),
});
}
......
@@ -11,12 +11,15 @@ use axum::{
};
use tracing::debug;
use super::{context::SharedComponents, pipeline::RequestPipeline, responses};
use super::{
context::SharedComponents,
harmony::{serve_harmony_responses, HarmonyDetector, HarmonyResponsesContext},
pipeline::RequestPipeline,
responses,
};
use crate::{
app_context::AppContext,
config::types::RetryConfig,
core::WorkerRegistry,
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
@@ -26,10 +29,7 @@ use crate::{
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
/// gRPC router implementation for SGLang
@@ -37,19 +37,13 @@ use crate::{
#[allow(dead_code)]
pub struct GrpcRouter {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ReasoningParserFactory,
tool_parser_factory: ToolParserFactory,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>,
pipeline: RequestPipeline,
harmony_pipeline: RequestPipeline,
shared_components: Arc<SharedComponents>,
// Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
responses_context: responses::ResponsesContext,
// Harmony responses context (uses harmony pipeline)
harmony_responses_context: responses::ResponsesContext,
}
impl GrpcRouter {
@@ -73,7 +67,7 @@ impl GrpcRouter {
.clone();
let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
let _policy_registry = ctx.policy_registry.clone();
// Create shared components for pipeline
let shared_components = Arc::new(SharedComponents {
@@ -82,10 +76,10 @@
reasoning_parser_factory: reasoning_parser_factory.clone(),
});
// Create pipeline
// Create regular pipeline
let pipeline = RequestPipeline::new_regular(
worker_registry.clone(),
policy_registry.clone(),
_policy_registry.clone(),
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
@@ -93,34 +87,48 @@
ctx.configured_reasoning_parser.clone(),
);
// Create responses context with all dependencies
let responses_context = responses::ResponsesContext::new(
Arc::new(pipeline.clone()),
shared_components.clone(),
// Create Harmony pipeline
let harmony_pipeline = RequestPipeline::new_harmony(
worker_registry.clone(),
ctx.response_storage.clone(),
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
ctx.mcp_manager
.get()
.ok_or_else(|| "gRPC router requires MCP manager".to_string())?
.clone(),
_policy_registry.clone(),
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
ctx.configured_tool_parser.clone(),
ctx.configured_reasoning_parser.clone(),
);
// Extract shared dependencies for responses contexts
let mcp_manager = ctx
.mcp_manager
.get()
.ok_or_else(|| "gRPC router requires MCP manager".to_string())?
.clone();
// Helper closure to create responses context with a given pipeline
let create_responses_context = |pipeline: &RequestPipeline| {
responses::ResponsesContext::new(
Arc::new(pipeline.clone()),
shared_components.clone(),
worker_registry.clone(),
ctx.response_storage.clone(),
ctx.conversation_storage.clone(),
ctx.conversation_item_storage.clone(),
mcp_manager.clone(),
)
};
// Create responses contexts for both pipelines
let responses_context = create_responses_context(&pipeline);
let harmony_responses_context = create_responses_context(&harmony_pipeline);
Ok(GrpcRouter {
worker_registry,
policy_registry,
tokenizer,
reasoning_parser_factory,
tool_parser_factory,
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
configured_tool_parser: ctx.configured_tool_parser.clone(),
pipeline,
harmony_pipeline,
shared_components,
responses_context,
harmony_responses_context,
})
}
@@ -131,13 +139,22 @@ impl GrpcRouter {
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response {
// Choose Harmony pipeline if model indicates Harmony
let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
debug!(
"Processing chat completion request for model: {:?}",
model_id
"Processing chat completion request for model: {:?}, using_harmony={}",
model_id, is_harmony
);
// Use pipeline for ALL requests (streaming and non-streaming)
self.pipeline
let pipeline = if is_harmony {
&self.harmony_pipeline
} else {
&self.pipeline
};
// Use selected pipeline for ALL requests (streaming and non-streaming)
pipeline
.execute_chat(
Arc::new(body.clone()),
headers.cloned(),
@@ -166,6 +183,33 @@
)
.await
}
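// Sketch (assumption, editorial): HarmonyDetector's implementation is not part
// of this diff. Harmony is the message format spoken by gpt-oss models, so a
// plausible, hypothetical detector is a model-name check along these lines; the
// real HarmonyDetector::is_harmony_model may differ.
fn is_harmony_model_sketch(model: &str) -> bool {
    model.to_ascii_lowercase().contains("gpt-oss")
}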
/// Main route_responses implementation (pipeline-based for Harmony)
async fn route_responses_impl(
&self,
_headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
debug!(
"Processing Harmony responses request for model: {:?}",
model_id
);
// Create HarmonyResponsesContext from existing responses context
let harmony_ctx = HarmonyResponsesContext::new(
Arc::new(self.harmony_pipeline.clone()),
self.shared_components.clone(),
self.harmony_responses_context.mcp_manager.clone(),
self.harmony_responses_context.response_storage.clone(),
);
// Use serve_harmony_responses for multi-turn MCP tool orchestration
match serve_harmony_responses(&harmony_ctx, body.clone()).await {
Ok(response) => axum::Json(response).into_response(),
Err(error_response) => error_response,
}
}
}
impl std::fmt::Debug for GrpcRouter {
@@ -173,7 +217,6 @@ impl std::fmt::Debug for GrpcRouter {
let stats = self.worker_registry.stats();
f.debug_struct("GrpcRouter")
.field("workers_count", &stats.total_workers)
.field("dp_aware", &self.dp_aware)
.finish()
}
}
@@ -238,13 +281,27 @@ impl RouterTrait for GrpcRouter {
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
responses::route_responses(
&self.responses_context,
Arc::new(body.clone()),
headers.cloned(),
model_id.map(|s| s.to_string()),
)
.await
// Choose implementation based on Harmony model detection
let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
debug!(
"Processing responses request for model: {:?}, using_harmony={}",
model_id, is_harmony
);
if is_harmony {
// Use pipeline-based implementation for Harmony models
self.route_responses_impl(headers, body, model_id).await
} else {
// Use legacy responses module for non-Harmony models
responses::route_responses(
&self.responses_context,
Arc::new(body.clone()),
headers.cloned(),
model_id.map(|s| s.to_string()),
)
.await
}
}
async fn get_response(
@@ -260,19 +317,19 @@ impl RouterTrait for GrpcRouter {
responses::cancel_response_impl(&self.responses_context, response_id).await
}
async fn route_classify(
async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,
_body: &ClassifyRequest,
_body: &EmbeddingRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(
async fn route_classify(
&self,
_headers: Option<&HeaderMap>,
_body: &EmbeddingRequest,
_body: &ClassifyRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
......
//! Client acquisition stage: Get gRPC clients from selected workers
use async_trait::async_trait;
use axum::response::Response;
use super::PipelineStage;
use crate::routers::grpc::{
context::{ClientSelection, RequestContext, WorkerSelection},
utils,
};
/// Client acquisition stage: Get gRPC clients from selected workers
pub struct ClientAcquisitionStage;
#[async_trait]
impl PipelineStage for ClientAcquisitionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let workers = ctx
.state
.workers
.as_ref()
.ok_or_else(|| utils::internal_error_static("Worker selection not completed"))?;
let clients = match workers {
WorkerSelection::Single { worker } => {
let client = utils::get_grpc_client_from_worker(worker).await?;
ClientSelection::Single { client }
}
WorkerSelection::Dual { prefill, decode } => {
let prefill_client = utils::get_grpc_client_from_worker(prefill).await?;
let decode_client = utils::get_grpc_client_from_worker(decode).await?;
ClientSelection::Dual {
prefill: prefill_client,
decode: decode_client,
}
}
};
ctx.state.clients = Some(clients);
Ok(None)
}
fn name(&self) -> &'static str {
"ClientAcquisition"
}
}
//! Dispatch metadata stage: Prepare metadata for dispatch
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use axum::response::Response;
use super::PipelineStage;
use crate::routers::grpc::{
context::{DispatchMetadata, RequestContext, RequestType, WorkerSelection},
utils,
};
/// Dispatch metadata stage: Prepare metadata for dispatch
pub struct DispatchMetadataStage;
#[async_trait]
impl PipelineStage for DispatchMetadataStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let proto_request = ctx
.state
.proto_request
.as_ref()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
let request_id = proto_request.request_id.clone();
let model = match &ctx.input.request_type {
RequestType::Chat(req) => req.model.clone(),
RequestType::Generate(_req) => {
// Generate requests don't have a model field
// Use model_id from input or default
ctx.input
.model_id
.clone()
.unwrap_or_else(|| "default".to_string())
}
RequestType::Responses(req) => req.model.clone(),
};
let weight_version = ctx
.state
.workers
.as_ref()
.map(|w| match w {
WorkerSelection::Single { worker } => worker,
WorkerSelection::Dual { decode, .. } => decode,
})
.and_then(|w| w.metadata().labels.get("weight_version").cloned())
.unwrap_or_else(|| "default".to_string());
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
ctx.state.dispatch = Some(DispatchMetadata {
request_id,
model,
created,
weight_version: Some(weight_version),
is_streaming: ctx.is_streaming(),
});
Ok(None)
}
fn name(&self) -> &'static str {
"DispatchMetadata"
}
}
//! Pipeline stages for gRPC router request processing
//!
//! This module defines the core pipeline abstraction and individual processing stages
//! that transform a RequestContext through its lifecycle.
use async_trait::async_trait;
use axum::response::Response;
use crate::routers::grpc::context::RequestContext;
// ============================================================================
// Pipeline Trait
// ============================================================================
/// Trait for pipeline stages that process requests
#[async_trait]
pub trait PipelineStage: Send + Sync {
/// Execute this stage, mutating the context
///
/// Returns:
/// - `Ok(None)` - Continue to next stage
/// - `Ok(Some(response))` - Pipeline complete, return this response (e.g., streaming)
/// - `Err(response)` - Error occurred, return this error response
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response>;
/// Stage name for logging
fn name(&self) -> &'static str;
}
// ============================================================================
// Stage Modules
// ============================================================================
mod client_acquisition;
mod dispatch_metadata;
mod preparation;
mod request_building;
mod request_execution;
mod response_processing;
mod worker_selection;
// ============================================================================
// Public Exports
// ============================================================================
pub use client_acquisition::ClientAcquisitionStage;
pub use dispatch_metadata::DispatchMetadataStage;
pub use preparation::PreparationStage;
pub use request_building::RequestBuildingStage;
pub use request_execution::{ExecutionMode, RequestExecutionStage};
pub use response_processing::ResponseProcessingStage;
pub use worker_selection::{WorkerSelectionMode, WorkerSelectionStage};
//! Preparation stage: Filter tools, process messages, tokenize, build constraints
use std::{borrow::Cow, sync::Arc};
use async_trait::async_trait;
use axum::response::Response;
use super::PipelineStage;
use crate::{
protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
routers::grpc::{
context::{PreparationOutput, RequestContext, RequestType},
utils,
},
tokenizer::traits::Tokenizer,
};
/// Preparation stage: Filter tools, process messages, tokenize, build constraints
pub struct PreparationStage;
#[async_trait]
impl PipelineStage for PreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Clone Arc before match to avoid borrow checker issues
// (matching borrows ctx, but prepare_* methods need mutable borrow)
// Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
if is_chat {
let request_arc = ctx.chat_request_arc();
self.prepare_chat(ctx, &request_arc).await?;
} else {
let request_arc = ctx.generate_request_arc();
self.prepare_generate(ctx, &request_arc).await?;
}
Ok(None)
}
fn name(&self) -> &'static str {
"Preparation"
}
}
impl PreparationStage {
async fn prepare_chat(
&self,
ctx: &mut RequestContext,
request: &ChatCompletionRequest,
) -> Result<(), Response> {
// Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request);
// Step 2: Process messages and apply chat template
let processed_messages =
match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
Ok(msgs) => msgs,
Err(e) => {
return Err(utils::bad_request_error(e));
}
};
// Step 3: Tokenize the processed text
let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Tokenization failed: {}",
e
)));
}
};
let token_ids = encoding.token_ids().to_vec();
// Step 4: Build tool constraints if needed
let tool_call_constraint = if let Some(tools) = body_ref.tools.as_ref() {
utils::generate_tool_constraints(tools, &request.tool_choice, &request.model).map_err(
|e| utils::bad_request_error(format!("Invalid tool configuration: {}", e)),
)?
} else {
None
};
// Step 5: Create stop sequence decoder (build once, reuse in non-stream)
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
request.stop.as_ref(),
request.stop_token_ids.as_ref(),
request.skip_special_tokens,
request.no_stop_trim,
);
// Store results in context
ctx.state.preparation = Some(PreparationOutput {
original_text: Some(processed_messages.text.clone()),
token_ids,
processed_messages: Some(processed_messages),
tool_constraints: tool_call_constraint,
filtered_request: if matches!(body_ref, Cow::Owned(_)) {
Some(body_ref.into_owned())
} else {
None
},
// Harmony fields (not used for regular preparation)
harmony_mode: false,
selection_text: None,
harmony_messages: None,
harmony_stop_ids: None,
});
// Store stop decoder for reuse in response processing
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
async fn prepare_generate(
&self,
ctx: &mut RequestContext,
request: &GenerateRequest,
) -> Result<(), Response> {
// Resolve input (text, prompt, or input_ids)
let (original_text, token_ids) = match self.resolve_generate_input(ctx, request) {
Ok(res) => res,
Err(msg) => {
return Err(utils::bad_request_error(msg));
}
};
// Create stop sequence decoder for generate requests
let params = request.sampling_params.as_ref();
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
params.and_then(|p| p.stop.as_ref()),
params.and_then(|p| p.stop_token_ids.as_ref()),
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
);
ctx.state.preparation = Some(PreparationOutput {
original_text,
token_ids,
processed_messages: None,
tool_constraints: None,
filtered_request: None,
// Harmony fields (not used for generate requests)
harmony_mode: false,
selection_text: None,
harmony_messages: None,
harmony_stop_ids: None,
});
// Store stop decoder
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
fn resolve_generate_input(
&self,
ctx: &RequestContext,
request: &GenerateRequest,
) -> Result<(Option<String>, Vec<u32>), String> {
if let Some(text) = &request.text {
return self
.tokenize_single_text(&ctx.components.tokenizer, text)
.map(|(original, ids)| (Some(original), ids));
}
// Handle input_ids - validate and convert
if let Some(input_ids) = &request.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| u32::try_from(id))
.collect::<Result<Vec<u32>, _>>()
.map(|converted| (None, converted))
.map_err(|_| "input_ids must be non-negative".to_string()),
InputIds::Batch(_) => {
Err("Batch input_ids are not supported over gRPC generate yet".to_string())
}
};
}
Err("Either `text` or `input_ids` must be provided".to_string())
}
fn tokenize_single_text(
&self,
tokenizer: &Arc<dyn Tokenizer>,
text: &str,
) -> Result<(String, Vec<u32>), String> {
let encoding = tokenizer
.encode(text)
.map_err(|e| format!("Tokenization failed: {}", e))?;
Ok((text.to_string(), encoding.token_ids().to_vec()))
}
}
//! Request building stage: Build proto GenerateRequest
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use proto::DisaggregatedParams;
use rand::Rng;
use tracing::debug;
use uuid::Uuid;
use super::PipelineStage;
use crate::{
core::Worker,
grpc_client::proto,
routers::grpc::{
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
utils,
},
};
/// Request building stage: Build proto GenerateRequest
pub struct RequestBuildingStage {
inject_pd_metadata: bool,
}
impl RequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
}
#[async_trait]
impl PipelineStage for RequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation not completed"))?;
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
// Get client for building request (use prefill client if PD mode)
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
let mut proto_request = match &ctx.input.request_type {
RequestType::Chat(request) => {
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let body_ref = prep.filtered_request.as_ref().unwrap_or(request);
builder_client
.build_generate_request(
request_id,
body_ref,
prep.processed_messages.as_ref().unwrap().text.clone(),
prep.token_ids.clone(),
prep.processed_messages
.as_ref()
.unwrap()
.multimodal_inputs
.clone(),
prep.tool_constraints.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?
}
RequestType::Generate(request) => {
let request_id = request
.rid
.clone()
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
builder_client
.build_plain_generate_request(
request_id,
request,
prep.original_text.clone(),
prep.token_ids.clone(),
)
.map_err(utils::bad_request_error)?
}
RequestType::Responses(_request) => {
// Responses API builds request during the MCP loop
// For now, create minimal request - responses handler will populate it
let request_id = format!("resp-{}", Uuid::new_v4());
proto::GenerateRequest {
request_id,
..Default::default()
}
}
};
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
self.inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"RequestBuilding"
}
}
impl RequestBuildingStage {
fn inject_bootstrap_metadata(
&self,
request: &mut proto::GenerateRequest,
prefill_worker: &Arc<dyn Worker>,
) {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
}
//! Request execution stage: Execute gRPC requests (single or dual dispatch)
use async_trait::async_trait;
use axum::response::Response;
use super::PipelineStage;
use crate::{
grpc_client::proto,
routers::grpc::{
context::{ClientSelection, ExecutionResult, RequestContext},
utils,
},
};
/// Request execution stage: Execute gRPC requests (single or dual dispatch)
pub struct RequestExecutionStage {
mode: ExecutionMode,
}
pub enum ExecutionMode {
/// Regular mode: single worker execution
Single,
/// PD mode: dual dispatch to prefill + decode workers
DualDispatch,
}
impl RequestExecutionStage {
pub fn new(mode: ExecutionMode) -> Self {
Self { mode }
}
}
#[async_trait]
impl PipelineStage for RequestExecutionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let proto_request = ctx
.state
.proto_request
.take()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
let clients = ctx
.state
.clients
.as_mut()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
let result = match self.mode {
ExecutionMode::Single => self.execute_single(proto_request, clients).await?,
ExecutionMode::DualDispatch => {
self.execute_dual_dispatch(proto_request, clients).await?
}
};
// Store result in context for ResponseProcessingStage
ctx.state.response.execution_result = Some(result);
Ok(None)
}
fn name(&self) -> &'static str {
"RequestExecution"
}
}
impl RequestExecutionStage {
async fn execute_single(
&self,
proto_request: proto::GenerateRequest,
clients: &mut ClientSelection,
) -> Result<ExecutionResult, Response> {
let client = clients
.single_mut()
.ok_or_else(|| utils::internal_error_static("Expected single client but got dual"))?;
let stream = client.generate(proto_request).await.map_err(|e| {
utils::internal_error_message(format!("Failed to start generation: {}", e))
})?;
Ok(ExecutionResult::Single { stream })
}
async fn execute_dual_dispatch(
&self,
proto_request: proto::GenerateRequest,
clients: &mut ClientSelection,
) -> Result<ExecutionResult, Response> {
let (prefill_client, decode_client) = clients
.dual_mut()
.ok_or_else(|| utils::internal_error_static("Expected dual clients but got single"))?;
let prefill_request = proto_request.clone();
let decode_request = proto_request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Handle prefill result
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Prefill worker failed to start: {}",
e
)));
}
};
// Handle decode result
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Decode worker failed to start: {}",
e
)));
}
};
Ok(ExecutionResult::Dual {
prefill: prefill_stream,
decode: Box::new(decode_stream),
})
}
}
//! Response processing stage: Handles both streaming and non-streaming responses
//!
//! - For streaming: Spawns background task and returns SSE response (early exit)
//! - For non-streaming: Collects all responses and builds final ChatCompletionResponse
use std::{sync::Arc, time::Instant};
use async_trait::async_trait;
use axum::response::Response;
use super::PipelineStage;
use crate::routers::grpc::{
context::{FinalResponse, RequestContext, RequestType},
processing, streaming, utils,
};
/// Response processing stage: Handles both streaming and non-streaming responses
///
/// - For streaming: Spawns background task and returns SSE response (early exit)
/// - For non-streaming: Collects all responses and builds final ChatCompletionResponse
pub struct ResponseProcessingStage {
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
}
impl ResponseProcessingStage {
pub fn new(
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
Self {
processor,
streaming_processor,
}
}
}
#[async_trait]
impl PipelineStage for ResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
// Delegate to request-type specific processing
match &ctx.input.request_type {
RequestType::Chat(_) => self.process_chat_response(ctx).await,
RequestType::Generate(_) => self.process_generate_response(ctx).await,
RequestType::Responses(_) => Err(utils::bad_request_error(
"Responses API processing must be handled by responses handler".to_string(),
)),
}
}
fn name(&self) -> &'static str {
"ResponseProcessing"
}
}
impl ResponseProcessingStage {
async fn process_chat_response(
&self,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
let is_streaming = ctx.is_streaming();
// Extract execution result
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
.clone();
if is_streaming {
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_response(
execution_result,
ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
dispatch,
),
));
}
// Non-streaming: Delegate to ResponseProcessor
let request_logprobs = match &ctx.input.request_type {
RequestType::Chat(req) => req.logprobs,
_ => false,
};
let chat_request = ctx.chat_request_arc();
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
let response = self
.processor
.process_non_streaming_chat_response(
execution_result,
chat_request,
dispatch,
stop_decoder,
request_logprobs,
)
.await?;
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
Ok(None)
}
async fn process_generate_response(
&self,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
let start_time = Instant::now();
let is_streaming = ctx.is_streaming();
// Extract execution result
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
// Get dispatch metadata (needed by both streaming and non-streaming)
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
.clone();
if is_streaming {
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_generate(
execution_result,
ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
dispatch,
),
));
}
// Non-streaming: Delegate to ResponseProcessor
let request_logprobs = ctx.generate_request().return_logprob.unwrap_or(false);
let generate_request = ctx.generate_request_arc();
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
let result_array = self
.processor
.process_non_streaming_generate_response(
execution_result,
generate_request,
dispatch,
stop_decoder,
request_logprobs,
start_time,
)
.await?;
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
Ok(None)
}
}
//! Worker selection stage: Select appropriate worker(s) based on routing mode
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use tracing::warn;
use super::PipelineStage;
use crate::{
core::{ConnectionMode, Worker, WorkerRegistry, WorkerType},
policies::PolicyRegistry,
routers::grpc::{
context::{RequestContext, WorkerSelection},
utils,
},
};
/// Worker selection stage: Select appropriate worker(s) based on routing mode
pub struct WorkerSelectionStage {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
mode: WorkerSelectionMode,
}
pub enum WorkerSelectionMode {
/// Regular mode: select single worker
Regular,
/// PD mode: select prefill + decode workers
PrefillDecode,
}
impl WorkerSelectionStage {
pub fn new(
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
mode: WorkerSelectionMode,
) -> Self {
Self {
worker_registry,
policy_registry,
mode,
}
}
}
#[async_trait]
impl PipelineStage for WorkerSelectionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation stage not completed"))?;
// For Harmony, use selection_text produced during Harmony encoding
// Otherwise, use original_text from regular preparation
let text = if prep.harmony_mode {
prep.selection_text.as_deref()
} else {
prep.original_text.as_deref()
};
let workers = match self.mode {
WorkerSelectionMode::Regular => {
match self.select_single_worker(ctx.input.model_id.as_deref(), text) {
Some(w) => WorkerSelection::Single { worker: w },
None => {
return Err(utils::service_unavailable_error(format!(
"No available workers for model: {:?}",
ctx.input.model_id
)));
}
}
}
WorkerSelectionMode::PrefillDecode => {
match self.select_pd_pair(ctx.input.model_id.as_deref(), text) {
Some((prefill, decode)) => WorkerSelection::Dual { prefill, decode },
None => {
return Err(utils::service_unavailable_error(format!(
"No available PD worker pairs for model: {:?}",
ctx.input.model_id
)));
}
}
}
};
ctx.state.workers = Some(workers);
Ok(None)
}
fn name(&self) -> &'static str {
"WorkerSelection"
}
}
impl WorkerSelectionStage {
fn select_single_worker(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model, filtered by connection mode
let workers = self.worker_registry.get_workers_filtered(
model_id,
Some(WorkerType::Regular),
Some(ConnectionMode::Grpc { port: None }),
false, // get all workers, we'll filter by is_available() next
);
// Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.cloned()
.collect();
if available.is_empty() {
return None;
}
// Get the appropriate policy for this model
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
// Select worker using the policy
let idx = policy.select_worker(&available, text)?;
Some(available[idx].clone())
}
fn select_pd_pair(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<(Arc<dyn Worker>, Arc<dyn Worker>)> {
let all_workers = self.worker_registry.get_workers_filtered(
model_id,
None,
Some(ConnectionMode::Grpc { port: None }), // Match any gRPC worker
false,
);
let (available_prefill, available_decode): (Vec<_>, Vec<_>) =
all_workers
.into_iter()
.fold((Vec::new(), Vec::new()), |mut acc, w| {
if w.is_available() {
match w.metadata().worker_type {
WorkerType::Prefill { .. } => acc.0.push(w),
WorkerType::Decode => acc.1.push(w),
_ => {}
}
}
acc
});
if available_prefill.is_empty() {
warn!("No available prefill workers");
return None;
}
if available_decode.is_empty() {
warn!("No available decode workers");
return None;
}
// Select using policies
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
let prefill_idx = policy.select_worker(&available_prefill, text)?;
let decode_idx = policy.select_worker(&available_decode, text)?;
Some((
available_prefill[prefill_idx].clone(),
available_decode[decode_idx].clone(),
))
}
}
@@ -80,6 +80,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
tool_choice: Some(ToolChoice::default()),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
function: None,
server_url: Some(mcp.url()),
authorization: None,
server_label: Some("mock".to_string()),
......