Unverified Commit edd86b88 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Refactor chat handler in grpc/ to use centralized orchestrator (#11314)


Co-authored-by: default avatarSimo Lin <linsimo.mark@gmail.com>
parent 4b4dc132
...@@ -2066,6 +2066,40 @@ impl GenerationRequest for GenerateRequest { ...@@ -2066,6 +2066,40 @@ impl GenerationRequest for GenerateRequest {
} }
} }
// TODO(generate): Define GenerateResponse and GenerateChoice structs
//
// Required for pipeline generate response processing (see grpc/pipeline.rs:931-964)
//
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct GenerateResponse {
// pub id: String,
// pub object: String, // "text.completion"
// pub created: u64,
// pub model: String,
// pub choices: Vec<GenerateChoice>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub usage: Option<Usage>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub system_fingerprint: Option<String>,
// }
//
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct GenerateChoice {
// pub index: u32,
// pub text: String,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub output_ids: Option<Vec<u32>>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub finish_reason: Option<String>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub logprobs: Option<TopLogprobs>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub matched_stop: Option<Value>,
// }
//
// Note: Verify if similar structs already exist elsewhere before implementing.
// May need streaming variant (GenerateStreamResponse) as well.
// Constants for rerank API // Constants for rerank API
pub const DEFAULT_MODEL_NAME: &str = "default"; pub const DEFAULT_MODEL_NAME: &str = "default";
......
//! Request context types for gRPC router pipeline
//!
//! This module provides the core context types that flow through the router pipeline,
//! eliminating deep parameter passing chains and providing a single source of truth
//! for request state.
use std::collections::HashMap;
use std::sync::Arc;
use axum::http::HeaderMap;
use serde_json::Value;
use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{ChatCompletionRequest, ChatCompletionResponse, GenerateRequest};
use crate::reasoning_parser::ReasoningParserFactory;
use crate::tokenizer::stop::StopSequenceDecoder;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory;
// ============================================================================
// Core Context Types
// ============================================================================
/// Main request processing context
///
/// This is the single source of truth for all request state as it flows
/// through the pipeline stages. Uses Rust's type system to enforce proper
/// stage ordering at compile time.
pub struct RequestContext {
// === Input (Immutable) ===
pub input: RequestInput,
// === Shared Components (Immutable References) ===
pub components: Arc<SharedComponents>,
// === Processing State (Mutable, evolves through pipeline) ===
pub state: ProcessingState,
}
/// Immutable request input
pub struct RequestInput {
pub request_type: RequestType,
pub headers: Option<HeaderMap>,
pub model_id: Option<String>,
}
/// Request type variants
pub enum RequestType {
Chat(Box<ChatCompletionRequest>),
Generate(Box<GenerateRequest>),
}
/// Shared components (injected once at creation)
pub struct SharedComponents {
pub tokenizer: Arc<dyn Tokenizer>,
pub tool_parser_factory: ToolParserFactory,
pub reasoning_parser_factory: ReasoningParserFactory,
}
/// Mutable processing state (evolves through pipeline stages)
#[derive(Default)]
pub struct ProcessingState {
// Stage 1: Preparation outputs
pub preparation: Option<PreparationOutput>,
// Stage 2: Worker selection outputs
pub workers: Option<WorkerSelection>,
// Stage 3: Client acquisition outputs
pub clients: Option<ClientSelection>,
// Stage 4: Request building outputs
pub proto_request: Option<proto::GenerateRequest>,
// Stage 5: Dispatch metadata
pub dispatch: Option<DispatchMetadata>,
// Stage 6: Response processing state
pub response: ResponseState,
}
// ============================================================================
// Stage-Specific Output Types
// ============================================================================
/// Output from preparation stage (Step 1)
pub struct PreparationOutput {
/// Original text (for chat) or resolved text (for generate)
pub original_text: Option<String>,
/// Tokenized input
pub token_ids: Vec<u32>,
/// Processed messages (chat only)
pub processed_messages: Option<super::ProcessedMessages>,
/// Tool call constraints (if applicable)
pub tool_constraints: Option<(String, String)>,
/// Filtered request (if tools were filtered)
pub filtered_request: Option<ChatCompletionRequest>,
}
/// Worker selection (Step 2)
pub enum WorkerSelection {
Single {
worker: Arc<dyn Worker>,
},
Dual {
prefill: Arc<dyn Worker>,
decode: Arc<dyn Worker>,
},
}
/// Client selection (Step 3)
pub enum ClientSelection {
Single {
client: SglangSchedulerClient,
},
Dual {
prefill: SglangSchedulerClient,
decode: SglangSchedulerClient,
},
}
/// Dispatch metadata (Step 5)
#[derive(Clone)]
pub struct DispatchMetadata {
pub request_id: String,
pub model: String,
pub created: u64,
pub weight_version: Option<String>,
pub is_streaming: bool,
}
/// Response processing state (Step 6)
#[derive(Default)]
pub struct ResponseState {
/// Stop sequence decoder
pub stop_decoder: Option<StopSequenceDecoder>,
/// Per-index streaming state (for n>1 support)
pub streaming: StreamingState,
/// Collected responses (non-streaming)
pub collected: Option<Vec<proto::GenerateComplete>>,
/// Execution result (streams from workers)
pub execution_result: Option<ExecutionResult>,
/// Final processed response
pub final_response: Option<FinalResponse>,
}
/// Streaming state (per-choice tracking)
#[derive(Default)]
pub struct StreamingState {
pub is_firsts: HashMap<u32, bool>,
pub stream_buffers: HashMap<u32, String>,
pub finish_reasons: HashMap<u32, String>,
pub matched_stops: HashMap<u32, Option<Value>>,
pub prompt_tokens: HashMap<u32, u32>,
pub completion_tokens: HashMap<u32, u32>,
pub cached_tokens: HashMap<u32, u32>,
// Parser state (lazy initialization per index)
pub reasoning_parsers:
HashMap<u32, Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>>,
pub tool_parsers:
HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>>,
pub has_tool_calls: HashMap<u32, bool>,
}
// ============================================================================
// Context Builders
// ============================================================================
impl RequestContext {
/// Create context for chat completion request
pub fn for_chat(
request: ChatCompletionRequest,
headers: Option<HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Self {
Self {
input: RequestInput {
request_type: RequestType::Chat(Box::new(request)),
headers,
model_id,
},
components,
state: ProcessingState::default(),
}
}
/// Create context for generate request
pub fn for_generate(
request: GenerateRequest,
headers: Option<HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Self {
Self {
input: RequestInput {
request_type: RequestType::Generate(Box::new(request)),
headers,
model_id,
},
components,
state: ProcessingState::default(),
}
}
/// Get reference to original request (type-safe)
pub fn request(&self) -> &RequestType {
&self.input.request_type
}
/// Get chat request (panics if not chat)
pub fn chat_request(&self) -> &ChatCompletionRequest {
match &self.input.request_type {
RequestType::Chat(req) => req.as_ref(),
_ => panic!("Expected chat request"),
}
}
/// Try to get chat request
pub fn try_chat_request(&self) -> Option<&ChatCompletionRequest> {
match &self.input.request_type {
RequestType::Chat(req) => Some(req.as_ref()),
_ => None,
}
}
/// Get generate request (panics if not generate)
pub fn generate_request(&self) -> &GenerateRequest {
match &self.input.request_type {
RequestType::Generate(req) => req.as_ref(),
_ => panic!("Expected generate request"),
}
}
/// Try to get generate request
pub fn try_generate_request(&self) -> Option<&GenerateRequest> {
match &self.input.request_type {
RequestType::Generate(req) => Some(req.as_ref()),
_ => None,
}
}
/// Check if request is streaming
pub fn is_streaming(&self) -> bool {
match &self.input.request_type {
RequestType::Chat(req) => req.stream,
RequestType::Generate(req) => req.stream,
}
}
/// Check if request is chat
pub fn is_chat(&self) -> bool {
matches!(&self.input.request_type, RequestType::Chat(_))
}
/// Check if request is generate
pub fn is_generate(&self) -> bool {
matches!(&self.input.request_type, RequestType::Generate(_))
}
}
// ============================================================================
// Default Implementations
// ============================================================================
// ============================================================================
// Helper Methods
// ============================================================================
impl WorkerSelection {
pub fn is_dual(&self) -> bool {
matches!(self, Self::Dual { .. })
}
pub fn single(&self) -> Option<&Arc<dyn Worker>> {
match self {
Self::Single { worker } => Some(worker),
_ => None,
}
}
#[allow(clippy::type_complexity)]
pub fn dual(&self) -> Option<(&Arc<dyn Worker>, &Arc<dyn Worker>)> {
match self {
Self::Dual { prefill, decode } => Some((prefill, decode)),
_ => None,
}
}
pub fn prefill_worker(&self) -> Option<&Arc<dyn Worker>> {
match self {
Self::Dual { prefill, .. } => Some(prefill),
_ => None,
}
}
pub fn decode_worker(&self) -> Option<&Arc<dyn Worker>> {
match self {
Self::Dual { decode, .. } => Some(decode),
_ => None,
}
}
}
impl ClientSelection {
pub fn is_dual(&self) -> bool {
matches!(self, Self::Dual { .. })
}
pub fn single(&self) -> Option<&SglangSchedulerClient> {
match self {
Self::Single { client } => Some(client),
_ => None,
}
}
pub fn single_mut(&mut self) -> Option<&mut SglangSchedulerClient> {
match self {
Self::Single { client } => Some(client),
_ => None,
}
}
pub fn dual(&self) -> Option<(&SglangSchedulerClient, &SglangSchedulerClient)> {
match self {
Self::Dual { prefill, decode } => Some((prefill, decode)),
_ => None,
}
}
pub fn dual_mut(&mut self) -> Option<(&mut SglangSchedulerClient, &mut SglangSchedulerClient)> {
match self {
Self::Dual { prefill, decode } => Some((prefill, decode)),
_ => None,
}
}
pub fn prefill_client(&self) -> Option<&SglangSchedulerClient> {
match self {
Self::Dual { prefill, .. } => Some(prefill),
_ => None,
}
}
pub fn prefill_client_mut(&mut self) -> Option<&mut SglangSchedulerClient> {
match self {
Self::Dual { prefill, .. } => Some(prefill),
_ => None,
}
}
pub fn decode_client(&self) -> Option<&SglangSchedulerClient> {
match self {
Self::Dual { decode, .. } => Some(decode),
_ => None,
}
}
pub fn decode_client_mut(&mut self) -> Option<&mut SglangSchedulerClient> {
match self {
Self::Dual { decode, .. } => Some(decode),
_ => None,
}
}
}
// ============================================================================
// Execution and Response Types
// ============================================================================
use tonic::codec::Streaming;
/// Result of request execution (streams from workers)
pub enum ExecutionResult {
Single {
stream: Streaming<proto::GenerateResponse>,
},
Dual {
prefill: Streaming<proto::GenerateResponse>,
decode: Box<Streaming<proto::GenerateResponse>>,
},
}
/// Final processed response
pub enum FinalResponse {
Chat(ChatCompletionResponse),
Generate(Box<GenerateRequest>),
}
...@@ -3,8 +3,12 @@ ...@@ -3,8 +3,12 @@
use crate::grpc_client::proto; use crate::grpc_client::proto;
use crate::protocols::spec::StringOrArray; use crate::protocols::spec::StringOrArray;
pub mod context;
pub mod pd_router; pub mod pd_router;
pub mod pipeline;
pub mod processing;
pub mod router; pub mod router;
pub mod streaming;
pub mod utils; pub mod utils;
/// Processed chat messages ready for gRPC generation /// Processed chat messages ready for gRPC generation
......
...@@ -6,19 +6,16 @@ use crate::grpc_client::proto; ...@@ -6,19 +6,16 @@ use crate::grpc_client::proto;
use crate::grpc_client::SglangSchedulerClient; use crate::grpc_client::SglangSchedulerClient;
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds,
ChatCompletionStreamResponse, ChatLogProbs, ChatLogProbsContent, ChatMessageDelta, RerankRequest, ResponsesGetParams, ResponsesRequest,
ChatStreamChoice, CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse,
GenerateRequest, InputIds, RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray,
Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, TopLogProb, Usage,
}; };
use crate::reasoning_parser::{ParserResult, ReasoningParser, ReasoningParserFactory}; use crate::reasoning_parser::ReasoningParserFactory;
use crate::routers::http::pd_types::generate_room_id; use crate::routers::http::pd_types::generate_room_id;
use crate::routers::{grpc, RouterTrait}; use crate::routers::{grpc, RouterTrait};
use crate::server::AppContext; use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::SequenceDecoderOutput;
use crate::tool_parser::{StreamingParseResult, ToolParser, ToolParserFactory}; use crate::tool_parser::ToolParserFactory;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
...@@ -29,16 +26,14 @@ use axum::{ ...@@ -29,16 +26,14 @@ use axum::{
}; };
use grpc::utils; use grpc::utils;
use proto::generate_response::Response::{Chunk, Complete, Error}; use proto::generate_response::Response::{Chunk, Complete, Error};
use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc::unbounded_channel; use tokio::sync::mpsc::unbounded_channel;
use tokio::sync::mpsc::UnboundedSender; use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::Stream; use tokio_stream::Stream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::{debug, error, warn}; use tracing::{debug, error};
use uuid::Uuid; use uuid::Uuid;
/// gRPC PD (Prefill-Decode) router implementation for SGLang /// gRPC PD (Prefill-Decode) router implementation for SGLang
...@@ -55,6 +50,10 @@ pub struct GrpcPDRouter { ...@@ -55,6 +50,10 @@ pub struct GrpcPDRouter {
retry_config: RetryConfig, retry_config: RetryConfig,
configured_reasoning_parser: Option<String>, configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>, configured_tool_parser: Option<String>,
// Pipeline for non-streaming requests
pipeline: super::pipeline::ChatCompletionPipeline,
// Shared components for pipeline
shared_components: Arc<super::context::SharedComponents>,
} }
impl GrpcPDRouter { impl GrpcPDRouter {
...@@ -81,6 +80,39 @@ impl GrpcPDRouter { ...@@ -81,6 +80,39 @@ impl GrpcPDRouter {
.ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())? .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
.clone(); .clone();
// Create shared components for pipeline
let shared_components = Arc::new(super::context::SharedComponents {
tokenizer: tokenizer.clone(),
tool_parser_factory: tool_parser_factory.clone(),
reasoning_parser_factory: reasoning_parser_factory.clone(),
});
// Create response processor
let processor = super::processing::ResponseProcessor::new(
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
ctx.configured_tool_parser.clone(),
ctx.configured_reasoning_parser.clone(),
);
// Create streaming processor
let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
ctx.configured_tool_parser.clone(),
ctx.configured_reasoning_parser.clone(),
));
// Create PD pipeline
let pipeline = super::pipeline::ChatCompletionPipeline::new_pd(
worker_registry.clone(),
policy_registry.clone(),
processor,
streaming_processor,
);
Ok(GrpcPDRouter { Ok(GrpcPDRouter {
worker_registry, worker_registry,
policy_registry, policy_registry,
...@@ -92,6 +124,8 @@ impl GrpcPDRouter { ...@@ -92,6 +124,8 @@ impl GrpcPDRouter {
retry_config: ctx.router_config.effective_retry_config(), retry_config: ctx.router_config.effective_retry_config(),
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(), configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
configured_tool_parser: ctx.configured_tool_parser.clone(), configured_tool_parser: ctx.configured_tool_parser.clone(),
pipeline,
shared_components,
}) })
} }
...@@ -314,7 +348,7 @@ impl GrpcPDRouter { ...@@ -314,7 +348,7 @@ impl GrpcPDRouter {
/// Main route_chat implementation with PD dual dispatch /// Main route_chat implementation with PD dual dispatch
async fn route_chat_impl( async fn route_chat_impl(
&self, &self,
_headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ChatCompletionRequest, body: &ChatCompletionRequest,
model_id: Option<&str>, model_id: Option<&str>,
) -> Response { ) -> Response {
...@@ -323,91 +357,15 @@ impl GrpcPDRouter { ...@@ -323,91 +357,15 @@ impl GrpcPDRouter {
model_id model_id
); );
// Step 1: Filter tools if needed for allowed_tools or specific function // Use pipeline for ALL requests (streaming and non-streaming)
let body_ref = utils::filter_tools_for_request(body); self.pipeline
.execute_chat(
// Step 2: Process messages and apply chat template body.clone(),
let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { headers.cloned(),
Ok(msgs) => msgs, model_id.map(|s| s.to_string()),
Err(e) => { self.shared_components.clone(),
return utils::bad_request_error(e.to_string()); )
}
};
// Step 3: Tokenize the processed text
let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return utils::internal_error_message(format!("Tokenization failed: {}", e));
}
};
// Step 4: Build tool constraints if needed
// body_ref already has filtered tools if needed
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
utils::generate_tool_constraints(tools, &body.tool_choice, &body.model)
});
let token_ids = encoding.token_ids().to_vec();
debug!("Tokenized {} tokens from input", token_ids.len());
// Step 5: Select prefill-decode worker pair
let (prefill_worker, decode_worker) = match self
.select_pd_pair(Some(&processed_messages.text), model_id)
.await .await
{
Ok(pair) => pair,
Err(e) => {
return utils::service_unavailable_error(e);
}
};
debug!(
"Selected PD pair: prefill={}, decode={}",
prefill_worker.url(),
decode_worker.url()
);
// Step 6: Get gRPC clients for both workers
let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await {
Ok(client) => client,
Err(response) => return response,
};
let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await {
Ok(client) => client,
Err(response) => return response,
};
// Step 7: Build the base gRPC request
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let mut request = match prefill_client.build_generate_request(
request_id.clone(),
&body_ref,
processed_messages.text.clone(),
token_ids,
processed_messages.multimodal_inputs,
tool_call_constraint,
) {
Ok(request) => request,
Err(e) => {
return utils::bad_request_error(format!("Invalid request parameters: {}", e));
}
};
// Step 8: Inject bootstrap metadata into the request
if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) {
return utils::internal_error_message(e);
}
// Step 9: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_chat(prefill_client, decode_client, request, body)
.await
} else {
self.handle_non_streaming_chat(prefill_client, decode_client, request, body)
.await
}
} }
/// Resolve the generate input into optional original text and token IDs /// Resolve the generate input into optional original text and token IDs
...@@ -441,109 +399,6 @@ impl GrpcPDRouter { ...@@ -441,109 +399,6 @@ impl GrpcPDRouter {
Err("Either `text` or `input_ids` must be provided".to_string()) Err("Either `text` or `input_ids` must be provided".to_string())
} }
/// Submit request and handle streaming response for chat completions (PD mode)
async fn handle_streaming_chat(
&self,
mut prefill_client: SglangSchedulerClient,
mut decode_client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
let request_id = request.request_id.clone();
let model = original_request.model.clone();
// Create channel for SSE streaming
let (tx, rx) = unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
// Send requests in parallel to both prefill and decode workers
debug!("Starting concurrent streaming requests to prefill and decode workers");
let prefill_request = request.clone();
let decode_request = request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Get prefill stream
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
return utils::internal_error_message(format!(
"Prefill worker failed to start: {}",
e
));
}
};
// Get decode stream - this is what we'll process for output
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
return utils::internal_error_message(format!(
"Decode worker failed to start: {}",
e
));
}
};
let stop_params = (
original_request.stop.clone(),
original_request.stop_token_ids.clone(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Spawn processing task for both streams
let self_clone = self.clone();
let original_request_clone = original_request.clone();
tokio::spawn(async move {
let result = Self::process_dual_streaming_chunks(
&self_clone,
prefill_stream,
decode_stream,
request_id,
model,
stop_params,
original_request_clone,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
serde_json::json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
}
// Send DONE marker
let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
});
// Create response with SSE headers
let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(stream));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("text/event-stream"),
);
response
.headers_mut()
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
response
.headers_mut()
.insert("Connection", HeaderValue::from_static("keep-alive"));
response
}
/// Submit request and handle streaming response for generate endpoint (PD mode) /// Submit request and handle streaming response for generate endpoint (PD mode)
async fn handle_streaming_generate( async fn handle_streaming_generate(
&self, &self,
...@@ -766,778 +621,6 @@ impl GrpcPDRouter { ...@@ -766,778 +621,6 @@ impl GrpcPDRouter {
Ok(()) Ok(())
} }
/// Process dual streaming chunks (prefill + decode) and send SSE events (PD mode)
#[allow(clippy::too_many_arguments)]
async fn process_dual_streaming_chunks(
router: &GrpcPDRouter,
mut prefill_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
mut decode_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
request_id: String,
model: String,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: ChatCompletionRequest,
tx: &UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
) -> Result<(), String> {
// Extract request parameters
let separate_reasoning = original_request.separate_reasoning;
let tool_choice = &original_request.tool_choice;
let tools = &original_request.tools;
let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request);
let stream_options = &original_request.stream_options;
// Phase 1: Initialize state tracking (per-index for n>1 support)
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
// Parser state (lazy initialization per index)
type PooledReasoningParser = Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>;
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
// Create stop decoder
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
let mut stop_decoder = utils::create_stop_decoder(
&router.tokenizer,
stop.as_ref(),
stop_token_ids.as_ref(),
skip_special_tokens,
no_stop_trim,
);
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
// Phase 1.5: Collect input_logprobs from prefill stream if requested
// Note: In PD mode, input_logprobs come from prefill worker
// TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming
if original_request.logprobs {
while let Some(response) = prefill_stream.next().await {
let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?;
match gen_response.response {
Some(Complete(_complete)) => {
// Input logprobs collected but not yet used in streaming
// (OpenAI spec doesn't require prompt logprobs in streaming responses)
break;
}
Some(Error(error)) => {
return Err(format!("Prefill error: {}", error.message));
}
_ => continue,
}
}
}
// Phase 2: Main streaming loop (decode stream)
while let Some(response) = decode_stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Process tokens through stop decoder
let (chunk_text, _should_stop) =
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
if chunk_text.is_empty() {
continue;
}
// Process logprobs if present
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
match router.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
warn!("Failed to process logprobs: {}", e);
None
}
}
} else {
None
};
// Initialize stream buffer if first time
let stream_buffer = stream_buffers.entry(index).or_default();
// Send first chunk with role
if is_firsts.get(&index).copied().unwrap_or(true) {
let first_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&first_chunk))))
.map_err(|_| "Failed to send first chunk".to_string())?;
is_firsts.insert(index, false);
}
// Calculate delta
let mut delta = chunk_text;
stream_buffer.push_str(&delta);
// Reasoning content handling
let in_reasoning = if separate_reasoning {
let (normal_text, reasoning_chunk, in_reasoning) = router
.process_reasoning_stream(
&delta,
index,
&mut reasoning_parsers,
&request_id,
&model,
created,
);
if let Some(chunk) = reasoning_chunk {
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
}
delta = normal_text;
in_reasoning
} else {
false
};
// Tool call handling
let tool_choice_enabled =
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
if !in_reasoning && tool_choice_enabled && tools.is_some() {
let (should_skip, tool_chunks) = router
.process_tool_calls_stream(
&delta,
index,
&mut tool_parsers,
&mut has_tool_calls,
tools.as_ref().unwrap(),
&request_id,
&model,
created,
history_tool_calls_count,
)
.await;
for chunk in tool_chunks {
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send tool call chunk".to_string())?;
}
if should_skip {
continue;
}
}
// Regular content emission
if !delta.is_empty() {
let content_chunk = Self::create_content_chunk(
delta,
index,
&request_id,
&model,
created,
choice_logprobs,
);
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(
&content_chunk,
))))
.map_err(|_| "Failed to send content chunk".to_string())?;
}
}
Some(Complete(complete)) => {
// Flush any remaining text
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
if !text.is_empty() {
let index = complete.index;
let stream_buffer = stream_buffers.entry(index).or_default();
stream_buffer.push_str(&text);
let content_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&content_chunk)
.map_err(|e| format!("Failed to serialize content chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send flushed content".to_string())?;
}
}
// Store metadata
let index = complete.index;
prompt_tokens.insert(index, complete.prompt_tokens as u32);
completion_tokens.insert(index, complete.completion_tokens as u32);
cached_tokens.insert(index, complete.cached_tokens as u32);
finish_reasons.insert(index, complete.finish_reason.clone());
// Extract matched_stop
let matched_stop_value = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
matched_stops.insert(index, matched_stop_value);
break;
}
Some(Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
// Phase 3: Check unstreamed tool args
for (index, parser) in &tool_parsers {
let parser_guard = parser.lock().await;
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
for tool_call_item in unstreamed_items {
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: None,
tool_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
let tool_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&tool_chunk)
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
}
}
}
// Phase 4: Finish reason chunks
for (index, finish_reason) in finish_reasons.iter() {
let final_finish_reason =
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
"tool_calls".to_string()
} else {
finish_reason.clone()
};
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
let finish_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(final_finish_reason),
matched_stop: matched_stop_value,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&finish_chunk)
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send finish chunk".to_string())?;
}
// Phase 5: Usage chunk
if let Some(stream_opts) = stream_options {
if stream_opts.include_usage.unwrap_or(false) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
let usage_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![],
usage: Some(Usage {
prompt_tokens: total_prompt,
completion_tokens: total_completion,
total_tokens: total_prompt + total_completion,
completion_tokens_details: None,
}),
};
let sse_chunk = serde_json::to_string(&usage_chunk)
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send usage chunk".to_string())?;
}
}
Ok(())
}
/// Helper: Process reasoning content in streaming mode
fn process_reasoning_stream(
&self,
delta: &str,
index: u32,
reasoning_parsers: &mut HashMap<u32, Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>>,
request_id: &str,
model: &str,
created: u64,
) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index
reasoning_parsers.entry(index).or_insert_with(|| {
utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let (parse_result, in_reasoning) = {
let mut parser = pooled_parser.lock().unwrap();
let result = parser.parse_reasoning_streaming_incremental(delta);
let in_reasoning = parser.is_in_reasoning();
(result, in_reasoning)
};
match parse_result {
Ok(ParserResult {
reasoning_text,
normal_text,
}) => {
let chunk = if !reasoning_text.is_empty() {
Some(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: Some(reasoning_text),
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
})
} else {
None
};
return (normal_text, chunk, in_reasoning);
}
Err(e) => {
warn!("Reasoning parsing error: {}", e);
}
}
}
(delta.to_string(), None, false)
}
/// Helper: Process tool calls in streaming mode
#[allow(clippy::too_many_arguments)]
async fn process_tool_calls_stream(
&self,
delta: &str,
index: u32,
tool_parsers: &mut HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>>,
has_tool_calls: &mut HashMap<u32, bool>,
tools: &[Tool],
request_id: &str,
model: &str,
created: u64,
history_tool_calls_count: usize,
) -> (bool, Vec<ChatCompletionStreamResponse>) {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers.entry(index).or_insert_with(|| {
utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await {
Ok(StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present
if !normal_text.is_empty() {
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(normal_text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// Emit tool call chunks
for tool_call_item in calls {
has_tool_calls.insert(index, true);
let tool_call_id = if let Some(ref name) = tool_call_item.name {
Some(utils::generate_tool_call_id(
model,
name,
tool_call_item.tool_index,
history_tool_calls_count,
))
} else {
None
};
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: tool_call_id,
tool_type: if tool_call_item.name.is_some() {
Some("function".to_string())
} else {
None
},
function: Some(FunctionCallDelta {
name: tool_call_item.name,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// If we emitted chunks, skip regular content
return (!chunks.is_empty(), chunks);
}
Err(e) => {
warn!("Tool call parsing error: {}", e);
}
}
}
(false, chunks)
}
/// Helper: Create content chunk
fn create_content_chunk(
content: String,
index: u32,
request_id: &str,
model: &str,
created: u64,
logprobs: Option<ChatLogProbs>,
) -> ChatCompletionStreamResponse {
ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(content),
tool_calls: None,
reasoning_content: None,
},
logprobs,
finish_reason: None,
matched_stop: None,
}],
usage: None,
}
}
/// Helper: Format response as SSE chunk
fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String {
format!(
"data: {}\n\n",
serde_json::to_string(response).unwrap_or_default()
)
}
/// Process a chunk of tokens through the stop decoder
fn process_chunk_tokens(
stop_decoder: &mut StopSequenceDecoder,
token_ids: &[u32],
) -> (String, bool) {
let mut chunk_text = String::new();
for &token_id in token_ids {
match stop_decoder.process_token(token_id).unwrap_or_else(|e| {
debug!(
"Error processing token {}: {}. Treating as Held.",
token_id, e
);
SequenceDecoderOutput::Held
}) {
SequenceDecoderOutput::Text(text) => {
chunk_text.push_str(&text);
}
SequenceDecoderOutput::StoppedWithText(text) => {
chunk_text.push_str(&text);
return (chunk_text, true);
}
SequenceDecoderOutput::Stopped => {
return (chunk_text, true);
}
SequenceDecoderOutput::Held => {}
}
}
(chunk_text, false)
}
/// Submit request and handle non-streaming response for chat completions (PD mode)
async fn handle_non_streaming_chat(
&self,
mut prefill_client: SglangSchedulerClient,
mut decode_client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
// Step 1: Create stop decoder
let mut stop_decoder = utils::create_stop_decoder(
&self.tokenizer,
original_request.stop.as_ref(),
original_request.stop_token_ids.as_ref(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Step 2: Send requests in parallel
debug!("Sending concurrent requests to prefill and decode workers");
let prefill_request = request.clone();
let decode_request = request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Step 3: Process prefill stream in parallel - if it fails, assume decode fails
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start prefill generation: {}", e);
return utils::internal_error_message(format!(
"Prefill worker failed to start: {}",
e
));
}
};
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start decode generation: {}", e);
return utils::internal_error_message(format!(
"Decode worker failed to start: {}",
e
));
}
};
// Collect prefill response (for input_logprobs if requested)
let prefill_responses =
match utils::collect_stream_responses(prefill_stream, "Prefill").await {
Ok(responses) => responses,
Err(error_response) => return error_response,
};
// Extract input_logprobs from prefill response if available
let prefill_input_logprobs = prefill_responses
.first()
.and_then(|r| r.input_logprobs.clone());
// Step 4: Process decode stream (collect all responses for n>1 support)
let all_responses = match utils::collect_stream_responses(decode_stream, "Decode").await {
Ok(responses) => responses,
Err(error_response) => return error_response,
};
if all_responses.is_empty() {
return utils::internal_error_static("No responses from decode worker");
}
// Process each response into a ChatChoice
let history_tool_calls_count = utils::get_history_tool_calls_count(original_request);
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
// Merge prefill input_logprobs if available and requested
let mut complete_with_logprobs = complete.clone();
if prefill_input_logprobs.is_some() && original_request.logprobs {
complete_with_logprobs.input_logprobs = prefill_input_logprobs.clone();
}
match self
.process_single_choice(
&complete_with_logprobs,
index,
original_request,
&mut stop_decoder,
history_tool_calls_count,
)
.await
{
Ok(choice) => choices.push(choice),
Err(e) => {
return utils::internal_error_message(format!(
"Failed to process choice {}: {}",
index, e
));
}
}
}
// Aggregate usage information from all responses
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
};
// Build final ChatCompletionResponse
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: original_request.model.clone(),
choices,
usage: Some(usage),
system_fingerprint: None,
};
// Serialize and return JSON response
Json(response).into_response()
}
/// Submit request and handle non-streaming response for generate endpoint (PD mode) /// Submit request and handle non-streaming response for generate endpoint (PD mode)
async fn handle_non_streaming_generate( async fn handle_non_streaming_generate(
&self, &self,
...@@ -1683,301 +766,6 @@ impl GrpcPDRouter { ...@@ -1683,301 +766,6 @@ impl GrpcPDRouter {
Json(result_array).into_response() Json(result_array).into_response()
} }
/// Process a single GenerateComplete response into a ChatChoice
async fn process_single_choice(
&self,
complete: &proto::GenerateComplete,
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut StopSequenceDecoder,
history_tool_calls_count: usize,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
let outputs = stop_decoder
.process_tokens(&complete.output_ids)
.map_err(|e| format!("Failed to process tokens: {}", e))?;
// Accumulate text with early breaks
let mut final_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
final_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
final_text.push_str(&t);
}
// Step 1: Handle reasoning content parsing
let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
let pooled_parser = utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser
.lock()
.map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
match parser.detect_and_parse_reasoning(&processed_text) {
Ok(result) => {
if !result.reasoning_text.is_empty() {
reasoning_text = Some(result.reasoning_text);
}
processed_text = result.normal_text;
}
Err(e) => {
return Err(format!("Reasoning parsing error: {}", e));
}
}
}
// Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<ToolCall>> = None;
// Check if tool calls should be processed
let tool_choice_enabled = !matches!(
&original_request.tool_choice,
Some(ToolChoice::Value(ToolChoiceValue::None))
);
if tool_choice_enabled && original_request.tools.is_some() {
// Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match &original_request.tool_choice {
Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false,
};
if used_json_schema {
(tool_calls, processed_text) = utils::parse_json_schema_response(
&processed_text,
&original_request.tool_choice,
);
} else {
(tool_calls, processed_text) = self
.parse_tool_calls(
&processed_text,
&original_request.model,
history_tool_calls_count,
)
.await;
}
}
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
let finish_reason_str = &complete.finish_reason;
// Override finish reason if we have tool calls
let final_finish_reason_str = if tool_calls.is_some() {
"tool_calls"
} else {
finish_reason_str
};
// Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
// Step 4: Convert output logprobs if present
// Note: complete.input_logprobs exists in proto but is not used for chat completions
// (input logprobs are only used in /v1/completions endpoint with echo=true)
let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs {
match self.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
error!("Failed to convert logprobs: {}", e);
None
}
}
} else {
None
};
// Step 5: Build ChatCompletionMessage (proper response message type)
let chat_message = ChatCompletionMessage {
role: "assistant".to_string(),
content: if processed_text.is_empty() {
None
} else {
Some(processed_text)
},
tool_calls,
reasoning_content: reasoning_text,
};
// Step 6: Build ChatChoice
let choice = ChatChoice {
index: index as u32,
message: chat_message,
logprobs,
finish_reason: Some(final_finish_reason_str.to_string()),
matched_stop,
hidden_states: None,
};
Ok(choice)
}
/// Parse tool calls using model-specific parser
async fn parse_tool_calls(
&self,
processed_text: &str,
model: &str,
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Check format detection first
let can_parse = {
let parser = pooled_parser.lock().await;
parser.has_tool_markers(processed_text)
// Lock is dropped here
};
if !can_parse {
return (None, processed_text.to_string());
}
// Lock again for async parsing
let result = {
let parser = pooled_parser.lock().await;
parser.parse_complete(processed_text).await
// Lock is dropped here
};
match result {
Ok((normal_text, parsed_tool_calls)) => {
if parsed_tool_calls.is_empty() {
return (None, normal_text);
}
let spec_tool_calls = parsed_tool_calls
.into_iter()
.enumerate()
.map(|(index, tc)| {
// Generate ID for this tool call
let id = utils::generate_tool_call_id(
model,
&tc.function.name,
index,
history_tool_calls_count,
);
ToolCall {
id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
}
})
.collect();
(Some(spec_tool_calls), normal_text)
}
Err(e) => {
error!("Tool call parsing error: {}", e);
(None, processed_text.to_string())
}
}
}
/// Convert proto LogProbs to OpenAI ChatLogProbs format
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
fn convert_proto_to_openai_logprobs(
&self,
proto_logprobs: &proto::OutputLogProbs,
) -> Result<ChatLogProbs, String> {
let mut content_items = Vec::new();
// Decode token IDs to text (always with skip_special_tokens=false for logprobs)
let token_texts: Vec<String> = proto_logprobs
.token_ids
.iter()
.map(|&token_id| {
self.tokenizer
.decode(&[token_id as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", token_id))
})
.collect();
// Build ChatLogProbsContent for each token
for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
let token_text = token_texts.get(i).cloned().unwrap_or_default();
let bytes = Some(token_text.as_bytes().to_vec());
// Build top_logprobs for this position
let mut top_logprobs = Vec::new();
if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
// Decode top token IDs (always with skip_special_tokens=false)
let top_token_texts: Vec<String> = top_logprobs_entry
.token_ids
.iter()
.map(|&tid| {
self.tokenizer
.decode(&[tid as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", tid))
})
.collect();
for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
.values
.iter()
.zip(top_logprobs_entry.token_ids.iter())
.enumerate()
{
if let Some(top_token_text) = top_token_texts.get(j) {
top_logprobs.push(TopLogProb {
token: top_token_text.clone(),
logprob: top_logprob,
bytes: Some(top_token_text.as_bytes().to_vec()),
});
}
}
}
content_items.push(ChatLogProbsContent {
token: token_text,
logprob,
bytes,
top_logprobs,
});
}
Ok(ChatLogProbs::Detailed {
content: (!content_items.is_empty()).then_some(content_items),
})
}
} }
impl std::fmt::Debug for GrpcPDRouter { impl std::fmt::Debug for GrpcPDRouter {
......
//! Pipeline stages for gRPC router request processing
//!
//! This module defines the core pipeline abstraction and individual processing stages
//! that transform a RequestContext through its lifecycle.
use async_trait::async_trait;
use axum::response::{IntoResponse, Response};
use tracing::{debug, error, warn};
use super::context::*;
use super::processing;
use super::streaming;
use super::utils;
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
use crate::grpc_client::proto;
use crate::policies::PolicyRegistry;
use crate::protocols::spec::{
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage,
};
use rand::Rng;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
// ============================================================================
// Pipeline Trait
// ============================================================================
/// Trait for pipeline stages that process requests
#[async_trait]
pub trait PipelineStage: Send + Sync {
/// Execute this stage, mutating the context
///
/// Returns:
/// - `Ok(None)` - Continue to next stage
/// - `Ok(Some(response))` - Pipeline complete, return this response (e.g., streaming)
/// - `Err(response)` - Error occurred, return this error response
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response>;
/// Stage name for logging
fn name(&self) -> &'static str;
}
// ============================================================================
// Stage 1: Preparation
// ============================================================================
/// Preparation stage: Filter tools, process messages, tokenize, build constraints
pub struct PreparationStage;
#[async_trait]
impl PipelineStage for PreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
debug!("Stage {}: Processing request", self.name());
// Clone the request to avoid borrowing issues
match &ctx.input.request_type {
RequestType::Chat(request) => {
let request_clone = request.clone();
self.prepare_chat(ctx, &request_clone).await?;
}
RequestType::Generate(request) => {
let request_clone = request.clone();
self.prepare_generate(ctx, &request_clone).await?;
}
}
Ok(None)
}
fn name(&self) -> &'static str {
"Preparation"
}
}
impl PreparationStage {
async fn prepare_chat(
&self,
ctx: &mut RequestContext,
request: &ChatCompletionRequest,
) -> Result<(), Response> {
// Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request);
// Step 2: Process messages and apply chat template
let processed_messages =
match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) {
Ok(msgs) => msgs,
Err(e) => {
return Err(utils::bad_request_error(e));
}
};
// Step 3: Tokenize the processed text
let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Tokenization failed: {}",
e
)));
}
};
let token_ids = encoding.token_ids().to_vec();
debug!("Tokenized {} tokens from input", token_ids.len());
// Step 4: Build tool constraints if needed
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
utils::generate_tool_constraints(tools, &request.tool_choice, &request.model)
});
// Step 5: Create stop sequence decoder (build once, reuse in non-stream)
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
request.stop.as_ref(),
request.stop_token_ids.as_ref(),
request.skip_special_tokens,
request.no_stop_trim,
);
// Store results in context
ctx.state.preparation = Some(PreparationOutput {
original_text: Some(processed_messages.text.clone()),
token_ids,
processed_messages: Some(processed_messages),
tool_constraints: tool_call_constraint,
filtered_request: if matches!(body_ref, std::borrow::Cow::Owned(_)) {
Some(body_ref.into_owned())
} else {
None
},
});
// Store stop decoder for reuse in response processing
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
async fn prepare_generate(
&self,
ctx: &mut RequestContext,
request: &GenerateRequest,
) -> Result<(), Response> {
// Resolve input (text, prompt, or input_ids)
let (original_text, token_ids) = match self.resolve_generate_input(ctx, request) {
Ok(res) => res,
Err(msg) => {
return Err(utils::bad_request_error(msg));
}
};
debug!("Resolved input with {} tokens", token_ids.len());
// Create stop sequence decoder for generate requests
let params = request.sampling_params.as_ref();
let stop_decoder = utils::create_stop_decoder(
&ctx.components.tokenizer,
params.and_then(|p| p.stop.as_ref()),
params.and_then(|p| p.stop_token_ids.as_ref()),
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
);
ctx.state.preparation = Some(PreparationOutput {
original_text,
token_ids,
processed_messages: None,
tool_constraints: None,
filtered_request: None,
});
// Store stop decoder
ctx.state.response.stop_decoder = Some(stop_decoder);
Ok(())
}
fn resolve_generate_input(
&self,
ctx: &RequestContext,
request: &GenerateRequest,
) -> Result<(Option<String>, Vec<u32>), String> {
if let Some(text) = &request.text {
return self
.tokenize_single_text(&ctx.components.tokenizer, text)
.map(|(original, ids)| (Some(original), ids));
}
// Handle input_ids - validate and convert
if let Some(input_ids) = &request.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| u32::try_from(id))
.collect::<Result<Vec<u32>, _>>()
.map(|converted| (None, converted))
.map_err(|_| "input_ids must be non-negative".to_string()),
InputIds::Batch(_) => {
Err("Batch input_ids are not supported over gRPC generate yet".to_string())
}
};
}
Err("Either `text` or `input_ids` must be provided".to_string())
}
fn tokenize_single_text(
&self,
tokenizer: &Arc<dyn crate::tokenizer::traits::Tokenizer>,
text: &str,
) -> Result<(String, Vec<u32>), String> {
let encoding = tokenizer
.encode(text)
.map_err(|e| format!("Tokenization failed: {}", e))?;
Ok((text.to_string(), encoding.token_ids().to_vec()))
}
}
// ============================================================================
// Stage 2: Worker Selection
// ============================================================================
/// Worker selection stage: Select appropriate worker(s) based on routing mode
pub struct WorkerSelectionStage {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
mode: WorkerSelectionMode,
}
pub enum WorkerSelectionMode {
/// Regular mode: select single worker
Regular,
/// PD mode: select prefill + decode workers
PrefillDecode,
}
impl WorkerSelectionStage {
pub fn new(
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
mode: WorkerSelectionMode,
) -> Self {
Self {
worker_registry,
policy_registry,
mode,
}
}
}
#[async_trait]
impl PipelineStage for WorkerSelectionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
debug!("Stage {}: Selecting workers", self.name());
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation stage not completed"))?;
let text = prep.original_text.as_deref();
let workers = match self.mode {
WorkerSelectionMode::Regular => {
match self.select_single_worker(ctx.input.model_id.as_deref(), text) {
Some(w) => WorkerSelection::Single { worker: w },
None => {
return Err(utils::service_unavailable_error(format!(
"No available workers for model: {:?}",
ctx.input.model_id
)));
}
}
}
WorkerSelectionMode::PrefillDecode => {
match self.select_pd_pair(ctx.input.model_id.as_deref(), text) {
Some((prefill, decode)) => WorkerSelection::Dual { prefill, decode },
None => {
return Err(utils::service_unavailable_error(format!(
"No available PD worker pairs for model: {:?}",
ctx.input.model_id
)));
}
}
}
};
ctx.state.workers = Some(workers);
Ok(None)
}
fn name(&self) -> &'static str {
"WorkerSelection"
}
}
impl WorkerSelectionStage {
fn select_single_worker(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> {
// Get workers for the specified model, filtered by connection mode
let workers = self.worker_registry.get_workers_filtered(
model_id,
Some(WorkerType::Regular),
Some(ConnectionMode::Grpc { port: None }),
false, // get all workers, we'll filter by is_available() next
);
// Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn crate::core::Worker>> = workers
.iter()
.filter(|w| w.is_available())
.cloned()
.collect();
if available.is_empty() {
return None;
}
// Get the appropriate policy for this model
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
// Select worker using the policy
let idx = policy.select_worker(&available, text)?;
Some(available[idx].clone())
}
fn select_pd_pair(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<(Arc<dyn crate::core::Worker>, Arc<dyn crate::core::Worker>)> {
// Get prefill workers - use None for WorkerType filter to get all types,
// then filter manually (since Prefill is a struct variant)
let all_workers = self.worker_registry.get_workers_filtered(
model_id,
None, // Get all types
Some(ConnectionMode::Grpc { port: None }),
false,
);
let prefill_workers: Vec<_> = all_workers
.iter()
.filter(|w| matches!(w.metadata().worker_type, WorkerType::Prefill { .. }))
.cloned()
.collect();
let available_prefill: Vec<_> = prefill_workers
.iter()
.filter(|w| w.is_available())
.cloned()
.collect();
if available_prefill.is_empty() {
warn!("No available prefill workers");
return None;
}
// Get decode workers from the same all_workers list
let decode_workers: Vec<_> = all_workers
.iter()
.filter(|w| matches!(w.metadata().worker_type, WorkerType::Decode))
.cloned()
.collect();
let available_decode: Vec<_> = decode_workers
.iter()
.filter(|w| w.is_available())
.cloned()
.collect();
if available_decode.is_empty() {
warn!("No available decode workers");
return None;
}
// Select using policies
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
let prefill_idx = policy.select_worker(&available_prefill, text)?;
let decode_idx = policy.select_worker(&available_decode, text)?;
Some((
available_prefill[prefill_idx].clone(),
available_decode[decode_idx].clone(),
))
}
}
// ============================================================================
// Stage 3: Client Acquisition
// ============================================================================
/// Client acquisition stage: Get gRPC clients from selected workers
pub struct ClientAcquisitionStage;
#[async_trait]
impl PipelineStage for ClientAcquisitionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
debug!("Stage {}: Acquiring gRPC clients", self.name());
let workers = ctx
.state
.workers
.as_ref()
.ok_or_else(|| utils::internal_error_static("Worker selection not completed"))?;
let clients = match workers {
WorkerSelection::Single { worker } => {
let client = utils::get_grpc_client_from_worker(worker).await?;
ClientSelection::Single { client }
}
WorkerSelection::Dual { prefill, decode } => {
let prefill_client = utils::get_grpc_client_from_worker(prefill).await?;
let decode_client = utils::get_grpc_client_from_worker(decode).await?;
ClientSelection::Dual {
prefill: prefill_client,
decode: decode_client,
}
}
};
ctx.state.clients = Some(clients);
Ok(None)
}
fn name(&self) -> &'static str {
"ClientAcquisition"
}
}
// ============================================================================
// Stage 4: Request Building
// ============================================================================
/// Request building stage: Build proto GenerateRequest
pub struct RequestBuildingStage {
inject_pd_metadata: bool,
}
impl RequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
}
#[async_trait]
impl PipelineStage for RequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
debug!("Stage {}: Building proto request", self.name());
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| utils::internal_error_static("Preparation not completed"))?;
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
// Get client for building request (use prefill client if PD mode)
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
let mut proto_request = match &ctx.input.request_type {
RequestType::Chat(request) => {
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let body_ref = prep.filtered_request.as_ref().unwrap_or(request);
builder_client
.build_generate_request(
request_id,
body_ref,
prep.processed_messages.as_ref().unwrap().text.clone(),
prep.token_ids.clone(),
prep.processed_messages
.as_ref()
.unwrap()
.multimodal_inputs
.clone(),
prep.tool_constraints.clone(),
)
.map_err(|e| {
utils::bad_request_error(format!("Invalid request parameters: {}", e))
})?
}
RequestType::Generate(request) => {
let request_id = request
.rid
.clone()
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
builder_client
.build_plain_generate_request(
request_id,
request,
prep.original_text.clone(),
prep.token_ids.clone(),
)
.map_err(utils::bad_request_error)?
}
};
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
self.inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"RequestBuilding"
}
}
impl RequestBuildingStage {
fn inject_bootstrap_metadata(
&self,
request: &mut proto::GenerateRequest,
prefill_worker: &Arc<dyn crate::core::Worker>,
) {
use proto::DisaggregatedParams;
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
}
// ============================================================================
// Stage 5: Dispatch Metadata
// ============================================================================
/// Dispatch metadata stage: Prepare metadata for dispatch
pub struct DispatchMetadataStage;
#[async_trait]
impl PipelineStage for DispatchMetadataStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
debug!("Stage {}: Preparing dispatch metadata", self.name());
let proto_request = ctx
.state
.proto_request
.as_ref()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
let request_id = proto_request.request_id.clone();
let model = match &ctx.input.request_type {
RequestType::Chat(req) => req.model.clone(),
RequestType::Generate(_req) => {
// Generate requests don't have a model field
// Use model_id from input or default
ctx.input
.model_id
.clone()
.unwrap_or_else(|| "default".to_string())
}
};
let weight_version = ctx
.state
.workers
.as_ref()
.map(|w| match w {
WorkerSelection::Single { worker } => worker,
WorkerSelection::Dual { decode, .. } => decode,
})
.and_then(|w| w.metadata().labels.get("weight_version").cloned())
.unwrap_or_else(|| "default".to_string());
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
ctx.state.dispatch = Some(DispatchMetadata {
request_id,
model,
created,
weight_version: Some(weight_version),
is_streaming: ctx.is_streaming(),
});
Ok(None)
}
fn name(&self) -> &'static str {
"DispatchMetadata"
}
}
// ============================================================================
// Stage 6: Request Execution
// ============================================================================
/// Request execution stage: Execute gRPC requests (single or dual dispatch)
pub struct RequestExecutionStage {
mode: ExecutionMode,
}
pub enum ExecutionMode {
/// Regular mode: single worker execution
Single,
/// PD mode: dual dispatch to prefill + decode workers
DualDispatch,
}
impl RequestExecutionStage {
pub fn new(mode: ExecutionMode) -> Self {
Self { mode }
}
}
#[async_trait]
impl PipelineStage for RequestExecutionStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
debug!("Stage {}: Executing gRPC request", self.name());
let proto_request = ctx
.state
.proto_request
.take()
.ok_or_else(|| utils::internal_error_static("Proto request not built"))?;
let clients = ctx
.state
.clients
.as_mut()
.ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?;
let result = match self.mode {
ExecutionMode::Single => self.execute_single(proto_request, clients).await?,
ExecutionMode::DualDispatch => {
self.execute_dual_dispatch(proto_request, clients).await?
}
};
// Store result in context for ResponseProcessingStage
ctx.state.response.execution_result = Some(result);
Ok(None)
}
fn name(&self) -> &'static str {
"RequestExecution"
}
}
impl RequestExecutionStage {
async fn execute_single(
&self,
proto_request: proto::GenerateRequest,
clients: &mut ClientSelection,
) -> Result<ExecutionResult, Response> {
let client = clients
.single_mut()
.ok_or_else(|| utils::internal_error_static("Expected single client but got dual"))?;
let stream = client.generate(proto_request).await.map_err(|e| {
utils::internal_error_message(format!("Failed to start generation: {}", e))
})?;
Ok(ExecutionResult::Single { stream })
}
async fn execute_dual_dispatch(
&self,
proto_request: proto::GenerateRequest,
clients: &mut ClientSelection,
) -> Result<ExecutionResult, Response> {
let (prefill_client, decode_client) = clients
.dual_mut()
.ok_or_else(|| utils::internal_error_static("Expected dual clients but got single"))?;
debug!("Sending concurrent requests to prefill and decode workers");
let prefill_request = proto_request.clone();
let decode_request = proto_request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Handle prefill result
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Prefill worker failed to start: {}",
e
)));
}
};
// Handle decode result
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Decode worker failed to start: {}",
e
)));
}
};
Ok(ExecutionResult::Dual {
prefill: prefill_stream,
decode: Box::new(decode_stream),
})
}
}
// ============================================================================
// Stage 7: Response Processing
// ============================================================================
/// Response processing stage: Handles both streaming and non-streaming responses
///
/// - For streaming: Spawns background task and returns SSE response (early exit)
/// - For non-streaming: Collects all responses and builds final ChatCompletionResponse
pub struct ResponseProcessingStage {
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
}
impl ResponseProcessingStage {
pub fn new(
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
Self {
processor,
streaming_processor,
}
}
}
#[async_trait]
impl PipelineStage for ResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
debug!("Stage {}: Processing response", self.name());
// Delegate to request-type specific processing
match &ctx.input.request_type {
RequestType::Chat(_) => return self.process_chat_response(ctx).await,
RequestType::Generate(_) => return self.process_generate_response(ctx).await,
}
}
fn name(&self) -> &'static str {
"ResponseProcessing"
}
}
impl ResponseProcessingStage {
async fn process_chat_response(
&self,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
let is_streaming = ctx.is_streaming();
// Extract execution result
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
if is_streaming {
// Get dispatch metadata for consistent response fields
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_response(
execution_result,
ctx.chat_request().clone(),
dispatch.clone(),
),
));
}
// Non-streaming: Extract chat request details before mutable borrows
let request_logprobs = match &ctx.input.request_type {
RequestType::Chat(req) => req.logprobs,
_ => false,
};
// Collect all responses from the execution result
let all_responses = match execution_result {
ExecutionResult::Single { stream } => {
utils::collect_stream_responses(stream, "Single").await?
}
ExecutionResult::Dual { prefill, decode } => {
// Collect prefill for input_logprobs
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
// Collect decode for actual output
let mut decode_responses =
utils::collect_stream_responses(*decode, "Decode").await?;
// Merge prefill input_logprobs if requested
if request_logprobs {
if let Some(prefill_input_logprobs) = prefill_responses
.first()
.and_then(|r| r.input_logprobs.clone())
{
for response in &mut decode_responses {
response.input_logprobs = Some(prefill_input_logprobs.clone());
}
}
}
decode_responses
}
};
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
}
// Clone chat_request to avoid borrow checker conflict
// (ctx.chat_request() borrows ctx, preventing mutable borrow of ctx.state.response.stop_decoder)
let chat_request = ctx.chat_request().clone();
let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
match self
.processor
.process_single_choice(
complete,
index,
&chat_request,
stop_decoder,
history_tool_calls_count,
)
.await
{
Ok(choice) => choices.push(choice),
Err(e) => {
return Err(utils::internal_error_message(format!(
"Failed to process choice {}: {}",
index, e
)));
}
}
}
// Build usage
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
};
// Build final ChatCompletionResponse
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
let response = ChatCompletionResponse {
id: dispatch.request_id.clone(),
object: "chat.completion".to_string(),
created: dispatch.created,
model: dispatch.model.clone(),
choices,
usage: Some(usage),
system_fingerprint: dispatch.weight_version.clone(),
};
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
Ok(None)
}
async fn process_generate_response(
&self,
_ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
// TODO(generate): Implement generate response processing
//
// Required implementation:
// 1. Extract execution_result from ctx
// 2. Check is_streaming flag
// 3. For streaming:
// - Add StreamingProcessor::process_streaming_generate() method
// - Similar to process_streaming_response but WITHOUT tool/reasoning parsing
// - Return Err(sse_response) for early exit
// 4. For non-streaming:
// - Collect stream responses using utils::collect_stream_responses()
// - Process through stop decoder (sequential with reset for n>1, like chat)
// - Build GenerateResponse struct (see TODO in protocols/spec.rs)
// - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response))
//
// Reference implementation: router.rs:297-595
// Key differences from chat:
// - No tool parsing
// - No reasoning parsing
// - Different response format (GenerateResponse instead of ChatCompletionResponse)
// - Still needs: stop decoder, logprobs, finish_reason, matched_stop
Err((
axum::http::StatusCode::NOT_IMPLEMENTED,
axum::Json(serde_json::json!({
"error": {
"message": "Generate response processing not yet implemented in pipeline",
"type": "not_implemented",
"code": 501
}
})),
)
.into_response())
}
}
// ============================================================================
// Pipeline Orchestrator
// ============================================================================
/// Complete chat completion pipeline
///
/// Orchestrates all stages from request preparation to response delivery.
/// Configured differently for regular vs PD mode.
#[derive(Clone)]
pub struct ChatCompletionPipeline {
stages: Arc<Vec<Box<dyn PipelineStage>>>,
}
impl ChatCompletionPipeline {
/// Create a regular (single-worker) pipeline
pub fn new_regular(
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
let stages: Vec<Box<dyn PipelineStage>> = vec![
Box::new(PreparationStage),
Box::new(WorkerSelectionStage::new(
worker_registry,
policy_registry,
WorkerSelectionMode::Regular,
)),
Box::new(ClientAcquisitionStage),
Box::new(RequestBuildingStage::new(false)), // No PD metadata
Box::new(DispatchMetadataStage),
Box::new(RequestExecutionStage::new(ExecutionMode::Single)),
Box::new(ResponseProcessingStage::new(
processor,
streaming_processor.clone(),
)),
];
Self {
stages: Arc::new(stages),
}
}
/// Create a PD (prefill-decode) pipeline
pub fn new_pd(
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
processor: processing::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
let stages: Vec<Box<dyn PipelineStage>> = vec![
Box::new(PreparationStage),
Box::new(WorkerSelectionStage::new(
worker_registry,
policy_registry,
WorkerSelectionMode::PrefillDecode,
)),
Box::new(ClientAcquisitionStage),
Box::new(RequestBuildingStage::new(true)), // Inject PD metadata
Box::new(DispatchMetadataStage),
Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)),
Box::new(ResponseProcessingStage::new(
processor,
streaming_processor.clone(),
)),
];
Self {
stages: Arc::new(stages),
}
}
/// Execute the complete pipeline for a chat request
pub async fn execute_chat(
&self,
request: ChatCompletionRequest,
headers: Option<axum::http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Response {
let mut ctx = RequestContext::for_chat(request, headers, model_id, components);
// Execute each stage in sequence
for (idx, stage) in self.stages.iter().enumerate() {
debug!("Executing stage {}: {}", idx + 1, stage.name());
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
// Stage completed successfully with a response (e.g., streaming)
debug!(
"Stage {} ({}) completed with response",
idx + 1,
stage.name()
);
return response;
}
Ok(None) => {
// Continue to next stage
continue;
}
Err(response) => {
// Error occurred
error!(
"Stage {} ({}) failed with status {}",
idx + 1,
stage.name(),
response.status()
);
return response;
}
}
}
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Chat(response)) => axum::Json(response).into_response(),
Some(FinalResponse::Generate(_)) => {
utils::internal_error_static("Internal error: wrong response type")
}
None => utils::internal_error_static("No response produced"),
}
}
/// Execute the complete pipeline for a generate request
pub async fn execute_generate(
&self,
request: GenerateRequest,
headers: Option<axum::http::HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Response {
let mut ctx = RequestContext::for_generate(request, headers, model_id, components);
// Execute each stage in sequence
for (idx, stage) in self.stages.iter().enumerate() {
debug!("Executing stage {}: {}", idx + 1, stage.name());
match stage.execute(&mut ctx).await {
Ok(Some(response)) => {
// Stage completed successfully with a response (e.g., streaming)
debug!(
"Stage {} ({}) completed with response",
idx + 1,
stage.name()
);
return response;
}
Ok(None) => {
// Continue to next stage
continue;
}
Err(response) => {
// Error occurred
error!(
"Stage {} ({}) failed with status {}",
idx + 1,
stage.name(),
response.status()
);
return response;
}
}
}
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(),
Some(FinalResponse::Chat(_)) => {
utils::internal_error_static("Internal error: wrong response type")
}
None => utils::internal_error_static("No response produced"),
}
}
}
//! Shared response processing logic for gRPC routers
//!
//! This module contains response processing functions that are shared between
//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates.
use std::sync::Arc;
use serde_json::Value;
use tracing::error;
use crate::grpc_client::proto;
use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall,
ToolChoice, ToolChoiceValue,
};
use crate::reasoning_parser::ReasoningParserFactory;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory;
use super::utils;
// ============================================================================
// Response Processor - Main Entry Point
// ============================================================================
/// Unified response processor for both routers
#[derive(Clone)]
pub struct ResponseProcessor {
pub tokenizer: Arc<dyn Tokenizer>,
pub tool_parser_factory: ToolParserFactory,
pub reasoning_parser_factory: ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
}
impl ResponseProcessor {
pub fn new(
tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: ToolParserFactory,
reasoning_parser_factory: ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
) -> Self {
Self {
tokenizer,
tool_parser_factory,
reasoning_parser_factory,
configured_tool_parser,
configured_reasoning_parser,
}
}
/// Process a single choice from GenerateComplete response (EXACT COPY from router.rs:1573-1725)
pub async fn process_single_choice(
&self,
complete: &proto::GenerateComplete,
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut StopSequenceDecoder,
history_tool_calls_count: usize,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
let outputs = stop_decoder
.process_tokens(&complete.output_ids)
.map_err(|e| format!("Failed to process tokens: {}", e))?;
// Accumulate text with early breaks
let mut final_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
final_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
final_text.push_str(&t);
}
// Step 1: Handle reasoning content parsing
let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
let pooled_parser = utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser
.lock()
.map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
match parser.detect_and_parse_reasoning(&processed_text) {
Ok(result) => {
if !result.reasoning_text.is_empty() {
reasoning_text = Some(result.reasoning_text);
}
processed_text = result.normal_text;
}
Err(e) => {
return Err(format!("Reasoning parsing error: {}", e));
}
}
}
// Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<ToolCall>> = None;
// Check if tool calls should be processed
let tool_choice_enabled = !matches!(
&original_request.tool_choice,
Some(ToolChoice::Value(ToolChoiceValue::None))
);
if tool_choice_enabled && original_request.tools.is_some() {
// Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match &original_request.tool_choice {
Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false,
};
if used_json_schema {
(tool_calls, processed_text) = utils::parse_json_schema_response(
&processed_text,
&original_request.tool_choice,
);
} else {
(tool_calls, processed_text) = self
.parse_tool_calls(
&processed_text,
&original_request.model,
history_tool_calls_count,
)
.await;
}
}
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
let finish_reason_str = &complete.finish_reason;
// Override finish reason if we have tool calls
let final_finish_reason_str = if tool_calls.is_some() {
"tool_calls"
} else {
finish_reason_str
};
// Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
// Step 4: Convert output logprobs if present
let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs {
match utils::convert_proto_to_openai_logprobs(proto_logprobs, &self.tokenizer) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
error!("Failed to convert logprobs: {}", e);
None
}
}
} else {
None
};
// Step 5: Build ChatCompletionMessage (proper response message type)
let chat_message = ChatCompletionMessage {
role: "assistant".to_string(),
content: if processed_text.is_empty() {
None
} else {
Some(processed_text)
},
tool_calls,
reasoning_content: reasoning_text,
};
// Step 6: Build ChatChoice
let choice = ChatChoice {
index: index as u32,
message: chat_message,
logprobs,
finish_reason: Some(final_finish_reason_str.to_string()),
matched_stop,
hidden_states: None,
};
Ok(choice)
}
/// Parse tool calls using model-specific parser (EXACT COPY from router.rs:296-361)
pub async fn parse_tool_calls(
&self,
processed_text: &str,
model: &str,
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Try parsing directly (parser will handle detection internally)
let result = {
let parser = pooled_parser.lock().await;
parser.parse_complete(processed_text).await
// Lock is dropped here
};
match result {
Ok((normal_text, parsed_tool_calls)) => {
if parsed_tool_calls.is_empty() {
return (None, normal_text);
}
let spec_tool_calls = parsed_tool_calls
.into_iter()
.enumerate()
.map(|(index, tc)| {
// Generate ID for this tool call
let id = utils::generate_tool_call_id(
model,
&tc.function.name,
index,
history_tool_calls_count,
);
ToolCall {
id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
}
})
.collect();
(Some(spec_tool_calls), normal_text)
}
Err(e) => {
error!("Tool call parsing error: {}", e);
(None, processed_text.to_string())
}
}
}
}
// gRPC Router Implementation // gRPC Router Implementation
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
extract::Request, extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json, Json,
}; };
use bytes::Bytes; use tracing::debug;
use std::io;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds,
ChatCompletionStreamResponse, ChatMessage, ChatMessageDelta, ChatStreamChoice, RerankRequest, ResponsesGetParams, ResponsesRequest,
CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest,
RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray, ToolCall, ToolCallDelta,
ToolChoice, ToolChoiceValue, Usage,
}; };
use crate::reasoning_parser::{ParserResult, ReasoningParserFactory}; use crate::reasoning_parser::ReasoningParserFactory;
use crate::routers::{grpc, RouterTrait}; use crate::routers::{grpc, RouterTrait};
use crate::server::AppContext; use crate::server::AppContext;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::stop::SequenceDecoderOutput;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::{StreamingParseResult, ToolParserFactory}; use crate::tool_parser::ToolParserFactory;
use grpc::utils; use grpc::utils;
use proto::generate_response::Response::{Chunk, Complete, Error}; use serde_json::json;
use serde_json::{json, Value}; use std::time::Instant;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio_stream::StreamExt;
use uuid::Uuid; use uuid::Uuid;
/// gRPC router implementation for SGLang /// gRPC router implementation for SGLang
...@@ -55,6 +45,10 @@ pub struct GrpcRouter { ...@@ -55,6 +45,10 @@ pub struct GrpcRouter {
retry_config: RetryConfig, retry_config: RetryConfig,
configured_reasoning_parser: Option<String>, configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>, configured_tool_parser: Option<String>,
// Pipeline for non-streaming requests
pipeline: super::pipeline::ChatCompletionPipeline,
// Shared components for pipeline
shared_components: Arc<super::context::SharedComponents>,
} }
impl GrpcRouter { impl GrpcRouter {
...@@ -80,6 +74,39 @@ impl GrpcRouter { ...@@ -80,6 +74,39 @@ impl GrpcRouter {
let worker_registry = ctx.worker_registry.clone(); let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone(); let policy_registry = ctx.policy_registry.clone();
// Create shared components for pipeline
let shared_components = Arc::new(super::context::SharedComponents {
tokenizer: tokenizer.clone(),
tool_parser_factory: tool_parser_factory.clone(),
reasoning_parser_factory: reasoning_parser_factory.clone(),
});
// Create response processor
let processor = super::processing::ResponseProcessor::new(
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
ctx.configured_tool_parser.clone(),
ctx.configured_reasoning_parser.clone(),
);
// Create streaming processor
let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
tokenizer.clone(),
tool_parser_factory.clone(),
reasoning_parser_factory.clone(),
ctx.configured_tool_parser.clone(),
ctx.configured_reasoning_parser.clone(),
));
// Create pipeline
let pipeline = super::pipeline::ChatCompletionPipeline::new_regular(
worker_registry.clone(),
policy_registry.clone(),
processor,
streaming_processor,
);
Ok(GrpcRouter { Ok(GrpcRouter {
worker_registry, worker_registry,
policy_registry, policy_registry,
...@@ -91,13 +118,15 @@ impl GrpcRouter { ...@@ -91,13 +118,15 @@ impl GrpcRouter {
retry_config: ctx.router_config.effective_retry_config(), retry_config: ctx.router_config.effective_retry_config(),
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(), configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
configured_tool_parser: ctx.configured_tool_parser.clone(), configured_tool_parser: ctx.configured_tool_parser.clone(),
pipeline,
shared_components,
}) })
} }
/// Main route_chat implementation /// Main route_chat implementation
async fn route_chat_impl( async fn route_chat_impl(
&self, &self,
_headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ChatCompletionRequest, body: &ChatCompletionRequest,
model_id: Option<&str>, model_id: Option<&str>,
) -> Response { ) -> Response {
...@@ -106,76 +135,15 @@ impl GrpcRouter { ...@@ -106,76 +135,15 @@ impl GrpcRouter {
model_id model_id
); );
// Step 1: Filter tools if needed for allowed_tools or specific function // Use pipeline for ALL requests (streaming and non-streaming)
let body_ref = utils::filter_tools_for_request(body); self.pipeline
.execute_chat(
// Step 2: Process messages and apply chat template body.clone(),
let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { headers.cloned(),
Ok(msgs) => msgs, model_id.map(|s| s.to_string()),
Err(e) => { self.shared_components.clone(),
return utils::bad_request_error(e.to_string()); )
} .await
};
// Step 3: Tokenize the processed text
let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
return utils::internal_error_message(format!("Tokenization failed: {}", e));
}
};
let token_ids = encoding.token_ids().to_vec();
debug!("Tokenized {} tokens from input", token_ids.len());
// Step 4: Build tool constraints if needed
// body_ref already has filtered tools if needed
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
utils::generate_tool_constraints(tools, &body.tool_choice, &body.model)
});
// Step 5: Select worker
let worker = match self.select_worker_for_request(model_id, Some(&processed_messages.text))
{
Some(w) => w,
None => {
return utils::service_unavailable_error(format!(
"No available workers for model: {:?}",
model_id
));
}
};
debug!("Selected worker: {}", worker.url());
// Step 6: Get gRPC client from worker
let client = match utils::get_grpc_client_from_worker(&worker).await {
Ok(client) => client,
Err(response) => return response,
};
// Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable)
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let request = match client.build_generate_request(
request_id,
&body_ref,
processed_messages.text.clone(),
token_ids,
processed_messages.multimodal_inputs,
tool_call_constraint, // Pass the full tuple (type, value)
) {
Ok(request) => request,
Err(e) => {
return utils::bad_request_error(format!("Invalid request parameters: {}", e));
}
};
// Step 7: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_chat(client, request, body).await
} else {
self.handle_non_streaming_chat(client, request, body).await
}
} }
/// Main route_generate implementation /// Main route_generate implementation
...@@ -288,77 +256,6 @@ impl GrpcRouter { ...@@ -288,77 +256,6 @@ impl GrpcRouter {
Some(available[idx].clone()) Some(available[idx].clone())
} }
/// Parse tool calls using model-specific parser
async fn parse_tool_calls(
&self,
processed_text: &str,
model: &str,
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Check format detection first
let can_parse = {
let parser = pooled_parser.lock().await;
parser.has_tool_markers(processed_text)
// Lock is dropped here
};
if !can_parse {
return (None, processed_text.to_string());
}
// Lock again for async parsing
let result = {
let parser = pooled_parser.lock().await;
parser.parse_complete(processed_text).await
// Lock is dropped here
};
match result {
Ok((normal_text, parsed_tool_calls)) => {
if parsed_tool_calls.is_empty() {
return (None, normal_text);
}
let spec_tool_calls = parsed_tool_calls
.into_iter()
.enumerate()
.map(|(index, tc)| {
// Generate ID for this tool call
let id = Self::generate_tool_call_id(
model,
&tc.function.name,
index,
history_tool_calls_count,
);
ToolCall {
id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
}
})
.collect();
(Some(spec_tool_calls), normal_text)
}
Err(e) => {
error!("Tool call parsing error: {}", e);
(None, processed_text.to_string())
}
}
}
/// Resolve the generate input into optional original text and token IDs /// Resolve the generate input into optional original text and token IDs
fn resolve_generate_input( fn resolve_generate_input(
&self, &self,
...@@ -373,13 +270,13 @@ impl GrpcRouter { ...@@ -373,13 +270,13 @@ impl GrpcRouter {
// Handle input_ids - validate and convert // Handle input_ids - validate and convert
if let Some(input_ids) = &request.input_ids { if let Some(input_ids) = &request.input_ids {
return match input_ids { return match input_ids {
crate::protocols::spec::InputIds::Single(ids) => ids InputIds::Single(ids) => ids
.iter() .iter()
.map(|&id| u32::try_from(id)) .map(|&id| u32::try_from(id))
.collect::<Result<Vec<u32>, _>>() .collect::<Result<Vec<u32>, _>>()
.map(|converted| (None, converted)) .map(|converted| (None, converted))
.map_err(|_| "input_ids must be non-negative".to_string()), .map_err(|_| "input_ids must be non-negative".to_string()),
crate::protocols::spec::InputIds::Batch(_) => { InputIds::Batch(_) => {
Err("Batch input_ids are not supported over gRPC generate yet".to_string()) Err("Batch input_ids are not supported over gRPC generate yet".to_string())
} }
}; };
...@@ -396,837 +293,6 @@ impl GrpcRouter { ...@@ -396,837 +293,6 @@ impl GrpcRouter {
Ok((text.to_string(), encoding.token_ids().to_vec())) Ok((text.to_string(), encoding.token_ids().to_vec()))
} }
/// Count the number of tool calls in the request message history
/// This is used for KimiK2 format which needs globally unique indices
fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
request
.messages
.iter()
.filter_map(|msg| {
if let ChatMessage::Assistant { tool_calls, .. } = msg {
tool_calls.as_ref().map(|calls| calls.len())
} else {
None
}
})
.sum()
}
/// Generate a tool call ID based on model format
///
/// # Arguments
/// * `model` - Model name to determine ID format
/// * `tool_name` - Name of the tool being called
/// * `tool_index` - Index of this tool call within the current message
/// * `history_count` - Number of tool calls in previous messages
///
/// # Returns
/// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
fn generate_tool_call_id(
model: &str,
tool_name: &str,
tool_index: usize,
history_count: usize,
) -> String {
if model.to_lowercase().contains("kimi") {
// KimiK2 format: functions.{name}:{global_index}
format!("functions.{}:{}", tool_name, history_count + tool_index)
} else {
// Standard OpenAI format: call_{24-char-uuid}
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
}
}
/// Process a chunk of tokens through the stop decoder
fn process_chunk_tokens(
stop_decoder: &mut StopSequenceDecoder,
token_ids: &[u32],
) -> (String, bool) {
let mut chunk_text = String::new();
for &token_id in token_ids {
match stop_decoder.process_token(token_id).unwrap_or_else(|e| {
debug!(
"Error processing token {}: {}. Treating as Held.",
token_id, e
);
SequenceDecoderOutput::Held
}) {
SequenceDecoderOutput::Text(text) => {
chunk_text.push_str(&text);
}
SequenceDecoderOutput::StoppedWithText(text) => {
chunk_text.push_str(&text);
return (chunk_text, true); // Return text and signal to stop
}
SequenceDecoderOutput::Stopped => {
return (chunk_text, true); // Return text and signal to stop
}
SequenceDecoderOutput::Held => {
// Text held for potential stop sequence match
}
}
}
(chunk_text, false) // Return text and continue processing
}
/// Helper: Process reasoning content in streaming mode
/// Returns (modified_delta, optional_reasoning_chunk)
fn process_reasoning_stream(
&self,
delta: &str,
index: u32,
reasoning_parsers: &mut HashMap<
u32,
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>,
>,
request_id: &str,
model: &str,
created: u64,
) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index
reasoning_parsers.entry(index).or_insert_with(|| {
utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let (parse_result, in_reasoning) = {
let mut parser = pooled_parser.lock().unwrap();
let result = parser.parse_reasoning_streaming_incremental(delta);
let in_reasoning = parser.is_in_reasoning();
(result, in_reasoning)
};
match parse_result {
Ok(ParserResult {
reasoning_text,
normal_text,
}) => {
let chunk = if !reasoning_text.is_empty() {
Some(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: Some(reasoning_text),
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
})
} else {
None
};
return (normal_text, chunk, in_reasoning);
}
Err(e) => {
warn!("Reasoning parsing error: {}", e);
}
}
}
(delta.to_string(), None, false)
}
/// Helper: Process tool calls in streaming mode
/// Returns (should_skip_content, chunks_to_emit)
#[allow(clippy::too_many_arguments)]
async fn process_tool_calls_stream(
&self,
delta: &str,
index: u32,
tool_parsers: &mut HashMap<
u32,
Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>,
>,
has_tool_calls: &mut HashMap<u32, bool>,
tools: &[crate::protocols::spec::Tool],
request_id: &str,
model: &str,
created: u64,
history_tool_calls_count: usize,
) -> (bool, Vec<ChatCompletionStreamResponse>) {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers.entry(index).or_insert_with(|| {
utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await {
Ok(StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present
if !normal_text.is_empty() {
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(normal_text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// Emit tool call chunks
for tool_call_item in calls {
has_tool_calls.insert(index, true);
let tool_call_id = if let Some(ref name) = tool_call_item.name {
Some(Self::generate_tool_call_id(
model,
name,
tool_call_item.tool_index,
history_tool_calls_count,
))
} else {
None
};
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: tool_call_id,
tool_type: if tool_call_item.name.is_some() {
Some("function".to_string())
} else {
None
},
function: Some(FunctionCallDelta {
name: tool_call_item.name,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// If we emitted chunks, skip regular content
return (!chunks.is_empty(), chunks);
}
Err(e) => {
warn!("Tool call parsing error: {}", e);
}
}
}
(false, chunks)
}
/// Helper: Create content chunk
fn create_content_chunk(
content: String,
index: u32,
request_id: &str,
model: &str,
created: u64,
logprobs: Option<crate::protocols::spec::ChatLogProbs>,
) -> ChatCompletionStreamResponse {
ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(content),
tool_calls: None,
reasoning_content: None,
},
logprobs,
finish_reason: None,
matched_stop: None,
}],
usage: None,
}
}
/// Helper: Format response as SSE chunk
fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String {
format!(
"data: {}\n\n",
serde_json::to_string(response).unwrap_or_default()
)
}
/// Submit request and handle streaming response for chat completions route
async fn handle_streaming_chat(
&self,
mut client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
let request_id = request.request_id.clone();
let model = original_request.model.clone();
// Create channel for SSE streaming
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
// Start the gRPC stream
let mut grpc_stream = match client.generate(request).await {
Ok(stream) => stream,
Err(e) => {
return utils::internal_error_message(format!("Generation failed: {}", e));
}
};
let stop_params = (
original_request.stop.clone(),
original_request.stop_token_ids.clone(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Spawn processing task
let self_clone = self.clone();
let original_request_clone = original_request.clone();
tokio::spawn(async move {
let result = Self::process_streaming_chunks(
&self_clone,
&mut grpc_stream,
request_id,
model,
stop_params,
original_request_clone,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(Bytes::from(error_chunk)));
}
// Send DONE marker
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
});
// Create response with SSE headers
let stream = UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(stream));
*response.status_mut() = StatusCode::OK;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
.headers_mut()
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
response
.headers_mut()
.insert("Connection", HeaderValue::from_static("keep-alive"));
response
}
/// Process streaming chunks and send SSE events
async fn process_streaming_chunks(
router: &GrpcRouter,
grpc_stream: &mut (impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
+ Unpin),
request_id: String,
model: String,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: ChatCompletionRequest,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// Extract request parameters
let separate_reasoning = original_request.separate_reasoning;
let tool_choice = &original_request.tool_choice;
let tools = &original_request.tools;
let history_tool_calls_count = Self::get_history_tool_calls_count(&original_request);
let stream_options = &original_request.stream_options;
// Phase 1: Initialize state tracking (per-index for n>1 support)
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
// Parser state (lazy initialization per index)
type PooledReasoningParser =
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>;
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>;
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
// Create stop decoder
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
let mut stop_decoder = utils::create_stop_decoder(
&router.tokenizer,
stop.as_ref(),
stop_token_ids.as_ref(),
skip_special_tokens,
no_stop_trim,
);
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
// Phase 2: Main streaming loop
while let Some(response) = grpc_stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Process tokens through stop decoder
let (chunk_text, _should_stop) =
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
if chunk_text.is_empty() {
continue;
}
// Process logprobs if present
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
match router.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
warn!("Failed to process logprobs: {}", e);
None
}
}
} else {
None
};
// Initialize stream buffer if first time
let stream_buffer = stream_buffers.entry(index).or_default();
// Send first chunk with role
if is_firsts.get(&index).copied().unwrap_or(true) {
let first_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
.map_err(|_| "Failed to send first chunk".to_string())?;
is_firsts.insert(index, false);
}
// Calculate delta
let mut delta = chunk_text;
stream_buffer.push_str(&delta);
// Reasoning content handling
let in_reasoning = if separate_reasoning {
let (normal_text, reasoning_chunk, in_reasoning) = router
.process_reasoning_stream(
&delta,
index,
&mut reasoning_parsers,
&request_id,
&model,
created,
);
if let Some(chunk) = reasoning_chunk {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
}
delta = normal_text;
in_reasoning
} else {
false
};
// Tool call handling
let tool_choice_enabled =
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
if !in_reasoning && tool_choice_enabled && tools.is_some() {
let (should_skip, tool_chunks) = router
.process_tool_calls_stream(
&delta,
index,
&mut tool_parsers,
&mut has_tool_calls,
tools.as_ref().unwrap(),
&request_id,
&model,
created,
history_tool_calls_count,
)
.await;
for chunk in tool_chunks {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send tool call chunk".to_string())?;
}
if should_skip {
continue;
}
}
// Regular content emission
if !delta.is_empty() {
let content_chunk = Self::create_content_chunk(
delta,
index,
&request_id,
&model,
created,
choice_logprobs,
);
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
.map_err(|_| "Failed to send content chunk".to_string())?;
}
}
Some(Complete(complete)) => {
// Flush any remaining text
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
if !text.is_empty() {
let index = complete.index;
let stream_buffer = stream_buffers.entry(index).or_default();
stream_buffer.push_str(&text);
let content_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&content_chunk)
.map_err(|e| format!("Failed to serialize content chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send flushed content".to_string())?;
}
}
// Store metadata
let index = complete.index;
prompt_tokens.insert(index, complete.prompt_tokens as u32);
completion_tokens.insert(index, complete.completion_tokens as u32);
cached_tokens.insert(index, complete.cached_tokens as u32);
finish_reasons.insert(index, complete.finish_reason.clone());
// Extract matched_stop
let matched_stop_value = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
matched_stops.insert(index, matched_stop_value);
break;
}
Some(Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
// Phase 3: Check unstreamed tool args
// Check if parsers have any remaining arguments that haven't been streamed yet
for (index, parser) in &tool_parsers {
let parser_guard = parser.lock().await;
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
for tool_call_item in unstreamed_items {
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: None,
tool_type: None, // No type for argument deltas
function: Some(FunctionCallDelta {
name: None, // No name for argument deltas
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
let tool_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&tool_chunk)
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
}
}
}
// Phase 4: Finish reason chunks
for (index, finish_reason) in finish_reasons.iter() {
let final_finish_reason =
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
"tool_calls".to_string()
} else {
finish_reason.clone()
};
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
let finish_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(final_finish_reason),
matched_stop: matched_stop_value,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&finish_chunk)
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send finish chunk".to_string())?;
}
// Phase 5: Usage chunk
if let Some(stream_opts) = stream_options {
if stream_opts.include_usage.unwrap_or(false) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
let usage_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![],
usage: Some(Usage {
prompt_tokens: total_prompt,
completion_tokens: total_completion,
total_tokens: total_prompt + total_completion,
completion_tokens_details: None,
}),
};
let sse_chunk = serde_json::to_string(&usage_chunk)
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send usage chunk".to_string())?;
}
}
Ok(())
}
/// Submit request and handle non-streaming response for chat completions route
async fn handle_non_streaming_chat(
&self,
mut client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
let mut stop_decoder = utils::create_stop_decoder(
&self.tokenizer,
original_request.stop.as_ref(),
original_request.stop_token_ids.as_ref(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Start generation
let stream = match client.generate(request).await {
Ok(s) => s,
Err(e) => {
return utils::internal_error_message(format!("Failed to start generation: {}", e))
}
};
let all_responses = match utils::collect_stream_responses(stream, "Regular").await {
Ok(responses) => responses,
Err(err_response) => return err_response,
};
if all_responses.is_empty() {
return utils::internal_error_static("No responses from server");
}
// Process each response into a ChatChoice
let history_tool_calls_count = Self::get_history_tool_calls_count(original_request);
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
match self
.process_single_choice(
complete,
index,
original_request,
&mut stop_decoder,
history_tool_calls_count,
)
.await
{
Ok(choice) => choices.push(choice),
Err(e) => {
return utils::internal_error_message(format!(
"Failed to process choice {}: {}",
index, e
));
}
}
}
// Aggregate usage information from all responses
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
};
// Build final ChatCompletionResponse
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: original_request.model.clone(),
choices,
usage: Some(usage),
system_fingerprint: None,
};
// Serialize and return JSON response
Json(response).into_response()
}
/// Submit request and handle non-streaming response for the `/generate` endpoint /// Submit request and handle non-streaming response for the `/generate` endpoint
async fn handle_non_streaming_generate( async fn handle_non_streaming_generate(
&self, &self,
...@@ -1498,234 +564,6 @@ impl GrpcRouter { ...@@ -1498,234 +564,6 @@ impl GrpcRouter {
Ok(()) Ok(())
} }
/// Convert proto LogProbs to OpenAI ChatLogProbs format
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
fn convert_proto_to_openai_logprobs(
&self,
proto_logprobs: &proto::OutputLogProbs,
) -> Result<crate::protocols::spec::ChatLogProbs, String> {
let mut content_items = Vec::new();
// Decode token IDs to text (always with skip_special_tokens=false for logprobs)
let token_texts: Vec<String> = proto_logprobs
.token_ids
.iter()
.map(|&token_id| {
self.tokenizer
.decode(&[token_id as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", token_id))
})
.collect();
// Build ChatLogProbsContent for each token (consume iterator to avoid clones)
for (i, (&logprob, token_text)) in proto_logprobs
.token_logprobs
.iter()
.zip(token_texts.into_iter())
.enumerate()
{
let bytes = Some(token_text.as_bytes().to_vec());
// Build top_logprobs for this position
let mut top_logprobs = Vec::new();
if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
// Decode top token IDs (always with skip_special_tokens=false)
let top_token_texts: Vec<String> = top_logprobs_entry
.token_ids
.iter()
.map(|&tid| {
self.tokenizer
.decode(&[tid as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", tid))
})
.collect();
for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
.values
.iter()
.zip(top_logprobs_entry.token_ids.iter())
.enumerate()
{
if let Some(top_token_text) = top_token_texts.get(j) {
top_logprobs.push(crate::protocols::spec::TopLogProb {
token: top_token_text.clone(),
logprob: top_logprob,
bytes: Some(top_token_text.as_bytes().to_vec()),
});
}
}
}
content_items.push(crate::protocols::spec::ChatLogProbsContent {
token: token_text,
logprob,
bytes,
top_logprobs,
});
}
Ok(crate::protocols::spec::ChatLogProbs::Detailed {
content: (!content_items.is_empty()).then_some(content_items),
})
}
/// Process a single GenerateComplete response into a ChatChoice
async fn process_single_choice(
&self,
complete: &proto::GenerateComplete,
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut StopSequenceDecoder,
history_tool_calls_count: usize,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
let outputs = stop_decoder
.process_tokens(&complete.output_ids)
.map_err(|e| format!("Failed to process tokens: {}", e))?;
// Accumulate text with early breaks
let mut final_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
final_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
final_text.push_str(&t);
}
// Step 1: Handle reasoning content parsing
let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
let pooled_parser = utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser
.lock()
.map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
match parser.detect_and_parse_reasoning(&processed_text) {
Ok(result) => {
if !result.reasoning_text.is_empty() {
reasoning_text = Some(result.reasoning_text);
}
processed_text = result.normal_text;
}
Err(e) => {
return Err(format!("Reasoning parsing error: {}", e));
}
}
}
// Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<ToolCall>> = None;
// Check if tool calls should be processed
let tool_choice_enabled = !matches!(
&original_request.tool_choice,
Some(ToolChoice::Value(ToolChoiceValue::None))
);
if tool_choice_enabled && original_request.tools.is_some() {
// Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match &original_request.tool_choice {
Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false,
};
if used_json_schema {
(tool_calls, processed_text) = utils::parse_json_schema_response(
&processed_text,
&original_request.tool_choice,
);
} else {
(tool_calls, processed_text) = self
.parse_tool_calls(
&processed_text,
&original_request.model,
history_tool_calls_count,
)
.await;
}
}
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
let finish_reason_str = &complete.finish_reason;
// Override finish reason if we have tool calls
let final_finish_reason_str = if tool_calls.is_some() {
"tool_calls"
} else {
finish_reason_str
};
// Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
// Step 4: Convert output logprobs if present
// Note: complete.input_logprobs exists in proto but is not used for chat completions
// (input logprobs are only used in /v1/completions endpoint with echo=true)
let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs {
match self.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
error!("Failed to convert logprobs: {}", e);
None
}
}
} else {
None
};
// Step 5: Build ChatCompletionMessage (proper response message type)
let chat_message = ChatCompletionMessage {
role: "assistant".to_string(),
content: if processed_text.is_empty() {
None
} else {
Some(processed_text)
},
tool_calls,
reasoning_content: reasoning_text,
};
// Step 6: Build ChatChoice
let choice = ChatChoice {
index: index as u32,
message: chat_message,
logprobs,
finish_reason: Some(final_finish_reason_str.to_string()),
matched_stop,
hidden_states: None,
};
Ok(choice)
}
} }
impl std::fmt::Debug for GrpcRouter { impl std::fmt::Debug for GrpcRouter {
......
//! Streaming response processor for gRPC routers
//!
//! This module contains shared streaming logic for both Regular and PD routers,
//! eliminating ~600 lines of duplication.
use axum::response::Response;
use axum::{body::Body, http::StatusCode};
use bytes::Bytes;
use http::header::{HeaderValue, CONTENT_TYPE};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::io;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tonic::codec::Streaming;
use tracing::{debug, error, warn};
use crate::grpc_client::proto;
use crate::protocols::spec::*;
use crate::reasoning_parser::ReasoningParser;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParser;
use super::context;
use super::utils;
/// Shared streaming processor for both single and dual dispatch modes
#[derive(Clone)]
pub struct StreamingProcessor {
tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: crate::tool_parser::ToolParserFactory,
reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
}
impl StreamingProcessor {
pub fn new(
tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: crate::tool_parser::ToolParserFactory,
reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
) -> Self {
Self {
tokenizer,
tool_parser_factory,
reasoning_parser_factory,
configured_tool_parser,
configured_reasoning_parser,
}
}
/// Process streaming chat response and return SSE response
///
/// This is the high-level entry point for streaming responses, handling:
/// - Channel creation
/// - Background task spawning
/// - SSE response building
pub fn process_streaming_response(
self: Arc<Self>,
execution_result: context::ExecutionResult,
chat_request: ChatCompletionRequest,
dispatch: context::DispatchMetadata,
) -> axum::response::Response {
use bytes::Bytes;
use tokio::sync::mpsc;
let stop_params = (
chat_request.stop.clone(),
chat_request.stop_token_ids.clone(),
chat_request.skip_special_tokens,
chat_request.no_stop_trim,
);
// Create SSE channel
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
// Spawn background task based on execution mode
match execution_result {
context::ExecutionResult::Single { stream } => {
let processor = self.clone();
let dispatch_clone = dispatch.clone();
tokio::spawn(async move {
let result = processor
.process_streaming_chunks(
stream,
dispatch_clone,
stop_params,
chat_request,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(Bytes::from(error_chunk)));
}
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
});
}
context::ExecutionResult::Dual { prefill, decode } => {
let processor = self.clone();
tokio::spawn(async move {
let result = processor
.process_dual_streaming_chunks(
prefill,
*decode,
dispatch,
stop_params,
chat_request,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(Bytes::from(error_chunk)));
}
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
});
}
}
// Return SSE response
build_sse_response(rx)
}
/// Process streaming chunks from a single stream (Regular mode)
pub async fn process_streaming_chunks(
&self,
mut grpc_stream: Streaming<proto::GenerateResponse>,
dispatch: context::DispatchMetadata,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: ChatCompletionRequest,
tx: &UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// Extract request parameters
let separate_reasoning = original_request.separate_reasoning;
let tool_choice = &original_request.tool_choice;
let tools = &original_request.tools;
let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request);
let stream_options = &original_request.stream_options;
// Phase 1: Initialize state tracking (per-index for n>1 support)
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
// Parser state (lazy initialization per index)
type PooledReasoningParser = Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>;
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
// Per-index stop decoders (each index needs its own state for n>1 support)
let mut stop_decoders: HashMap<u32, StopSequenceDecoder> = HashMap::new();
// Use dispatch metadata for consistent response fields
let request_id = &dispatch.request_id;
let model = &dispatch.model;
let created = dispatch.created;
let system_fingerprint = dispatch.weight_version.as_deref();
// Phase 2: Main streaming loop
while let Some(response) = grpc_stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(proto::generate_response::Response::Chunk(chunk)) => {
let index = chunk.index;
// Get or create stop decoder for this index
let stop_decoder = stop_decoders.entry(index).or_insert_with(|| {
let (ref stop, ref stop_token_ids, skip_special_tokens, no_stop_trim) =
stop_params;
utils::create_stop_decoder(
&self.tokenizer,
stop.as_ref(),
stop_token_ids.as_ref(),
skip_special_tokens,
no_stop_trim,
)
});
// Process tokens through stop decoder
let (chunk_text, _should_stop) =
Self::process_chunk_tokens(stop_decoder, &chunk.token_ids);
if chunk_text.is_empty() {
continue;
}
// Process logprobs if present
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
match utils::convert_proto_to_openai_logprobs(
proto_logprobs,
&self.tokenizer,
) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
warn!("Failed to process logprobs: {}", e);
None
}
}
} else {
None
};
// Initialize stream buffer if first time
let stream_buffer = stream_buffers.entry(index).or_default();
// Send first chunk with role
if is_firsts.get(&index).copied().unwrap_or(true) {
let first_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
.map_err(|_| "Failed to send first chunk".to_string())?;
is_firsts.insert(index, false);
}
// Calculate delta
let mut delta = chunk_text;
stream_buffer.push_str(&delta);
// Reasoning content handling
let in_reasoning = if separate_reasoning {
let (normal_text, reasoning_chunk, in_reasoning) = self
.process_reasoning_stream(
&delta,
index,
&mut reasoning_parsers,
request_id,
model,
created,
system_fingerprint,
);
if let Some(chunk) = reasoning_chunk {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
}
delta = normal_text;
in_reasoning
} else {
false
};
// Tool call handling
let tool_choice_enabled =
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
if !in_reasoning && tool_choice_enabled && tools.is_some() {
let (should_skip, tool_chunks) = self
.process_tool_calls_stream(
&delta,
index,
&mut tool_parsers,
&mut has_tool_calls,
tools.as_ref().unwrap(),
request_id,
model,
created,
system_fingerprint,
history_tool_calls_count,
)
.await;
for chunk in tool_chunks {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send tool call chunk".to_string())?;
}
// Continue to process the next chunk as we have tool chunks
if should_skip {
continue;
}
}
// Regular content emission
if !delta.is_empty() {
let content_chunk = Self::create_content_chunk(
delta,
index,
request_id,
model,
created,
system_fingerprint,
choice_logprobs,
);
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
.map_err(|_| "Failed to send content chunk".to_string())?;
}
}
Some(proto::generate_response::Response::Complete(complete)) => {
let index = complete.index;
// Flush any remaining text for this index's stop_decoder
if let Some(decoder) = stop_decoders.get_mut(&index) {
if let SequenceDecoderOutput::Text(text) = decoder.flush() {
if !text.is_empty() {
let stream_buffer = stream_buffers.entry(index).or_default();
stream_buffer.push_str(&text);
let content_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk =
serde_json::to_string(&content_chunk).map_err(|e| {
format!("Failed to serialize content chunk: {}", e)
})?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send flushed content".to_string())?;
}
}
}
// Store metadata
prompt_tokens.insert(index, complete.prompt_tokens as u32);
completion_tokens.insert(index, complete.completion_tokens as u32);
cached_tokens.insert(index, complete.cached_tokens as u32);
finish_reasons.insert(index, complete.finish_reason.clone());
// Extract matched_stop
let matched_stop_value = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
matched_stops.insert(index, matched_stop_value);
// Don't break - continue reading all Complete messages for n>1
}
Some(proto::generate_response::Response::Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
// Phase 3: Check unstreamed tool args
for (index, parser) in &tool_parsers {
let parser_guard = parser.lock().await;
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
for tool_call_item in unstreamed_items {
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: None,
tool_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
let tool_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&tool_chunk)
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
}
}
}
// Phase 4: Finish reason chunks
for (index, finish_reason) in finish_reasons.iter() {
let final_finish_reason =
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
"tool_calls".to_string()
} else {
finish_reason.clone()
};
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
let finish_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(final_finish_reason),
matched_stop: matched_stop_value,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&finish_chunk)
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send finish chunk".to_string())?;
}
// Phase 5: Usage chunk
if let Some(stream_opts) = stream_options {
if stream_opts.include_usage.unwrap_or(false) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
let usage_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![],
usage: Some(Usage {
prompt_tokens: total_prompt,
completion_tokens: total_completion,
total_tokens: total_prompt + total_completion,
completion_tokens_details: None,
}),
};
let sse_chunk = serde_json::to_string(&usage_chunk)
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send usage chunk".to_string())?;
}
}
Ok(())
}
/// Process dual streaming chunks (prefill + decode) - PD mode
pub async fn process_dual_streaming_chunks(
&self,
mut prefill_stream: Streaming<proto::GenerateResponse>,
decode_stream: Streaming<proto::GenerateResponse>,
dispatch: context::DispatchMetadata,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: ChatCompletionRequest,
tx: &UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// Phase 1.5: Collect input_logprobs from prefill stream if requested
if original_request.logprobs {
while let Some(response) = prefill_stream.next().await {
let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?;
match gen_response.response {
Some(proto::generate_response::Response::Complete(_complete)) => {
// Input logprobs collected but not yet used in streaming
// (OpenAI spec doesn't require prompt logprobs in streaming responses)
break;
}
Some(proto::generate_response::Response::Error(error)) => {
return Err(format!("Prefill error: {}", error.message));
}
_ => continue,
}
}
}
// Phase 2-5: Process decode stream (same as single mode)
self.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx)
.await
}
// TODO(generate): Add streaming generate handler
//
// pub async fn process_streaming_generate(
// self: Arc<Self>,
// execution_result: context::ExecutionResult,
// generate_request: GenerateRequest,
// dispatch: context::DispatchMetadata,
// ) -> axum::response::Response {
// // Similar to process_streaming_response but:
// // - No tool parsing
// // - No reasoning parsing
// // - Simpler chunk format (just text + finish_reason + logprobs)
// // - Extract stop params from generate_request.sampling_params
// // - Use same per-index stop decoder logic
// // - Emit SSE chunks with format similar to chat but without delta.tool_calls
// // Reference: router.rs:422-595
// }
// ========================================================================
// Helper Methods
// ========================================================================
/// Process a chunk of tokens through the stop decoder
fn process_chunk_tokens(
stop_decoder: &mut StopSequenceDecoder,
token_ids: &[u32],
) -> (String, bool) {
let mut chunk_text = String::new();
for &token_id in token_ids {
match stop_decoder.process_token(token_id).unwrap_or_else(|e| {
debug!(
"Error processing token {}: {}. Treating as Held.",
token_id, e
);
SequenceDecoderOutput::Held
}) {
SequenceDecoderOutput::Text(text) => {
chunk_text.push_str(&text);
}
SequenceDecoderOutput::StoppedWithText(text) => {
chunk_text.push_str(&text);
return (chunk_text, true);
}
SequenceDecoderOutput::Stopped => {
return (chunk_text, true);
}
SequenceDecoderOutput::Held => {}
}
}
(chunk_text, false)
}
/// Helper: Process reasoning content in streaming mode
#[allow(clippy::too_many_arguments)]
fn process_reasoning_stream(
&self,
delta: &str,
index: u32,
reasoning_parsers: &mut HashMap<u32, Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>>,
request_id: &str,
model: &str,
created: u64,
system_fingerprint: Option<&str>,
) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index
reasoning_parsers.entry(index).or_insert_with(|| {
utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let (parse_result, in_reasoning) = {
let mut parser = pooled_parser.lock().unwrap();
let result = parser.parse_reasoning_streaming_incremental(delta);
let in_reasoning = parser.is_in_reasoning();
(result, in_reasoning)
};
match parse_result {
Ok(crate::reasoning_parser::ParserResult {
reasoning_text,
normal_text,
}) => {
let chunk = if !reasoning_text.is_empty() {
Some(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: Some(reasoning_text),
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
})
} else {
None
};
return (normal_text, chunk, in_reasoning);
}
Err(e) => {
warn!("Reasoning parsing error: {}", e);
}
}
}
(delta.to_string(), None, false)
}
/// Helper: Process tool calls in streaming mode
#[allow(clippy::too_many_arguments)]
async fn process_tool_calls_stream(
&self,
delta: &str,
index: u32,
tool_parsers: &mut HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>>,
has_tool_calls: &mut HashMap<u32, bool>,
tools: &[Tool],
request_id: &str,
model: &str,
created: u64,
system_fingerprint: Option<&str>,
history_tool_calls_count: usize,
) -> (bool, Vec<ChatCompletionStreamResponse>) {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers.entry(index).or_insert_with(|| {
utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await {
Ok(crate::tool_parser::StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present
if !normal_text.is_empty() {
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(normal_text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// Emit tool call chunks
for tool_call_item in calls {
has_tool_calls.insert(index, true);
let tool_call_id = if let Some(ref name) = tool_call_item.name {
Some(utils::generate_tool_call_id(
model,
name,
tool_call_item.tool_index,
history_tool_calls_count,
))
} else {
None
};
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: tool_call_id,
tool_type: if tool_call_item.name.is_some() {
Some("function".to_string())
} else {
None
},
function: Some(FunctionCallDelta {
name: tool_call_item.name,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// If we emitted chunks, skip regular content
return (!chunks.is_empty(), chunks);
}
Err(e) => {
error!("Tool call parsing error: {}", e);
}
}
}
(false, chunks)
}
/// Format a response as SSE chunk
fn format_sse_chunk(chunk: &ChatCompletionStreamResponse) -> String {
match serde_json::to_string(chunk) {
Ok(json) => format!("data: {}\n\n", json),
Err(e) => {
error!("Failed to serialize SSE chunk: {}", e);
format!("data: {}\n\n", json!({"error": "serialization_failed"}))
}
}
}
/// Create a content chunk response
fn create_content_chunk(
content: String,
index: u32,
request_id: &str,
model: &str,
created: u64,
system_fingerprint: Option<&str>,
logprobs: Option<ChatLogProbs>,
) -> ChatCompletionStreamResponse {
ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(content),
tool_calls: None,
reasoning_content: None,
},
logprobs,
finish_reason: None,
matched_stop: None,
}],
usage: None,
}
}
}
/// Build SSE response with proper headers
pub fn build_sse_response(
rx: tokio::sync::mpsc::UnboundedReceiver<Result<Bytes, io::Error>>,
) -> Response {
let stream = UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(stream));
*response.status_mut() = StatusCode::OK;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
.headers_mut()
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
response
.headers_mut()
.insert("Connection", HeaderValue::from_static("keep-alive"));
response
}
...@@ -4,8 +4,8 @@ use super::ProcessedMessages; ...@@ -4,8 +4,8 @@ use super::ProcessedMessages;
use crate::core::Worker; use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, FunctionCallResponse, StringOrArray, Tool, ToolCall, ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
ToolChoice, ToolChoiceValue, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
}; };
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
...@@ -736,6 +736,79 @@ pub fn get_tool_parser( ...@@ -736,6 +736,79 @@ pub fn get_tool_parser(
} }
} }
/// Convert proto::OutputLogProbs to OpenAI ChatLogProbs format
///
/// This function decodes token IDs using the tokenizer and builds the logprobs structure
/// expected by the OpenAI API format.
pub fn convert_proto_to_openai_logprobs(
proto_logprobs: &proto::OutputLogProbs,
tokenizer: &Arc<dyn Tokenizer>,
) -> Result<ChatLogProbs, String> {
let mut content_items = Vec::new();
// Decode token IDs to text (always with skip_special_tokens=false for logprobs)
let token_texts: Vec<String> = proto_logprobs
.token_ids
.iter()
.map(|&token_id| {
tokenizer
.decode(&[token_id as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", token_id))
})
.collect();
// Build ChatLogProbsContent for each token (consume iterator to avoid clones)
for (i, (&logprob, token_text)) in proto_logprobs
.token_logprobs
.iter()
.zip(token_texts.into_iter())
.enumerate()
{
let bytes = Some(token_text.as_bytes().to_vec());
// Build top_logprobs for this position
let mut top_logprobs = Vec::new();
if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
// Decode top token IDs (always with skip_special_tokens=false)
let top_token_texts: Vec<String> = top_logprobs_entry
.token_ids
.iter()
.map(|&tid| {
tokenizer
.decode(&[tid as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", tid))
})
.collect();
for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
.values
.iter()
.zip(top_logprobs_entry.token_ids.iter())
.enumerate()
{
if let Some(top_token_text) = top_token_texts.get(j) {
top_logprobs.push(TopLogProb {
token: top_token_text.clone(),
logprob: top_logprob,
bytes: Some(top_token_text.as_bytes().to_vec()),
});
}
}
}
content_items.push(ChatLogProbsContent {
token: token_text,
logprob,
bytes,
top_logprobs,
});
}
Ok(ChatLogProbs::Detailed {
content: (!content_items.is_empty()).then_some(content_items),
})
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment