//! GLM-4 MoE tool-call parser (`glm4_moe_parser.rs`).
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;

use crate::{
    protocols::common::Tool,
    tool_parser::{
        errors::{ParserError, ParserResult},
        parsers::helpers,
        traits::ToolParser,
        types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
    },
};

/// GLM-4 MoE format parser for tool calls
///
/// Handles the GLM-4 MoE specific format:
/// `<tool_call>{name}\n<arg_key>{key}</arg_key>\n<arg_value>{value}</arg_value>\n</tool_call>`
///
/// Features:
/// - XML-style tags for tool calls
/// - Key-value pairs for arguments
/// - Support for multiple sequential tool calls
pub struct Glm4MoeParser {
    /// Regex for extracting complete tool calls
    tool_call_extractor: Regex,
    /// Regex for extracting function details
    func_detail_extractor: Regex,
    /// Regex for extracting argument key-value pairs
    arg_extractor: Regex,
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Token configuration
    bot_token: &'static str,
    eot_token: &'static str,
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
}

impl Glm4MoeParser {
    /// Create a new GLM-4 MoE parser
    pub fn new() -> Self {
        // Use (?s) flag for DOTALL mode to handle newlines
        let tool_call_pattern = r"(?s)<tool_call>.*?</tool_call>";
        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");

        let func_detail_pattern = r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>";
        let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");

        let arg_pattern = r"(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>";
        let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern");

        Self {
            tool_call_extractor,
            func_detail_extractor,
            arg_extractor,
66
67
68
69
70
71
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            streamed_args_for_tool: Vec::new(),
            bot_token: "<tool_call>",
            eot_token: "</tool_call>",
72
73
74
75
        }
    }

    /// Parse arguments from key-value pairs
76
    fn parse_arguments(&self, args_text: &str) -> ParserResult<serde_json::Map<String, Value>> {
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
        let mut arguments = serde_json::Map::new();

        for capture in self.arg_extractor.captures_iter(args_text) {
            let key = capture.get(1).map_or("", |m| m.as_str()).trim();
            let value_str = capture.get(2).map_or("", |m| m.as_str()).trim();

            // Try to parse the value as JSON first, fallback to string
            let value = if let Ok(json_val) = serde_json::from_str::<Value>(value_str) {
                json_val
            } else {
                // Try parsing as Python literal (similar to Python's ast.literal_eval)
                if value_str == "true" || value_str == "True" {
                    Value::Bool(true)
                } else if value_str == "false" || value_str == "False" {
                    Value::Bool(false)
                } else if value_str == "null" || value_str == "None" {
                    Value::Null
                } else if let Ok(num) = value_str.parse::<i64>() {
                    Value::Number(num.into())
                } else if let Ok(num) = value_str.parse::<f64>() {
                    if let Some(n) = serde_json::Number::from_f64(num) {
                        Value::Number(n)
                    } else {
                        Value::String(value_str.to_string())
                    }
                } else {
                    Value::String(value_str.to_string())
                }
            };

            arguments.insert(key.to_string(), value);
        }

        Ok(arguments)
    }

    /// Parse a single tool call block
114
    fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
115
116
117
118
119
120
121
122
123
124
125
        if let Some(captures) = self.func_detail_extractor.captures(block) {
            // Get function name
            let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();

            // Get arguments text
            let args_text = captures.get(2).map_or("", |m| m.as_str());

            // Parse arguments
            let arguments = self.parse_arguments(args_text)?;

            let arguments_str = serde_json::to_string(&arguments)
126
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
127
128
129
130
131
132
133
134
135
136
137

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: func_name.to_string(),
                    arguments: arguments_str,
                },
            }))
        } else {
            Ok(None)
        }
    }
138
139
140

    /// Parse and return StreamingParseResult (mirrors Python's detect_and_parse)
    /// Parse all tool calls from text (shared logic for complete and incremental parsing)
141
    fn parse_tool_calls_from_text(&self, text: &str) -> ParserResult<Vec<ToolCall>> {
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
        let mut tools = Vec::new();

        for mat in self.tool_call_extractor.find_iter(text) {
            match self.parse_tool_call(mat.as_str()) {
                Ok(Some(tool)) => tools.push(tool),
                Ok(None) => continue,
                Err(e) => {
                    tracing::warn!("Failed to parse tool call: {}", e);
                    continue;
                }
            }
        }

        Ok(tools)
    }
157
158
159
160
161
162
163
164
165
166
}

impl Default for Glm4MoeParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for Glm4MoeParser {
167
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
168
169
        // Check if text contains GLM-4 MoE format
        if !self.has_tool_markers(text) {
170
            return Ok((text.to_string(), vec![]));
171
172
        }

173
174
175
        // Find where tool calls begin
        let idx = text.find("<tool_call>").unwrap();
        let normal_text = text[..idx].to_string();
176

177
178
        // Parse all tool calls using shared helper
        let tools = self.parse_tool_calls_from_text(text)?;
179

180
181
182
183
        // If no tools were successfully parsed despite having markers, return entire text as fallback
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
        }
184
185

        Ok((normal_text, tools))
186
187
188
    }

    async fn parse_incremental(
189
        &mut self,
190
        chunk: &str,
191
        tools: &[Tool],
192
    ) -> ParserResult<StreamingParseResult> {
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
        // Python logic: Wait for complete tool call, then parse it all at once
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if we have bot_token
        let start = current_text.find(self.bot_token);
        if start.is_none() {
            self.buffer.clear();
            // If we're in the middle of streaming (current_tool_id > 0), don't return text
            let normal_text = if self.current_tool_id > 0 {
                String::new()
            } else {
                current_text.clone()
            };
            return Ok(StreamingParseResult {
                normal_text,
                calls: vec![],
            });
211
212
        }

213
214
215
216
217
218
219
220
221
222
        // Check if we have eot_token (end of tool call)
        let end = current_text.find(self.eot_token);
        if let Some(end_pos) = end {
            // We have a complete tool call!

            // Initialize state if this is the first tool call
            if self.current_tool_id == -1 {
                self.current_tool_id = 0;
                self.prev_tool_call_arr = Vec::new();
                self.streamed_args_for_tool = vec![String::new()];
223
            }
224

225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
            // Ensure we have enough entries in our tracking arrays
            helpers::ensure_capacity(
                self.current_tool_id,
                &mut self.prev_tool_call_arr,
                &mut self.streamed_args_for_tool,
            );

            // Parse the complete block using shared helper
            let block_end = end_pos + self.eot_token.len();
            let parsed_tools = self.parse_tool_calls_from_text(&current_text[..block_end])?;

            // Extract normal text before tool calls
            let idx = current_text.find(self.bot_token);
            let normal_text = if let Some(pos) = idx {
                current_text[..pos].trim().to_string()
            } else {
                String::new()
            };
243

244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
            // Build tool indices for validation
            let tool_indices = helpers::get_tool_indices(tools);

            let mut calls = Vec::new();

            if !parsed_tools.is_empty() {
                // Take the first tool and convert to ToolCallItem
                let tool_call = &parsed_tools[0];
                let tool_id = self.current_tool_id as usize;

                // Validate tool name
                if !tool_indices.contains_key(&tool_call.function.name) {
                    // Invalid tool name - skip this tool, preserve indexing for next tool
                    tracing::warn!("Invalid tool name '{}' - skipping", tool_call.function.name);
                    helpers::reset_current_tool_state(
                        &mut self.buffer,
                        &mut false, // glm4_moe doesn't track name_sent per tool
                        &mut self.streamed_args_for_tool,
                        &self.prev_tool_call_arr,
                    );
                    return Ok(StreamingParseResult::default());
                }
266

267
268
269
270
271
                calls.push(ToolCallItem {
                    tool_index: tool_id,
                    name: Some(tool_call.function.name.clone()),
                    parameters: tool_call.function.arguments.clone(),
                });
272

273
274
275
276
                // Store in tracking arrays
                if self.prev_tool_call_arr.len() <= tool_id {
                    self.prev_tool_call_arr
                        .resize_with(tool_id + 1, || Value::Null);
277
278
                }

279
280
281
282
283
284
285
                // Parse parameters as JSON and store
                if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
                    self.prev_tool_call_arr[tool_id] = serde_json::json!({
                        "name": tool_call.function.name,
                        "arguments": args,
                    });
                }
286

287
288
289
                if self.streamed_args_for_tool.len() <= tool_id {
                    self.streamed_args_for_tool
                        .resize_with(tool_id + 1, String::new);
290
                }
291
292
293
                self.streamed_args_for_tool[tool_id] = tool_call.function.arguments.clone();

                self.current_tool_id += 1;
294
            }
295
296
297
298

            // Remove processed portion from buffer
            self.buffer = current_text[block_end..].to_string();
            return Ok(StreamingParseResult { normal_text, calls });
299
300
        }

301
302
303
304
305
306
307
308
309
        // No complete tool call yet - return normal text before start token
        let start_pos = start.unwrap();
        let normal_text = current_text[..start_pos].to_string();
        self.buffer = current_text[start_pos..].to_string();

        Ok(StreamingParseResult {
            normal_text,
            calls: vec![],
        })
310
311
    }

312
313
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains(self.bot_token)
314
    }
315
316
317
318

    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
319
320
321
322
323
324
325

    fn reset(&mut self) {
        self.buffer.clear();
        self.prev_tool_call_arr.clear();
        self.current_tool_id = -1;
        self.streamed_args_for_tool.clear();
    }
326
}