qwen_parser.rs 8.92 KB
Newer Older
1
2
3
4
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;

5
6
use crate::protocols::spec::Tool;

7
8
use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
9
    parsers::helpers,
10
11
    partial_json::PartialJson,
    traits::ToolParser,
12
    types::{FunctionCall, StreamingParseResult, ToolCall},
13
14
15
16
17
18
19
20
21
22
23
};

/// Qwen format parser for tool calls
///
/// Handles the Qwen 2.5/3 specific format:
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
///
/// Features:
/// - XML-style tags with JSON content
/// - Support for multiple sequential tool calls
/// - Newline-aware parsing
24
/// - Buffering for partial end tokens
25
26
27
pub struct QwenParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
28
29

    /// Regex for extracting tool calls in parse_complete
30
    extractor: Regex,
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Buffer for normal text that might precede partial end tokens
    normal_text_buffer: String,

    /// Token configuration
    bot_token: &'static str,
    eot_token: &'static str,
    tool_call_separator: &'static str,
54
55
56
57
58
59
60
61
62
63
64
65
}

impl QwenParser {
    /// Create a new Qwen parser
    pub fn new() -> Self {
        // Use (?s) flag for DOTALL mode to handle newlines
        let pattern = r"(?s)<tool_call>\n(.*?)\n</tool_call>";
        let extractor = Regex::new(pattern).expect("Valid regex pattern");

        Self {
            partial_json: PartialJson::default(),
            extractor,
66
67
68
69
70
71
72
73
74
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            normal_text_buffer: String::new(),
            bot_token: "<tool_call>\n",
            eot_token: "\n</tool_call>",
            tool_call_separator: "\n",
75
76
77
78
        }
    }

    /// Parse a single JSON object into a ToolCall
79
    fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
80
81
82
83
84
85
86
87
88
89
90
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Qwen uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
                .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

91
92
93
94
95
96
            // Generate unique ID
            let id = obj
                .get("id")
                .and_then(|v| v.as_str())
                .map(String::from)
                .unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4()));
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

            Ok(Some(ToolCall {
                id,
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }

    /// Check if text contains Qwen tool markers
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<tool_call>")
    }

116
117
118
    /// Check if text has tool call
    fn has_tool_call(&self, text: &str) -> bool {
        text.contains("<tool_call>")
119
120
121
122
123
124
125
126
127
128
129
    }
}

impl Default for QwenParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for QwenParser {
130
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
131
132
        // Check if text contains Qwen format
        if !self.has_tool_markers(text) {
133
            return Ok((text.to_string(), vec![]));
134
135
        }

136
137
138
        // Find where the first tool call begins
        let idx = text.find("<tool_call>").unwrap(); // Safe because has_tool_markers checked
        let normal_text = text[..idx].to_string();
139

140
141
        // Extract tool calls
        let mut tools = Vec::new();
142
        for captures in self.extractor.captures_iter(text) {
143
            if let Some(json_str) = captures.get(1) {
144
145
                let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
                    .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
146
                    .and_then(|v| self.parse_single_object(&v));
147
148
149
150

                match parsed {
                    Ok(Some(tool)) => tools.push(tool),
                    Ok(None) => continue,
151
                    Err(e) => {
152
                        tracing::warn!("Failed to parse tool call: {:?}", e);
153
                        continue;
154
155
156
157
158
                    }
                }
            }
        }

159
160
161
162
        // If no tools were successfully parsed despite having markers, return entire text as fallback
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
        }
163
164

        Ok((normal_text, tools))
165
166
167
    }

    async fn parse_incremental(
168
        &mut self,
169
        chunk: &str,
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
        tools: &[Tool],
    ) -> ToolParserResult<StreamingParseResult> {
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_call(current_text)
            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
193
            }
194
195
        }

196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);

        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
        };

        let mut result = helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )?;

        // Qwen-specific: Handle partial end tokens in normal text
        // After tool calls complete, normal text might contain partial "</tool_call>" tags
        if !result.normal_text.is_empty() {
            self.normal_text_buffer.push_str(&result.normal_text);

            // Check if buffer contains complete end token (without leading newline)
            let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
            if self.normal_text_buffer.contains(end_token_without_newline) {
                // Complete end token found - clean it and return
                let cleaned_text = self
                    .normal_text_buffer
                    .replace(end_token_without_newline, "");
                self.normal_text_buffer.clear();
                result.normal_text = cleaned_text;
234
            } else {
235
236
237
238
239
240
241
242
243
244
245
246
247
                // Check if buffer might contain partial end token at the end
                if let Some(partial_match_len) = helpers::ends_with_partial_token(
                    &self.normal_text_buffer,
                    end_token_without_newline,
                ) {
                    // Keep potential partial match in buffer, return the rest
                    let split_point = self.normal_text_buffer.len() - partial_match_len;
                    result.normal_text = self.normal_text_buffer[..split_point].to_string();
                    self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
                } else {
                    // No partial match, return all buffered text
                    result.normal_text = self.normal_text_buffer.clone();
                    self.normal_text_buffer.clear();
248
249
250
251
                }
            }
        }

252
        Ok(result)
253
254
255
    }

    fn detect_format(&self, text: &str) -> bool {
256
        self.has_tool_markers(text)
257
258
    }
}