qwen.rs 8.82 KB
Newer Older
1
2
3
4
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;

5
6
7
8
9
10
11
12
13
use crate::{
    protocols::common::Tool,
    tool_parser::{
        errors::{ParserError, ParserResult},
        parsers::helpers,
        partial_json::PartialJson,
        traits::ToolParser,
        types::{FunctionCall, StreamingParseResult, ToolCall},
    },
14
15
16
17
18
19
20
21
22
23
24
};

/// Qwen format parser for tool calls
///
/// Handles the Qwen 2.5/3 specific format:
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
///
/// Features:
/// - XML-style tags with JSON content
/// - Support for multiple sequential tool calls
/// - Newline-aware parsing
25
/// - Buffering for partial end tokens
26
27
28
pub struct QwenParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
29
30

    /// Regex for extracting tool calls in parse_complete
31
    extractor: Regex,
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Buffer for normal text that might precede partial end tokens
    normal_text_buffer: String,

    /// Token configuration
    bot_token: &'static str,
    eot_token: &'static str,
    tool_call_separator: &'static str,
55
56
57
58
59
60
61
62
63
64
65
66
}

impl QwenParser {
    /// Create a new Qwen parser
    pub fn new() -> Self {
        // Use (?s) flag for DOTALL mode to handle newlines
        let pattern = r"(?s)<tool_call>\n(.*?)\n</tool_call>";
        let extractor = Regex::new(pattern).expect("Valid regex pattern");

        Self {
            partial_json: PartialJson::default(),
            extractor,
67
68
69
70
71
72
73
74
75
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            normal_text_buffer: String::new(),
            bot_token: "<tool_call>\n",
            eot_token: "\n</tool_call>",
            tool_call_separator: "\n",
76
77
78
79
        }
    }

    /// Parse a single JSON object into a ToolCall
80
    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
81
82
83
84
85
86
87
88
89
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Qwen uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
90
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }
}

impl Default for QwenParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for QwenParser {
112
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
113
114
        // Check if text contains Qwen format
        if !self.has_tool_markers(text) {
115
            return Ok((text.to_string(), vec![]));
116
117
        }

118
119
120
        // Find where the first tool call begins
        let idx = text.find("<tool_call>").unwrap(); // Safe because has_tool_markers checked
        let normal_text = text[..idx].to_string();
121

122
123
        // Extract tool calls
        let mut tools = Vec::new();
124
        for captures in self.extractor.captures_iter(text) {
125
            if let Some(json_str) = captures.get(1) {
126
                let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
127
                    .map_err(|e| ParserError::ParsingFailed(e.to_string()))
128
                    .and_then(|v| self.parse_single_object(&v));
129
130
131
132

                match parsed {
                    Ok(Some(tool)) => tools.push(tool),
                    Ok(None) => continue,
133
                    Err(e) => {
134
                        tracing::warn!("Failed to parse tool call: {:?}", e);
135
                        continue;
136
137
138
139
140
                    }
                }
            }
        }

141
142
143
144
        // If no tools were successfully parsed despite having markers, return entire text as fallback
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
        }
145
146

        Ok((normal_text, tools))
147
148
149
    }

    async fn parse_incremental(
150
        &mut self,
151
        chunk: &str,
152
        tools: &[Tool],
153
    ) -> ParserResult<StreamingParseResult> {
154
155
156
157
158
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
159
        let has_tool_start = self.has_tool_markers(current_text)
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
175
            }
176
177
        }

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);

        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
        };

        let mut result = helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )?;

        // Qwen-specific: Handle partial end tokens in normal text
        // After tool calls complete, normal text might contain partial "</tool_call>" tags
        if !result.normal_text.is_empty() {
            self.normal_text_buffer.push_str(&result.normal_text);

            // Check if buffer contains complete end token (without leading newline)
            let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
            if self.normal_text_buffer.contains(end_token_without_newline) {
                // Complete end token found - clean it and return
                let cleaned_text = self
                    .normal_text_buffer
                    .replace(end_token_without_newline, "");
                self.normal_text_buffer.clear();
                result.normal_text = cleaned_text;
216
            } else {
217
218
219
220
221
222
223
224
225
226
227
228
229
                // Check if buffer might contain partial end token at the end
                if let Some(partial_match_len) = helpers::ends_with_partial_token(
                    &self.normal_text_buffer,
                    end_token_without_newline,
                ) {
                    // Keep potential partial match in buffer, return the rest
                    let split_point = self.normal_text_buffer.len() - partial_match_len;
                    result.normal_text = self.normal_text_buffer[..split_point].to_string();
                    self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
                } else {
                    // No partial match, return all buffered text
                    result.normal_text = self.normal_text_buffer.clone();
                    self.normal_text_buffer.clear();
230
231
232
233
                }
            }
        }

234
        Ok(result)
235
236
    }

237
238
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<tool_call>")
239
    }
240
241
242
243

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
244
245
246
247
248
249
250
251
252
253

    fn reset(&mut self) {
        helpers::reset_parser_state(
            &mut self.buffer,
            &mut self.prev_tool_call_arr,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
    }
254
}