// kimik2_parser.rs — Kimi K2 tool-call parser
use async_trait::async_trait;
use regex::Regex;
3
4
5
use serde_json::Value;

use crate::protocols::spec::Tool;
6
7
8

use crate::tool_parser::{
    errors::ToolParserResult,
9
    parsers::helpers,
10
    traits::ToolParser,
11
    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
};

/// Kimi K2 format parser for tool calls
///
/// Handles the Kimi K2 specific format:
/// `<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|><|tool_calls_section_end|>`
///
/// Features:
/// - Token-based delimiters
/// - Function calls with explicit indexing
/// - JSON arguments
pub struct KimiK2Parser {
    /// Regex for extracting complete tool calls
    tool_call_extractor: Regex,
    /// Regex for extracting partial tool calls (streaming)
    stream_tool_call_extractor: Regex,
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    /// Regex pattern for removing completed tool calls from buffer
    tool_call_end_pattern: Regex,
    /// Robust parser for ids like "functions.search:0" or fallback "search:0"
    tool_call_id_regex: Regex,

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Tracks the last arguments sent for incremental diffing
    last_arguments: String,
50
51
52
53
54
55
56
57
58
59
60
61
62
}

impl KimiK2Parser {
    /// Create a new Kimi K2 parser
    pub fn new() -> Self {
        // Pattern for complete tool calls
        let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");

        // Pattern for streaming (partial) tool calls
        let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
        let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");

63
64
65
66
67
68
69
70
        // Pattern for removing completed tool calls
        let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");

        // Robust parser for ids like "functions.search:0" or fallback "search:0"
        let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
        let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");

71
72
73
        Self {
            tool_call_extractor,
            stream_tool_call_extractor,
74
75
76
77
78
79
80
81
            tool_call_end_pattern,
            tool_call_id_regex,
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            last_arguments: String::new(),
82
83
84
85
86
        }
    }

    /// Parse function ID to extract name and index
    fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
87
88
89
90
91
92
        if let Some(captures) = self.tool_call_id_regex.captures(id) {
            let name = captures.name("name")?.as_str().to_string();
            let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
            Some((name, index))
        } else {
            None
93
94
95
96
97
98
99
100
101
102
103
104
        }
    }
}

impl Default for KimiK2Parser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for KimiK2Parser {
105
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
106
        if !self.has_tool_markers(text) {
107
            return Ok((text.to_string(), vec![]));
108
109
        }

110
111
112
        // Find where tool calls begin
        let idx = text.find("<|tool_calls_section_begin|>").unwrap();
        let normal_text = text[..idx].to_string();
113

114
115
116
        // Try to extract tool calls
        let mut tools = Vec::new();
        for captures in self.tool_call_extractor.captures_iter(text) {
117
118
119
120
121
122
123
124
125
            if let (Some(id_match), Some(args_match)) = (
                captures.name("tool_call_id"),
                captures.name("function_arguments"),
            ) {
                let function_id = id_match.as_str();
                let function_args = args_match.as_str();

                // Parse function ID
                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
                    // Try to parse JSON arguments
                    match serde_json::from_str::<serde_json::Value>(function_args) {
                        Ok(_) => {
                            tools.push(ToolCall {
                                function: FunctionCall {
                                    name: func_name,
                                    arguments: function_args.to_string(),
                                },
                            });
                        }
                        Err(e) => {
                            tracing::warn!(
                                "Failed to parse JSON arguments for {}: {}",
                                func_name,
                                e
                            );
                            continue;
                        }
144
                    }
145
146
147
                } else {
                    tracing::warn!("Failed to parse function ID: {}", function_id);
                    continue;
148
149
150
151
                }
            }
        }

152
153
154
155
        // If no tools were successfully parsed despite having markers, return entire text as fallback
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
        }
156
157

        Ok((normal_text, tools))
158
159
160
    }

    async fn parse_incremental(
161
        &mut self,
162
        chunk: &str,
163
164
165
166
        tools: &[Tool],
    ) -> ToolParserResult<StreamingParseResult> {
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();
167

168
        // Check if we have a tool call (either the start token or individual tool call)
169
        let has_tool_call =
170
            self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
171
172

        if !has_tool_call {
173
            // No tool markers detected - return all buffered content as normal text
174
175
176
177
178
179
180
181
182
            let mut normal_text = std::mem::take(&mut self.buffer);
            // Remove end tokens if present
            for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
                normal_text = normal_text.replace(e_token, "");
            }
            return Ok(StreamingParseResult {
                normal_text,
                calls: vec![],
            });
183
184
        }

185
186
        // Build tool indices for validation
        let tool_indices = helpers::get_tool_indices(tools);
187

188
        let mut calls: Vec<ToolCallItem> = Vec::new();
189
190

        // Try to match streaming pattern
191
        if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
192
193
194
195
196
            if let (Some(id_match), Some(args_match)) = (
                captures.name("tool_call_id"),
                captures.name("function_arguments"),
            ) {
                let function_id = id_match.as_str();
197
                let function_args = args_match.as_str();
198
199
200

                // Parse function ID
                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
201
202
203
204
205
206
207
208
209
210
211
                    // Validate tool name
                    if !tool_indices.contains_key(&func_name) {
                        // Invalid tool name - skip this tool, preserve indexing for next tool
                        tracing::warn!("Invalid tool name '{}' - skipping", func_name);
                        helpers::reset_current_tool_state(
                            &mut self.buffer,
                            &mut self.current_tool_name_sent,
                            &mut self.streamed_args_for_tool,
                            &self.prev_tool_call_arr,
                        );
                        return Ok(StreamingParseResult::default());
212
213
                    }

214
215
216
217
218
219
                    // Initialize state if this is the first tool call
                    if self.current_tool_id == -1 {
                        self.current_tool_id = 0;
                        self.prev_tool_call_arr = Vec::new();
                        self.streamed_args_for_tool = vec![String::new()];
                    }
220

221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
                    // Ensure we have enough entries in our tracking arrays
                    helpers::ensure_capacity(
                        self.current_tool_id,
                        &mut self.prev_tool_call_arr,
                        &mut self.streamed_args_for_tool,
                    );

                    // Send tool name if not sent yet
                    if !self.current_tool_name_sent {
                        calls.push(ToolCallItem {
                            tool_index: self.current_tool_id as usize,
                            name: Some(func_name.clone()),
                            parameters: String::new(),
                        });
                        self.current_tool_name_sent = true;
236

237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
                        // Store the tool call info for serving layer completions endpoint
                        let tool_id = self.current_tool_id as usize;
                        if self.prev_tool_call_arr.len() <= tool_id {
                            self.prev_tool_call_arr
                                .resize_with(tool_id + 1, || Value::Null);
                        }
                        self.prev_tool_call_arr[tool_id] = serde_json::json!({
                            "name": func_name,
                            "arguments": {},
                        });
                    } else {
                        // Compute incremental diff
                        let argument_diff = if function_args.starts_with(&self.last_arguments) {
                            &function_args[self.last_arguments.len()..]
                        } else {
                            function_args
                        };

                        // Split by end token before sending (like Python does)
                        let parsed_args_diff =
                            if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
                                &argument_diff[..pos]
                            } else {
                                argument_diff
261
262
                            };

263
264
265
266
267
268
269
270
271
272
273
                        if !parsed_args_diff.is_empty() {
                            calls.push(ToolCallItem {
                                tool_index: self.current_tool_id as usize,
                                name: None,
                                parameters: parsed_args_diff.to_string(),
                            });
                            // Note: Python adds full diff to _last_arguments, not just parsed part
                            self.last_arguments.push_str(argument_diff);
                            let tool_id = self.current_tool_id as usize;
                            if tool_id < self.streamed_args_for_tool.len() {
                                self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
274
275
                            }
                        }
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297

                        // Check completeness - split by end token first
                        let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
                        {
                            &function_args[..pos]
                        } else {
                            function_args
                        };

                        if helpers::is_complete_json(parsed_args) {
                            // Update the stored arguments
                            if let Ok(parsed_args_value) =
                                serde_json::from_str::<Value>(parsed_args)
                            {
                                let tool_id = self.current_tool_id as usize;
                                if tool_id < self.prev_tool_call_arr.len() {
                                    if let Some(obj) =
                                        self.prev_tool_call_arr[tool_id].as_object_mut()
                                    {
                                        obj.insert("arguments".to_string(), parsed_args_value);
                                    }
                                }
298
                            }
299
300
301
302
303
304
305

                            // Find the end of the current tool call and remove only that part from buffer
                            if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
                                // Remove the completed tool call from buffer, keep any remaining content
                                self.buffer = current_text[mat.end()..].to_string();
                            } else {
                                self.buffer.clear();
306
                            }
307
308
309
310
311
312
313
314
315
316

                            let result = StreamingParseResult {
                                normal_text: String::new(),
                                calls,
                            };

                            self.current_tool_id += 1;
                            self.last_arguments.clear();
                            self.current_tool_name_sent = false;
                            return Ok(result);
317
318
319
320
321
322
                        }
                    }
                }
            }
        }

323
324
325
326
        Ok(StreamingParseResult {
            normal_text: String::new(),
            calls,
        })
327
328
    }

329
330
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<|tool_calls_section_begin|>")
331
    }
332
333
334
335

    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
336
}