Unverified Commit 7b59b0b8 authored by Chang Su, committed by GitHub

[router][grpc] Further delegate non-stream processing to `processing.rs` (#11553)

parent acc2327b
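In short: both gRPC routers used to construct a `ResponseProcessor` and a `StreamingProcessor` themselves and hand them to the pipeline; after this change they pass the raw components (tokenizer, parser factories, configured parser names) and the pipeline, renamed from `ChatCompletionPipeline` to `RequestPipeline` since it now serves both chat and generate requests, builds the processors internally. A before/after sketch of the PD call site, distilled from the hunks below (all names appear in the diff; the `/* ... */` elision is mine):

```rust
// Before: the router builds both processors, the pipeline just receives them.
let processor = super::processing::ResponseProcessor::new(
    tokenizer.clone(),
    tool_parser_factory.clone(),
    reasoning_parser_factory.clone(),
    ctx.configured_tool_parser.clone(),
    ctx.configured_reasoning_parser.clone(),
);
let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
    /* same five arguments */
));
let pipeline = super::pipeline::ChatCompletionPipeline::new_pd(
    worker_registry.clone(),
    policy_registry.clone(),
    processor,
    streaming_processor,
);

// After: the router passes raw components; RequestPipeline::new_pd builds the
// processors itself, so the construction logic lives in one place.
let pipeline = RequestPipeline::new_pd(
    worker_registry.clone(),
    policy_registry.clone(),
    tokenizer.clone(),
    tool_parser_factory.clone(),
    reasoning_parser_factory.clone(),
    ctx.configured_tool_parser.clone(),
    ctx.configured_reasoning_parser.clone(),
);
```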
@@ -23,6 +23,9 @@ use std::sync::Arc;
 use tracing::debug;
 
+use super::context::SharedComponents;
+use super::pipeline::RequestPipeline;
+
 /// gRPC PD (Prefill-Decode) router implementation for SGLang
 #[derive(Clone)]
 #[allow(dead_code)] // Fields will be used once implementation is complete
@@ -37,8 +40,8 @@ pub struct GrpcPDRouter {
     retry_config: RetryConfig,
     configured_reasoning_parser: Option<String>,
     configured_tool_parser: Option<String>,
-    pipeline: super::pipeline::ChatCompletionPipeline,
-    shared_components: Arc<super::context::SharedComponents>,
+    pipeline: RequestPipeline,
+    shared_components: Arc<SharedComponents>,
 }
 
 impl GrpcPDRouter {
@@ -66,36 +69,21 @@ impl GrpcPDRouter {
             .clone();
 
         // Create shared components for pipeline
-        let shared_components = Arc::new(super::context::SharedComponents {
+        let shared_components = Arc::new(SharedComponents {
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });
 
-        // Create response processor
-        let processor = super::processing::ResponseProcessor::new(
-            tokenizer.clone(),
-            tool_parser_factory.clone(),
-            reasoning_parser_factory.clone(),
-            ctx.configured_tool_parser.clone(),
-            ctx.configured_reasoning_parser.clone(),
-        );
-
-        // Create streaming processor
-        let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
-            tokenizer.clone(),
-            tool_parser_factory.clone(),
-            reasoning_parser_factory.clone(),
-            ctx.configured_tool_parser.clone(),
-            ctx.configured_reasoning_parser.clone(),
-        ));
-
-        // Create PD pipeline
-        let pipeline = super::pipeline::ChatCompletionPipeline::new_pd(
-            worker_registry.clone(),
-            policy_registry.clone(),
-            processor,
-            streaming_processor,
-        );
+        // Create PD pipeline
+        let pipeline = RequestPipeline::new_pd(
+            worker_registry.clone(),
+            policy_registry.clone(),
+            tokenizer.clone(),
+            tool_parser_factory.clone(),
+            reasoning_parser_factory.clone(),
+            ctx.configured_tool_parser.clone(),
+            ctx.configured_reasoning_parser.clone(),
+        );
 
         Ok(GrpcPDRouter {
...
@@ -14,13 +14,10 @@ use super::utils;
 use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
 use crate::grpc_client::proto;
 use crate::policies::PolicyRegistry;
-use crate::protocols::spec::{
-    ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest,
-    GenerateResponse, InputIds, Usage,
-};
-use crate::tokenizer::stop::SequenceDecoderOutput;
+use crate::protocols::spec::{ChatCompletionRequest, GenerateRequest, InputIds};
+use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
 use crate::tokenizer::traits::Tokenizer;
-use proto::generate_complete::MatchedStop;
+use crate::tool_parser::ParserFactory as ToolParserFactory;
 use proto::DisaggregatedParams;
 use rand::Rng;
 use std::sync::Arc;
@@ -790,114 +787,32 @@ impl ResponseProcessingStage {
             .take()
             .ok_or_else(|| utils::internal_error_static("No execution result"))?;
 
-        if is_streaming {
-            // Get dispatch metadata for consistent response fields
-            let dispatch = ctx
-                .state
-                .dispatch
-                .as_ref()
-                .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
+        // Get dispatch metadata (needed by both streaming and non-streaming)
+        let dispatch = ctx
+            .state
+            .dispatch
+            .as_ref()
+            .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
+            .clone();
 
+        if is_streaming {
             // Streaming: Use StreamingProcessor and return SSE response (done)
             return Ok(Some(
                 self.streaming_processor.clone().process_streaming_response(
                     execution_result,
                     ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
-                    dispatch.clone(),
+                    dispatch,
                 ),
             ));
         }
 
-        // Non-streaming: Extract chat request details before mutable borrows
+        // Non-streaming: Delegate to ResponseProcessor
        let request_logprobs = match &ctx.input.request_type {
            RequestType::Chat(req) => req.logprobs,
            _ => false,
        };
-
-        // Collect all responses from the execution result
-        let all_responses = match execution_result {
-            ExecutionResult::Single { mut stream } => {
-                let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
-                stream.mark_completed();
-                responses
-            }
-            ExecutionResult::Dual {
-                mut prefill,
-                decode,
-            } => {
-                // Collect prefill for input_logprobs (don't mark completed yet)
-                let prefill_responses =
-                    utils::collect_stream_responses(&mut prefill, "Prefill").await?;
-
-                // Collect decode for actual output (don't mark completed yet)
-                let mut decode_stream = *decode;
-                let mut decode_responses =
-                    utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
-
-                // Mark both streams as completed now that both succeeded
-                prefill.mark_completed();
-                decode_stream.mark_completed();
-
-                // Merge prefill input_logprobs if requested
-                if request_logprobs {
-                    if let Some(prefill_input_logprobs) = prefill_responses
-                        .first()
-                        .and_then(|r| r.input_logprobs.clone())
-                    {
-                        for response in &mut decode_responses {
-                            response.input_logprobs = Some(prefill_input_logprobs.clone());
-                        }
-                    }
-                }
-
-                decode_responses
-            }
-        };
-
-        if all_responses.is_empty() {
-            return Err(utils::internal_error_static("No responses from server"));
-        }
 
         let chat_request = ctx.chat_request_arc();
-        let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
-
-        // Check parser availability once upfront (not per choice)
-        let reasoning_parser_available = chat_request.separate_reasoning
-            && utils::check_reasoning_parser_availability(
-                &self.processor.reasoning_parser_factory,
-                self.processor.configured_reasoning_parser.as_ref(),
-                &chat_request.model,
-            );
-        let tool_choice_enabled = !matches!(
-            &chat_request.tool_choice,
-            Some(crate::protocols::spec::ToolChoice::Value(
-                crate::protocols::spec::ToolChoiceValue::None
-            ))
-        );
-        let tool_parser_available = tool_choice_enabled
-            && chat_request.tools.is_some()
-            && utils::check_tool_parser_availability(
-                &self.processor.tool_parser_factory,
-                self.processor.configured_tool_parser.as_ref(),
-                &chat_request.model,
-            );
-
-        // Log once per request (not per choice)
-        if chat_request.separate_reasoning && !reasoning_parser_available {
-            debug!(
-                "No reasoning parser found for model '{}', skipping reasoning parsing",
-                chat_request.model
-            );
-        }
-        if chat_request.tools.is_some() && tool_choice_enabled && !tool_parser_available {
-            debug!(
-                "No tool parser found for model '{}', skipping tool call parsing",
-                chat_request.model
-            );
-        }
 
         let stop_decoder = ctx
             .state
@@ -906,60 +821,16 @@ impl ResponseProcessingStage {
             .as_mut()
             .ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
 
-        let mut choices = Vec::new();
-        for (index, complete) in all_responses.iter().enumerate() {
-            match self
-                .processor
-                .process_single_choice(
-                    complete,
-                    index,
-                    &chat_request,
-                    stop_decoder,
-                    history_tool_calls_count,
-                    reasoning_parser_available,
-                    tool_parser_available,
-                )
-                .await
-            {
-                Ok(choice) => choices.push(choice),
-                Err(e) => {
-                    return Err(utils::internal_error_message(format!(
-                        "Failed to process choice {}: {}",
-                        index, e
-                    )));
-                }
-            }
-        }
-
-        // Build usage
-        let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
-        let total_completion_tokens: u32 = all_responses
-            .iter()
-            .map(|r| r.completion_tokens as u32)
-            .sum();
-        let usage = Usage {
-            prompt_tokens: total_prompt_tokens,
-            completion_tokens: total_completion_tokens,
-            total_tokens: total_prompt_tokens + total_completion_tokens,
-            completion_tokens_details: None,
-        };
-
-        // Build final ChatCompletionResponse
-        let dispatch = ctx
-            .state
-            .dispatch
-            .as_ref()
-            .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
-        let response = ChatCompletionResponse {
-            id: dispatch.request_id.clone(),
-            object: "chat.completion".to_string(),
-            created: dispatch.created,
-            model: dispatch.model.clone(),
-            choices,
-            usage: Some(usage),
-            system_fingerprint: dispatch.weight_version.clone(),
-        };
+        let response = self
+            .processor
+            .process_non_streaming_chat_response(
+                execution_result,
+                chat_request,
+                dispatch,
+                stop_decoder,
+                request_logprobs,
+            )
+            .await?;
 
         // Store the final response
         ctx.state.response.final_response = Some(FinalResponse::Chat(response));
@@ -982,70 +853,29 @@ impl ResponseProcessingStage {
             .take()
             .ok_or_else(|| utils::internal_error_static("No execution result"))?;
 
-        if is_streaming {
-            // Get dispatch metadata for consistent response fields
-            let dispatch = ctx
-                .state
-                .dispatch
-                .as_ref()
-                .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
+        // Get dispatch metadata (needed by both streaming and non-streaming)
+        let dispatch = ctx
+            .state
+            .dispatch
+            .as_ref()
+            .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
+            .clone();
 
+        if is_streaming {
             // Streaming: Use StreamingProcessor and return SSE response (done)
             return Ok(Some(
                 self.streaming_processor.clone().process_streaming_generate(
                     execution_result,
                     ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
-                    dispatch.clone(),
+                    dispatch,
                ),
            ));
        }
 
-        // Non-streaming: Collect all responses
+        // Non-streaming: Delegate to ResponseProcessor
        let request_logprobs = ctx.generate_request().return_logprob;
-        let all_responses = match execution_result {
-            ExecutionResult::Single { mut stream } => {
-                let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
-                stream.mark_completed();
-                responses
-            }
-            ExecutionResult::Dual {
-                mut prefill,
-                decode,
-            } => {
-                // Collect prefill for input_logprobs (don't mark completed yet)
-                let prefill_responses =
-                    utils::collect_stream_responses(&mut prefill, "Prefill").await?;
-
-                // Collect decode for actual output (don't mark completed yet)
-                let mut decode_stream = *decode;
-                let mut decode_responses =
-                    utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
-
-                // Mark both streams as completed now that both succeeded
-                prefill.mark_completed();
-                decode_stream.mark_completed();
-
-                // Merge prefill input_logprobs if requested
-                if request_logprobs {
-                    if let Some(prefill_input_logprobs) = prefill_responses
-                        .first()
-                        .and_then(|r| r.input_logprobs.clone())
-                    {
-                        for response in &mut decode_responses {
-                            response.input_logprobs = Some(prefill_input_logprobs.clone());
-                        }
-                    }
-                }
-
-                decode_responses
-            }
-        };
-
-        if all_responses.is_empty() {
-            return Err(utils::internal_error_static("No responses from server"));
-        }
+        let generate_request = ctx.generate_request_arc();
 
-        // Get stop decoder for processing
        let stop_decoder = ctx
            .state
            .response
@@ -1053,103 +883,17 @@ impl ResponseProcessingStage {
             .as_mut()
             .ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
 
-        // Get dispatch metadata
-        let dispatch = ctx
-            .state
-            .dispatch
-            .as_ref()
-            .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
-
-        // Process each completion (similar to router.rs:336-400)
-        let mut result_array = Vec::new();
-        for mut complete in all_responses {
-            stop_decoder.reset();
-
-            // Process tokens through stop decoder
-            let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
-                Ok(outputs) => outputs,
-                Err(e) => {
-                    return Err(utils::internal_error_message(format!(
-                        "Failed to process tokens: {}",
-                        e
-                    )))
-                }
-            };
-
-            // Accumulate text with early breaks
-            let mut decoded_text = String::new();
-            for output in outputs {
-                match output {
-                    SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
-                    SequenceDecoderOutput::StoppedWithText(t) => {
-                        decoded_text.push_str(&t);
-                        break;
-                    }
-                    SequenceDecoderOutput::Stopped => break,
-                    SequenceDecoderOutput::Held => {}
-                }
-            }
-
-            // Flush remaining text
-            if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
-                decoded_text.push_str(&t);
-            }
-
-            let output_ids = std::mem::take(&mut complete.output_ids);
-            let finish_reason_str = std::mem::take(&mut complete.finish_reason);
-
-            // Parse finish_reason from string to proper type
-            let finish_reason =
-                utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
-
-            // Handle matched_stop if present
-            let matched_stop = complete.matched_stop.take().map(|matched| match matched {
-                MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
-                MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
-            });
-
-            // Extract logprobs if requested (convert proto types to Generate format)
-            let input_token_logprobs = if request_logprobs {
-                complete
-                    .input_logprobs
-                    .as_ref()
-                    .map(utils::convert_generate_input_logprobs)
-            } else {
-                None
-            };
-            let output_token_logprobs = if request_logprobs {
-                complete
-                    .output_logprobs
-                    .as_ref()
-                    .map(utils::convert_generate_output_logprobs)
-            } else {
-                None
-            };
-
-            // Build GenerateResponse struct
-            let meta_info = GenerateMetaInfo {
-                id: dispatch.request_id.clone(),
-                finish_reason,
-                prompt_tokens: complete.prompt_tokens as u32,
-                weight_version: dispatch
-                    .weight_version
-                    .clone()
-                    .unwrap_or_else(|| "default".to_string()),
-                input_token_logprobs,
-                output_token_logprobs,
-                completion_tokens: complete.completion_tokens as u32,
-                cached_tokens: complete.cached_tokens as u32,
-                e2e_latency: start_time.elapsed().as_secs_f64(),
-                matched_stop,
-            };
-
-            result_array.push(GenerateResponse {
-                text: decoded_text,
-                output_ids,
-                meta_info,
-            });
-        }
+        let result_array = self
+            .processor
+            .process_non_streaming_generate_response(
+                execution_result,
+                generate_request,
+                dispatch,
+                stop_decoder,
+                request_logprobs,
+                start_time,
+            )
+            .await?;
 
         // Store the final response
         ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
@@ -1162,23 +906,44 @@ impl ResponseProcessingStage {
 // Pipeline Orchestrator
 // ============================================================================
 
-/// Complete chat completion pipeline
+/// Generic request pipeline for all request types
 ///
 /// Orchestrates all stages from request preparation to response delivery.
 /// Configured differently for regular vs PD mode.
 #[derive(Clone)]
-pub struct ChatCompletionPipeline {
+pub struct RequestPipeline {
     stages: Arc<Vec<Box<dyn PipelineStage>>>,
 }
 
-impl ChatCompletionPipeline {
+impl RequestPipeline {
     /// Create a regular (single-worker) pipeline
     pub fn new_regular(
         worker_registry: Arc<WorkerRegistry>,
         policy_registry: Arc<PolicyRegistry>,
-        processor: processing::ResponseProcessor,
-        streaming_processor: Arc<streaming::StreamingProcessor>,
+        tokenizer: Arc<dyn Tokenizer>,
+        tool_parser_factory: ToolParserFactory,
+        reasoning_parser_factory: ReasoningParserFactory,
+        configured_tool_parser: Option<String>,
+        configured_reasoning_parser: Option<String>,
     ) -> Self {
+        // Create response processor
+        let processor = processing::ResponseProcessor::new(
+            tokenizer.clone(),
+            tool_parser_factory.clone(),
+            reasoning_parser_factory.clone(),
+            configured_tool_parser.clone(),
+            configured_reasoning_parser.clone(),
+        );
+
+        // Create streaming processor
+        let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
+            tokenizer,
+            tool_parser_factory,
+            reasoning_parser_factory,
+            configured_tool_parser,
+            configured_reasoning_parser,
+        ));
+
         let stages: Vec<Box<dyn PipelineStage>> = vec![
             Box::new(PreparationStage),
             Box::new(WorkerSelectionStage::new(
@@ -1190,10 +955,7 @@ impl ChatCompletionPipeline {
             Box::new(RequestBuildingStage::new(false)), // No PD metadata
             Box::new(DispatchMetadataStage),
             Box::new(RequestExecutionStage::new(ExecutionMode::Single)),
-            Box::new(ResponseProcessingStage::new(
-                processor,
-                streaming_processor.clone(),
-            )),
+            Box::new(ResponseProcessingStage::new(processor, streaming_processor)),
         ];
 
         Self {
@@ -1205,9 +967,30 @@ impl ChatCompletionPipeline {
     pub fn new_pd(
         worker_registry: Arc<WorkerRegistry>,
         policy_registry: Arc<PolicyRegistry>,
-        processor: processing::ResponseProcessor,
-        streaming_processor: Arc<streaming::StreamingProcessor>,
+        tokenizer: Arc<dyn Tokenizer>,
+        tool_parser_factory: ToolParserFactory,
+        reasoning_parser_factory: ReasoningParserFactory,
+        configured_tool_parser: Option<String>,
+        configured_reasoning_parser: Option<String>,
     ) -> Self {
+        // Create response processor
+        let processor = processing::ResponseProcessor::new(
+            tokenizer.clone(),
+            tool_parser_factory.clone(),
+            reasoning_parser_factory.clone(),
+            configured_tool_parser.clone(),
+            configured_reasoning_parser.clone(),
+        );
+
+        // Create streaming processor
+        let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
+            tokenizer,
+            tool_parser_factory,
+            reasoning_parser_factory,
+            configured_tool_parser,
+            configured_reasoning_parser,
+        ));
+
         let stages: Vec<Box<dyn PipelineStage>> = vec![
             Box::new(PreparationStage),
             Box::new(WorkerSelectionStage::new(
@@ -1219,10 +1002,7 @@ impl ChatCompletionPipeline {
             Box::new(RequestBuildingStage::new(true)), // Inject PD metadata
             Box::new(DispatchMetadataStage),
             Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)),
-            Box::new(ResponseProcessingStage::new(
-                processor,
-                streaming_processor.clone(),
-            )),
+            Box::new(ResponseProcessingStage::new(processor, streaming_processor)),
        ];
 
        Self {
...
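As the doc comment in the hunk above says, `RequestPipeline` is "configured differently for regular vs PD mode": the two stage stacks differ only in the `RequestBuildingStage` flag (inject PD metadata or not) and the execution mode (`Single` vs `DualDispatch`). A self-contained toy of the same design, an `Arc`-wrapped list of boxed stage trait objects so pipeline clones are cheap; the names `Stage`, `Prepare`, `Execute`, and `Pipeline` are invented for this sketch, not taken from the codebase:

```rust
use std::sync::Arc;

trait Stage: Send + Sync {
    fn run(&self, req: &mut String);
}

struct Prepare;
impl Stage for Prepare {
    fn run(&self, req: &mut String) {
        req.push_str(" -> prepared");
    }
}

struct Execute {
    dual_dispatch: bool,
}
impl Stage for Execute {
    fn run(&self, req: &mut String) {
        // PD mode fans the request out to a prefill worker and a decode worker.
        req.push_str(if self.dual_dispatch {
            " -> prefill+decode"
        } else {
            " -> single worker"
        });
    }
}

#[derive(Clone)]
struct Pipeline {
    // Mirrors the diff's `stages: Arc<Vec<Box<dyn PipelineStage>>>`: the stage
    // list is built once; cloning the pipeline is just a refcount bump.
    stages: Arc<Vec<Box<dyn Stage>>>,
}

impl Pipeline {
    fn new(pd_mode: bool) -> Self {
        let stages: Vec<Box<dyn Stage>> = vec![
            Box::new(Prepare),
            Box::new(Execute { dual_dispatch: pd_mode }),
        ];
        Self { stages: Arc::new(stages) }
    }

    fn run(&self, mut req: String) -> String {
        for stage in self.stages.iter() {
            stage.run(&mut req);
        }
        req
    }
}

fn main() {
    println!("{}", Pipeline::new(false).run("req".to_string())); // req -> prepared -> single worker
    println!("{}", Pipeline::new(true).run("req".to_string()));  // req -> prepared -> prefill+decode
}
```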
@@ -10,14 +10,18 @@ use tracing::error;
 use crate::grpc_client::proto;
 use crate::protocols::spec::{
-    ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall,
-    ToolChoice, ToolChoiceValue,
+    ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
+    FunctionCallResponse, GenerateMetaInfo, GenerateRequest, GenerateResponse, ToolCall,
+    ToolChoice, ToolChoiceValue, Usage,
 };
 use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
 use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
 use crate::tokenizer::traits::Tokenizer;
 use crate::tool_parser::ParserFactory as ToolParserFactory;
+use proto::generate_complete::MatchedStop;
+use std::time::Instant;
 
+use super::context::{DispatchMetadata, ExecutionResult};
 use super::utils;
 
 // ============================================================================
@@ -51,6 +55,57 @@ impl ResponseProcessor {
         }
     }
 
+    /// Helper to collect responses from execution result and merge logprobs if needed
+    async fn collect_and_merge_responses(
+        execution_result: ExecutionResult,
+        request_logprobs: bool,
+    ) -> Result<Vec<proto::GenerateComplete>, axum::response::Response> {
+        let all_responses = match execution_result {
+            ExecutionResult::Single { mut stream } => {
+                let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
+                stream.mark_completed();
+                responses
+            }
+            ExecutionResult::Dual {
+                mut prefill,
+                decode,
+            } => {
+                // Collect prefill for input_logprobs (don't mark completed yet)
+                let prefill_responses =
+                    utils::collect_stream_responses(&mut prefill, "Prefill").await?;
+
+                // Collect decode for actual output (don't mark completed yet)
+                let mut decode_stream = *decode;
+                let mut decode_responses =
+                    utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
+
+                // Mark both streams as completed now that both succeeded
+                prefill.mark_completed();
+                decode_stream.mark_completed();
+
+                // Merge prefill input_logprobs if requested
+                if request_logprobs {
+                    if let Some(prefill_input_logprobs) = prefill_responses
+                        .first()
+                        .and_then(|r| r.input_logprobs.clone())
+                    {
+                        for response in &mut decode_responses {
+                            response.input_logprobs = Some(prefill_input_logprobs.clone());
+                        }
+                    }
+                }
+
+                decode_responses
+            }
+        };
+
+        if all_responses.is_empty() {
+            return Err(utils::internal_error_static("No responses from server"));
+        }
+
+        Ok(all_responses)
+    }
+
     /// Process a single choice from GenerateComplete response (EXACT COPY from router.rs:1573-1725)
     #[allow(clippy::too_many_arguments)]
     pub async fn process_single_choice(
@@ -158,12 +213,10 @@ impl ResponseProcessor {
         // Extract matched_stop information from proto
         let matched_stop = match &complete.matched_stop {
-            Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
+            Some(MatchedStop::MatchedTokenId(token_id)) => {
                 Some(Value::Number(serde_json::Number::from(*token_id)))
             }
-            Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
-                Some(Value::String(stop_str.clone()))
-            }
+            Some(MatchedStop::MatchedStopStr(stop_str)) => Some(Value::String(stop_str.clone())),
             None => None,
         };
@@ -205,6 +258,109 @@ impl ResponseProcessor {
         Ok(choice)
     }
 
+    /// Process non-streaming chat response (collects all responses and builds final response)
+    pub async fn process_non_streaming_chat_response(
+        &self,
+        execution_result: ExecutionResult,
+        chat_request: Arc<ChatCompletionRequest>,
+        dispatch: DispatchMetadata,
+        stop_decoder: &mut StopSequenceDecoder,
+        request_logprobs: bool,
+    ) -> Result<ChatCompletionResponse, axum::response::Response> {
+        // Collect all responses from the execution result
+        let all_responses =
+            Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
+
+        let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
+
+        // Check parser availability once upfront (not per choice)
+        let reasoning_parser_available = chat_request.separate_reasoning
+            && utils::check_reasoning_parser_availability(
+                &self.reasoning_parser_factory,
+                self.configured_reasoning_parser.as_ref(),
+                &chat_request.model,
+            );
+        let tool_choice_enabled = !matches!(
+            &chat_request.tool_choice,
+            Some(ToolChoice::Value(ToolChoiceValue::None))
+        );
+        let tool_parser_available = tool_choice_enabled
+            && chat_request.tools.is_some()
+            && utils::check_tool_parser_availability(
+                &self.tool_parser_factory,
+                self.configured_tool_parser.as_ref(),
+                &chat_request.model,
+            );
+
+        // Log once per request (not per choice)
+        if chat_request.separate_reasoning && !reasoning_parser_available {
+            tracing::debug!(
+                "No reasoning parser found for model '{}', skipping reasoning parsing",
+                chat_request.model
+            );
+        }
+        if chat_request.tools.is_some() && tool_choice_enabled && !tool_parser_available {
+            tracing::debug!(
+                "No tool parser found for model '{}', skipping tool call parsing",
+                chat_request.model
+            );
+        }
+
+        // Process all choices
+        let mut choices = Vec::new();
+        for (index, complete) in all_responses.iter().enumerate() {
+            match self
+                .process_single_choice(
+                    complete,
+                    index,
+                    &chat_request,
+                    stop_decoder,
+                    history_tool_calls_count,
+                    reasoning_parser_available,
+                    tool_parser_available,
+                )
+                .await
+            {
+                Ok(choice) => choices.push(choice),
+                Err(e) => {
+                    return Err(utils::internal_error_message(format!(
+                        "Failed to process choice {}: {}",
+                        index, e
+                    )));
+                }
+            }
+        }
+
+        // Build usage
+        let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
+        let total_completion_tokens: u32 = all_responses
+            .iter()
+            .map(|r| r.completion_tokens as u32)
+            .sum();
+        let usage = Usage {
+            prompt_tokens: total_prompt_tokens,
+            completion_tokens: total_completion_tokens,
+            total_tokens: total_prompt_tokens + total_completion_tokens,
+            completion_tokens_details: None,
+        };
+
+        // Build final ChatCompletionResponse
+        let response = ChatCompletionResponse {
+            id: dispatch.request_id.clone(),
+            object: "chat.completion".to_string(),
+            created: dispatch.created,
+            model: dispatch.model.clone(),
+            choices,
+            usage: Some(usage),
+            system_fingerprint: dispatch.weight_version.clone(),
+        };
+
+        Ok(response)
+    }
+
     /// Parse tool calls using model-specific parser (EXACT COPY from router.rs:296-361)
     pub async fn parse_tool_calls(
         &self,
@@ -264,4 +420,112 @@ impl ResponseProcessor {
             }
         }
     }
+
+    /// Process non-streaming generate response (collects all responses and builds final response array)
+    pub async fn process_non_streaming_generate_response(
+        &self,
+        execution_result: ExecutionResult,
+        _generate_request: Arc<GenerateRequest>,
+        dispatch: DispatchMetadata,
+        stop_decoder: &mut StopSequenceDecoder,
+        request_logprobs: bool,
+        start_time: Instant,
+    ) -> Result<Vec<GenerateResponse>, axum::response::Response> {
+        // Collect all responses from the execution result
+        let all_responses =
+            Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
+
+        // Process each completion
+        let mut result_array = Vec::new();
+        for mut complete in all_responses {
+            stop_decoder.reset();
+
+            // Process tokens through stop decoder
+            let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
+                Ok(outputs) => outputs,
+                Err(e) => {
+                    return Err(utils::internal_error_message(format!(
+                        "Failed to process tokens: {}",
+                        e
+                    )))
+                }
+            };
+
+            // Accumulate text with early breaks
+            let mut decoded_text = String::new();
+            for output in outputs {
+                match output {
+                    SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
+                    SequenceDecoderOutput::StoppedWithText(t) => {
+                        decoded_text.push_str(&t);
+                        break;
+                    }
+                    SequenceDecoderOutput::Stopped => break,
+                    SequenceDecoderOutput::Held => {}
+                }
+            }
+
+            // Flush remaining text
+            if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
+                decoded_text.push_str(&t);
+            }
+
+            let output_ids = std::mem::take(&mut complete.output_ids);
+            let finish_reason_str = std::mem::take(&mut complete.finish_reason);
+
+            // Parse finish_reason from string to proper type
+            let finish_reason =
+                utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
+
+            // Handle matched_stop if present
+            let matched_stop = complete.matched_stop.take().map(|matched| match matched {
+                MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
+                MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
+            });
+
+            // Extract logprobs if requested (convert proto types to Generate format)
+            let input_token_logprobs = if request_logprobs {
+                complete
+                    .input_logprobs
+                    .as_ref()
+                    .map(utils::convert_generate_input_logprobs)
+            } else {
+                None
+            };
+            let output_token_logprobs = if request_logprobs {
+                complete
+                    .output_logprobs
+                    .as_ref()
+                    .map(utils::convert_generate_output_logprobs)
+            } else {
+                None
+            };
+
+            // Build GenerateResponse struct
+            let meta_info = GenerateMetaInfo {
+                id: dispatch.request_id.clone(),
+                finish_reason,
+                prompt_tokens: complete.prompt_tokens as u32,
+                weight_version: dispatch
+                    .weight_version
+                    .clone()
+                    .unwrap_or_else(|| "default".to_string()),
+                input_token_logprobs,
+                output_token_logprobs,
+                completion_tokens: complete.completion_tokens as u32,
+                cached_tokens: complete.cached_tokens as u32,
+                e2e_latency: start_time.elapsed().as_secs_f64(),
+                matched_stop,
+            };
+
+            result_array.push(GenerateResponse {
+                text: decoded_text,
+                output_ids,
+                meta_info,
+            });
+        }
+
+        Ok(result_array)
+    }
 }
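The new `collect_and_merge_responses` helper above is the heart of the PD (prefill/decode) path: it drains both streams, and when the client asked for logprobs it copies the prefill stream's `input_logprobs` onto every decode response, since the prompt logprobs come from the prefill phase. A runnable toy of just that merge step, using a simplified stand-in for the `proto::GenerateComplete` message (the `Complete` struct and field type here are invented for the sketch):

```rust
// Simplified stand-in for proto::GenerateComplete; the real message also
// carries output_ids, finish_reason, token counts, and output logprobs.
#[derive(Clone, Debug)]
struct Complete {
    input_logprobs: Option<Vec<f32>>,
}

// Mirrors the merge step of collect_and_merge_responses: when logprobs were
// requested, the first prefill response's input_logprobs are cloned onto
// every decode response.
fn merge_prefill_logprobs(prefill: &[Complete], decode: &mut [Complete], request_logprobs: bool) {
    if request_logprobs {
        if let Some(prefill_input_logprobs) =
            prefill.first().and_then(|r| r.input_logprobs.clone())
        {
            for response in decode.iter_mut() {
                response.input_logprobs = Some(prefill_input_logprobs.clone());
            }
        }
    }
}

fn main() {
    let prefill = vec![Complete { input_logprobs: Some(vec![-0.11, -2.37]) }];
    let mut decode = vec![
        Complete { input_logprobs: None }, // decode responses arrive without prompt logprobs
        Complete { input_logprobs: None },
    ];
    merge_prefill_logprobs(&prefill, &mut decode, true);
    assert!(decode.iter().all(|r| r.input_logprobs.is_some()));
    println!("{:?}", decode);
}
```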
@@ -24,6 +24,9 @@ use crate::server::AppContext;
 use crate::tokenizer::traits::Tokenizer;
 use crate::tool_parser::ParserFactory as ToolParserFactory;
 
+use super::context::SharedComponents;
+use super::pipeline::RequestPipeline;
+
 /// gRPC router implementation for SGLang
 #[derive(Clone)]
 #[allow(dead_code)]
@@ -38,8 +41,8 @@ pub struct GrpcRouter {
     retry_config: RetryConfig,
     configured_reasoning_parser: Option<String>,
     configured_tool_parser: Option<String>,
-    pipeline: super::pipeline::ChatCompletionPipeline,
-    shared_components: Arc<super::context::SharedComponents>,
+    pipeline: RequestPipeline,
+    shared_components: Arc<SharedComponents>,
 }
 
 impl GrpcRouter {
@@ -66,36 +69,21 @@ impl GrpcRouter {
         let policy_registry = ctx.policy_registry.clone();
 
         // Create shared components for pipeline
-        let shared_components = Arc::new(super::context::SharedComponents {
+        let shared_components = Arc::new(SharedComponents {
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });
 
-        // Create response processor
-        let processor = super::processing::ResponseProcessor::new(
-            tokenizer.clone(),
-            tool_parser_factory.clone(),
-            reasoning_parser_factory.clone(),
-            ctx.configured_tool_parser.clone(),
-            ctx.configured_reasoning_parser.clone(),
-        );
-
-        // Create streaming processor
-        let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
-            tokenizer.clone(),
-            tool_parser_factory.clone(),
-            reasoning_parser_factory.clone(),
-            ctx.configured_tool_parser.clone(),
-            ctx.configured_reasoning_parser.clone(),
-        ));
-
-        // Create pipeline
-        let pipeline = super::pipeline::ChatCompletionPipeline::new_regular(
-            worker_registry.clone(),
-            policy_registry.clone(),
-            processor,
-            streaming_processor,
-        );
+        // Create pipeline
+        let pipeline = RequestPipeline::new_regular(
+            worker_registry.clone(),
+            policy_registry.clone(),
+            tokenizer.clone(),
+            tool_parser_factory.clone(),
+            reasoning_parser_factory.clone(),
+            ctx.configured_tool_parser.clone(),
+            ctx.configured_reasoning_parser.clone(),
+        );
 
         Ok(GrpcRouter {
...