use crate::tool_parser::json_parser::JsonParser;
use crate::tool_parser::mistral_parser::MistralParser;
use crate::tool_parser::pythonic_parser::PythonicParser;
use crate::tool_parser::qwen_parser::QwenParser;
use crate::tool_parser::traits::ToolParser;
use std::collections::HashMap;
use std::sync::Arc;

/// Registry for tool parsers and model mappings.
///
/// Holds the registered parsers keyed by name, a table that maps model names
/// (or prefix patterns ending in `*`, e.g. `"gpt-4*"`) to parser names, and
/// the name of the parser to fall back to when no mapping applies.
pub struct ParserRegistry {
    /// Map of parser name to shared parser instance.
    parsers: HashMap<String, Arc<dyn ToolParser>>,
    /// Map of model name or trailing-`*` prefix pattern to parser name.
    model_mapping: HashMap<String, String>,
    /// Name of the parser to use when no mapping matches a model.
    default_parser: String,
}

impl ParserRegistry {
    /// Create a new parser registry with default mappings
    pub fn new() -> Self {
        let mut registry = Self {
            parsers: HashMap::new(),
            model_mapping: HashMap::new(),
            default_parser: "json".to_string(),
        };

28
29
30
        // Register default parsers
        registry.register_default_parsers();

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
        // Register default model mappings
        registry.register_default_mappings();

        registry
    }

    /// Register a parser
    pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
        self.parsers.insert(name.into(), parser);
    }

    /// Map a model name/pattern to a parser
    pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
        self.model_mapping.insert(model.into(), parser.into());
    }

    /// Get parser for a specific model
    pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
        // Try exact match first
        if let Some(parser_name) = self.model_mapping.get(model) {
            if let Some(parser) = self.parsers.get(parser_name) {
                return Some(parser.clone());
            }
        }

56
57
58
59
60
61
62
63
64
65
66
        // Try prefix matching with more specific patterns first
        // Collect all matching patterns and sort by specificity (longer = more specific)
        let mut matches: Vec<(&String, &String)> = self
            .model_mapping
            .iter()
            .filter(|(pattern, _)| {
                if pattern.ends_with('*') {
                    let prefix = &pattern[..pattern.len() - 1];
                    model.starts_with(prefix)
                } else {
                    false
67
                }
68
69
70
71
72
73
74
75
76
77
            })
            .collect();

        // Sort by pattern length in descending order (longer patterns are more specific)
        matches.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len()));

        // Return the first matching parser
        for (_, parser_name) in matches {
            if let Some(parser) = self.parsers.get(parser_name) {
                return Some(parser.clone());
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
            }
        }

        // Fall back to default parser if it exists
        self.parsers.get(&self.default_parser).cloned()
    }

    /// List all registered parsers
    pub fn list_parsers(&self) -> Vec<&str> {
        self.parsers.keys().map(|s| s.as_str()).collect()
    }

    /// List all model mappings
    pub fn list_mappings(&self) -> Vec<(&str, &str)> {
        self.model_mapping
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect()
    }

98
99
100
101
102
    /// Register default parsers
    fn register_default_parsers(&mut self) {
        // JSON parser - most common format
        self.register_parser("json", Arc::new(JsonParser::new()));

103
104
105
106
107
        // Mistral parser - [TOOL_CALLS] [...] format
        self.register_parser("mistral", Arc::new(MistralParser::new()));

        // Qwen parser - <tool_call>...</tool_call> format
        self.register_parser("qwen", Arc::new(QwenParser::new()));
108
109
110

        // Pythonic parser - [func(arg=val)] format
        self.register_parser("pythonic", Arc::new(PythonicParser::new()));
111
112
    }

113
114
115
116
117
118
119
120
121
122
    /// Register default model mappings
    fn register_default_mappings(&mut self) {
        // OpenAI models
        self.map_model("gpt-4*", "json");
        self.map_model("gpt-3.5*", "json");
        self.map_model("gpt-4o*", "json");

        // Anthropic models
        self.map_model("claude-*", "json");

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
        // Mistral models - use Mistral parser
        self.map_model("mistral-*", "mistral");
        self.map_model("mixtral-*", "mistral");

        // Qwen models - use Qwen parser
        self.map_model("qwen*", "qwen");
        self.map_model("Qwen*", "qwen");

        // Llama models
        // Llama 4 uses pythonic format
        self.map_model("llama-4*", "pythonic");
        self.map_model("meta-llama-4*", "pythonic");
        // Llama 3.2 uses python_tag format
        self.map_model("llama-3.2*", "llama");
        self.map_model("meta-llama-3.2*", "llama");
        // Other Llama models use JSON
139
140
        self.map_model("llama-*", "json");
        self.map_model("meta-llama-*", "json");
141

142
143
144
        // DeepSeek models - DeepSeek v3 would need custom parser, v2 uses pythonic
        self.map_model("deepseek-*", "pythonic");

145
146
147
        // Other models default to JSON
        self.map_model("gemini-*", "json");
        self.map_model("palm-*", "json");
148
        self.map_model("gemma-*", "json");
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
    }

    /// Set the default parser
    pub fn set_default_parser(&mut self, name: impl Into<String>) {
        self.default_parser = name.into();
    }

    /// Check if a parser is registered
    pub fn has_parser(&self, name: &str) -> bool {
        self.parsers.contains_key(name)
    }
}

impl Default for ParserRegistry {
    fn default() -> Self {
        Self::new()
    }
}