Unverified Commit edd86b88 authored by Chang Su, committed by GitHub

[router][grpc] Refactor chat handler in grpc/ to use centralized orchestrator (#11314)


Co-authored-by: Simo Lin <linsimo.mark@gmail.com>
parent 4b4dc132
@@ -2066,6 +2066,40 @@ impl GenerationRequest for GenerateRequest {
}
}
// TODO(generate): Define GenerateResponse and GenerateChoice structs
//
// Required for pipeline generate response processing (see grpc/pipeline.rs:931-964)
//
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct GenerateResponse {
// pub id: String,
// pub object: String, // "text.completion"
// pub created: u64,
// pub model: String,
// pub choices: Vec<GenerateChoice>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub usage: Option<Usage>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub system_fingerprint: Option<String>,
// }
//
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct GenerateChoice {
// pub index: u32,
// pub text: String,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub output_ids: Option<Vec<u32>>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub finish_reason: Option<String>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub logprobs: Option<TopLogprobs>,
// #[serde(skip_serializing_if = "Option::is_none")]
// pub matched_stop: Option<Value>,
// }
//
// Note: Verify if similar structs already exist elsewhere before implementing.
// May need streaming variant (GenerateStreamResponse) as well.
// Constants for rerank API
pub const DEFAULT_MODEL_NAME: &str = "default";
//! Request context types for gRPC router pipeline
//!
//! This module provides the core context types that flow through the router pipeline,
//! eliminating deep parameter passing chains and providing a single source of truth
//! for request state.
use std::collections::HashMap;
use std::sync::Arc;
use axum::http::HeaderMap;
use serde_json::Value;
use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{ChatCompletionRequest, ChatCompletionResponse, GenerateRequest};
use crate::reasoning_parser::ReasoningParserFactory;
use crate::tokenizer::stop::StopSequenceDecoder;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory;
// ============================================================================
// Core Context Types
// ============================================================================
/// Main request processing context
///
/// This is the single source of truth for all request state as it flows
/// through the pipeline stages. Uses Rust's type system to enforce proper
/// stage ordering at compile time.
pub struct RequestContext {
// === Input (Immutable) ===
pub input: RequestInput,
// === Shared Components (Immutable References) ===
pub components: Arc<SharedComponents>,
// === Processing State (Mutable, evolves through pipeline) ===
pub state: ProcessingState,
}
/// Immutable request input
pub struct RequestInput {
pub request_type: RequestType,
pub headers: Option<HeaderMap>,
pub model_id: Option<String>,
}
/// Request type variants
pub enum RequestType {
Chat(Box<ChatCompletionRequest>),
Generate(Box<GenerateRequest>),
}
/// Shared components (injected once at creation)
pub struct SharedComponents {
pub tokenizer: Arc<dyn Tokenizer>,
pub tool_parser_factory: ToolParserFactory,
pub reasoning_parser_factory: ReasoningParserFactory,
}
/// Mutable processing state (evolves through pipeline stages)
#[derive(Default)]
pub struct ProcessingState {
// Stage 1: Preparation outputs
pub preparation: Option<PreparationOutput>,
// Stage 2: Worker selection outputs
pub workers: Option<WorkerSelection>,
// Stage 3: Client acquisition outputs
pub clients: Option<ClientSelection>,
// Stage 4: Request building outputs
pub proto_request: Option<proto::GenerateRequest>,
// Stage 5: Dispatch metadata
pub dispatch: Option<DispatchMetadata>,
// Stage 6: Response processing state
pub response: ResponseState,
}
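// Illustrative sketch (not part of this commit): a later pipeline stage is
// expected to read the output of an earlier one from ProcessingState and fail
// fast if it is missing. `build_proto_request` below is a hypothetical helper
// standing in for the real request-building logic.
//
// fn request_building_stage(ctx: &mut RequestContext) -> Result<(), String> {
//     let prep = ctx
//         .state
//         .preparation
//         .as_ref()
//         .ok_or("preparation stage must run before request building")?;
//     let proto_request = build_proto_request(&prep.token_ids);
//     ctx.state.proto_request = Some(proto_request);
//     Ok(())
// }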
// ============================================================================
// Stage-Specific Output Types
// ============================================================================
/// Output from preparation stage (Step 1)
pub struct PreparationOutput {
/// Original text (for chat) or resolved text (for generate)
pub original_text: Option<String>,
/// Tokenized input
pub token_ids: Vec<u32>,
/// Processed messages (chat only)
pub processed_messages: Option<super::ProcessedMessages>,
/// Tool call constraints (if applicable)
pub tool_constraints: Option<(String, String)>,
/// Filtered request (if tools were filtered)
pub filtered_request: Option<ChatCompletionRequest>,
}
/// Worker selection (Step 2)
pub enum WorkerSelection {
Single {
worker: Arc<dyn Worker>,
},
Dual {
prefill: Arc<dyn Worker>,
decode: Arc<dyn Worker>,
},
}
/// Client selection (Step 3)
pub enum ClientSelection {
Single {
client: SglangSchedulerClient,
},
Dual {
prefill: SglangSchedulerClient,
decode: SglangSchedulerClient,
},
}
/// Dispatch metadata (Step 5)
#[derive(Clone)]
pub struct DispatchMetadata {
pub request_id: String,
pub model: String,
pub created: u64,
pub weight_version: Option<String>,
pub is_streaming: bool,
}
/// Response processing state (Step 6)
#[derive(Default)]
pub struct ResponseState {
/// Stop sequence decoder
pub stop_decoder: Option<StopSequenceDecoder>,
/// Per-index streaming state (for n>1 support)
pub streaming: StreamingState,
/// Collected responses (non-streaming)
pub collected: Option<Vec<proto::GenerateComplete>>,
/// Execution result (streams from workers)
pub execution_result: Option<ExecutionResult>,
/// Final processed response
pub final_response: Option<FinalResponse>,
}
/// Streaming state (per-choice tracking)
#[derive(Default)]
pub struct StreamingState {
pub is_firsts: HashMap<u32, bool>,
pub stream_buffers: HashMap<u32, String>,
pub finish_reasons: HashMap<u32, String>,
pub matched_stops: HashMap<u32, Option<Value>>,
pub prompt_tokens: HashMap<u32, u32>,
pub completion_tokens: HashMap<u32, u32>,
pub cached_tokens: HashMap<u32, u32>,
// Parser state (lazy initialization per index)
pub reasoning_parsers:
HashMap<u32, Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>>,
pub tool_parsers:
HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>>,
pub has_tool_calls: HashMap<u32, bool>,
}
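// Illustrative sketch (not part of this commit) of the lazy per-index
// initialization these parser maps are meant to support; `new_reasoning_parser`
// is a hypothetical constructor standing in for whatever the factory provides.
//
// fn reasoning_parser_for(
//     state: &mut StreamingState,
//     index: u32,
//     factory: &ReasoningParserFactory,
// ) -> Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>> {
//     state
//         .reasoning_parsers
//         .entry(index)
//         .or_insert_with(|| Arc::new(std::sync::Mutex::new(new_reasoning_parser(factory))))
//         .clone()
// }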
// ============================================================================
// Context Builders
// ============================================================================
impl RequestContext {
/// Create context for chat completion request
pub fn for_chat(
request: ChatCompletionRequest,
headers: Option<HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Self {
Self {
input: RequestInput {
request_type: RequestType::Chat(Box::new(request)),
headers,
model_id,
},
components,
state: ProcessingState::default(),
}
}
/// Create context for generate request
pub fn for_generate(
request: GenerateRequest,
headers: Option<HeaderMap>,
model_id: Option<String>,
components: Arc<SharedComponents>,
) -> Self {
Self {
input: RequestInput {
request_type: RequestType::Generate(Box::new(request)),
headers,
model_id,
},
components,
state: ProcessingState::default(),
}
}
/// Get reference to original request (type-safe)
pub fn request(&self) -> &RequestType {
&self.input.request_type
}
/// Get chat request (panics if not chat)
pub fn chat_request(&self) -> &ChatCompletionRequest {
match &self.input.request_type {
RequestType::Chat(req) => req.as_ref(),
_ => panic!("Expected chat request"),
}
}
/// Try to get chat request
pub fn try_chat_request(&self) -> Option<&ChatCompletionRequest> {
match &self.input.request_type {
RequestType::Chat(req) => Some(req.as_ref()),
_ => None,
}
}
/// Get generate request (panics if not generate)
pub fn generate_request(&self) -> &GenerateRequest {
match &self.input.request_type {
RequestType::Generate(req) => req.as_ref(),
_ => panic!("Expected generate request"),
}
}
/// Try to get generate request
pub fn try_generate_request(&self) -> Option<&GenerateRequest> {
match &self.input.request_type {
RequestType::Generate(req) => Some(req.as_ref()),
_ => None,
}
}
/// Check if request is streaming
pub fn is_streaming(&self) -> bool {
match &self.input.request_type {
RequestType::Chat(req) => req.stream,
RequestType::Generate(req) => req.stream,
}
}
/// Check if request is chat
pub fn is_chat(&self) -> bool {
matches!(&self.input.request_type, RequestType::Chat(_))
}
/// Check if request is generate
pub fn is_generate(&self) -> bool {
matches!(&self.input.request_type, RequestType::Generate(_))
}
}
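// Illustrative usage sketch (not part of this commit): a handler builds the
// context once and routes on its type and streaming flag; the pipeline stages
// themselves are elided here.
//
// fn route(request: ChatCompletionRequest, components: Arc<SharedComponents>) {
//     let ctx = RequestContext::for_chat(request, None, None, components);
//     debug_assert!(ctx.is_chat() && !ctx.is_generate());
//     if ctx.is_streaming() {
//         // hand off to the streaming pipeline
//     } else {
//         // hand off to the non-streaming pipeline
//     }
// }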
// ============================================================================
// Helper Methods
// ============================================================================
impl WorkerSelection {
pub fn is_dual(&self) -> bool {
matches!(self, Self::Dual { .. })
}
pub fn single(&self) -> Option<&Arc<dyn Worker>> {
match self {
Self::Single { worker } => Some(worker),
_ => None,
}
}
#[allow(clippy::type_complexity)]
pub fn dual(&self) -> Option<(&Arc<dyn Worker>, &Arc<dyn Worker>)> {
match self {
Self::Dual { prefill, decode } => Some((prefill, decode)),
_ => None,
}
}
pub fn prefill_worker(&self) -> Option<&Arc<dyn Worker>> {
match self {
Self::Dual { prefill, .. } => Some(prefill),
_ => None,
}
}
pub fn decode_worker(&self) -> Option<&Arc<dyn Worker>> {
match self {
Self::Dual { decode, .. } => Some(decode),
_ => None,
}
}
}
impl ClientSelection {
pub fn is_dual(&self) -> bool {
matches!(self, Self::Dual { .. })
}
pub fn single(&self) -> Option<&SglangSchedulerClient> {
match self {
Self::Single { client } => Some(client),
_ => None,
}
}
pub fn single_mut(&mut self) -> Option<&mut SglangSchedulerClient> {
match self {
Self::Single { client } => Some(client),
_ => None,
}
}
pub fn dual(&self) -> Option<(&SglangSchedulerClient, &SglangSchedulerClient)> {
match self {
Self::Dual { prefill, decode } => Some((prefill, decode)),
_ => None,
}
}
pub fn dual_mut(&mut self) -> Option<(&mut SglangSchedulerClient, &mut SglangSchedulerClient)> {
match self {
Self::Dual { prefill, decode } => Some((prefill, decode)),
_ => None,
}
}
pub fn prefill_client(&self) -> Option<&SglangSchedulerClient> {
match self {
Self::Dual { prefill, .. } => Some(prefill),
_ => None,
}
}
pub fn prefill_client_mut(&mut self) -> Option<&mut SglangSchedulerClient> {
match self {
Self::Dual { prefill, .. } => Some(prefill),
_ => None,
}
}
pub fn decode_client(&self) -> Option<&SglangSchedulerClient> {
match self {
Self::Dual { decode, .. } => Some(decode),
_ => None,
}
}
pub fn decode_client_mut(&mut self) -> Option<&mut SglangSchedulerClient> {
match self {
Self::Dual { decode, .. } => Some(decode),
_ => None,
}
}
}
// ============================================================================
// Execution and Response Types
// ============================================================================
use tonic::codec::Streaming;
/// Result of request execution (streams from workers)
pub enum ExecutionResult {
Single {
stream: Streaming<proto::GenerateResponse>,
},
Dual {
prefill: Streaming<proto::GenerateResponse>,
decode: Box<Streaming<proto::GenerateResponse>>,
},
}
/// Final processed response
pub enum FinalResponse {
Chat(ChatCompletionResponse),
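// Presumably a placeholder: expected to become a GenerateResponse once that
// type is defined (see the GenerateResponse TODO earlier in this commit).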
Generate(Box<GenerateRequest>),
}
@@ -3,8 +3,12 @@
use crate::grpc_client::proto;
use crate::protocols::spec::StringOrArray;
pub mod context;
pub mod pd_router;
pub mod pipeline;
pub mod processing;
pub mod router;
pub mod streaming;
pub mod utils;
/// Processed chat messages ready for gRPC generation
//! Shared response processing logic for gRPC routers
//!
//! This module contains response processing functions that are shared between
//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates.
use std::sync::Arc;
use serde_json::Value;
use tracing::error;
use crate::grpc_client::proto;
use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall,
ToolChoice, ToolChoiceValue,
};
use crate::reasoning_parser::ReasoningParserFactory;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory;
use super::utils;
// ============================================================================
// Response Processor - Main Entry Point
// ============================================================================
/// Unified response processor for both routers
#[derive(Clone)]
pub struct ResponseProcessor {
pub tokenizer: Arc<dyn Tokenizer>,
pub tool_parser_factory: ToolParserFactory,
pub reasoning_parser_factory: ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
}
impl ResponseProcessor {
pub fn new(
tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: ToolParserFactory,
reasoning_parser_factory: ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
) -> Self {
Self {
tokenizer,
tool_parser_factory,
reasoning_parser_factory,
configured_tool_parser,
configured_reasoning_parser,
}
}
/// Process a single choice from GenerateComplete response (EXACT COPY from router.rs:1573-1725)
pub async fn process_single_choice(
&self,
complete: &proto::GenerateComplete,
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut StopSequenceDecoder,
history_tool_calls_count: usize,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
let outputs = stop_decoder
.process_tokens(&complete.output_ids)
.map_err(|e| format!("Failed to process tokens: {}", e))?;
// Accumulate text with early breaks
let mut final_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
final_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
final_text.push_str(&t);
}
// Step 1: Handle reasoning content parsing
let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
let pooled_parser = utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser
.lock()
.map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
match parser.detect_and_parse_reasoning(&processed_text) {
Ok(result) => {
if !result.reasoning_text.is_empty() {
reasoning_text = Some(result.reasoning_text);
}
processed_text = result.normal_text;
}
Err(e) => {
return Err(format!("Reasoning parsing error: {}", e));
}
}
}
// Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<ToolCall>> = None;
// Check if tool calls should be processed
let tool_choice_enabled = !matches!(
&original_request.tool_choice,
Some(ToolChoice::Value(ToolChoiceValue::None))
);
if tool_choice_enabled && original_request.tools.is_some() {
// Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match &original_request.tool_choice {
Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false,
};
if used_json_schema {
(tool_calls, processed_text) = utils::parse_json_schema_response(
&processed_text,
&original_request.tool_choice,
);
} else {
(tool_calls, processed_text) = self
.parse_tool_calls(
&processed_text,
&original_request.model,
history_tool_calls_count,
)
.await;
}
}
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
let finish_reason_str = &complete.finish_reason;
// Override finish reason if we have tool calls
let final_finish_reason_str = if tool_calls.is_some() {
"tool_calls"
} else {
finish_reason_str
};
// Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
// Step 4: Convert output logprobs if present
let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs {
match utils::convert_proto_to_openai_logprobs(proto_logprobs, &self.tokenizer) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
error!("Failed to convert logprobs: {}", e);
None
}
}
} else {
None
};
// Step 5: Build ChatCompletionMessage (proper response message type)
let chat_message = ChatCompletionMessage {
role: "assistant".to_string(),
content: if processed_text.is_empty() {
None
} else {
Some(processed_text)
},
tool_calls,
reasoning_content: reasoning_text,
};
// Step 6: Build ChatChoice
let choice = ChatChoice {
index: index as u32,
message: chat_message,
logprobs,
finish_reason: Some(final_finish_reason_str.to_string()),
matched_stop,
hidden_states: None,
};
Ok(choice)
}
/// Parse tool calls using model-specific parser (EXACT COPY from router.rs:296-361)
pub async fn parse_tool_calls(
&self,
processed_text: &str,
model: &str,
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Try parsing directly (parser will handle detection internally)
let result = {
let parser = pooled_parser.lock().await;
parser.parse_complete(processed_text).await
// Lock is dropped here
};
match result {
Ok((normal_text, parsed_tool_calls)) => {
if parsed_tool_calls.is_empty() {
return (None, normal_text);
}
let spec_tool_calls = parsed_tool_calls
.into_iter()
.enumerate()
.map(|(index, tc)| {
// Generate ID for this tool call
let id = utils::generate_tool_call_id(
model,
&tc.function.name,
index,
history_tool_calls_count,
);
ToolCall {
id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
}
})
.collect();
(Some(spec_tool_calls), normal_text)
}
Err(e) => {
error!("Tool call parsing error: {}", e);
(None, processed_text.to_string())
}
}
}
}
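// Illustrative sketch (not part of this commit) of how a router might drive the
// processor over the completed generations collected for one request; the
// `completes` slice and `stop_decoder` are assumed to come from the dispatch and
// response stages, and the history tool-call count is hard-coded to 0 here.
//
// async fn build_choices(
//     processor: &ResponseProcessor,
//     completes: &[proto::GenerateComplete],
//     request: &ChatCompletionRequest,
//     stop_decoder: &mut StopSequenceDecoder,
// ) -> Result<Vec<ChatChoice>, String> {
//     let mut choices = Vec::with_capacity(completes.len());
//     for (i, complete) in completes.iter().enumerate() {
//         choices.push(
//             processor
//                 .process_single_choice(complete, i, request, stop_decoder, 0)
//                 .await?,
//         );
//     }
//     Ok(choices)
// }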
@@ -4,8 +4,8 @@ use super::ProcessedMessages;
use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{
-    ChatCompletionRequest, ChatMessage, FunctionCallResponse, StringOrArray, Tool, ToolCall,
-    ToolChoice, ToolChoiceValue,
+    ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
+    StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
};
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer;
@@ -736,6 +736,79 @@ pub fn get_tool_parser(
}
}
/// Convert proto::OutputLogProbs to OpenAI ChatLogProbs format
///
/// This function decodes token IDs using the tokenizer and builds the logprobs structure
/// expected by the OpenAI API format.
pub fn convert_proto_to_openai_logprobs(
proto_logprobs: &proto::OutputLogProbs,
tokenizer: &Arc<dyn Tokenizer>,
) -> Result<ChatLogProbs, String> {
let mut content_items = Vec::new();
// Decode token IDs to text (always with skip_special_tokens=false for logprobs)
let token_texts: Vec<String> = proto_logprobs
.token_ids
.iter()
.map(|&token_id| {
tokenizer
.decode(&[token_id as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", token_id))
})
.collect();
// Build ChatLogProbsContent for each token (consume iterator to avoid clones)
for (i, (&logprob, token_text)) in proto_logprobs
.token_logprobs
.iter()
.zip(token_texts.into_iter())
.enumerate()
{
let bytes = Some(token_text.as_bytes().to_vec());
// Build top_logprobs for this position
let mut top_logprobs = Vec::new();
if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
// Decode top token IDs (always with skip_special_tokens=false)
let top_token_texts: Vec<String> = top_logprobs_entry
.token_ids
.iter()
.map(|&tid| {
tokenizer
.decode(&[tid as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", tid))
})
.collect();
for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
.values
.iter()
.zip(top_logprobs_entry.token_ids.iter())
.enumerate()
{
if let Some(top_token_text) = top_token_texts.get(j) {
top_logprobs.push(TopLogProb {
token: top_token_text.clone(),
logprob: top_logprob,
bytes: Some(top_token_text.as_bytes().to_vec()),
});
}
}
}
content_items.push(ChatLogProbsContent {
token: token_text,
logprob,
bytes,
top_logprobs,
});
}
Ok(ChatLogProbs::Detailed {
content: (!content_items.is_empty()).then_some(content_items),
})
}
#[cfg(test)]
mod tests {
use super::*;