// Factory and registry for creating model-specific reasoning parsers.
// Includes parser pooling support for efficient reuse across requests.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use tokio::sync::Mutex;

use crate::reasoning_parser::parsers::{
    BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
    QwenThinkingParser, Step3Parser,
};
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};

/// Type alias for pooled parser instances.
/// Uses tokio::Mutex to avoid blocking the async executor.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;

/// Registry for model-specific parsers with pooling support.
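///
/// A minimal usage sketch; `MyParser` is a hypothetical `ReasoningParser`
/// implementation standing in for any custom parser, not a type in this crate:
///
/// ```ignore
/// let registry = ParserRegistry::new();
/// // The creator closure is invoked lazily whenever the pool has no instance yet.
/// registry.register_parser("my_parser", || Box::new(MyParser::new()));
/// // Model IDs containing "my-model" (case-insensitive) resolve to "my_parser".
/// registry.register_pattern("my-model", "my_parser");
/// let pooled = registry.get_pooled_parser("my_parser").expect("registered above");
/// ```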
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}

impl ParserRegistry {
    /// Create a new empty registry.
    pub fn new() -> Self {
        Self {
            creators: Arc::new(RwLock::new(HashMap::new())),
            pool: Arc::new(RwLock::new(HashMap::new())),
            patterns: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Register a parser creator for a given parser type.
    pub fn register_parser<F>(&self, name: &str, creator: F)
    where
        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
    {
        let mut creators = self.creators.write().unwrap();
        creators.insert(name.to_string(), Arc::new(creator));
    }

    /// Register a model pattern to parser mapping.
    /// Patterns are checked in order, first match wins.
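    ///
    /// For example, in the default factory the more specific "qwen3-thinking"
    /// pattern is registered before the broader "qwen3" pattern; reversing that
    /// order would let "qwen3" shadow it (sketch):
    ///
    /// ```ignore
    /// registry.register_pattern("qwen3-thinking", "qwen3_thinking");
    /// registry.register_pattern("qwen3", "qwen3"); // must come after the more specific pattern
    /// ```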
    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
        let mut patterns = self.patterns.write().unwrap();
        patterns.push((pattern.to_string(), parser_name.to_string()));
    }

    /// Get a pooled parser by exact name.
    /// Returns a shared parser instance from the pool, creating one if needed.
    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
        // First check if we have a pooled instance
        {
            let pool = self.pool.read().unwrap();
            if let Some(parser) = pool.get(name) {
                return Some(Arc::clone(parser));
            }
        }

        // If not in pool, create one and add to pool
        let creators = self.creators.read().unwrap();
        if let Some(creator) = creators.get(name) {
            let parser = Arc::new(Mutex::new(creator()));

            // Add to pool for future use
            let mut pool = self.pool.write().unwrap();
            pool.insert(name.to_string(), Arc::clone(&parser));

            Some(parser)
        } else {
            None
        }
    }

    /// Get a parser by exact name (creates new instance, not pooled).
    /// Use this for backward compatibility or when you need a fresh, isolated instance.
    pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
        let creators = self.creators.read().unwrap();
        creators.get(name).map(|creator| creator())
    }

    /// Find a pooled parser for a given model ID by pattern matching.
    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                return self.get_pooled_parser(parser_name);
            }
        }
        None
    }

    /// Find a parser for a given model ID by pattern matching (creates new instance).
    pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
        let patterns = self.patterns.read().unwrap();
        let model_lower = model_id.to_lowercase();

        for (pattern, parser_name) in patterns.iter() {
            if model_lower.contains(&pattern.to_lowercase()) {
                return self.get_parser(parser_name);
            }
        }
        None
    }

    /// Clear the parser pool, forcing new instances to be created.
    /// Useful for testing or when parsers need to be reset globally.
    pub fn clear_pool(&self) {
        let mut pool = self.pool.write().unwrap();
        pool.clear();
    }
}

impl Default for ParserRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Factory for creating reasoning parsers based on model type.
#[derive(Clone)]
pub struct ReasoningParserFactory {
    registry: ParserRegistry,
}

impl ReasoningParserFactory {
    /// Create a new factory with default parsers registered.
    pub fn new() -> Self {
        let registry = ParserRegistry::new();

        // Register base parser
        registry.register_parser("base", || {
            Box::new(BaseReasoningParser::new(ParserConfig::default()))
        });

        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));

        // Register Qwen3 parser (starts with in_reasoning=false)
        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));

        // Register Qwen3-thinking parser (starts with in_reasoning=true)
        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));

        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
        registry.register_parser("kimi", || Box::new(KimiParser::new()));

        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));

        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
        registry.register_parser("step3", || Box::new(Step3Parser::new()));

        // Register model patterns
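        // Matching is first-match-wins, so more specific patterns
        // (e.g. "qwen3-thinking") are registered before broader ones ("qwen3").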
        registry.register_pattern("deepseek-r1", "deepseek_r1");
        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
        registry.register_pattern("qwen-thinking", "qwen3_thinking");
        registry.register_pattern("qwen3", "qwen3");
        registry.register_pattern("qwen", "qwen3");
        registry.register_pattern("glm45", "glm45");
        registry.register_pattern("kimi", "kimi");
        registry.register_pattern("step3", "step3");

        Self { registry }
    }

    /// Get a pooled parser for the given model ID.
    /// Returns a shared instance that can be used concurrently.
    /// Falls back to a passthrough parser if model is not recognized.
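    ///
    /// A rough usage sketch (in an async context):
    ///
    /// ```ignore
    /// let factory = ReasoningParserFactory::new();
    /// let parser = factory.get_pooled("deepseek-r1");
    /// // tokio::Mutex lock: awaiting here does not block the executor.
    /// let mut guard = parser.lock().await;
    /// let result = guard.detect_and_parse_reasoning("some reasoning</think>the answer").unwrap();
    /// assert_eq!(result.normal_text, "the answer");
    /// ```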
    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
        // First try to find by pattern
        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
            return parser;
        }

        // Fall back to no-op parser (get or create passthrough in pool)
        self.registry
            .get_pooled_parser("passthrough")
            .unwrap_or_else(|| {
                // Register passthrough if not already registered
                self.registry.register_parser("passthrough", || {
                    let config = ParserConfig {
                        think_start_token: "".to_string(),
                        think_end_token: "".to_string(),
                        stream_reasoning: true,
                        max_buffer_size: 65536,
                        initial_in_reasoning: false,
                    };
                    Box::new(
                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
                    )
                });
                self.registry.get_pooled_parser("passthrough").unwrap()
            })
    }

    /// Create a new parser instance for the given model ID.
    /// Returns a fresh instance (not pooled).
    /// Use this when you need an isolated parser instance.
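    ///
    /// Sketch: each call builds its own boxed parser, so state is never shared:
    ///
    /// ```ignore
    /// let factory = ReasoningParserFactory::new();
    /// let mut parser_a = factory.create("qwen3-7b").unwrap();
    /// let mut parser_b = factory.create("qwen3-7b").unwrap(); // independent of parser_a
    /// ```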
    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
        // First try to find by pattern
        if let Some(parser) = self.registry.find_parser_for_model(model_id) {
            return Ok(parser);
        }

        // Fall back to no-op parser (base parser without reasoning detection)
        let config = ParserConfig {
            think_start_token: "".to_string(),
            think_end_token: "".to_string(),
            stream_reasoning: true,
            max_buffer_size: 65536,
            initial_in_reasoning: false,
        };
        Ok(Box::new(
            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
        ))
    }

    /// Get the internal registry for custom registration.
    pub fn registry(&self) -> &ParserRegistry {
        &self.registry
    }

    /// Clear the parser pool.
    /// Useful for testing or when parsers need to be reset globally.
    pub fn clear_pool(&self) {
        self.registry.clear_pool();
    }
}

impl Default for ReasoningParserFactory {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ReasoningParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ReasoningParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ReasoningParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ReasoningParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[test]
    fn test_case_insensitive_matching() {
        let factory = ReasoningParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ReasoningParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ReasoningParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ReasoningParserFactory::new();

        // Get the same parser twice - should be the same instance
        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Both should point to the same Arc
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // Different models should get different parsers
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ReasoningParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ReasoningParserFactory::new();

        // Get a pooled parser
        let parser1 = factory.get_pooled("deepseek-r1");

        // Clear the pool
        factory.clear_pool();

        // Get another parser - should be a new instance
        let parser2 = factory.get_pooled("deepseek-r1");

        // They should be different instances (different Arc pointers)
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ReasoningParserFactory::new();

        // Unknown models should get passthrough parser
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Instant;

        let factory = ReasoningParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through different models
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use async lock - tokio::Mutex doesn't poison
                    let mut p = parser.lock().await;

                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {} * {} = {}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        // Print stats for debugging
        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!(
            "Throughput: {:.0} requests/sec",
            (total_requests as f64) / duration.as_secs_f64()
        );

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Performance check: should handle at least 1000 req/sec
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ReasoningParserFactory::new();
        let mut handles = vec![];

        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: Get different parsers
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
            handle.await.unwrap();
        }
    }
}