// Factory and registry for creating model-specific reasoning parsers.
// Now with parser pooling support for efficient reuse across requests.

use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
};

use tokio::sync::Mutex;

use crate::reasoning_parser::{
    parsers::{
        BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
        QwenThinkingParser, Step3Parser,
    },
    traits::{ParseError, ParserConfig, ReasoningParser},
};

/// Type alias for pooled parser instances.
/// Uses tokio::Mutex to avoid blocking the async executor.
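///
/// # Example
///
/// A minimal usage sketch (marked `ignore` so it is not compiled as a
/// doctest), assuming the `ParserFactory` defined later in this file:
///
/// ```ignore
/// let factory = ParserFactory::new();
/// let pooled: PooledParser = factory.get_pooled("deepseek-r1");
/// // Lock asynchronously; tokio's Mutex does not block the executor.
/// let mut parser = pooled.lock().await;
/// // DeepSeek-R1 starts with in_reasoning=true, so no think-start token is needed.
/// let result = parser.detect_and_parse_reasoning("why</think>answer")?;
/// assert_eq!(result.normal_text, "answer");
/// ```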
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;

/// Registry for model-specific parsers with pooling support.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}

impl ParserRegistry {
    /// Create a new empty registry.
    pub fn new() -> Self {
        Self {
            creators: Arc::new(RwLock::new(HashMap::new())),
            pool: Arc::new(RwLock::new(HashMap::new())),
            patterns: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Register a parser creator for a given parser type.
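    ///
    /// # Example
    ///
    /// A sketch (not compiled) mirroring the default registrations performed
    /// in `ParserFactory::new()`:
    ///
    /// ```ignore
    /// registry.register_parser("base", || {
    ///     Box::new(BaseReasoningParser::new(ParserConfig::default()))
    /// });
    /// ```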
    pub fn register_parser<F>(&self, name: &str, creator: F)
    where
        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
    {
        let mut creators = self.creators.write().unwrap();
        creators.insert(name.to_string(), Arc::new(creator));
    }

    /// Register a model pattern to parser mapping.
    /// Patterns are checked in order, first match wins.
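    ///
    /// # Example
    ///
    /// A sketch (not compiled): the more specific pattern must be registered
    /// first, because matching is a first-wins substring check and "qwen3"
    /// would otherwise shadow "qwen3-thinking":
    ///
    /// ```ignore
    /// registry.register_pattern("qwen3-thinking", "qwen3_thinking");
    /// registry.register_pattern("qwen3", "qwen3");
    /// ```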
    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
        let mut patterns = self.patterns.write().unwrap();
        patterns.push((pattern.to_string(), parser_name.to_string()));
    }

    /// Get a pooled parser by exact name.
    /// Returns a shared parser instance from the pool, creating one if needed.
    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
        // First check if we have a pooled instance
        {
            let pool = self.pool.read().unwrap();
            if let Some(parser) = pool.get(name) {
                return Some(Arc::clone(parser));
            }
        }

        // If not in pool, create one and add to the pool
        let creators = self.creators.read().unwrap();
        if let Some(creator) = creators.get(name) {
            let mut pool = self.pool.write().unwrap();
            // Re-check under the write lock so two callers that raced past
            // the read check above still end up sharing a single instance.
            let parser = pool
                .entry(name.to_string())
                .or_insert_with(|| Arc::new(Mutex::new(creator())));
            Some(Arc::clone(parser))
        } else {
            None
        }
    }

    /// Check if a parser with the given name is registered.
    pub fn has_parser(&self, name: &str) -> bool {
        let creators = self.creators.read().unwrap();
        creators.contains_key(name)
    }

    /// Create a fresh parser instance by exact name (not pooled).
    /// Returns a new parser instance for each call; useful for streaming where state isolation is needed.
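    ///
    /// ```ignore
    /// // Sketch (not compiled): each call yields an independent instance,
    /// // so per-stream parser state cannot leak between concurrent requests.
    /// let a = registry.create_parser("qwen3").expect("registered");
    /// let b = registry.create_parser("qwen3").expect("registered");
    /// // `a` and `b` hold separate parsing state.
    /// ```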
    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
        let creators = self.creators.read().unwrap();
        creators.get(name).map(|creator| creator())
    }

    /// Find a pooled parser for a given model ID by pattern matching.
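    ///
    /// ```ignore
    /// // Sketch (not compiled): matching is case-insensitive and
    /// // substring-based, so this resolves to the "deepseek_r1" parser.
    /// let parser = registry.find_pooled_parser_for_model("DeepSeek-R1-Distill");
    /// assert!(parser.is_some());
    /// ```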
    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                return self.get_pooled_parser(parser_name);
            }
        }
        None
    }

    /// Check if a parser can be created for a specific model without actually creating it.
    /// Returns true if a parser is available (registered) for this model.
    pub fn has_parser_for_model(&self, model_id: &str) -> bool {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                let creators = self.creators.read().unwrap();
                return creators.contains_key(parser_name);
            }
        }
        false
    }

    /// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
    /// Returns a new parser instance for each call; useful for streaming where state isolation is needed.
    pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                return self.create_parser(parser_name);
            }
        }
        None
    }

    /// Clear the parser pool, forcing new instances to be created.
    /// Useful for testing or when parsers need to be reset globally.
    pub fn clear_pool(&self) {
        let mut pool = self.pool.write().unwrap();
        pool.clear();
    }
}

impl Default for ParserRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Factory for creating reasoning parsers based on model type.
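///
/// # Example
///
/// A minimal sketch (not compiled), mirroring the tests at the bottom of
/// this file:
///
/// ```ignore
/// let factory = ParserFactory::new();
/// let parser = factory.create("deepseek-r1-distill")?;
/// assert_eq!(parser.model_type(), "deepseek_r1");
/// ```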
#[derive(Clone)]
pub struct ParserFactory {
    registry: ParserRegistry,
}

impl ParserFactory {
    /// Create a new factory with default parsers registered.
    pub fn new() -> Self {
        let registry = ParserRegistry::new();

        // Register base parser
        registry.register_parser("base", || {
            Box::new(BaseReasoningParser::new(ParserConfig::default()))
        });

        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));

        // Register Qwen3 parser (starts with in_reasoning=false)
        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));

        // Register Qwen3-thinking parser (starts with in_reasoning=true)
        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));

        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
        registry.register_parser("kimi", || Box::new(KimiParser::new()));

        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));

        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
        registry.register_parser("step3", || Box::new(Step3Parser::new()));

        // Register model patterns
        registry.register_pattern("deepseek-r1", "deepseek_r1");
        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
        registry.register_pattern("qwen-thinking", "qwen3_thinking");
        registry.register_pattern("qwen3", "qwen3");
        registry.register_pattern("qwen", "qwen3");
        registry.register_pattern("glm45", "glm45");
        registry.register_pattern("kimi", "kimi");
        registry.register_pattern("step3", "step3");

        Self { registry }
    }

    /// Get a pooled parser for the given model ID.
    /// Returns a shared instance that can be used concurrently.
    /// Falls back to a passthrough parser if the model is not recognized.
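    ///
    /// ```ignore
    /// // Sketch (not compiled): repeated lookups for the same model share
    /// // one pooled instance.
    /// let p1 = factory.get_pooled("deepseek-r1");
    /// let p2 = factory.get_pooled("deepseek-r1");
    /// assert!(Arc::ptr_eq(&p1, &p2));
    /// ```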
    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
        // First try to find by pattern
        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
            return parser;
        }

        // Fall back to no-op parser (get or create passthrough in pool)
        self.registry
            .get_pooled_parser("passthrough")
            .unwrap_or_else(|| {
                // Register passthrough if not already registered
                self.registry.register_parser("passthrough", || {
                    let config = ParserConfig {
                        think_start_token: "".to_string(),
                        think_end_token: "".to_string(),
                        stream_reasoning: true,
                        max_buffer_size: 65536,
                        initial_in_reasoning: false,
                    };
                    Box::new(
                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
                    )
                });
                self.registry.get_pooled_parser("passthrough").unwrap()
            })
    }

    /// Create a new parser instance for the given model ID.
    /// Returns a fresh instance (not pooled).
    /// Use this when you need an isolated parser instance.
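    ///
    /// ```ignore
    /// // Sketch (not compiled): unrecognized models fall back to a
    /// // passthrough parser instead of returning an error.
    /// let parser = factory.create("unknown-model")?;
    /// assert_eq!(parser.model_type(), "passthrough");
    /// ```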
    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
        // First try to find by pattern
        if let Some(parser) = self.registry.create_for_model(model_id) {
            return Ok(parser);
        }

        // Fall back to no-op parser (base parser without reasoning detection)
        let config = ParserConfig {
            think_start_token: "".to_string(),
            think_end_token: "".to_string(),
            stream_reasoning: true,
            max_buffer_size: 65536,
            initial_in_reasoning: false,
        };
        Ok(Box::new(
            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
        ))
    }

    /// Get the internal registry for custom registration.
    pub fn registry(&self) -> &ParserRegistry {
        &self.registry
    }

    /// Clear the parser pool.
    /// Useful for testing or when parsers need to be reset globally.
    pub fn clear_pool(&self) {
        self.registry.clear_pool();
    }
}

impl Default for ParserFactory {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Get the same parser twice - should be the same instance
        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Both should point to the same Arc
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // Different models should get different parsers
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        // Get a pooled parser
        let parser1 = factory.get_pooled("deepseek-r1");

        // Clear the pool
        factory.clear_pool();

        // Get another parser - should be a new instance
        let parser2 = factory.get_pooled("deepseek-r1");

        // They should be different instances (different Arc pointers)
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Unknown models should get passthrough parser
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through different models
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use async lock - tokio::Mutex doesn't poison
                    let mut p = parser.lock().await;

                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {} * {} = {}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        // Print stats for debugging
        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!(
            "Throughput: {:.0} requests/sec",
            (total_requests as f64) / duration.as_secs_f64()
        );

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Performance check: should handle at least 1000 req/sec
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: Get different parsers
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
            handle.await.unwrap();
        }
    }
}