// Factory and registry for creating model-specific reasoning parsers.
// Now with parser pooling support for efficient reuse across requests.

use std::collections::HashMap;
5
use std::sync::{Arc, Mutex, RwLock};
6

7
use crate::reasoning_parser::parsers::{
8
9
    BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
    QwenThinkingParser, Step3Parser,
10
};
11
12
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};

13
14
15
/// Type alias for pooled parser instances.
///
/// A boxed [`ReasoningParser`] guarded by a `Mutex` and shared through an
/// `Arc`, so one instance can be handed to several callers. Callers lock the
/// mutex to use the parser.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

16
17
18
/// Type alias for parser creator functions.
///
/// A reference-counted, `Send + Sync` closure that returns a boxed
/// [`ReasoningParser`] when called.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;

19
/// Registry for model-specific parsers with pooling support.
20
21
#[derive(Clone)]
pub struct ParserRegistry {
22
23
24
25
26
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings
27
28
29
30
31
32
33
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}

impl ParserRegistry {
    /// Create a new empty registry.
    pub fn new() -> Self {
        Self {
34
35
            creators: Arc::new(RwLock::new(HashMap::new())),
            pool: Arc::new(RwLock::new(HashMap::new())),
36
37
38
39
40
41
42
43
44
            patterns: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Register a parser creator for a given parser type.
    pub fn register_parser<F>(&self, name: &str, creator: F)
    where
        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
    {
45
46
        let mut creators = self.creators.write().unwrap();
        creators.insert(name.to_string(), Arc::new(creator));
47
48
49
50
51
52
53
54
55
    }

    /// Map a model-name pattern to a registered parser name.
    ///
    /// Patterns are appended to the end of the list and are checked in
    /// registration order, so the first matching pattern wins.
    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
        let mapping = (pattern.to_string(), parser_name.to_string());
        self.patterns.write().unwrap().push(mapping);
    }

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
    /// Get a pooled parser by exact name.
    ///
    /// Returns the shared instance stored in the pool under `name`. If none is
    /// pooled yet, one is built from the creator registered under `name` and
    /// stored for later calls. Returns `None` when no creator is registered
    /// under `name`.
    ///
    /// When several threads miss the pool at the same time, the first instance
    /// inserted is kept and returned to all of them, so every caller shares
    /// the same parser.
    ///
    /// # Panics
    ///
    /// Panics if the creator lock or the pool lock is poisoned.
    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
        // Fast path: an instance is already pooled, so a read lock suffices.
        {
            let pool = self.pool.read().unwrap();
            if let Some(parser) = pool.get(name) {
                return Some(Arc::clone(parser));
            }
        }

        // Clone the creator handle so that no registry lock is held while the
        // caller-supplied creator runs.
        let creator = Arc::clone(self.creators.read().unwrap().get(name)?);
        let fresh: PooledParser = Arc::new(Mutex::new(creator()));

        // Re-check under the write lock: another thread may have pooled an
        // instance since the read above. `or_insert` keeps the existing entry
        // in that case and discards `fresh`.
        let mut pool = self.pool.write().unwrap();
        let pooled = pool.entry(name.to_string()).or_insert(fresh);
        Some(Arc::clone(pooled))
    }

    /// Get a parser by exact name (creates new instance, not pooled).
    /// Use this for compatibility or when you need a fresh instance.
84
    pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
        let creators = self.creators.read().unwrap();
        creators.get(name).map(|creator| creator())
    }

    /// Find a pooled parser for a given model ID by pattern matching.
    ///
    /// The model ID and each pattern are lowercased, and a pattern matches
    /// when it is a substring of the model ID. Patterns are tried in
    /// registration order; the first match decides the parser name, which is
    /// then resolved through [`Self::get_pooled_parser`]. Returns `None` when
    /// no pattern matches or the matched name has no registered creator.
    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
        let model_lower = model_id.to_lowercase();
        let patterns = self.patterns.read().unwrap();

        patterns
            .iter()
            .find(|(pattern, _)| model_lower.contains(&pattern.to_lowercase()))
            .and_then(|(_, parser_name)| self.get_pooled_parser(parser_name))
    }

102
    /// Find a parser for a given model ID by pattern matching (creates new instance).
103
104
105
106
107
108
109
110
111
112
113
    pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                return self.get_parser(parser_name);
            }
        }
        None
    }
114
115
116
117
118
119
120

    /// Drop every pooled parser so later lookups build new instances.
    ///
    /// Handles already given out stay valid; only the pool's own references
    /// are removed. Useful for testing or when parsers need to be reset
    /// globally.
    pub fn clear_pool(&self) {
        self.pool.write().unwrap().clear();
    }
121
122
123
124
125
126
127
128
129
}

impl Default for ParserRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Factory for creating reasoning parsers based on model type.
130
#[derive(Clone)]
131
132
133
134
135
136
137
138
139
140
141
142
143
144
pub struct ParserFactory {
    registry: ParserRegistry,
}

impl ParserFactory {
    /// Create a new factory with default parsers registered.
    pub fn new() -> Self {
        let registry = ParserRegistry::new();

        // Register base parser
        registry.register_parser("base", || {
            Box::new(BaseReasoningParser::new(ParserConfig::default()))
        });

145
146
        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
147

148
149
        // Register Qwen3 parser (starts with in_reasoning=false)
        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
150

151
152
        // Register Qwen3-thinking parser (starts with in_reasoning=true)
        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
153

154
155
        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
        registry.register_parser("kimi", || Box::new(KimiParser::new()));
156

157
158
159
160
161
162
        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));

        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
        registry.register_parser("step3", || Box::new(Step3Parser::new()));

163
164
165
166
167
168
        // Register model patterns
        registry.register_pattern("deepseek-r1", "deepseek_r1");
        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
        registry.register_pattern("qwen-thinking", "qwen3_thinking");
        registry.register_pattern("qwen3", "qwen3");
        registry.register_pattern("qwen", "qwen3");
169
        registry.register_pattern("glm45", "glm45");
170
        registry.register_pattern("kimi", "kimi");
171
        registry.register_pattern("step3", "step3");
172
173
174
175

        Self { registry }
    }

176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
    /// Look up the pooled parser for `model_id`.
    ///
    /// The returned handle is a shared, mutex-guarded instance. When no model
    /// pattern matches, the pooled `passthrough` parser is returned instead.
    /// Its creator is registered the first time it is needed.
    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
        let registry = &self.registry;

        // Preferred: a parser selected by model pattern.
        if let Some(matched) = registry.find_pooled_parser_for_model(model_id) {
            return matched;
        }

        // Fallback: the passthrough parser, if its creator is already registered.
        if let Some(passthrough) = registry.get_pooled_parser("passthrough") {
            return passthrough;
        }

        // First use of the fallback: register the passthrough creator, then
        // fetch the instance through the pool.
        registry.register_parser("passthrough", || {
            Box::new(
                BaseReasoningParser::new(ParserConfig {
                    think_start_token: "".to_string(),
                    think_end_token: "".to_string(),
                    stream_reasoning: true,
                    max_buffer_size: 65536,
                    initial_in_reasoning: false,
                })
                .with_model_type("passthrough".to_string()),
            )
        });
        registry.get_pooled_parser("passthrough").unwrap()
    }

    /// Create a new parser instance for the given model ID.
    /// Returns a fresh instance (not pooled).
    /// Use this when you need an isolated parser instance.
209
210
211
212
213
214
215
216
217
218
219
220
    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
        // First try to find by pattern
        if let Some(parser) = self.registry.find_parser_for_model(model_id) {
            return Ok(parser);
        }

        // Fall back to no-op parser (base parser without reasoning detection)
        let config = ParserConfig {
            think_start_token: "".to_string(),
            think_end_token: "".to_string(),
            stream_reasoning: true,
            max_buffer_size: 65536,
221
            initial_in_reasoning: false,
222
223
224
225
226
227
228
229
230
231
        };
        Ok(Box::new(
            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
        ))
    }

    /// Get the internal registry for custom registration.
    ///
    /// Returns a shared reference to the registry this factory resolves
    /// parsers through. The registry's registration methods take `&self`, so
    /// parsers and patterns can be added through this reference.
    pub fn registry(&self) -> &ParserRegistry {
        &self.registry
    }
232
233
234
235
236
237

    /// Empty the registry's parser pool.
    ///
    /// Delegates to [`ParserRegistry::clear_pool`]. Useful for testing or when
    /// parsers need to be reset globally.
    pub fn clear_pool(&self) {
        self.registry().clear_pool();
    }
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
}

impl Default for ParserFactory {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A DeepSeek-R1 distill model ID resolves to the `deepseek_r1` parser.
    #[test]
    fn test_factory_creates_deepseek_r1() {
        let created = ParserFactory::new().create("deepseek-r1-distill").unwrap();
        assert_eq!(created.model_type(), "deepseek_r1");
    }

    /// A Qwen3 model ID resolves to the `qwen3` parser.
    #[test]
    fn test_factory_creates_qwen3() {
        let created = ParserFactory::new().create("qwen3-7b").unwrap();
        assert_eq!(created.model_type(), "qwen3");
    }

    /// A Kimi model ID resolves to the `kimi` parser.
    #[test]
    fn test_factory_creates_kimi() {
        let created = ParserFactory::new().create("kimi-chat").unwrap();
        assert_eq!(created.model_type(), "kimi");
    }

    /// A model ID that matches no pattern falls back to the passthrough parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let created = ParserFactory::new().create("unknown-model").unwrap();
        assert_eq!(created.model_type(), "passthrough");
    }

    /// Pattern matching ignores the case of the model ID.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();

        for (model_id, expected) in [
            ("DeepSeek-R1", "deepseek_r1"),
            ("QWEN3", "qwen3"),
            ("Kimi", "kimi"),
        ] {
            let created = factory.create(model_id).unwrap();
            assert_eq!(created.model_type(), expected);
        }
    }

    /// A Step3 model ID resolves to the `step3` parser.
    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }
296

297
298
299
300
301
    /// A GLM45 model ID resolves to the `glm45` parser.
    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563

    /// Repeated pooled lookups for one model share an instance; different
    /// models get different instances.
    #[test]
    fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Two lookups for the same model must yield the same Arc.
        let first = factory.get_pooled("deepseek-r1");
        let second = factory.get_pooled("deepseek-r1");
        assert!(Arc::ptr_eq(&first, &second));

        // A different model must be backed by a different parser.
        let other_model = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&first, &other_model));
    }

    /// Several threads can take turns using one pooled parser through its mutex.
    #[test]
    fn test_pooled_parser_concurrent_access() {
        use std::thread;

        let factory = ParserFactory::new();
        let shared = factory.get_pooled("deepseek-r1");

        // Each worker locks the shared parser and parses its own input.
        let workers: Vec<_> = (0..3)
            .map(|i| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    let mut guard = shared.lock().unwrap();
                    let input = format!("thread {} reasoning</think>answer", i);
                    let parsed = guard.detect_and_parse_reasoning(&input).unwrap();
                    assert_eq!(parsed.normal_text, "answer");
                    assert!(parsed.reasoning_text.contains("reasoning"));
                })
            })
            .collect();

        // A panic in any worker surfaces here.
        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Clearing the pool makes the next lookup build a new instance.
    #[test]
    fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let before_clear = factory.get_pooled("deepseek-r1");
        factory.clear_pool();
        let after_clear = factory.get_pooled("deepseek-r1");

        // The old handle is still alive, so the new allocation must differ.
        assert!(!Arc::ptr_eq(&before_clear, &after_clear));
    }

    /// Unknown models all share one pooled passthrough parser.
    #[test]
    fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        let first_unknown = factory.get_pooled("unknown-model-1");
        let second_unknown = factory.get_pooled("unknown-model-2");

        // Both fallbacks must point at the same pooled instance.
        assert!(Arc::ptr_eq(&first_unknown, &second_unknown));

        let guard = first_unknown.lock().unwrap();
        assert_eq!(guard.model_type(), "passthrough");
    }

    /// Stress test: many threads parse through the shared pooled parsers.
    ///
    /// Every request must parse successfully. Throughput is printed for
    /// information only; it is not asserted, because a wall-clock threshold
    /// depends on the build profile and on machine load and would make this
    /// correctness test flaky.
    #[test]
    fn test_high_concurrency_parser_access() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::thread;
        use std::time::Instant;

        let factory = ParserFactory::new();
        let num_threads = 100;
        let requests_per_thread = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for thread_id in 0..num_threads {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = thread::spawn(move || {
                for request_id in 0..requests_per_thread {
                    // Rotate through different models
                    let model = &models[(thread_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use a blocking lock: requests wait for the parser to be
                    // available. A poisoned lock (another thread panicked while
                    // holding it) is counted as an error and skipped.
                    let mut p = match parser.lock() {
                        Ok(guard) => guard,
                        Err(_poisoned) => {
                            error_count.fetch_add(1, Ordering::Relaxed);
                            continue;
                        }
                    };

                    // Simulate realistic parsing work with substantial text
                    let reasoning_text = format!(
                        "Thread {} is processing request {}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {} * {} = {}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                        thread_id, request_id, thread_id, request_id, thread_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for thread {} request {} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                        thread_id,
                        request_id,
                        thread_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // The answer text must always reach the normal output.
                            assert!(result
                                .normal_text
                                .contains(&format!("thread {}", thread_id)));

                            // Only checked when the parser returned reasoning text.
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Thread {}", thread_id)));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all threads
        for handle in handles {
            handle.join().unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_threads * requests_per_thread;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);
        let throughput = (total_requests as f64) / duration.as_secs_f64();

        // Print stats for debugging (visible with `--nocapture`)
        println!(
            "High concurrency test: {} threads, {} requests each",
            num_threads, requests_per_thread
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!("Throughput: {:.0} requests/sec", throughput);

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);
    }

    /// Pooled lookups and pool clears running at the same time must neither
    /// deadlock nor panic.
    #[test]
    fn test_concurrent_pool_modifications() {
        use std::thread;

        let factory = ParserFactory::new();

        // Worker 1: repeatedly fetch the same pooled parser.
        let getter_factory = factory.clone();
        let getter = thread::spawn(move || {
            for _ in 0..100 {
                let _parser = getter_factory.get_pooled("deepseek-r1");
            }
        });

        // Worker 2: repeatedly clear the pool, pausing briefly between clears.
        let clearer_factory = factory.clone();
        let clearer = thread::spawn(move || {
            for _ in 0..10 {
                clearer_factory.clear_pool();
                thread::sleep(std::time::Duration::from_micros(100));
            }
        });

        // Worker 3: cycle through several models, including an unknown one.
        let rotating_factory = factory.clone();
        let rotator = thread::spawn(move || {
            let models = ["qwen3", "kimi", "unknown"];
            for i in 0..100 {
                let _parser = rotating_factory.get_pooled(models[i % 3]);
            }
        });

        // Joining surfaces any panic from the workers.
        for worker in vec![getter, clearer, rotator] {
            worker.join().unwrap();
        }
    }
564
}