kimik2_parser.rs 14.1 KB
Newer Older
1
2
use async_trait::async_trait;
use regex::Regex;
3
4
use serde_json::Value;

5
6
7
8
9
10
11
12
use crate::{
    protocols::common::Tool,
    tool_parser::{
        errors::ParserResult,
        parsers::helpers,
        traits::ToolParser,
        types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
    },
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
};

/// Kimi K2 format parser for tool calls
///
/// Handles the Kimi K2 specific format:
/// `<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|><|tool_calls_section_end|>`
///
/// Features:
/// - Token-based delimiters
/// - Function calls with explicit indexing
/// - JSON arguments
pub struct KimiK2Parser {
    /// Regex for extracting complete tool calls
    tool_call_extractor: Regex,
    /// Regex for extracting partial tool calls (streaming)
    stream_tool_call_extractor: Regex,
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    /// Regex pattern for removing completed tool calls from buffer
    tool_call_end_pattern: Regex,
    /// Robust parser for ids like "functions.search:0" or fallback "search:0"
    tool_call_id_regex: Regex,

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Tracks the last arguments sent for incremental diffing
    last_arguments: String,
51
52
53
54
55
56
57
58
59
60
61
62
63
}

impl KimiK2Parser {
    /// Create a new Kimi K2 parser
    pub fn new() -> Self {
        // Pattern for complete tool calls
        let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");

        // Pattern for streaming (partial) tool calls
        let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
        let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");

64
65
66
67
68
69
70
71
        // Pattern for removing completed tool calls
        let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");

        // Robust parser for ids like "functions.search:0" or fallback "search:0"
        let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
        let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");

72
73
74
        Self {
            tool_call_extractor,
            stream_tool_call_extractor,
75
76
77
78
79
80
81
82
            tool_call_end_pattern,
            tool_call_id_regex,
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            last_arguments: String::new(),
83
84
85
86
87
        }
    }

    /// Parse function ID to extract name and index
    fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
88
89
90
91
92
93
        if let Some(captures) = self.tool_call_id_regex.captures(id) {
            let name = captures.name("name")?.as_str().to_string();
            let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
            Some((name, index))
        } else {
            None
94
95
96
97
98
99
100
101
102
103
104
105
        }
    }
}

impl Default for KimiK2Parser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for KimiK2Parser {
106
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
107
        if !self.has_tool_markers(text) {
108
            return Ok((text.to_string(), vec![]));
109
110
        }

111
112
113
        // Find where tool calls begin
        let idx = text.find("<|tool_calls_section_begin|>").unwrap();
        let normal_text = text[..idx].to_string();
114

115
116
117
        // Try to extract tool calls
        let mut tools = Vec::new();
        for captures in self.tool_call_extractor.captures_iter(text) {
118
119
120
121
122
123
124
125
126
            if let (Some(id_match), Some(args_match)) = (
                captures.name("tool_call_id"),
                captures.name("function_arguments"),
            ) {
                let function_id = id_match.as_str();
                let function_args = args_match.as_str();

                // Parse function ID
                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
127
                    // Try to parse JSON arguments
128
                    match serde_json::from_str::<Value>(function_args) {
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
                        Ok(_) => {
                            tools.push(ToolCall {
                                function: FunctionCall {
                                    name: func_name,
                                    arguments: function_args.to_string(),
                                },
                            });
                        }
                        Err(e) => {
                            tracing::warn!(
                                "Failed to parse JSON arguments for {}: {}",
                                func_name,
                                e
                            );
                            continue;
                        }
145
                    }
146
147
148
                } else {
                    tracing::warn!("Failed to parse function ID: {}", function_id);
                    continue;
149
150
151
152
                }
            }
        }

153
154
155
156
        // If no tools were successfully parsed despite having markers, return entire text as fallback
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
        }
157
158

        Ok((normal_text, tools))
159
160
161
    }

    async fn parse_incremental(
162
        &mut self,
163
        chunk: &str,
164
        tools: &[Tool],
165
    ) -> ParserResult<StreamingParseResult> {
166
167
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();
168

169
        // Check if we have a tool call (either the start token or individual tool call)
170
        let has_tool_call =
171
            self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
172
173

        if !has_tool_call {
174
            // No tool markers detected - return all buffered content as normal text
175
176
177
178
179
180
181
182
183
            let mut normal_text = std::mem::take(&mut self.buffer);
            // Remove end tokens if present
            for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
                normal_text = normal_text.replace(e_token, "");
            }
            return Ok(StreamingParseResult {
                normal_text,
                calls: vec![],
            });
184
185
        }

186
187
        // Build tool indices for validation
        let tool_indices = helpers::get_tool_indices(tools);
188

189
        let mut calls: Vec<ToolCallItem> = Vec::new();
190
191

        // Try to match streaming pattern
192
        if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
193
194
195
196
197
            if let (Some(id_match), Some(args_match)) = (
                captures.name("tool_call_id"),
                captures.name("function_arguments"),
            ) {
                let function_id = id_match.as_str();
198
                let function_args = args_match.as_str();
199
200
201

                // Parse function ID
                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
202
203
204
205
206
207
208
209
210
211
212
                    // Validate tool name
                    if !tool_indices.contains_key(&func_name) {
                        // Invalid tool name - skip this tool, preserve indexing for next tool
                        tracing::warn!("Invalid tool name '{}' - skipping", func_name);
                        helpers::reset_current_tool_state(
                            &mut self.buffer,
                            &mut self.current_tool_name_sent,
                            &mut self.streamed_args_for_tool,
                            &self.prev_tool_call_arr,
                        );
                        return Ok(StreamingParseResult::default());
213
214
                    }

215
216
217
218
219
220
                    // Initialize state if this is the first tool call
                    if self.current_tool_id == -1 {
                        self.current_tool_id = 0;
                        self.prev_tool_call_arr = Vec::new();
                        self.streamed_args_for_tool = vec![String::new()];
                    }
221

222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
                    // Ensure we have enough entries in our tracking arrays
                    helpers::ensure_capacity(
                        self.current_tool_id,
                        &mut self.prev_tool_call_arr,
                        &mut self.streamed_args_for_tool,
                    );

                    // Send tool name if not sent yet
                    if !self.current_tool_name_sent {
                        calls.push(ToolCallItem {
                            tool_index: self.current_tool_id as usize,
                            name: Some(func_name.clone()),
                            parameters: String::new(),
                        });
                        self.current_tool_name_sent = true;
237

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
                        // Store the tool call info for serving layer completions endpoint
                        let tool_id = self.current_tool_id as usize;
                        if self.prev_tool_call_arr.len() <= tool_id {
                            self.prev_tool_call_arr
                                .resize_with(tool_id + 1, || Value::Null);
                        }
                        self.prev_tool_call_arr[tool_id] = serde_json::json!({
                            "name": func_name,
                            "arguments": {},
                        });
                    } else {
                        // Compute incremental diff
                        let argument_diff = if function_args.starts_with(&self.last_arguments) {
                            &function_args[self.last_arguments.len()..]
                        } else {
                            function_args
                        };

                        // Split by end token before sending (like Python does)
                        let parsed_args_diff =
                            if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
                                &argument_diff[..pos]
                            } else {
                                argument_diff
262
263
                            };

264
265
266
267
268
269
270
271
272
273
274
                        if !parsed_args_diff.is_empty() {
                            calls.push(ToolCallItem {
                                tool_index: self.current_tool_id as usize,
                                name: None,
                                parameters: parsed_args_diff.to_string(),
                            });
                            // Note: Python adds full diff to _last_arguments, not just parsed part
                            self.last_arguments.push_str(argument_diff);
                            let tool_id = self.current_tool_id as usize;
                            if tool_id < self.streamed_args_for_tool.len() {
                                self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
275
276
                            }
                        }
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298

                        // Check completeness - split by end token first
                        let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
                        {
                            &function_args[..pos]
                        } else {
                            function_args
                        };

                        if helpers::is_complete_json(parsed_args) {
                            // Update the stored arguments
                            if let Ok(parsed_args_value) =
                                serde_json::from_str::<Value>(parsed_args)
                            {
                                let tool_id = self.current_tool_id as usize;
                                if tool_id < self.prev_tool_call_arr.len() {
                                    if let Some(obj) =
                                        self.prev_tool_call_arr[tool_id].as_object_mut()
                                    {
                                        obj.insert("arguments".to_string(), parsed_args_value);
                                    }
                                }
299
                            }
300
301
302
303
304
305
306

                            // Find the end of the current tool call and remove only that part from buffer
                            if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
                                // Remove the completed tool call from buffer, keep any remaining content
                                self.buffer = current_text[mat.end()..].to_string();
                            } else {
                                self.buffer.clear();
307
                            }
308
309
310
311
312
313
314
315
316
317

                            let result = StreamingParseResult {
                                normal_text: String::new(),
                                calls,
                            };

                            self.current_tool_id += 1;
                            self.last_arguments.clear();
                            self.current_tool_name_sent = false;
                            return Ok(result);
318
319
320
321
322
323
                        }
                    }
                }
            }
        }

324
325
326
327
        Ok(StreamingParseResult {
            normal_text: String::new(),
            calls,
        })
328
329
    }

330
331
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<|tool_calls_section_begin|>")
332
    }
333
334
335
336

    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
337
338
339
340
341
342
343
344
345

    fn reset(&mut self) {
        self.buffer.clear();
        self.prev_tool_call_arr.clear();
        self.current_tool_id = -1;
        self.current_tool_name_sent = false;
        self.streamed_args_for_tool.clear();
        self.last_arguments.clear();
    }
346
}