// mistral_parser.rs - tool-call parser for the Mistral `[TOOL_CALLS]` output format.
use async_trait::async_trait;
use serde_json::Value;

use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
    partial_json::PartialJson,
    state::ParseState,
    traits::ToolParser,
    types::{FunctionCall, StreamResult, ToolCall},
};

/// Parser for tool calls emitted in the Mistral output format.
///
/// Handles the Mistral-specific format:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
/// Features:
/// - Bracket counting for proper JSON array extraction
/// - Support for multiple tool calls in a single array
/// - String-aware parsing, so brackets that appear inside JSON string values
///   are not mistaken for array delimiters
pub struct MistralParser {
    /// Partial-JSON parser; the streaming path (`parse_incremental`) runs the
    /// extracted array text through it.
    partial_json: PartialJson,
}

impl MistralParser {
    /// Build a parser whose partial-JSON helper uses its default configuration.
    pub fn new() -> Self {
        let partial_json = PartialJson::default();
        Self { partial_json }
    }

    /// Extract JSON array using bracket counting
    ///
    /// Handles nested brackets in JSON content by tracking:
    /// - String boundaries (quotes)
    /// - Escape sequences
    /// - Bracket depth
    fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
41
42
43
44
        self.extract_json_array_with_pos(text).map(|(_, json)| json)
    }

    fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
        const BOT_TOKEN: &str = "[TOOL_CALLS] [";

        // Find the start of the token
        let start_idx = text.find(BOT_TOKEN)?;

        // Start from the opening bracket after [TOOL_CALLS]
        // The -1 is to include the opening bracket that's part of the token
        let json_start = start_idx + BOT_TOKEN.len() - 1;

        let mut bracket_count = 0;
        let mut in_string = false;
        let mut escape_next = false;

        let bytes = text.as_bytes();

        for i in json_start..text.len() {
            let char = bytes[i];

            if escape_next {
                escape_next = false;
                continue;
            }

            if char == b'\\' {
                escape_next = true;
                continue;
            }

            if char == b'"' && !escape_next {
                in_string = !in_string;
                continue;
            }

            if !in_string {
                if char == b'[' {
                    bracket_count += 1;
                } else if char == b']' {
                    bracket_count -= 1;
                    if bracket_count == 0 {
                        // Found the matching closing bracket
85
                        return Some((start_idx, &text[json_start..=i]));
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
                    }
                }
            }
        }

        // Incomplete array (no matching closing bracket found)
        None
    }

    /// Convert the text of a JSON array into tool calls.
    ///
    /// Each element is passed to `parse_single_object` together with its
    /// position in the array; elements that do not yield a tool call are
    /// skipped. A top-level value that is not an array (not expected with the
    /// Mistral format) is treated as a single candidate at index 0.
    ///
    /// # Errors
    /// Returns `ToolParserError::ParsingFailed` when `json_str` is not valid
    /// JSON, and propagates any error raised while converting an element.
    fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
        let parsed = serde_json::from_str::<Value>(json_str)
            .map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;

        // View both shapes (array / lone value) as one slice of candidates.
        let candidates: &[Value] = match &parsed {
            Value::Array(entries) => entries.as_slice(),
            single => std::slice::from_ref(single),
        };

        let mut calls = Vec::new();
        for (position, candidate) in candidates.iter().enumerate() {
            // `extend` with an `Option` pushes the call only when one was produced.
            calls.extend(self.parse_single_object(candidate, position)?);
        }

        Ok(calls)
    }

    /// Build a `ToolCall` from one JSON value.
    ///
    /// Returns `Ok(None)` when `obj` has no string-valued `"name"` entry.
    /// Otherwise the `"arguments"` entry (an empty object when the key is
    /// absent) is serialized to a JSON string, and the call is given the id
    /// `mistral_call_{index}` so that several calls from one array receive
    /// distinct ids.
    ///
    /// # Errors
    /// Returns `ToolParserError::ParsingFailed` if the arguments cannot be
    /// serialized.
    fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
        let Some(tool_name) = obj.get("name").and_then(Value::as_str) else {
            return Ok(None);
        };

        // Mistral uses the "arguments" key; fall back to `{}` when it is missing.
        let serialized = match obj.get("arguments") {
            Some(args) => serde_json::to_string(args),
            None => serde_json::to_string(&Value::Object(serde_json::Map::new())),
        };
        let arguments =
            serialized.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;

        let call = ToolCall {
            id: format!("mistral_call_{}", index),
            r#type: "function".to_string(),
            function: FunctionCall {
                name: tool_name.to_string(),
                arguments,
            },
        };

        Ok(Some(call))
    }

    /// Report whether the literal `[TOOL_CALLS]` marker occurs anywhere in `text`.
    fn has_tool_markers(&self, text: &str) -> bool {
        text.find("[TOOL_CALLS]").is_some()
    }
}

impl Default for MistralParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for MistralParser {
161
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
162
163
        // Check if text contains Mistral format
        if !self.has_tool_markers(text) {
164
            return Ok((text.to_string(), vec![]));
165
166
        }

167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
        // Extract JSON array from Mistral format with position
        if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
            // Extract normal text before BOT_TOKEN
            let normal_text_before = if start_idx > 0 {
                text[..start_idx].to_string()
            } else {
                String::new()
            };

            match self.parse_json_array(json_array) {
                Ok(tools) => Ok((normal_text_before, tools)),
                Err(_) => {
                    // If JSON parsing fails, return the original text as normal text
                    Ok((text.to_string(), vec![]))
                }
            }
183
184
        } else {
            // Markers present but no complete array found
185
            Ok((text.to_string(), vec![]))
186
187
188
189
190
191
192
193
194
195
196
197
        }
    }

    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        state.buffer.push_str(chunk);

        // Check if we have the start marker
        if !self.has_tool_markers(&state.buffer) {
198
199
200
201
202
203
204
205
206
207
208
209
            // No tool markers detected - return all buffered content as normal text
            let normal_text = std::mem::take(&mut state.buffer);
            return Ok(StreamResult::NormalText(normal_text));
        }

        // Check for text before [TOOL_CALLS] and extract it as normal text
        if let Some(marker_pos) = state.buffer.find("[TOOL_CALLS]") {
            if marker_pos > 0 {
                // We have text before the tool marker - extract it as normal text
                let normal_text: String = state.buffer.drain(..marker_pos).collect();
                return Ok(StreamResult::NormalText(normal_text));
            }
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
        }

        // Try to extract complete JSON array
        if let Some(json_array) = self.extract_json_array(&state.buffer) {
            // Parse with partial JSON to handle incomplete content
            match self.partial_json.parse_value(json_array) {
                Ok((value, consumed)) => {
                    // Check if we have a complete JSON structure
                    if consumed == json_array.len() {
                        // Complete JSON, parse tool calls
                        let tools = if let Value::Array(arr) = value {
                            let mut result = Vec::new();
                            for (index, item) in arr.iter().enumerate() {
                                if let Some(tool) = self.parse_single_object(item, index)? {
                                    result.push(tool);
                                }
                            }
                            result
                        } else {
                            vec![]
                        };

                        if !tools.is_empty() {
                            // Clear buffer since we consumed everything
                            state.buffer.clear();

                            // Return the first tool (simplified for Phase 3)
                            // Full multi-tool streaming will be implemented later
                            if let Some(tool) = tools.into_iter().next() {
                                return Ok(StreamResult::ToolComplete(tool));
                            }
                        }
                    } else {
                        // Partial JSON - try to extract tool name for streaming
                        if let Value::Array(arr) = value {
                            if let Some(first_tool) = arr.first() {
                                if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
                                {
                                    // Check if we've already sent the name
                                    if !state.in_string {
                                        state.in_string = true; // Use as flag for "name sent"
                                        return Ok(StreamResult::ToolName {
                                            index: 0,
                                            name: name.to_string(),
                                        });
                                    }

                                    // Check for arguments
                                    if let Some(args) = first_tool.get("arguments") {
                                        if let Ok(args_str) = serde_json::to_string(args) {
                                            return Ok(StreamResult::ToolArguments {
                                                index: 0,
                                                arguments: args_str,
                                            });
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                Err(_) => {
                    // Failed to parse even as partial JSON
                    // Keep buffering
                }
            }
        }

        Ok(StreamResult::Incomplete)
    }

    /// Decide whether `text` looks like Mistral tool-call output.
    ///
    /// - `false` when the `[TOOL_CALLS]` marker is absent.
    /// - `true` when the marker is present but no complete array can be
    ///   extracted yet (the text might still be streaming in).
    /// - Otherwise `true` only if the extracted array is valid JSON and at
    ///   least one element is an object carrying both a `"name"` and an
    ///   `"arguments"` key.
    fn detect_format(&self, text: &str) -> bool {
        /// An object is tool-shaped when it has both expected keys.
        fn is_tool_shaped(fields: &serde_json::Map<String, Value>) -> bool {
            fields.contains_key("name") && fields.contains_key("arguments")
        }

        if !self.has_tool_markers(text) {
            return false;
        }

        // Has markers but no complete array - might be streaming.
        let Some(array_src) = self.extract_json_array(text) else {
            return true;
        };

        let Ok(parsed) = serde_json::from_str::<Value>(array_src) else {
            return false;
        };

        match &parsed {
            Value::Array(items) => items.iter().filter_map(Value::as_object).any(is_tool_shaped),
            Value::Object(fields) => is_tool_shaped(fields),
            _ => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_mistral_format() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;

322
323
324
325
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "get_weather");
        assert!(tools[0].function.arguments.contains("Paris"));
326
327
328
329
330
331
332
333
334
335
    }

    #[tokio::test]
    async fn test_parse_multiple_tools() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [
            {"name": "search", "arguments": {"query": "rust programming"}},
            {"name": "calculate", "arguments": {"expression": "2 + 2"}}
        ]"#;

336
337
338
339
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "search");
        assert_eq!(tools[1].function.name, "calculate");
340
341
342
343
344
345
346
    }

    #[tokio::test]
    async fn test_nested_brackets_in_json() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;

347
348
349
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "process");
350
        // JSON serialization removes spaces, so check for [3,4] without spaces
351
        assert!(tools[0].function.arguments.contains("[3,4]"));
352
353
354
355
356
357
358
    }

    #[tokio::test]
    async fn test_escaped_quotes_in_strings() {
        let parser = MistralParser::new();
        let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;

359
360
361
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "echo");
362
363
364
365
366
367
368
369
370
371
372
373
374
375
    }

    /// Format detection: accepts marked tool arrays, rejects unmarked text,
    /// treats an unfinished array as possibly-streaming, and rejects complete
    /// arrays that contain no tool-shaped object.
    #[test]
    fn test_detect_format() {
        let parser = MistralParser::new();

        assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#));
        assert!(
            parser.detect_format(r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#)
        );
        assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
        assert!(!parser.detect_format("plain text"));

        // Marker present but the array is not closed yet: may still be streaming.
        assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "te"#));
        // Complete, valid array without any object holding "name" and "arguments".
        assert!(!parser.detect_format(r#"[TOOL_CALLS] [1, 2, 3]"#));
    }
}