tool_parser_registry.rs 6.11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
//! Parser Registry Integration Tests
//!
//! Tests for model-to-parser mappings and registry functionality

use sglang_router_rs::tool_parser::ParserRegistry;

/// Every built-in parser name must be present in a freshly constructed registry.
#[tokio::test]
async fn test_registry_has_all_parsers() {
    let registry = ParserRegistry::new();
    let registered = registry.list_parsers();

    // One assertion per expected name; the message identifies which one is absent.
    for expected in ["json", "mistral", "qwen", "pythonic", "llama"] {
        assert!(
            registered.contains(&expected),
            "parser {expected:?} missing from registry"
        );
    }
}

/// OpenAI model names must resolve to a parser that accepts a bare JSON tool
/// call and yields exactly one call named `test`.
#[tokio::test]
async fn test_openai_models_use_json() {
    let registry = ParserRegistry::new();
    // Same input for every model, so build it once outside the loop.
    let test_input = r#"{"name": "test", "arguments": {}}"#;

    for model in ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gpt-4o"] {
        let parser = registry.get_parser(model).unwrap();
        let result = parser.parse_complete(test_input).await.unwrap();
        // Name the model in each message: a bare assert inside the loop would not
        // say which model failed.
        assert_eq!(result.len(), 1, "expected one tool call for model {model}");
        assert_eq!(result[0].function.name, "test", "wrong function name for model {model}");
    }
}

/// Anthropic model names must resolve to a parser that accepts a bare JSON tool
/// call and yields exactly one call named `test`.
#[tokio::test]
async fn test_anthropic_models_use_json() {
    let registry = ParserRegistry::new();
    // Same input for every model, so build it once outside the loop.
    let test_input = r#"{"name": "test", "arguments": {}}"#;

    for model in ["claude-3-opus", "claude-3-sonnet", "claude-2.1"] {
        let parser = registry.get_parser(model).unwrap();
        let result = parser.parse_complete(test_input).await.unwrap();
        assert_eq!(result.len(), 1, "expected one tool call for model {model}");
        // The count alone would accept a parser that returns one call with the
        // wrong name, so pin the name too.
        assert_eq!(result[0].function.name, "test", "wrong function name for model {model}");
    }
}

/// Mistral-family model names must resolve to a parser that understands the
/// `[TOOL_CALLS] [...]` prefixed format.
#[tokio::test]
async fn test_mistral_models() {
    let registry = ParserRegistry::new();
    // Same input for every model, so build it once outside the loop.
    let test_input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#;

    for model in ["mistral-large", "mistral-medium", "mixtral-8x7b"] {
        let parser = registry.get_parser(model).unwrap();
        let result = parser.parse_complete(test_input).await.unwrap();
        // Name the model in each message: a bare assert inside the loop would not
        // say which model failed.
        assert_eq!(result.len(), 1, "expected one tool call for model {model}");
        assert_eq!(result[0].function.name, "test", "wrong function name for model {model}");
    }
}

/// Qwen model names (any casing) must resolve to a parser that understands
/// JSON wrapped in `<tool_call>` ... `</tool_call>` tags.
#[tokio::test]
async fn test_qwen_models() {
    let registry = ParserRegistry::new();
    // Same input for every model, so build it once outside the loop.
    let test_input = r#"<tool_call>
{"name": "test", "arguments": {}}
</tool_call>"#;

    for model in ["qwen2.5-72b", "Qwen2-7B", "qwen-max"] {
        let parser = registry.get_parser(model).unwrap();
        let result = parser.parse_complete(test_input).await.unwrap();
        // Name the model in each message: a bare assert inside the loop would not
        // say which model failed.
        assert_eq!(result.len(), 1, "expected one tool call for model {model}");
        assert_eq!(result[0].function.name, "test", "wrong function name for model {model}");
    }
}

/// Llama model names resolve to different formats by version: Llama 4 uses the
/// pythonic list format, Llama 3.2 uses `<|python_tag|>`-prefixed JSON, and
/// other Llama models use plain JSON.
#[tokio::test]
async fn test_llama_model_variants() {
    let registry = ParserRegistry::new();

    // Llama 4 uses pythonic
    let parser = registry.get_parser("llama-4-70b").unwrap();
    let test_input = r#"[get_weather(city="NYC")]"#;
    let result = parser.parse_complete(test_input).await.unwrap();
    assert_eq!(result.len(), 1, "llama-4: expected one tool call");
    assert_eq!(result[0].function.name, "get_weather", "llama-4: wrong function name");

    // Llama 3.2 uses python_tag
    let parser = registry.get_parser("llama-3.2-8b").unwrap();
    let test_input = r#"<|python_tag|>{"name": "test", "arguments": {}}"#;
    let result = parser.parse_complete(test_input).await.unwrap();
    assert_eq!(result.len(), 1, "llama-3.2: expected one tool call");
    assert_eq!(result[0].function.name, "test", "llama-3.2: wrong function name");

    // Other Llama models use JSON. Check the name as well as the count, as the
    // other two variants do, so a single wrongly-named call cannot pass.
    let parser = registry.get_parser("llama-2-70b").unwrap();
    let test_input = r#"{"name": "test", "arguments": {}}"#;
    let result = parser.parse_complete(test_input).await.unwrap();
    assert_eq!(result.len(), 1, "llama-2: expected one tool call");
    assert_eq!(result[0].function.name, "test", "llama-2: wrong function name");
}

/// DeepSeek model names must resolve to a parser that handles the pythonic
/// list-of-calls format.
#[tokio::test]
async fn test_deepseek_models() {
    // Simplified mapping: DeepSeek is treated as pythonic here; v3 would need a
    // custom parser.
    let registry = ParserRegistry::new();
    let parser = registry.get_parser("deepseek-coder").unwrap();

    let calls = parser
        .parse_complete(r#"[function(arg="value")]"#)
        .await
        .unwrap();

    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].function.name, "function");
}

/// A model name that matches no registered pattern must still get a parser,
/// and that parser must accept a bare JSON tool call.
#[tokio::test]
async fn test_unknown_model_fallback() {
    let registry = ParserRegistry::new();

    // Expected fallback for unrecognised names is the JSON parser.
    let fallback = registry.get_parser("unknown-model-xyz").unwrap();
    let calls = fallback
        .parse_complete(r#"{"name": "fallback", "arguments": {}}"#)
        .await
        .unwrap();

    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].function.name, "fallback");
}

/// When several model patterns could match, the more specific one must win:
/// `llama-4*` names get a parser that detects pythonic calls, while other
/// llama names get one that detects bare JSON calls.
#[tokio::test]
async fn test_pattern_specificity() {
    let registry = ParserRegistry::new();

    // llama-4* must take precedence over the broader llama-* pattern.
    let llama4_parser = registry.get_parser("llama-4-70b").unwrap();
    let pythonic_call = r#"[test_function(x=1)]"#;
    assert!(llama4_parser.detect_format(pythonic_call));

    // A non-4 llama name falls through to the broader pattern (JSON format).
    let llama3_parser = registry.get_parser("llama-3-70b").unwrap();
    let json_call = r#"{"name": "test", "arguments": {}}"#;
    assert!(llama3_parser.detect_format(json_call));
}

/// Realistic model outputs, where tool calls are surrounded by natural-language
/// text, must still yield at least one call whose first entry has the expected name.
#[tokio::test]
async fn test_real_world_model_outputs() {
    let registry = ParserRegistry::new();

    // Each case: (model name, raw model output, expected name of the first parsed call).
    let cases = [
        (
            "gpt-4",
            r#"I'll help you with that.

{"name": "search_web", "arguments": {"query": "latest AI news", "max_results": 5}}

Let me search for that information."#,
            "search_web",
        ),
        (
            "mistral-large",
            r#"Let me search for information about Rust.

[TOOL_CALLS] [
    {"name": "search", "arguments": {"query": "Rust programming"}},
    {"name": "get_weather", "arguments": {"city": "San Francisco"}}
]

I've initiated the search."#,
            "search",
        ),
        (
            "qwen2.5",
            r#"I'll check the weather for you.

<tool_call>
{
    "name": "get_weather",
    "arguments": {
        "location": "Tokyo",
        "units": "celsius"
    }
}
</tool_call>

The weather information has been requested."#,
            "get_weather",
        ),
    ];

    for (model, raw_output, expected_name) in cases {
        let parser = registry.get_parser(model).unwrap();
        let calls = parser.parse_complete(raw_output).await.unwrap();

        assert!(!calls.is_empty(), "No tools parsed for model {}", model);

        let first_name = &calls[0].function.name;
        assert_eq!(
            *first_name, expected_name,
            "Wrong function name for model {}",
            model
        );
    }
}