mistral_parser.rs 8.5 KB
Newer Older
1
2
3
use async_trait::async_trait;
use serde_json::Value;

4
use crate::protocols::common::Tool;
5

6
use crate::tool_parser::{
7
    errors::{ParserError, ParserResult},
8
    parsers::helpers,
9
10
    partial_json::PartialJson,
    traits::ToolParser,
11
    types::{FunctionCall, StreamingParseResult, ToolCall},
12
13
14
15
16
17
18
19
20
21
22
23
24
25
};

/// Mistral format parser for tool calls
///
/// Handles the Mistral-specific format:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
/// Features:
/// - Bracket counting for proper JSON array extraction
/// - Support for multiple tool calls in a single array
/// - String-aware parsing to handle nested brackets in JSON
pub struct MistralParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Token configuration
    bot_token: &'static str,
    tool_call_separator: &'static str,
45
46
47
48
49
50
51
}

impl MistralParser {
    /// Create a new Mistral parser
    pub fn new() -> Self {
        Self {
            partial_json: PartialJson::default(),
52
53
54
55
56
57
58
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            bot_token: "[TOOL_CALLS] [",
            tool_call_separator: ", ",
59
60
61
        }
    }

62
    fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
        const BOT_TOKEN: &str = "[TOOL_CALLS] [";

        // Find the start of the token
        let start_idx = text.find(BOT_TOKEN)?;

        // Start from the opening bracket after [TOOL_CALLS]
        // The -1 is to include the opening bracket that's part of the token
        let json_start = start_idx + BOT_TOKEN.len() - 1;

        let mut bracket_count = 0;
        let mut in_string = false;
        let mut escape_next = false;

        let bytes = text.as_bytes();

        for i in json_start..text.len() {
            let char = bytes[i];

            if escape_next {
                escape_next = false;
                continue;
            }

            if char == b'\\' {
                escape_next = true;
                continue;
            }

            if char == b'"' && !escape_next {
                in_string = !in_string;
                continue;
            }

            if !in_string {
                if char == b'[' {
                    bracket_count += 1;
                } else if char == b']' {
                    bracket_count -= 1;
                    if bracket_count == 0 {
                        // Found the matching closing bracket
103
                        return Some((start_idx, &text[json_start..=i]));
104
105
106
107
108
109
110
111
112
113
                    }
                }
            }
        }

        // Incomplete array (no matching closing bracket found)
        None
    }

    /// Parse tool calls from a JSON array
114
    fn parse_json_array(&self, json_str: &str) -> ParserResult<Vec<ToolCall>> {
115
        let value: Value = serde_json::from_str(json_str)
116
            .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
117
118
119
120

        let mut tools = Vec::new();

        if let Value::Array(arr) = value {
121
122
            for item in arr.iter() {
                if let Some(tool) = self.parse_single_object(item)? {
123
124
125
126
127
                    tools.push(tool);
                }
            }
        } else {
            // Single object case (shouldn't happen with Mistral format, but handle it)
128
            if let Some(tool) = self.parse_single_object(&value)? {
129
130
131
132
133
134
135
136
                tools.push(tool);
            }
        }

        Ok(tools)
    }

    /// Parse a single JSON object into a ToolCall
137
    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
138
139
140
141
142
143
144
145
146
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Mistral uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
147
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }
}

impl Default for MistralParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for MistralParser {
169
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
170
171
        // Check if text contains Mistral format
        if !self.has_tool_markers(text) {
172
            return Ok((text.to_string(), vec![]));
173
174
        }

175
176
177
178
179
180
181
182
183
184
185
        // Extract JSON array from Mistral format with position
        if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
            // Extract normal text before BOT_TOKEN
            let normal_text_before = if start_idx > 0 {
                text[..start_idx].to_string()
            } else {
                String::new()
            };

            match self.parse_json_array(json_array) {
                Ok(tools) => Ok((normal_text_before, tools)),
186
                Err(e) => {
187
                    // If JSON parsing fails, return the original text as normal text
188
                    tracing::warn!("Failed to parse tool call: {}", e);
189
190
191
                    Ok((text.to_string(), vec![]))
                }
            }
192
193
        } else {
            // Markers present but no complete array found
194
            Ok((text.to_string(), vec![]))
195
196
197
198
        }
    }

    async fn parse_incremental(
199
        &mut self,
200
        chunk: &str,
201
        tools: &[Tool],
202
    ) -> ParserResult<StreamingParseResult> {
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_markers(current_text)
            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
224
            }
225
226
        }

227
228
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);
229

230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
        };

        helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )
250
251
    }

252
253
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("[TOOL_CALLS]")
254
    }
255
256
257
258

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
259
260
261
262
263
264
265
266
267
268

    fn reset(&mut self) {
        helpers::reset_parser_state(
            &mut self.buffer,
            &mut self.prev_tool_call_arr,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
    }
269
}