//! Mistral Parser Integration Tests
//!
//! Tests for the Mistral parser, which handles the `[TOOL_CALLS] [...]` format.

use serde_json::json;
use sglang_router_rs::tool_parser::{MistralParser, ToolParser};

/// A single tool call preceded by prose: the prose (including its trailing
/// newline) is returned as normal text, and the call's arguments decode as JSON.
#[tokio::test]
async fn test_mistral_single_tool() {
    let parser = MistralParser::new();
    let input = r#"Let me search for that.
[TOOL_CALLS] [{"name": "search_web", "arguments": {"query": "latest news", "max_results": 5}}]"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(normal_text, "Let me search for that.\n");
    assert_eq!(tools[0].function.name, "search_web");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["query"], "latest news");
    assert_eq!(args["max_results"], 5);
}

/// Two tool calls in one multi-line `[TOOL_CALLS]` array are both parsed, in
/// order, with every argument preserved.
#[tokio::test]
async fn test_mistral_multiple_tools() {
    let parser = MistralParser::new();
    let input = r#"I'll help you with both tasks.
[TOOL_CALLS] [
    {"name": "get_weather", "arguments": {"city": "Tokyo", "units": "celsius"}},
    {"name": "search_news", "arguments": {"query": "AI developments", "limit": 10}}
]"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 2);
    assert_eq!(normal_text, "I'll help you with both tasks.\n");

    assert_eq!(tools[0].function.name, "get_weather");
    let args0: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args0["city"], "Tokyo");
    assert_eq!(args0["units"], "celsius");

    assert_eq!(tools[1].function.name, "search_news");
    let args1: serde_json::Value = serde_json::from_str(&tools[1].function.arguments).unwrap();
    assert_eq!(args1["query"], "AI developments");
    assert_eq!(args1["limit"], 10);
}

/// Deeply nested argument objects/arrays survive parsing intact.
#[tokio::test]
async fn test_mistral_nested_json() {
    let parser = MistralParser::new();
    let input = r#"Processing complex data.
[TOOL_CALLS] [{"name": "process_data", "arguments": {"config": {"nested": {"value": [1, 2, 3]}}, "enabled": true}}]"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(normal_text, "Processing complex data.\n");
    assert_eq!(tools[0].function.name, "process_data");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["config"]["nested"]["value"], json!([1, 2, 3]));
    assert_eq!(args["enabled"], true);
}

#[tokio::test]
async fn test_mistral_with_text_after() {
    let parser = MistralParser::new();
    let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]
Stefan He's avatar
Stefan He committed
65

66
67
And here's some text after the tool call that should be ignored."#;

68
69
70
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "test");
71
72
73
74
75
76
77
}

#[tokio::test]
async fn test_mistral_empty_arguments() {
    let parser = MistralParser::new();
    let input = r#"[TOOL_CALLS] [{"name": "ping", "arguments": {}}]"#;

78
79
80
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
81
82
83
84
85
86
87
}

/// Square brackets inside JSON string values must not be mistaken for the
/// array's closing bracket.
#[tokio::test]
async fn test_mistral_with_brackets_in_strings() {
    let parser = MistralParser::new();
    let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array notation: arr[0] = value[1]"}}]"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "echo");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["text"], "Array notation: arr[0] = value[1]");
}

/// `has_tool_markers` keys on the `[TOOL_CALLS]` token, not on a bare JSON array.
///
/// `has_tool_markers` is synchronous, so no async runtime is needed here.
#[test]
fn test_mistral_format_detection() {
    let parser = MistralParser::new();

    assert!(parser.has_tool_markers("[TOOL_CALLS] ["));
    assert!(parser.has_tool_markers("Some text [TOOL_CALLS] ["));
    assert!(!parser.has_tool_markers("Just plain text"));
    assert!(!parser.has_tool_markers("[{\"name\": \"test\"}]")); // JSON array without TOOL_CALLS
}

#[tokio::test]
async fn test_mistral_malformed_json() {
    let parser = MistralParser::new();

    // Missing closing bracket
    let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}"#;
111
112
    if let Ok((_normal_text, tools)) = parser.parse_complete(input).await {
        assert_eq!(tools.len(), 0);
113
114
115
116
117
    }
    // Error is also acceptable for malformed input

    // Invalid JSON inside
    let input = r#"[TOOL_CALLS] [{"name": invalid}]"#;
118
119
    if let Ok((_normal_text, tools)) = parser.parse_complete(input).await {
        assert_eq!(tools.len(), 0);
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
    }
    // Error is also acceptable for malformed input
}

#[tokio::test]
async fn test_mistral_real_world_output() {
    let parser = MistralParser::new();

    // Actual output from Mistral model
    let input = r#"I'll search for information about Rust programming and check the weather in San Francisco.

[TOOL_CALLS] [
    {
        "name": "web_search",
        "arguments": {
            "query": "Rust programming language features 2024",
            "max_results": 3,
            "include_snippets": true
        }
    },
    {
        "name": "get_weather",
        "arguments": {
            "location": "San Francisco, CA",
            "units": "fahrenheit",
            "include_forecast": false
        }
    }
]

Let me execute these searches for you."#;

152
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
153
    assert_eq!(tools.len(), 2);
154
    assert_eq!(normal_text, "I'll search for information about Rust programming and check the weather in San Francisco.\n\n");
155
156
    assert_eq!(tools[0].function.name, "web_search");
    assert_eq!(tools[1].function.name, "get_weather");
157
}
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274

/// Streaming: the closing `]` of Mistral's `[TOOL_CALLS] [...]` array belongs to
/// the tool-call syntax and must not be emitted as normal text; only the text
/// chunks that follow the array should be.
#[tokio::test]
async fn test_mistral_streaming_closing_bracket() {
    use sglang_router_rs::protocols::common::{Function, Tool};

    let mut parser = MistralParser::new();

    let tools = vec![Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    }];

    // Token-sized pieces of one tool call, then three chunks of plain text.
    let chunks = [
        "[TOOL_CALLS] ",
        "[{",
        "\"",
        "name",
        "\":",
        "\"",
        "get",
        "_weather",
        "\",",
        "\"",
        "arguments",
        "\":",
        "{",
        "\"",
        "city",
        "\":",
        "\"",
        "Paris",
        "\"",
        "}",
        "}",
        "]",
        " Here's",
        " the weather",
        " info",
    ];

    let mut all_normal_text = String::new();

    for chunk in chunks {
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        all_normal_text.push_str(&result.normal_text);
    }

    // Only the three trailing text chunks should be emitted, NOT the array's closing ].
    assert_eq!(
        all_normal_text, " Here's the weather info",
        "Should not emit ] for Mistral array format, got: '{}'",
        all_normal_text
    );
}

/// Streaming: once the tool-call array has closed, a `]` appearing in the
/// following free text is ordinary content and must be passed through unchanged.
#[tokio::test]
async fn test_mistral_streaming_bracket_in_text_after_tools() {
    use sglang_router_rs::protocols::common::{Function, Tool};

    let weather_fn = Function {
        name: "get_weather".to_string(),
        description: Some("Get weather".to_string()),
        parameters: json!({}),
        strict: None,
    };
    let tools = vec![Tool {
        tool_type: "function".to_string(),
        function: weather_fn,
    }];

    let mut parser = MistralParser::new();

    // Tool-call tokens first, then free text that itself contains `[0]`.
    let stream: &[&str] = &[
        "[TOOL_CALLS] ",
        "[",
        "{",
        "\"name",
        "\":",
        "\"get_weather",
        "\",",
        "\"arguments",
        "\":",
        "{\"",
        "city",
        "\":",
        "\"Paris",
        "\"}",
        "}",
        "]",
        " Array",
        " notation:",
        " arr",
        "[",
        "0",
        "]",
    ];

    let mut emitted = String::new();
    for &piece in stream {
        let step = parser.parse_incremental(piece, &tools).await.unwrap();
        emitted.push_str(&step.normal_text);
    }

    assert_eq!(
        emitted, " Array notation: arr[0]",
        "Should preserve ] in normal text after tools, got: '{}'",
        emitted
    );
}