qwen_parser.rs 8.8 KB
Newer Older
1
2
3
4
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;

5
6
use crate::protocols::spec::Tool;

7
8
use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
9
    parsers::helpers,
10
11
    partial_json::PartialJson,
    traits::ToolParser,
12
    types::{FunctionCall, StreamingParseResult, ToolCall},
13
14
15
16
17
18
19
20
21
22
23
};

/// Qwen format parser for tool calls
///
/// Handles the Qwen 2.5/3 specific format:
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
///
/// Features:
/// - XML-style tags with JSON content
/// - Support for multiple sequential tool calls
/// - Newline-aware parsing
24
/// - Buffering for partial end tokens
25
26
27
pub struct QwenParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
28
29

    /// Regex for extracting tool calls in parse_complete
30
    extractor: Regex,
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Buffer for normal text that might precede partial end tokens
    normal_text_buffer: String,

    /// Token configuration
    bot_token: &'static str,
    eot_token: &'static str,
    tool_call_separator: &'static str,
54
55
56
57
58
59
60
61
62
63
64
65
}

impl QwenParser {
    /// Create a new Qwen parser
    pub fn new() -> Self {
        // Use (?s) flag for DOTALL mode to handle newlines
        let pattern = r"(?s)<tool_call>\n(.*?)\n</tool_call>";
        let extractor = Regex::new(pattern).expect("Valid regex pattern");

        Self {
            partial_json: PartialJson::default(),
            extractor,
66
67
68
69
70
71
72
73
74
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            normal_text_buffer: String::new(),
            bot_token: "<tool_call>\n",
            eot_token: "\n</tool_call>",
            tool_call_separator: "\n",
75
76
77
78
        }
    }

    /// Parse a single JSON object into a ToolCall
79
    fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Qwen uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
                .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }

    /// Check if text contains Qwen tool markers
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<tool_call>")
    }

107
108
109
    /// Check if text has tool call
    fn has_tool_call(&self, text: &str) -> bool {
        text.contains("<tool_call>")
110
111
112
113
114
115
116
117
118
119
120
    }
}

impl Default for QwenParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for QwenParser {
121
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
122
123
        // Check if text contains Qwen format
        if !self.has_tool_markers(text) {
124
            return Ok((text.to_string(), vec![]));
125
126
        }

127
128
129
        // Find where the first tool call begins
        let idx = text.find("<tool_call>").unwrap(); // Safe because has_tool_markers checked
        let normal_text = text[..idx].to_string();
130

131
132
        // Extract tool calls
        let mut tools = Vec::new();
133
        for captures in self.extractor.captures_iter(text) {
134
            if let Some(json_str) = captures.get(1) {
135
136
                let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
                    .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
137
                    .and_then(|v| self.parse_single_object(&v));
138
139
140
141

                match parsed {
                    Ok(Some(tool)) => tools.push(tool),
                    Ok(None) => continue,
142
                    Err(e) => {
143
                        tracing::warn!("Failed to parse tool call: {:?}", e);
144
                        continue;
145
146
147
148
149
                    }
                }
            }
        }

150
151
152
153
        // If no tools were successfully parsed despite having markers, return entire text as fallback
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
        }
154
155

        Ok((normal_text, tools))
156
157
158
    }

    async fn parse_incremental(
159
        &mut self,
160
        chunk: &str,
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
        tools: &[Tool],
    ) -> ToolParserResult<StreamingParseResult> {
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_call(current_text)
            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
184
            }
185
186
        }

187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);

        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
        };

        let mut result = helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )?;

        // Qwen-specific: Handle partial end tokens in normal text
        // After tool calls complete, normal text might contain partial "</tool_call>" tags
        if !result.normal_text.is_empty() {
            self.normal_text_buffer.push_str(&result.normal_text);

            // Check if buffer contains complete end token (without leading newline)
            let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
            if self.normal_text_buffer.contains(end_token_without_newline) {
                // Complete end token found - clean it and return
                let cleaned_text = self
                    .normal_text_buffer
                    .replace(end_token_without_newline, "");
                self.normal_text_buffer.clear();
                result.normal_text = cleaned_text;
225
            } else {
226
227
228
229
230
231
232
233
234
235
236
237
238
                // Check if buffer might contain partial end token at the end
                if let Some(partial_match_len) = helpers::ends_with_partial_token(
                    &self.normal_text_buffer,
                    end_token_without_newline,
                ) {
                    // Keep potential partial match in buffer, return the rest
                    let split_point = self.normal_text_buffer.len() - partial_match_len;
                    result.normal_text = self.normal_text_buffer[..split_point].to_string();
                    self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
                } else {
                    // No partial match, return all buffered text
                    result.normal_text = self.normal_text_buffer.clone();
                    self.normal_text_buffer.clear();
239
240
241
242
                }
            }
        }

243
        Ok(result)
244
245
246
    }

    fn detect_format(&self, text: &str) -> bool {
247
        self.has_tool_markers(text)
248
    }
249
250
251
252

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
253
}