// Factory and registry for creating model-specific reasoning parsers.
// Now with parser pooling support for efficient reuse across requests.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use tokio::sync::Mutex;

use crate::reasoning_parser::parsers::{
    BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
    QwenThinkingParser, Step3Parser,
};
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};

/// Type alias for pooled parser instances.
/// Uses tokio::Mutex to avoid blocking the async executor.
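///
/// # Example
///
/// A minimal usage sketch (assumes a `PooledParser` already obtained from a
/// registry or factory; `detect_and_parse_reasoning` is the trait method
/// exercised by the tests below):
///
/// ```ignore
/// let mut guard = parser.lock().await;
/// let result = guard.detect_and_parse_reasoning("<think>why</think>answer")?;
/// assert_eq!(result.normal_text, "answer");
/// ```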
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;

/// Registry for model-specific parsers with pooling support.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}

impl ParserRegistry {
    /// Create a new empty registry.
    pub fn new() -> Self {
        Self {
            creators: Arc::new(RwLock::new(HashMap::new())),
            pool: Arc::new(RwLock::new(HashMap::new())),
            patterns: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Register a parser creator for a given parser type.
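    ///
    /// # Example
    ///
    /// A sketch of registering a custom creator (the `MyParser` type here is
    /// hypothetical and must implement `ReasoningParser`):
    ///
    /// ```ignore
    /// registry.register_parser("my_parser", || Box::new(MyParser::new()));
    /// ```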
    pub fn register_parser<F>(&self, name: &str, creator: F)
    where
        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
    {
        let mut creators = self.creators.write().unwrap();
        creators.insert(name.to_string(), Arc::new(creator));
    }

    /// Register a model pattern to parser mapping.
    /// Patterns are checked in order, first match wins.
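    ///
    /// # Example
    ///
    /// Because the first match wins, more specific patterns should be
    /// registered before more general ones, mirroring what
    /// `ParserFactory::new` does:
    ///
    /// ```ignore
    /// registry.register_pattern("qwen3-thinking", "qwen3_thinking");
    /// registry.register_pattern("qwen3", "qwen3"); // checked only if the above misses
    /// ```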
    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
        let mut patterns = self.patterns.write().unwrap();
        patterns.push((pattern.to_string(), parser_name.to_string()));
    }

    /// Get a pooled parser by exact name.
    /// Returns a shared parser instance from the pool, creating one if needed.
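    ///
    /// # Example
    ///
    /// A sketch assuming a parser was registered under the name `"base"`:
    ///
    /// ```ignore
    /// if let Some(parser) = registry.get_pooled_parser("base") {
    ///     let mut guard = parser.lock().await;
    ///     let result = guard.detect_and_parse_reasoning(input)?;
    /// }
    /// ```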
    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
        // First check if we have a pooled instance
        {
            let pool = self.pool.read().unwrap();
            if let Some(parser) = pool.get(name) {
                return Some(Arc::clone(parser));
            }
        }

        // If not in pool, create one and add to pool
        let creators = self.creators.read().unwrap();
        if let Some(creator) = creators.get(name) {
            let parser = Arc::new(Mutex::new(creator()));

            // Add to pool for future use
            let mut pool = self.pool.write().unwrap();
            pool.insert(name.to_string(), Arc::clone(&parser));

            Some(parser)
        } else {
            None
        }
    }

    /// Check if a parser with the given name is registered.
    pub fn has_parser(&self, name: &str) -> bool {
        let creators = self.creators.read().unwrap();
        creators.contains_key(name)
    }

    /// Create a fresh parser instance by exact name (not pooled).
    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
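    ///
    /// # Example
    ///
    /// A sketch showing that each call yields an independent instance:
    ///
    /// ```ignore
    /// let a = registry.create_parser("base");
    /// let b = registry.create_parser("base");
    /// // `a` and `b` hold separate state, unlike pooled parsers.
    /// ```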
    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
        let creators = self.creators.read().unwrap();
        creators.get(name).map(|creator| creator())
    }

    /// Find a pooled parser for a given model ID by pattern matching.
    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                return self.get_pooled_parser(parser_name);
            }
        }
        None
    }

    /// Check if a parser can be created for a specific model without actually creating it.
    /// Returns true if a parser is available (registered) for this model.
    pub fn has_parser_for_model(&self, model_id: &str) -> bool {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                let creators = self.creators.read().unwrap();
                return creators.contains_key(parser_name);
            }
        }
        false
    }

    /// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
    pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                return self.create_parser(parser_name);
            }
        }
        None
    }

    /// Clear the parser pool, forcing new instances to be created.
    /// Useful for testing or when parsers need to be reset globally.
    pub fn clear_pool(&self) {
        let mut pool = self.pool.write().unwrap();
        pool.clear();
    }
}

impl Default for ParserRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Factory for creating reasoning parsers based on model type.
#[derive(Clone)]
pub struct ParserFactory {
    registry: ParserRegistry,
}

impl ParserFactory {
    /// Create a new factory with default parsers registered.
    pub fn new() -> Self {
        let registry = ParserRegistry::new();

        // Register base parser
        registry.register_parser("base", || {
            Box::new(BaseReasoningParser::new(ParserConfig::default()))
        });

        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));

        // Register Qwen3 parser (starts with in_reasoning=false)
        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));

        // Register Qwen3-thinking parser (starts with in_reasoning=true)
        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));

        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
        registry.register_parser("kimi", || Box::new(KimiParser::new()));

        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));

        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
        registry.register_parser("step3", || Box::new(Step3Parser::new()));

        // Register model patterns (checked in order, so the more specific
        // "qwen3-thinking" patterns must come before the generic "qwen3"/"qwen")
        registry.register_pattern("deepseek-r1", "deepseek_r1");
        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
        registry.register_pattern("qwen-thinking", "qwen3_thinking");
        registry.register_pattern("qwen3", "qwen3");
        registry.register_pattern("qwen", "qwen3");
        registry.register_pattern("glm45", "glm45");
        registry.register_pattern("kimi", "kimi");
        registry.register_pattern("step3", "step3");

        Self { registry }
    }

    /// Get a pooled parser for the given model ID.
    /// Returns a shared instance that can be used concurrently.
    /// Falls back to a passthrough parser if model is not recognized.
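    ///
    /// # Example
    ///
    /// A sketch of per-request usage from async code:
    ///
    /// ```ignore
    /// let factory = ParserFactory::new();
    /// let parser = factory.get_pooled("deepseek-r1");
    /// let mut guard = parser.lock().await;
    /// let result = guard.detect_and_parse_reasoning("<think>steps</think>done")?;
    /// ```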
    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
        // First try to find by pattern
        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
            return parser;
        }

        // Fall back to no-op parser (get or create passthrough in pool)
        self.registry
            .get_pooled_parser("passthrough")
            .unwrap_or_else(|| {
                // Register passthrough if not already registered
                self.registry.register_parser("passthrough", || {
                    let config = ParserConfig {
                        think_start_token: "".to_string(),
                        think_end_token: "".to_string(),
                        stream_reasoning: true,
                        max_buffer_size: 65536,
                        initial_in_reasoning: false,
                    };
                    Box::new(
                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
                    )
                });
                self.registry.get_pooled_parser("passthrough").unwrap()
            })
    }

    /// Create a new parser instance for the given model ID.
    /// Returns a fresh instance (not pooled).
    /// Use this when you need an isolated parser instance.
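    ///
    /// # Example
    ///
    /// A sketch mirroring the unit tests below:
    ///
    /// ```ignore
    /// let factory = ParserFactory::new();
    /// let parser = factory.create("qwen3-7b")?;
    /// assert_eq!(parser.model_type(), "qwen3");
    /// ```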
    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
        // First try to find by pattern
        if let Some(parser) = self.registry.create_for_model(model_id) {
            return Ok(parser);
        }

        // Fall back to no-op parser (base parser without reasoning detection)
        let config = ParserConfig {
            think_start_token: "".to_string(),
            think_end_token: "".to_string(),
            stream_reasoning: true,
            max_buffer_size: 65536,
            initial_in_reasoning: false,
        };
        Ok(Box::new(
            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
        ))
    }

    /// Get the internal registry for custom registration.
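    ///
    /// # Example
    ///
    /// A sketch of wiring in a custom parser (the `MyParser` type is
    /// hypothetical):
    ///
    /// ```ignore
    /// let factory = ParserFactory::new();
    /// factory.registry().register_parser("custom", || Box::new(MyParser::new()));
    /// factory.registry().register_pattern("my-model", "custom");
    /// ```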
    pub fn registry(&self) -> &ParserRegistry {
        &self.registry
    }

    /// Clear the parser pool.
    /// Useful for testing or when parsers need to be reset globally.
    pub fn clear_pool(&self) {
        self.registry.clear_pool();
    }
}

impl Default for ParserFactory {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Get the same parser twice - should be the same instance
        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Both should point to the same Arc
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // Different models should get different parsers
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        // Get a pooled parser
        let parser1 = factory.get_pooled("deepseek-r1");

        // Clear the pool
        factory.clear_pool();

        // Get another parser - should be a new instance
        let parser2 = factory.get_pooled("deepseek-r1");

        // They should be different instances (different Arc pointers)
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Unknown models should get passthrough parser
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Instant;

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through different models
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use async lock - tokio::Mutex doesn't poison
                    let mut p = parser.lock().await;

                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {} * {} = {}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        // Print stats for debugging
        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!(
            "Throughput: {:.0} requests/sec",
            (total_requests as f64) / duration.as_secs_f64()
        );

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Performance check: should handle at least 1000 req/sec
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: Get different parsers
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
            handle.await.unwrap();
        }
    }
}