"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "32e1ddb0432b34f78784fd6faae5e3a0aed7d06e"
Unverified Commit b45f753c authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] adds reasoning parser pooling and thread-safety (#9360)

parent c5057262
// Factory and registry for creating model-specific reasoning parsers. // Factory and registry for creating model-specific reasoning parsers.
// Now with parser pooling support for efficient reuse across requests.
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use crate::reasoning_parser::parsers::{ use crate::reasoning_parser::parsers::{
BaseReasoningParser, DeepSeekR1Parser, KimiParser, Qwen3Parser, QwenThinkingParser, BaseReasoningParser, DeepSeekR1Parser, KimiParser, Qwen3Parser, QwenThinkingParser,
}; };
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser}; use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
/// Type alias for pooled parser instances.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
/// Type alias for parser creator functions. /// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>; type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
/// Registry for model-specific parsers. /// Registry for model-specific parsers with pooling support.
#[derive(Clone)] #[derive(Clone)]
pub struct ParserRegistry { pub struct ParserRegistry {
parsers: Arc<RwLock<HashMap<String, ParserCreator>>>, /// Creator functions for parsers (used when pool is empty)
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
/// Pooled parser instances for reuse
pool: Arc<RwLock<HashMap<String, PooledParser>>>,
/// Model pattern to parser name mappings
patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name) patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
} }
...@@ -22,7 +30,8 @@ impl ParserRegistry { ...@@ -22,7 +30,8 @@ impl ParserRegistry {
/// Create a new empty registry. /// Create a new empty registry.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
parsers: Arc::new(RwLock::new(HashMap::new())), creators: Arc::new(RwLock::new(HashMap::new())),
pool: Arc::new(RwLock::new(HashMap::new())),
patterns: Arc::new(RwLock::new(Vec::new())), patterns: Arc::new(RwLock::new(Vec::new())),
} }
} }
...@@ -32,8 +41,8 @@ impl ParserRegistry { ...@@ -32,8 +41,8 @@ impl ParserRegistry {
where where
F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static, F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
{ {
let mut parsers = self.parsers.write().unwrap(); let mut creators = self.creators.write().unwrap();
parsers.insert(name.to_string(), Arc::new(creator)); creators.insert(name.to_string(), Arc::new(creator));
} }
/// Register a model pattern to parser mapping. /// Register a model pattern to parser mapping.
...@@ -43,13 +52,53 @@ impl ParserRegistry { ...@@ -43,13 +52,53 @@ impl ParserRegistry {
patterns.push((pattern.to_string(), parser_name.to_string())); patterns.push((pattern.to_string(), parser_name.to_string()));
} }
/// Get a parser by exact name. /// Get a pooled parser by exact name.
/// Returns a shared parser instance from the pool, creating one if needed.
/// Get a pooled parser by exact name.
/// Returns a shared parser instance from the pool, creating one if needed.
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
    // Fast path: the parser is already pooled; a shared read lock suffices.
    {
        let pool = self.pool.read().unwrap();
        if let Some(parser) = pool.get(name) {
            return Some(Arc::clone(parser));
        }
    }

    // Slow path: look up the creator first, then drop the creators lock
    // before taking the pool write lock (keeps the two locks disjoint).
    let creator = {
        let creators = self.creators.read().unwrap();
        Arc::clone(creators.get(name)?)
    };

    // Re-check under the write lock via the entry API. Without this,
    // two racing callers could both miss the read-path above, both create
    // a parser, and the second insert would overwrite the first — leaving
    // the callers with distinct instances and breaking the pooling
    // invariant (Arc::ptr_eq equality for the same name).
    let mut pool = self.pool.write().unwrap();
    let parser = pool
        .entry(name.to_string())
        .or_insert_with(|| Arc::new(Mutex::new(creator())));
    Some(Arc::clone(parser))
}
/// Get a parser by exact name (creates new instance, not pooled).
/// Use this for compatibility or when you need a fresh instance.
pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> { pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
let parsers = self.parsers.read().unwrap(); let creators = self.creators.read().unwrap();
parsers.get(name).map(|creator| creator()) creators.get(name).map(|creator| creator())
}
/// Find a pooled parser for a given model ID by pattern matching.
pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
let patterns = self.patterns.read().unwrap();
let model_lower = model_id.to_lowercase();
for (pattern, parser_name) in patterns.iter() {
if model_lower.contains(&pattern.to_lowercase()) {
return self.get_pooled_parser(parser_name);
}
}
None
} }
/// Find a parser for a given model ID by pattern matching. /// Find a parser for a given model ID by pattern matching (creates new instance).
pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> { pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
let patterns = self.patterns.read().unwrap(); let patterns = self.patterns.read().unwrap();
let model_lower = model_id.to_lowercase(); let model_lower = model_id.to_lowercase();
...@@ -61,6 +110,13 @@ impl ParserRegistry { ...@@ -61,6 +110,13 @@ impl ParserRegistry {
} }
None None
} }
/// Clear the parser pool, forcing new instances to be created.
/// Useful for testing or when parsers need to be reset globally.
pub fn clear_pool(&self) {
    self.pool.write().unwrap().clear();
}
} }
impl Default for ParserRegistry { impl Default for ParserRegistry {
...@@ -70,6 +126,7 @@ impl Default for ParserRegistry { ...@@ -70,6 +126,7 @@ impl Default for ParserRegistry {
} }
/// Factory for creating reasoning parsers based on model type. /// Factory for creating reasoning parsers based on model type.
#[derive(Clone)]
pub struct ParserFactory { pub struct ParserFactory {
registry: ParserRegistry, registry: ParserRegistry,
} }
...@@ -109,8 +166,39 @@ impl ParserFactory { ...@@ -109,8 +166,39 @@ impl ParserFactory {
Self { registry } Self { registry }
} }
/// Create a parser for the given model ID. /// Get a pooled parser for the given model ID.
/// Returns a no-op parser if model is not recognized. /// Returns a shared instance that can be used concurrently.
/// Get a pooled parser for the given model ID.
/// Returns a shared instance that can be used concurrently.
/// Falls back to a passthrough parser if model is not recognized.
pub fn get_pooled(&self, model_id: &str) -> PooledParser {
    // Known model: resolve through the pattern table.
    if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
        return parser;
    }

    // Unknown model: reuse the shared passthrough parser if one is pooled.
    if let Some(parser) = self.registry.get_pooled_parser("passthrough") {
        return parser;
    }

    // Passthrough not registered yet — register its creator, then fetch it
    // from the pool (the second lookup cannot fail once registered).
    self.registry.register_parser("passthrough", || {
        let config = ParserConfig {
            think_start_token: "".to_string(),
            think_end_token: "".to_string(),
            stream_reasoning: true,
            max_buffer_size: 65536,
            initial_in_reasoning: false,
        };
        Box::new(BaseReasoningParser::new(config).with_model_type("passthrough".to_string()))
    });
    self.registry.get_pooled_parser("passthrough").unwrap()
}
/// Create a new parser instance for the given model ID.
/// Returns a fresh instance (not pooled).
/// Use this when you need an isolated parser instance.
pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> { pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
// First try to find by pattern // First try to find by pattern
if let Some(parser) = self.registry.find_parser_for_model(model_id) { if let Some(parser) = self.registry.find_parser_for_model(model_id) {
...@@ -134,6 +222,12 @@ impl ParserFactory { ...@@ -134,6 +222,12 @@ impl ParserFactory {
pub fn registry(&self) -> &ParserRegistry { pub fn registry(&self) -> &ParserRegistry {
&self.registry &self.registry
} }
/// Clear the parser pool.
/// Useful for testing or when parsers need to be reset globally.
/// Thin delegator: forwards to [`ParserRegistry::clear_pool`] on the
/// factory's inner registry.
pub fn clear_pool(&self) {
    self.registry.clear_pool();
}
} }
impl Default for ParserFactory { impl Default for ParserFactory {
...@@ -195,4 +289,267 @@ mod tests { ...@@ -195,4 +289,267 @@ mod tests {
assert_eq!(step3.model_type(), "deepseek_r1"); assert_eq!(step3.model_type(), "deepseek_r1");
assert_eq!(glm45.model_type(), "qwen3"); assert_eq!(glm45.model_type(), "qwen3");
} }
#[test]
fn test_pooled_parser_reuse() {
    let factory = ParserFactory::new();

    // Repeated lookups for one model must hand back the same shared Arc.
    let first = factory.get_pooled("deepseek-r1");
    let second = factory.get_pooled("deepseek-r1");
    assert!(Arc::ptr_eq(&first, &second));

    // A different model maps to a distinct pooled instance.
    let other = factory.get_pooled("qwen3");
    assert!(!Arc::ptr_eq(&first, &other));
}
#[test]
fn test_pooled_parser_concurrent_access() {
    use std::thread;

    let factory = ParserFactory::new();
    let parser = factory.get_pooled("deepseek-r1");

    // Several threads parse through the same pooled parser; the Mutex
    // serializes access so each sees consistent results.
    let handles: Vec<_> = (0..3)
        .map(|i| {
            let shared = Arc::clone(&parser);
            thread::spawn(move || {
                let mut guard = shared.lock().unwrap();
                let input = format!("thread {} reasoning</think>answer", i);
                let result = guard.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            })
        })
        .collect();

    // Every worker must finish without panicking.
    for handle in handles {
        handle.join().unwrap();
    }
}
#[test]
fn test_pool_clearing() {
    let factory = ParserFactory::new();

    let before = factory.get_pooled("deepseek-r1");
    factory.clear_pool();
    let after = factory.get_pooled("deepseek-r1");

    // Clearing the pool forces a fresh instance on the next lookup,
    // so the two Arcs must not alias.
    assert!(!Arc::ptr_eq(&before, &after));
}
#[test]
fn test_passthrough_parser_pooling() {
    let factory = ParserFactory::new();

    // Every unrecognized model falls back to one shared passthrough parser.
    let a = factory.get_pooled("unknown-model-1");
    let b = factory.get_pooled("unknown-model-2");
    assert!(Arc::ptr_eq(&a, &b));

    // Confirm the fallback really is the passthrough parser.
    let guard = a.lock().unwrap();
    assert_eq!(guard.model_type(), "passthrough");
}
#[test]
fn test_high_concurrency_parser_access() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Instant;

    // Stress test: 100 threads x 50 requests each, rotating across four
    // pooled parsers, asserting both correctness and a throughput floor.
    let factory = ParserFactory::new();
    let num_threads = 100;
    let requests_per_thread = 50;
    let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

    // Track successful operations (Relaxed is fine: counters only,
    // no ordering dependency between them)
    let success_count = Arc::new(AtomicUsize::new(0));
    let error_count = Arc::new(AtomicUsize::new(0));

    let start = Instant::now();
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let factory = factory.clone();
        let models = models.clone();
        let success_count = Arc::clone(&success_count);
        let error_count = Arc::clone(&error_count);

        let handle = thread::spawn(move || {
            for request_id in 0..requests_per_thread {
                // Rotate through different models
                let model = &models[(thread_id + request_id) % models.len()];
                let parser = factory.get_pooled(model);

                // Use blocking lock - this is the realistic scenario
                // In production, requests would wait for the parser to be available
                // Handle poisoned locks gracefully
                let mut p = match parser.lock() {
                    Ok(guard) => guard,
                    Err(_poisoned) => {
                        // Lock was poisoned by a panicking thread
                        // In production, we might want to recreate the parser
                        // For testing, we'll just skip this iteration
                        error_count.fetch_add(1, Ordering::Relaxed);
                        continue;
                    }
                };

                // Simulate realistic parsing work with substantial text
                // Typical reasoning can be 500-5000 tokens
                let reasoning_text = format!(
                    "Thread {} is processing request {}. Let me think through this step by step. \
                     First, I need to understand the problem. The problem involves analyzing data \
                     and making calculations. Let me break this down: \n\
                     1. Initial analysis shows that we have multiple variables to consider. \
                     2. The data suggests a pattern that needs further investigation. \
                     3. Computing the values: {} * {} = {}. \
                     4. Cross-referencing with previous results indicates consistency. \
                     5. The mathematical proof follows from the axioms... \
                     6. Considering edge cases and boundary conditions... \
                     7. Validating against known constraints... \
                     8. The conclusion follows logically from premises A, B, and C. \
                     This reasoning chain demonstrates the validity of our approach.",
                    thread_id, request_id, thread_id, request_id, thread_id * request_id
                );
                let answer_text = format!(
                    "Based on my analysis, the answer for thread {} request {} is: \
                     The solution involves multiple steps as outlined in the reasoning. \
                     The final result is {} with confidence level high. \
                     This conclusion is supported by rigorous mathematical analysis \
                     and has been validated against multiple test cases. \
                     The implementation should handle edge cases appropriately.",
                    thread_id,
                    request_id,
                    thread_id * request_id
                );

                let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                match p.detect_and_parse_reasoning(&input) {
                    Ok(result) => {
                        // Verify parsing worked correctly with substantial content
                        // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                        assert!(result
                            .normal_text
                            .contains(&format!("thread {}", thread_id)));

                        // For parsers that accumulate reasoning (stream_reasoning=false)
                        // the reasoning_text should be populated
                        if !result.reasoning_text.is_empty() {
                            assert!(result
                                .reasoning_text
                                .contains(&format!("Thread {}", thread_id)));
                            assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                        }

                        // Normal text should always be present
                        assert!(result.normal_text.len() > 100); // Ensure substantial answer
                        success_count.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(e) => {
                        eprintln!("Parse error: {:?}", e);
                        error_count.fetch_add(1, Ordering::Relaxed);
                    }
                }

                // Explicitly drop the lock to release it quickly
                drop(p);
            }
        });
        handles.push(handle);
    }

    // Wait for all threads
    for handle in handles {
        handle.join().unwrap();
    }

    let duration = start.elapsed();
    let total_requests = num_threads * requests_per_thread;
    let successes = success_count.load(Ordering::Relaxed);
    let errors = error_count.load(Ordering::Relaxed);

    // Print stats for debugging
    println!(
        "High concurrency test: {} threads, {} requests each",
        num_threads, requests_per_thread
    );
    println!(
        "Completed in {:?}, {} successes, {} errors",
        duration, successes, errors
    );
    println!(
        "Throughput: {:.0} requests/sec",
        (total_requests as f64) / duration.as_secs_f64()
    );

    // All requests should succeed
    assert_eq!(successes, total_requests);
    assert_eq!(errors, 0);

    // Performance check: should handle at least 1000 req/sec
    // NOTE(review): timing assertion — may be flaky on heavily loaded CI; confirm threshold
    let throughput = (total_requests as f64) / duration.as_secs_f64();
    assert!(
        throughput > 1000.0,
        "Throughput too low: {:.0} req/sec",
        throughput
    );
}
#[test]
fn test_concurrent_pool_modifications() {
    use std::thread;

    let factory = ParserFactory::new();
    let mut handles = vec![];

    // Reader: repeatedly fetch the same pooled parser.
    let reader = factory.clone();
    handles.push(thread::spawn(move || {
        for _ in 0..100 {
            let _parser = reader.get_pooled("deepseek-r1");
        }
    }));

    // Clearer: periodically wipe the pool while readers are active.
    let clearer = factory.clone();
    handles.push(thread::spawn(move || {
        for _ in 0..10 {
            clearer.clear_pool();
            thread::sleep(std::time::Duration::from_micros(100));
        }
    }));

    // Mixed reader: cycle through several model names (including an
    // unknown one that exercises the passthrough fallback).
    let mixed = factory.clone();
    handles.push(thread::spawn(move || {
        let models = ["qwen3", "kimi", "unknown"];
        for i in 0..100 {
            let _parser = mixed.get_pooled(models[i % 3]);
        }
    }));

    // The test passes if everything completes without deadlock or panic.
    for handle in handles {
        handle.join().unwrap();
    }
}
} }
...@@ -2,7 +2,7 @@ pub mod factory; ...@@ -2,7 +2,7 @@ pub mod factory;
pub mod parsers; pub mod parsers;
pub mod traits; pub mod traits;
pub use factory::{ParserFactory, ParserRegistry}; pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use parsers::{ pub use parsers::{
BaseReasoningParser, DeepSeekR1Parser, KimiParser, Qwen3Parser, QwenThinkingParser, BaseReasoningParser, DeepSeekR1Parser, KimiParser, Qwen3Parser, QwenThinkingParser,
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment