//! Mistral tool-call parser (`mistral_parser.rs`).
use async_trait::async_trait;
use serde_json::Value;

4
5
use crate::protocols::spec::Tool;

6
7
use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
8
    parsers::helpers,
9
10
    partial_json::PartialJson,
    traits::ToolParser,
11
    types::{FunctionCall, StreamingParseResult, ToolCall},
12
13
14
15
16
17
18
19
20
21
22
23
24
25
};

/// Mistral format parser for tool calls
///
/// Handles the Mistral-specific format:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
/// Features:
/// - Bracket counting for proper JSON array extraction
/// - Support for multiple tool calls in a single array
/// - String-aware parsing to handle nested brackets in JSON
pub struct MistralParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Token configuration
    bot_token: &'static str,
    tool_call_separator: &'static str,
45
46
47
48
49
50
51
}

impl MistralParser {
    /// Create a new Mistral parser
    pub fn new() -> Self {
        Self {
            partial_json: PartialJson::default(),
52
53
54
55
56
57
58
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            bot_token: "[TOOL_CALLS] [",
            tool_call_separator: ", ",
59
60
61
        }
    }

62
    fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
        const BOT_TOKEN: &str = "[TOOL_CALLS] [";

        // Find the start of the token
        let start_idx = text.find(BOT_TOKEN)?;

        // Start from the opening bracket after [TOOL_CALLS]
        // The -1 is to include the opening bracket that's part of the token
        let json_start = start_idx + BOT_TOKEN.len() - 1;

        let mut bracket_count = 0;
        let mut in_string = false;
        let mut escape_next = false;

        let bytes = text.as_bytes();

        for i in json_start..text.len() {
            let char = bytes[i];

            if escape_next {
                escape_next = false;
                continue;
            }

            if char == b'\\' {
                escape_next = true;
                continue;
            }

            if char == b'"' && !escape_next {
                in_string = !in_string;
                continue;
            }

            if !in_string {
                if char == b'[' {
                    bracket_count += 1;
                } else if char == b']' {
                    bracket_count -= 1;
                    if bracket_count == 0 {
                        // Found the matching closing bracket
103
                        return Some((start_idx, &text[json_start..=i]));
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
                    }
                }
            }
        }

        // Incomplete array (no matching closing bracket found)
        None
    }

    /// Parse tool calls from a JSON array
    fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
        let value: Value = serde_json::from_str(json_str)
            .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

        let mut tools = Vec::new();

        if let Value::Array(arr) = value {
121
122
            for item in arr.iter() {
                if let Some(tool) = self.parse_single_object(item)? {
123
124
125
126
127
                    tools.push(tool);
                }
            }
        } else {
            // Single object case (shouldn't happen with Mistral format, but handle it)
128
            if let Some(tool) = self.parse_single_object(&value)? {
129
130
131
132
133
134
135
136
                tools.push(tool);
            }
        }

        Ok(tools)
    }

    /// Parse a single JSON object into a ToolCall
137
    fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Mistral uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
                .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }

    /// Check if text contains Mistral tool markers
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("[TOOL_CALLS]")
    }
}

impl Default for MistralParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for MistralParser {
174
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
175
176
        // Check if text contains Mistral format
        if !self.has_tool_markers(text) {
177
            return Ok((text.to_string(), vec![]));
178
179
        }

180
181
182
183
184
185
186
187
188
189
190
        // Extract JSON array from Mistral format with position
        if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
            // Extract normal text before BOT_TOKEN
            let normal_text_before = if start_idx > 0 {
                text[..start_idx].to_string()
            } else {
                String::new()
            };

            match self.parse_json_array(json_array) {
                Ok(tools) => Ok((normal_text_before, tools)),
191
                Err(e) => {
192
                    // If JSON parsing fails, return the original text as normal text
193
                    tracing::warn!("Failed to parse tool call: {}", e);
194
195
196
                    Ok((text.to_string(), vec![]))
                }
            }
197
198
        } else {
            // Markers present but no complete array found
199
            Ok((text.to_string(), vec![]))
200
201
202
203
        }
    }

    async fn parse_incremental(
204
        &mut self,
205
        chunk: &str,
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
        tools: &[Tool],
    ) -> ToolParserResult<StreamingParseResult> {
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_markers(current_text)
            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
229
            }
230
231
        }

232
233
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);
234

235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
        };

        helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )
255
256
257
    }

    fn detect_format(&self, text: &str) -> bool {
258
        self.has_tool_markers(text)
259
    }
260
261
262
263

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
264
}