//! Tool-call parser for the Mistral `[TOOL_CALLS]` output format (`mistral.rs`).
use async_trait::async_trait;
use serde_json::Value;

use crate::{
    protocols::common::Tool,
    tool_parser::{
        errors::{ParserError, ParserResult},
        parsers::helpers,
        partial_json::PartialJson,
        traits::ToolParser,
        types::{FunctionCall, StreamingParseResult, ToolCall},
    },
};

/// Mistral format parser for tool calls
///
/// Handles the Mistral-specific format:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
20
/// Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
21
22
23
pub struct MistralParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Token configuration
    bot_token: &'static str,
42
    eot_token: &'static str,
43
    tool_call_separator: &'static str,
44
45
46

    /// Track whether we've already stripped the closing ] bracket
    array_closed: bool,
47
48
49
50
51
52
53
}

impl MistralParser {
    /// Create a new Mistral parser
    pub fn new() -> Self {
        Self {
            partial_json: PartialJson::default(),
54
55
56
57
58
59
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            bot_token: "[TOOL_CALLS] [",
60
            eot_token: "]",
61
            tool_call_separator: ", ",
62
            array_closed: false,
63
64
65
        }
    }

66
    fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
        const BOT_TOKEN: &str = "[TOOL_CALLS] [";

        // Find the start of the token
        let start_idx = text.find(BOT_TOKEN)?;

        // Start from the opening bracket after [TOOL_CALLS]
        // The -1 is to include the opening bracket that's part of the token
        let json_start = start_idx + BOT_TOKEN.len() - 1;

        let mut bracket_count = 0;
        let mut in_string = false;
        let mut escape_next = false;

        let bytes = text.as_bytes();

        for i in json_start..text.len() {
            let char = bytes[i];

            if escape_next {
                escape_next = false;
                continue;
            }

            if char == b'\\' {
                escape_next = true;
                continue;
            }

            if char == b'"' && !escape_next {
                in_string = !in_string;
                continue;
            }

            if !in_string {
                if char == b'[' {
                    bracket_count += 1;
                } else if char == b']' {
                    bracket_count -= 1;
                    if bracket_count == 0 {
                        // Found the matching closing bracket
107
                        return Some((start_idx, &text[json_start..=i]));
108
109
110
111
112
113
114
115
116
117
                    }
                }
            }
        }

        // Incomplete array (no matching closing bracket found)
        None
    }

    /// Parse tool calls from a JSON array
118
    fn parse_json_array(&self, json_str: &str) -> ParserResult<Vec<ToolCall>> {
119
        let value: Value = serde_json::from_str(json_str)
120
            .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
121
122
123
124

        let mut tools = Vec::new();

        if let Value::Array(arr) = value {
125
126
            for item in arr.iter() {
                if let Some(tool) = self.parse_single_object(item)? {
127
128
129
130
131
                    tools.push(tool);
                }
            }
        } else {
            // Single object case (shouldn't happen with Mistral format, but handle it)
132
            if let Some(tool) = self.parse_single_object(&value)? {
133
134
135
136
137
138
139
140
                tools.push(tool);
            }
        }

        Ok(tools)
    }

    /// Parse a single JSON object into a ToolCall
141
    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
142
143
144
145
146
147
148
149
150
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Mistral uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
151
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }
}

impl Default for MistralParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for MistralParser {
173
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
174
175
        // Check if text contains Mistral format
        if !self.has_tool_markers(text) {
176
            return Ok((text.to_string(), vec![]));
177
178
        }

179
180
181
182
183
184
185
186
187
188
189
        // Extract JSON array from Mistral format with position
        if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
            // Extract normal text before BOT_TOKEN
            let normal_text_before = if start_idx > 0 {
                text[..start_idx].to_string()
            } else {
                String::new()
            };

            match self.parse_json_array(json_array) {
                Ok(tools) => Ok((normal_text_before, tools)),
190
                Err(e) => {
191
                    // If JSON parsing fails, return the original text as normal text
192
                    tracing::warn!("Failed to parse tool call: {}", e);
193
194
195
                    Ok((text.to_string(), vec![]))
                }
            }
196
197
        } else {
            // Markers present but no complete array found
198
            Ok((text.to_string(), vec![]))
199
200
201
202
        }
    }

    async fn parse_incremental(
203
        &mut self,
204
        chunk: &str,
205
        tools: &[Tool],
206
    ) -> ParserResult<StreamingParseResult> {
207
208
209
210
211
212
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_markers(current_text)
213
            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
214
215
216
217

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
218
                let mut normal_text = self.buffer.clone();
219
220
                self.buffer.clear();

221
222
223
224
225
226
227
228
229
230
231
232
233
                // Strip ] only once (the closing bracket of [TOOL_CALLS] array)
                // current_tool_id > 0 means we've parsed at least one tool
                if !self.array_closed
                    && self.current_tool_id > 0
                    && normal_text.starts_with(self.eot_token)
                {
                    normal_text = normal_text
                        .strip_prefix(self.eot_token)
                        .unwrap()
                        .to_string();
                    self.array_closed = true;
                }

234
235
236
237
238
239
240
                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
241
            }
242
243
        }

244
245
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);
246

247
248
249
        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
250
        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
            self.tool_call_separator.len()
        } else {
            0
        };

        helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )
267
268
    }

269
270
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("[TOOL_CALLS]")
271
    }
272
273
274
275

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
276
277
278
279
280
281
282
283
284

    fn reset(&mut self) {
        helpers::reset_parser_state(
            &mut self.buffer,
            &mut self.prev_tool_call_arr,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
285
        self.array_closed = false;
286
    }
287
}