Unverified Commit 677aa0e2 authored by Simo Lin, committed by GitHub

[router] improve reasoning parser lock and reduce req cloning (#11336)

parent 01c9ee1a
@@ -2,7 +2,9 @@
// Now with parser pooling support for efficient reuse across requests.
use std::collections::HashMap;
-use std::sync::{Arc, Mutex, RwLock};
+use std::sync::{Arc, RwLock};
+use tokio::sync::Mutex;
use crate::reasoning_parser::parsers::{
    BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
@@ -11,6 +13,7 @@ use crate::reasoning_parser::parsers::{
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
/// Type alias for pooled parser instances.
+/// Uses tokio::Mutex to avoid blocking the async executor.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
/// Type alias for parser creator functions.
@@ -301,8 +304,8 @@ mod tests {
        assert_eq!(glm45.model_type(), "glm45");
    }
-    #[test]
-    fn test_pooled_parser_reuse() {
+    #[tokio::test]
+    async fn test_pooled_parser_reuse() {
        let factory = ReasoningParserFactory::new();
        // Get the same parser twice - should be the same instance
@@ -317,20 +320,18 @@ mod tests {
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }
-    #[test]
-    fn test_pooled_parser_concurrent_access() {
-        use std::thread;
+    #[tokio::test]
+    async fn test_pooled_parser_concurrent_access() {
        let factory = ReasoningParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");
-        // Spawn multiple threads that use the same parser
+        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];
        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
-            let handle = thread::spawn(move || {
-                let mut parser = parser_clone.lock().unwrap();
+            let handle = tokio::spawn(async move {
+                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
@@ -339,14 +340,14 @@ mod tests {
            handles.push(handle);
        }
-        // Wait for all threads to complete
+        // Wait for all tasks to complete
        for handle in handles {
-            handle.join().unwrap();
+            handle.await.unwrap();
        }
    }
-    #[test]
-    fn test_pool_clearing() {
+    #[tokio::test]
+    async fn test_pool_clearing() {
        let factory = ReasoningParserFactory::new();
        // Get a pooled parser
@@ -362,8 +363,8 @@ mod tests {
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }
-    #[test]
-    fn test_passthrough_parser_pooling() {
+    #[tokio::test]
+    async fn test_passthrough_parser_pooling() {
        let factory = ReasoningParserFactory::new();
        // Unknown models should get passthrough parser
@@ -373,19 +374,18 @@ mod tests {
        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));
-        let parser = parser1.lock().unwrap();
+        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }
-    #[test]
-    fn test_high_concurrency_parser_access() {
+    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
+    async fn test_high_concurrency_parser_access() {
        use std::sync::atomic::{AtomicUsize, Ordering};
-        use std::thread;
        use std::time::Instant;
        let factory = ReasoningParserFactory::new();
-        let num_threads = 100;
-        let requests_per_thread = 50;
+        let num_tasks = 100;
+        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
        // Track successful operations
@@ -395,36 +395,25 @@ mod tests {
        let start = Instant::now();
        let mut handles = vec![];
-        for thread_id in 0..num_threads {
+        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);
-            let handle = thread::spawn(move || {
-                for request_id in 0..requests_per_thread {
+            let handle = tokio::spawn(async move {
+                for request_id in 0..requests_per_task {
                    // Rotate through different models
-                    let model = &models[(thread_id + request_id) % models.len()];
+                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);
-                    // Use blocking lock - this is the realistic scenario
-                    // In production, requests would wait for the parser to be available
-                    // Handle poisoned locks gracefully
-                    let mut p = match parser.lock() {
-                        Ok(guard) => guard,
-                        Err(_poisoned) => {
-                            // Lock was poisoned by a panicking thread
-                            // In production, we might want to recreate the parser
-                            // For testing, we'll just skip this iteration
-                            error_count.fetch_add(1, Ordering::Relaxed);
-                            continue;
-                        }
-                    };
+                    // Use async lock - tokio::Mutex doesn't poison
+                    let mut p = parser.lock().await;
                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let reasoning_text = format!(
-                        "Thread {} is processing request {}. Let me think through this step by step. \
+                        "Task {} is processing request {}. Let me think through this step by step. \
                         First, I need to understand the problem. The problem involves analyzing data \
                         and making calculations. Let me break this down: \n\
                         1. Initial analysis shows that we have multiple variables to consider. \
@@ -436,19 +425,19 @@ mod tests {
                         7. Validating against known constraints... \
                         8. The conclusion follows logically from premises A, B, and C. \
                         This reasoning chain demonstrates the validity of our approach.",
-                        thread_id, request_id, thread_id, request_id, thread_id * request_id
+                        task_id, request_id, task_id, request_id, task_id * request_id
                    );
                    let answer_text = format!(
-                        "Based on my analysis, the answer for thread {} request {} is: \
+                        "Based on my analysis, the answer for task {} request {} is: \
                         The solution involves multiple steps as outlined in the reasoning. \
                         The final result is {} with confidence level high. \
                         This conclusion is supported by rigorous mathematical analysis \
                         and has been validated against multiple test cases. \
                         The implementation should handle edge cases appropriately.",
-                        thread_id,
+                        task_id,
                        request_id,
-                        thread_id * request_id
+                        task_id * request_id
                    );
                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);
@@ -456,16 +445,14 @@ mod tests {
                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
-                            assert!(result
-                                .normal_text
-                                .contains(&format!("thread {}", thread_id)));
+                            assert!(result.normal_text.contains(&format!("task {}", task_id)));
                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
-                                    .contains(&format!("Thread {}", thread_id)));
+                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }
@@ -486,20 +473,20 @@ mod tests {
            handles.push(handle);
        }
-        // Wait for all threads
+        // Wait for all tasks
        for handle in handles {
-            handle.join().unwrap();
+            handle.await.unwrap();
        }
        let duration = start.elapsed();
-        let total_requests = num_threads * requests_per_thread;
+        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);
        // Print stats for debugging
        println!(
-            "High concurrency test: {} threads, {} requests each",
-            num_threads, requests_per_thread
+            "High concurrency test: {} tasks, {} requests each",
+            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
@@ -523,42 +510,40 @@ mod tests {
        );
    }
-    #[test]
-    fn test_concurrent_pool_modifications() {
-        use std::thread;
+    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+    async fn test_concurrent_pool_modifications() {
        let factory = ReasoningParserFactory::new();
        let mut handles = vec![];
-        // Thread 1: Continuously get parsers
+        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
-        handles.push(thread::spawn(move || {
+        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));
-        // Thread 2: Continuously clear pool
+        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
-        handles.push(thread::spawn(move || {
+        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
-                thread::sleep(std::time::Duration::from_micros(100));
+                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));
-        // Thread 3: Get different parsers
+        // Task 3: Get different parsers
        let factory3 = factory.clone();
-        handles.push(thread::spawn(move || {
+        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));
-        // Wait for all threads - should not deadlock or panic
+        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
-            handle.join().unwrap();
+            handle.await.unwrap();
        }
    }
}
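
For context, a minimal self-contained sketch (not part of this commit; DummyParser and the task count are invented for illustration) of the behavioral difference that motivates swapping std::sync::Mutex for tokio::sync::Mutex in the parser pool: the async lock suspends only the waiting task instead of blocking an executor thread, and its guard is never poisoned by a panicking holder, so the poison-handling unwrap/map_err paths above can be dropped.

use std::sync::Arc;
use tokio::sync::Mutex;

// Stand-in for a pooled parser; the real pool stores Box<dyn ReasoningParser>.
struct DummyParser {
    calls: usize,
}

#[tokio::main]
async fn main() {
    let pooled = Arc::new(Mutex::new(DummyParser { calls: 0 }));

    let mut handles = vec![];
    for _ in 0..3 {
        let pooled = Arc::clone(&pooled);
        handles.push(tokio::spawn(async move {
            // lock().await yields to the executor while waiting and returns the
            // guard directly - there is no PoisonError to unwrap or map_err away.
            let mut parser = pooled.lock().await;
            parser.calls += 1;
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
    assert_eq!(pooled.lock().await.calls, 3);
}
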
@@ -48,9 +48,10 @@ pub struct RequestInput {
}
/// Request type variants
+/// Using Arc instead of Box to enable cheap cloning for background tasks
pub enum RequestType {
-    Chat(Box<ChatCompletionRequest>),
-    Generate(Box<GenerateRequest>),
+    Chat(Arc<ChatCompletionRequest>),
+    Generate(Arc<GenerateRequest>),
}
/// Shared components (injected once at creation)
@@ -181,14 +182,14 @@ pub struct StreamingState {
impl RequestContext {
    /// Create context for chat completion request
    pub fn for_chat(
-        request: ChatCompletionRequest,
+        request: Arc<ChatCompletionRequest>,
        headers: Option<HeaderMap>,
        model_id: Option<String>,
        components: Arc<SharedComponents>,
    ) -> Self {
        Self {
            input: RequestInput {
-                request_type: RequestType::Chat(Box::new(request)),
+                request_type: RequestType::Chat(request),
                headers,
                model_id,
            },
@@ -199,14 +200,14 @@ impl RequestContext {
    /// Create context for generate request
    pub fn for_generate(
-        request: GenerateRequest,
+        request: Arc<GenerateRequest>,
        headers: Option<HeaderMap>,
        model_id: Option<String>,
        components: Arc<SharedComponents>,
    ) -> Self {
        Self {
            input: RequestInput {
-                request_type: RequestType::Generate(Box::new(request)),
+                request_type: RequestType::Generate(request),
                headers,
                model_id,
            },
@@ -228,6 +229,14 @@ impl RequestContext {
        }
    }
+    /// Get Arc clone of chat request (panics if not chat)
+    pub fn chat_request_arc(&self) -> Arc<ChatCompletionRequest> {
+        match &self.input.request_type {
+            RequestType::Chat(req) => Arc::clone(req),
+            _ => panic!("Expected chat request"),
+        }
+    }
    /// Get generate request (panics if not generate)
    pub fn generate_request(&self) -> &GenerateRequest {
        match &self.input.request_type {
@@ -236,6 +245,14 @@ impl RequestContext {
        }
    }
+    /// Get Arc clone of generate request (panics if not generate)
+    pub fn generate_request_arc(&self) -> Arc<GenerateRequest> {
+        match &self.input.request_type {
+            RequestType::Generate(req) => Arc::clone(req),
+            _ => panic!("Expected generate request"),
+        }
+    }
    /// Check if request is streaming
    pub fn is_streaming(&self) -> bool {
        match &self.input.request_type {
......
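
As a rough illustration of the RequestType change above (a sketch with a made-up BigRequest type, not the router's real ChatCompletionRequest): cloning a Box<T> deep-copies the entire payload, while cloning an Arc<T> copies a pointer and bumps a reference count, which is what makes handing the request to background tasks cheap.

use std::sync::Arc;

// Hypothetical stand-in for a large request body (real chat requests can be 15KB-200KB).
#[derive(Clone)]
struct BigRequest {
    messages: Vec<String>,
}

fn main() {
    let request = BigRequest {
        messages: vec!["a long prompt ".repeat(1_000); 16],
    };

    // Box: every clone duplicates all of the message data.
    let boxed: Box<BigRequest> = Box::new(request.clone());
    let _deep_copy = boxed.clone();

    // Arc: every clone is a pointer copy plus an atomic refcount increment.
    let shared: Arc<BigRequest> = Arc::new(request);
    let for_background_task = Arc::clone(&shared);

    // Both handles refer to the same allocation.
    assert!(Arc::ptr_eq(&shared, &for_background_task));
}
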
@@ -129,7 +129,7 @@ impl GrpcPDRouter {
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
-                body.clone(),
+                Arc::new(body.clone()),
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
@@ -152,7 +152,7 @@ impl GrpcPDRouter {
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
-                body.clone(),
+                Arc::new(body.clone()),
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
......
@@ -58,16 +58,17 @@ impl PipelineStage for PreparationStage {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        debug!("Stage {}: Processing request", self.name());
-        // Clone the request to avoid borrowing issues
-        match &ctx.input.request_type {
-            RequestType::Chat(request) => {
-                let request_clone = request.clone();
-                self.prepare_chat(ctx, &request_clone).await?;
-            }
-            RequestType::Generate(request) => {
-                let request_clone = request.clone();
-                self.prepare_generate(ctx, &request_clone).await?;
-            }
+        // Clone Arc before match to avoid borrow checker issues
+        // (matching borrows ctx, but prepare_* methods need mutable borrow)
+        // Arc clone is cheap (8 bytes) - avoids full request clone (15KB-200KB)
+        let is_chat = matches!(&ctx.input.request_type, RequestType::Chat(_));
+        if is_chat {
+            let request_arc = ctx.chat_request_arc();
+            self.prepare_chat(ctx, &request_arc).await?;
+        } else {
+            let request_arc = ctx.generate_request_arc();
+            self.prepare_generate(ctx, &request_arc).await?;
        }
        Ok(None)
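
The reshaped execute body above uses a common borrow-checker pattern: check the variant with matches!, then take an owned Arc handle before calling a method that needs the context mutably, so no borrow of the enum is held across the call. A condensed sketch under stand-in types (Ctx here is not the router's RequestContext):

use std::sync::Arc;

#[allow(dead_code)]
enum RequestType {
    Chat(Arc<String>),
    Generate(Arc<String>),
}

struct Ctx {
    request_type: RequestType,
    prepared: Vec<String>,
}

impl Ctx {
    // Returns an owned handle, so the borrow of self ends when this call returns.
    fn chat_request_arc(&self) -> Arc<String> {
        match &self.request_type {
            RequestType::Chat(req) => Arc::clone(req),
            _ => panic!("Expected chat request"),
        }
    }

    // Needs &mut self; would conflict with a still-live borrow of self.request_type.
    fn prepare_chat(&mut self, request: &str) {
        self.prepared.push(format!("chat: {request}"));
    }
}

fn main() {
    let mut ctx = Ctx {
        request_type: RequestType::Chat(Arc::new("hello".to_string())),
        prepared: Vec::new(),
    };

    let is_chat = matches!(&ctx.request_type, RequestType::Chat(_));
    if is_chat {
        let request_arc = ctx.chat_request_arc(); // cheap owned handle
        ctx.prepare_chat(&request_arc); // &mut ctx is now allowed
    }
    assert_eq!(ctx.prepared.len(), 1);
}
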
@@ -820,7 +821,7 @@ impl ResponseProcessingStage {
            return Ok(Some(
                self.streaming_processor.clone().process_streaming_response(
                    execution_result,
-                    ctx.chat_request().clone(),
+                    ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
                    dispatch.clone(),
                ),
            ));
@@ -865,9 +866,7 @@ impl ResponseProcessingStage {
            return Err(utils::internal_error_static("No responses from server"));
        }
-        // Clone chat_request to avoid borrow checker conflict
-        // (ctx.chat_request() borrows ctx, preventing mutable borrow of ctx.state.response.stop_decoder)
-        let chat_request = ctx.chat_request().clone();
+        let chat_request = ctx.chat_request_arc();
        let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
        let stop_decoder = ctx
@@ -959,13 +958,11 @@ impl ResponseProcessingStage {
            .as_ref()
            .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
-        let generate_request = ctx.generate_request().clone();
        // Streaming: Use StreamingProcessor and return SSE response (done)
        return Ok(Some(
            self.streaming_processor.clone().process_streaming_generate(
                execution_result,
-                generate_request,
+                ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
                dispatch.clone(),
            ),
        ));
@@ -1193,8 +1190,8 @@ impl ChatCompletionPipeline {
    /// Execute the complete pipeline for a chat request
    pub async fn execute_chat(
        &self,
-        request: ChatCompletionRequest,
-        headers: Option<axum::http::HeaderMap>,
+        request: Arc<ChatCompletionRequest>,
+        headers: Option<http::HeaderMap>,
        model_id: Option<String>,
        components: Arc<SharedComponents>,
    ) -> Response {
@@ -1243,8 +1240,8 @@ impl ChatCompletionPipeline {
    /// Execute the complete pipeline for a generate request
    pub async fn execute_generate(
        &self,
-        request: GenerateRequest,
-        headers: Option<axum::http::HeaderMap>,
+        request: Arc<GenerateRequest>,
+        headers: Option<http::HeaderMap>,
        model_id: Option<String>,
        components: Arc<SharedComponents>,
    ) -> Response {
......
@@ -97,9 +97,7 @@ impl ResponseProcessor {
            &original_request.model,
        );
-        let mut parser = pooled_parser
-            .lock()
-            .map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
+        let mut parser = pooled_parser.lock().await;
        match parser.detect_and_parse_reasoning(&processed_text) {
            Ok(result) => {
                if !result.reasoning_text.is_empty() {
......
@@ -129,7 +129,7 @@ impl GrpcRouter {
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
-                body.clone(),
+                Arc::new(body.clone()),
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
@@ -149,7 +149,7 @@ impl GrpcRouter {
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
-                body.clone(),
+                Arc::new(body.clone()),
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
......
@@ -66,7 +66,7 @@ impl StreamingProcessor {
    pub fn process_streaming_response(
        self: Arc<Self>,
        execution_result: context::ExecutionResult,
-        chat_request: ChatCompletionRequest,
+        chat_request: Arc<ChatCompletionRequest>,
        dispatch: context::DispatchMetadata,
    ) -> Response {
        use bytes::Bytes;
@@ -156,7 +156,7 @@ impl StreamingProcessor {
        mut grpc_stream: Streaming<proto::GenerateResponse>,
        dispatch: context::DispatchMetadata,
        stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
-        original_request: ChatCompletionRequest,
+        original_request: Arc<ChatCompletionRequest>,
        tx: &UnboundedSender<Result<Bytes, io::Error>>,
    ) -> Result<(), String> {
        // Extract request parameters
@@ -176,7 +176,7 @@ impl StreamingProcessor {
        let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
        // Parser state (lazy initialization per index)
-        type PooledReasoningParser = Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>;
+        type PooledReasoningParser = Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>;
        let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
        type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
@@ -186,6 +186,9 @@ impl StreamingProcessor {
        // Per-index stop decoders (each index needs its own state for n>1 support)
        let mut stop_decoders: HashMap<u32, StopSequenceDecoder> = HashMap::new();
+        // Reusable SSE formatting buffer to avoid allocations per chunk
+        let mut sse_buffer = Vec::with_capacity(512);
        // Use dispatch metadata for consistent response fields
        let request_id = &dispatch.request_id;
        let model = &dispatch.model;
@@ -262,7 +265,8 @@ impl StreamingProcessor {
                        }],
                        usage: None,
                    };
-                    tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
+                    Self::format_sse_chunk_into(&mut sse_buffer, &first_chunk);
+                    tx.send(Ok(Bytes::from(sse_buffer.clone())))
                        .map_err(|_| "Failed to send first chunk".to_string())?;
                    is_firsts.insert(index, false);
                }
@@ -282,9 +286,11 @@ impl StreamingProcessor {
                        model,
                        created,
                        system_fingerprint,
-                    );
+                    )
+                    .await;
                    if let Some(chunk) = reasoning_chunk {
-                        tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
+                        Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
+                        tx.send(Ok(Bytes::from(sse_buffer.clone())))
                            .map_err(|_| "Failed to send reasoning chunk".to_string())?;
                    }
                    delta = normal_text;
@@ -314,7 +320,8 @@ impl StreamingProcessor {
                        .await;
                    for chunk in tool_chunks {
-                        tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
+                        Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
+                        tx.send(Ok(Bytes::from(sse_buffer.clone())))
                            .map_err(|_| "Failed to send tool call chunk".to_string())?;
                    }
@@ -335,7 +342,8 @@ impl StreamingProcessor {
                        system_fingerprint,
                        choice_logprobs,
                    );
-                    tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
+                    Self::format_sse_chunk_into(&mut sse_buffer, &content_chunk);
+                    tx.send(Ok(Bytes::from(sse_buffer.clone())))
                        .map_err(|_| "Failed to send content chunk".to_string())?;
                }
            }
@@ -529,7 +537,7 @@ impl StreamingProcessor {
        decode_stream: Streaming<proto::GenerateResponse>,
        dispatch: context::DispatchMetadata,
        stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
-        original_request: ChatCompletionRequest,
+        original_request: Arc<ChatCompletionRequest>,
        tx: &UnboundedSender<Result<Bytes, io::Error>>,
    ) -> Result<(), String> {
        // Phase 1.5: Collect input_logprobs from prefill stream if requested
@@ -561,7 +569,7 @@ impl StreamingProcessor {
    pub fn process_streaming_generate(
        self: Arc<Self>,
        execution_result: context::ExecutionResult,
-        generate_request: GenerateRequest,
+        generate_request: Arc<GenerateRequest>,
        dispatch: context::DispatchMetadata,
    ) -> Response {
        let return_logprob = generate_request.return_logprob;
@@ -946,11 +954,11 @@ impl StreamingProcessor {
    /// Helper: Process reasoning content in streaming mode
    #[allow(clippy::too_many_arguments)]
-    fn process_reasoning_stream(
+    async fn process_reasoning_stream(
        &self,
        delta: &str,
        index: u32,
-        reasoning_parsers: &mut HashMap<u32, Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>>,
+        reasoning_parsers: &mut HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn ReasoningParser>>>>,
        request_id: &str,
        model: &str,
        created: u64,
@@ -967,7 +975,7 @@ impl StreamingProcessor {
        if let Some(pooled_parser) = reasoning_parsers.get(&index) {
            let (parse_result, in_reasoning) = {
-                let mut parser = pooled_parser.lock().unwrap();
+                let mut parser = pooled_parser.lock().await;
                let result = parser.parse_reasoning_streaming_incremental(delta);
                let in_reasoning = parser.is_in_reasoning();
                (result, in_reasoning)
@@ -1134,15 +1142,20 @@ impl StreamingProcessor {
        (false, chunks)
    }
-    /// Format a response as SSE chunk
-    fn format_sse_chunk(chunk: &ChatCompletionStreamResponse) -> String {
-        match serde_json::to_string(chunk) {
-            Ok(json) => format!("data: {}\n\n", json),
-            Err(e) => {
-                error!("Failed to serialize SSE chunk: {}", e);
-                format!("data: {}\n\n", json!({"error": "serialization_failed"}))
-            }
-        }
-    }
+    /// Format a response as SSE chunk into a reusable buffer
+    /// This avoids allocations by reusing the same buffer across multiple chunks
+    #[inline]
+    fn format_sse_chunk_into(buffer: &mut Vec<u8>, chunk: &ChatCompletionStreamResponse) {
+        buffer.clear();
+        buffer.extend_from_slice(b"data: ");
+        if let Err(e) = serde_json::to_writer(&mut *buffer, chunk) {
+            error!("Failed to serialize SSE chunk: {}", e);
+            buffer.clear();
+            buffer.extend_from_slice(b"data: ");
+            let error_msg = json!({"error": "serialization_failed"}).to_string();
+            buffer.extend_from_slice(error_msg.as_bytes());
+        }
+        buffer.extend_from_slice(b"\n\n");
+    }
    /// Create a content chunk response
......
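
The format_sse_chunk_into change above serializes each chunk straight into a caller-owned, reused Vec<u8> via serde_json::to_writer instead of building a fresh String per chunk. A standalone sketch of the same idea (DemoChunk is a stand-in for ChatCompletionStreamResponse):

use serde::Serialize;
use serde_json::json;

#[derive(Serialize)]
struct DemoChunk {
    id: String,
    content: String,
}

// Serialize an SSE "data: ...\n\n" frame into a caller-owned buffer,
// reusing its allocation across chunks instead of allocating a String each time.
fn format_sse_chunk_into(buffer: &mut Vec<u8>, chunk: &DemoChunk) {
    buffer.clear();
    buffer.extend_from_slice(b"data: ");
    if serde_json::to_writer(&mut *buffer, chunk).is_err() {
        buffer.clear();
        buffer.extend_from_slice(b"data: ");
        buffer.extend_from_slice(json!({"error": "serialization_failed"}).to_string().as_bytes());
    }
    buffer.extend_from_slice(b"\n\n");
}

fn main() {
    let mut sse_buffer = Vec::with_capacity(512);
    for i in 0..3 {
        let chunk = DemoChunk {
            id: format!("chunk-{i}"),
            content: "hello".to_string(),
        };
        format_sse_chunk_into(&mut sse_buffer, &chunk);
        // Only the bytes handed to the transport are copied; the formatting
        // allocation is reused on the next iteration.
        let frame = String::from_utf8(sse_buffer.clone()).unwrap();
        assert!(frame.starts_with("data: ") && frame.ends_with("\n\n"));
    }
}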