use crate::tool_parser::parsers::{
    DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser,
};
use crate::tool_parser::traits::ToolParser;
use std::collections::HashMap;
use std::sync::Arc;

/// Registry for tool parsers and model mappings
pub struct ParserRegistry {
    /// Map of parser name to parser instance
    parsers: HashMap<String, Arc<dyn ToolParser>>,
    /// Map of model name/pattern to parser name
    model_mapping: HashMap<String, String>,
    /// Default parser to use when no match found
    default_parser: String,
}

impl ParserRegistry {
    /// Create a new parser registry with default mappings
    pub fn new() -> Self {
        let mut registry = Self {
            parsers: HashMap::new(),
            model_mapping: HashMap::new(),
            default_parser: "json".to_string(),
        };

27
28
29
        // Register default parsers
        registry.register_default_parsers();

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
        // Register default model mappings
        registry.register_default_mappings();

        registry
    }

    /// Register a parser
    pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
        self.parsers.insert(name.into(), parser);
    }

    /// Map a model name/pattern to a parser
    pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
        self.model_mapping.insert(model.into(), parser.into());
    }

    /// Get parser for a specific model
    pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
        // Try exact match first
        if let Some(parser_name) = self.model_mapping.get(model) {
            if let Some(parser) = self.parsers.get(parser_name) {
                return Some(parser.clone());
            }
        }

55
56
57
58
59
60
61
62
63
64
65
        // Try prefix matching with more specific patterns first
        // Collect all matching patterns and sort by specificity (longer = more specific)
        let mut matches: Vec<(&String, &String)> = self
            .model_mapping
            .iter()
            .filter(|(pattern, _)| {
                if pattern.ends_with('*') {
                    let prefix = &pattern[..pattern.len() - 1];
                    model.starts_with(prefix)
                } else {
                    false
66
                }
67
68
69
70
71
72
73
74
75
76
            })
            .collect();

        // Sort by pattern length in descending order (longer patterns are more specific)
        matches.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len()));

        // Return the first matching parser
        for (_, parser_name) in matches {
            if let Some(parser) = self.parsers.get(parser_name) {
                return Some(parser.clone());
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
            }
        }

        // Fall back to default parser if it exists
        self.parsers.get(&self.default_parser).cloned()
    }

    /// List all registered parsers
    pub fn list_parsers(&self) -> Vec<&str> {
        self.parsers.keys().map(|s| s.as_str()).collect()
    }

    /// List all model mappings
    pub fn list_mappings(&self) -> Vec<(&str, &str)> {
        self.model_mapping
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect()
    }

97
98
99
100
101
    /// Register default parsers
    fn register_default_parsers(&mut self) {
        // JSON parser - most common format
        self.register_parser("json", Arc::new(JsonParser::new()));

102
103
104
105
106
        // Mistral parser - [TOOL_CALLS] [...] format
        self.register_parser("mistral", Arc::new(MistralParser::new()));

        // Qwen parser - <tool_call>...</tool_call> format
        self.register_parser("qwen", Arc::new(QwenParser::new()));
107
108
109

        // Pythonic parser - [func(arg=val)] format
        self.register_parser("pythonic", Arc::new(PythonicParser::new()));
110
111
112

        // Llama parser - <|python_tag|>{...} or plain JSON format
        self.register_parser("llama", Arc::new(LlamaParser::new()));
113
114
115

        // DeepSeek V3 parser - Unicode tokens with JSON blocks
        self.register_parser("deepseek", Arc::new(DeepSeekParser::new()));
116
117
    }

118
119
120
121
122
123
124
125
126
127
    /// Register default model mappings
    fn register_default_mappings(&mut self) {
        // OpenAI models
        self.map_model("gpt-4*", "json");
        self.map_model("gpt-3.5*", "json");
        self.map_model("gpt-4o*", "json");

        // Anthropic models
        self.map_model("claude-*", "json");

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
        // Mistral models - use Mistral parser
        self.map_model("mistral-*", "mistral");
        self.map_model("mixtral-*", "mistral");

        // Qwen models - use Qwen parser
        self.map_model("qwen*", "qwen");
        self.map_model("Qwen*", "qwen");

        // Llama models
        // Llama 4 uses pythonic format
        self.map_model("llama-4*", "pythonic");
        self.map_model("meta-llama-4*", "pythonic");
        // Llama 3.2 uses python_tag format
        self.map_model("llama-3.2*", "llama");
        self.map_model("meta-llama-3.2*", "llama");
        // Other Llama models use JSON
144
145
        self.map_model("llama-*", "json");
        self.map_model("meta-llama-*", "json");
146

147
148
149
150
151
        // DeepSeek models
        // DeepSeek V3 uses custom Unicode token format
        self.map_model("deepseek-v3*", "deepseek");
        self.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
        // DeepSeek V2 uses pythonic format
152
153
        self.map_model("deepseek-*", "pythonic");

154
155
156
        // Other models default to JSON
        self.map_model("gemini-*", "json");
        self.map_model("palm-*", "json");
157
        self.map_model("gemma-*", "json");
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
    }

    /// Set the default parser
    pub fn set_default_parser(&mut self, name: impl Into<String>) {
        self.default_parser = name.into();
    }

    /// Check if a parser is registered
    pub fn has_parser(&self, name: &str) -> bool {
        self.parsers.contains_key(name)
    }
}

impl Default for ParserRegistry {
    fn default() -> Self {
        Self::new()
    }
}