mistral_parser.rs 13.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
use async_trait::async_trait;
use serde_json::Value;

use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
    partial_json::PartialJson,
    state::ParseState,
    traits::ToolParser,
    types::{FunctionCall, StreamResult, ToolCall},
};

/// Tool-call parser for Mistral-family models.
///
/// Mistral emits tool calls as a marker followed by a JSON array of call objects:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
/// Highlights:
/// - The array is located by counting square brackets until they balance
/// - A single array may carry several tool calls
/// - Brackets that appear inside JSON string literals are ignored while counting
pub struct MistralParser {
    /// JSON parser applied to buffered, possibly incomplete content while streaming.
    partial_json: PartialJson,
}

impl MistralParser {
    /// Create a new Mistral parser
    pub fn new() -> Self {
        Self {
            partial_json: PartialJson::default(),
        }
    }

    /// Extract JSON array using bracket counting
    ///
    /// Handles nested brackets in JSON content by tracking:
    /// - String boundaries (quotes)
    /// - Escape sequences
    /// - Bracket depth
    fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
41
42
43
44
        self.extract_json_array_with_pos(text).map(|(_, json)| json)
    }

    fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
        const BOT_TOKEN: &str = "[TOOL_CALLS] [";

        // Find the start of the token
        let start_idx = text.find(BOT_TOKEN)?;

        // Start from the opening bracket after [TOOL_CALLS]
        // The -1 is to include the opening bracket that's part of the token
        let json_start = start_idx + BOT_TOKEN.len() - 1;

        let mut bracket_count = 0;
        let mut in_string = false;
        let mut escape_next = false;

        let bytes = text.as_bytes();

        for i in json_start..text.len() {
            let char = bytes[i];

            if escape_next {
                escape_next = false;
                continue;
            }

            if char == b'\\' {
                escape_next = true;
                continue;
            }

            if char == b'"' && !escape_next {
                in_string = !in_string;
                continue;
            }

            if !in_string {
                if char == b'[' {
                    bracket_count += 1;
                } else if char == b']' {
                    bracket_count -= 1;
                    if bracket_count == 0 {
                        // Found the matching closing bracket
85
                        return Some((start_idx, &text[json_start..=i]));
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
                    }
                }
            }
        }

        // Incomplete array (no matching closing bracket found)
        None
    }

    /// Parse tool calls from a JSON array
    fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
        let value: Value = serde_json::from_str(json_str)
            .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

        let mut tools = Vec::new();

        if let Value::Array(arr) = value {
            for (index, item) in arr.iter().enumerate() {
                if let Some(tool) = self.parse_single_object(item, index)? {
                    tools.push(tool);
                }
            }
        } else {
            // Single object case (shouldn't happen with Mistral format, but handle it)
            if let Some(tool) = self.parse_single_object(&value, 0)? {
                tools.push(tool);
            }
        }

        Ok(tools)
    }

    /// Parse a single JSON object into a ToolCall
    fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Mistral uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
                .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

            // Generate ID with index for multiple tools
            let id = format!("mistral_call_{}", index);

            Ok(Some(ToolCall {
                id,
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }

    /// Check if text contains Mistral tool markers
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("[TOOL_CALLS]")
    }
}

impl Default for MistralParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for MistralParser {
161
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
162
163
        // Check if text contains Mistral format
        if !self.has_tool_markers(text) {
164
            return Ok((text.to_string(), vec![]));
165
166
        }

167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
        // Extract JSON array from Mistral format with position
        if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
            // Extract normal text before BOT_TOKEN
            let normal_text_before = if start_idx > 0 {
                text[..start_idx].to_string()
            } else {
                String::new()
            };

            match self.parse_json_array(json_array) {
                Ok(tools) => Ok((normal_text_before, tools)),
                Err(_) => {
                    // If JSON parsing fails, return the original text as normal text
                    Ok((text.to_string(), vec![]))
                }
            }
183
184
        } else {
            // Markers present but no complete array found
185
            Ok((text.to_string(), vec![]))
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
        }
    }

    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        state.buffer.push_str(chunk);

        // Check if we have the start marker
        if !self.has_tool_markers(&state.buffer) {
            return Ok(StreamResult::Incomplete);
        }

        // Try to extract complete JSON array
        if let Some(json_array) = self.extract_json_array(&state.buffer) {
            // Parse with partial JSON to handle incomplete content
            match self.partial_json.parse_value(json_array) {
                Ok((value, consumed)) => {
                    // Check if we have a complete JSON structure
                    if consumed == json_array.len() {
                        // Complete JSON, parse tool calls
                        let tools = if let Value::Array(arr) = value {
                            let mut result = Vec::new();
                            for (index, item) in arr.iter().enumerate() {
                                if let Some(tool) = self.parse_single_object(item, index)? {
                                    result.push(tool);
                                }
                            }
                            result
                        } else {
                            vec![]
                        };

                        if !tools.is_empty() {
                            // Clear buffer since we consumed everything
                            state.buffer.clear();

                            // Return the first tool (simplified for Phase 3)
                            // Full multi-tool streaming will be implemented later
                            if let Some(tool) = tools.into_iter().next() {
                                return Ok(StreamResult::ToolComplete(tool));
                            }
                        }
                    } else {
                        // Partial JSON - try to extract tool name for streaming
                        if let Value::Array(arr) = value {
                            if let Some(first_tool) = arr.first() {
                                if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
                                {
                                    // Check if we've already sent the name
                                    if !state.in_string {
                                        state.in_string = true; // Use as flag for "name sent"
                                        return Ok(StreamResult::ToolName {
                                            index: 0,
                                            name: name.to_string(),
                                        });
                                    }

                                    // Check for arguments
                                    if let Some(args) = first_tool.get("arguments") {
                                        if let Ok(args_str) = serde_json::to_string(args) {
                                            return Ok(StreamResult::ToolArguments {
                                                index: 0,
                                                arguments: args_str,
                                            });
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                Err(_) => {
                    // Failed to parse even as partial JSON
                    // Keep buffering
                }
            }
        }

        Ok(StreamResult::Incomplete)
    }

    fn detect_format(&self, text: &str) -> bool {
        // Check if text contains Mistral-specific markers
        if self.has_tool_markers(text) {
            // Try to extract and validate the array
            if let Some(json_array) = self.extract_json_array(text) {
                // Check if it's valid JSON
                if let Ok(value) = serde_json::from_str::<Value>(json_array) {
                    // Check if it contains tool-like structures
                    match value {
                        Value::Array(ref arr) => arr.iter().any(|v| {
                            v.as_object().is_some_and(|o| {
                                o.contains_key("name") && o.contains_key("arguments")
                            })
                        }),
                        Value::Object(ref obj) => {
                            obj.contains_key("name") && obj.contains_key("arguments")
                        }
                        _ => false,
                    }
                } else {
                    false
                }
            } else {
                // Has markers but no complete array - might be streaming
                true
            }
        } else {
            false
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_mistral_format() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert!(normal_text.is_empty());
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].id, "mistral_call_0");
        assert_eq!(tools[0].r#type, "function");
        assert_eq!(tools[0].function.name, "get_weather");
        assert!(tools[0].function.arguments.contains("Paris"));
    }

    #[tokio::test]
    async fn test_parse_multiple_tools() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [
            {"name": "search", "arguments": {"query": "rust programming"}},
            {"name": "calculate", "arguments": {"expression": "2 + 2"}}
        ]"#;

        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "search");
        assert_eq!(tools[0].id, "mistral_call_0");
        assert_eq!(tools[1].function.name, "calculate");
        assert_eq!(tools[1].id, "mistral_call_1");
    }

    #[tokio::test]
    async fn test_nested_brackets_in_json() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;

        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "process");
        // JSON serialization removes spaces, so check for [3,4] without spaces
        assert!(tools[0].function.arguments.contains("[3,4]"));
    }

    #[tokio::test]
    async fn test_escaped_quotes_in_strings() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;

        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "echo");
        // The bracket inside the string must not have terminated the array early.
        assert!(tools[0].function.arguments.contains("Hello [World]"));
    }

    #[tokio::test]
    async fn test_text_before_marker_is_returned() {
        let parser = MistralParser::new();
        let input = r#"Let me check. [TOOL_CALLS] [{"name": "search", "arguments": {}}]"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(normal_text, "Let me check. ");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "search");
    }

    #[tokio::test]
    async fn test_no_markers_returns_text_unchanged() {
        let parser = MistralParser::new();
        let input = "Just a normal answer.";

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(normal_text, input);
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_unterminated_array_returns_text() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"q": "x"}"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(normal_text, input);
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_invalid_json_returns_text() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": oops}]"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(normal_text, input);
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_missing_arguments_default_to_empty_object() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "ping"}]"#;

        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.arguments, "{}");
    }

    #[tokio::test]
    async fn test_entries_without_name_are_skipped() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"arguments": {}}, {"name": "b", "arguments": {}}]"#;

        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "b");
        // Ids follow the position in the source array, not the output position.
        assert_eq!(tools[0].id, "mistral_call_1");
    }

    #[test]
    fn test_detect_format() {
        let parser = MistralParser::new();

        assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#));
        assert!(
            parser.detect_format(r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#)
        );
        // Marker present but array still streaming.
        assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "te"#));
        // Complete array without tool-like objects.
        assert!(!parser.detect_format(r#"[TOOL_CALLS] [{"foo": 1}]"#));
        assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
        assert!(!parser.detect_format("plain text"));
    }
}