qwen.rs 9.21 KB
Newer Older
1
2
3
4
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;

5
6
7
8
9
10
11
12
13
use crate::{
    protocols::common::Tool,
    tool_parser::{
        errors::{ParserError, ParserResult},
        parsers::helpers,
        partial_json::PartialJson,
        traits::ToolParser,
        types::{FunctionCall, StreamingParseResult, ToolCall},
    },
14
15
16
17
18
19
20
21
};

/// Qwen format parser for tool calls
///
/// Handles the Qwen 2.5/3 specific format:
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
///
/// Features:
22
23
24
25
26
/// - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
/// - Each individual call is separated by `\n`
/// - Function Call Object: JSON object with "name" and "arguments" fields
///
/// Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
27
28
29
pub struct QwenParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
30
31

    /// Regex for extracting tool calls in parse_complete
32
    extractor: Regex,
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Buffer for normal text that might precede partial end tokens
    normal_text_buffer: String,

    /// Token configuration
53
54
55
    /// Start/end tokens for each individual tool call (not the entire sequence)
    individual_tool_start_token: &'static str,
    individual_tool_end_token: &'static str,
56
    tool_call_separator: &'static str,
57
58
59
60
61
62
63
64
65
66
67
68
}

impl QwenParser {
    /// Create a new Qwen parser
    pub fn new() -> Self {
        // Use (?s) flag for DOTALL mode to handle newlines
        let pattern = r"(?s)<tool_call>\n(.*?)\n</tool_call>";
        let extractor = Regex::new(pattern).expect("Valid regex pattern");

        Self {
            partial_json: PartialJson::default(),
            extractor,
69
70
71
72
73
74
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            normal_text_buffer: String::new(),
75
76
            individual_tool_start_token: "<tool_call>\n",
            individual_tool_end_token: "\n</tool_call>",
77
            tool_call_separator: "\n",
78
79
80
81
        }
    }

    /// Parse a single JSON object into a ToolCall
82
    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
83
84
85
86
87
88
89
90
91
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - Qwen uses "arguments" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj.get("arguments").unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
92
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }
}

impl Default for QwenParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for QwenParser {
114
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
115
116
        // Check if text contains Qwen format
        if !self.has_tool_markers(text) {
117
            return Ok((text.to_string(), vec![]));
118
119
        }

120
121
122
        // Find where the first tool call begins
        let idx = text.find("<tool_call>").unwrap(); // Safe because has_tool_markers checked
        let normal_text = text[..idx].to_string();
123

124
125
        // Extract tool calls
        let mut tools = Vec::new();
126
        for captures in self.extractor.captures_iter(text) {
127
            if let Some(json_str) = captures.get(1) {
128
                let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
129
                    .map_err(|e| ParserError::ParsingFailed(e.to_string()))
130
                    .and_then(|v| self.parse_single_object(&v));
131
132
133
134

                match parsed {
                    Ok(Some(tool)) => tools.push(tool),
                    Ok(None) => continue,
135
                    Err(e) => {
136
                        tracing::warn!("Failed to parse tool call: {:?}", e);
137
                        continue;
138
139
140
141
142
                    }
                }
            }
        }

143
144
145
146
        // If no tools were successfully parsed despite having markers, return entire text as fallback
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
        }
147
148

        Ok((normal_text, tools))
149
150
151
    }

    async fn parse_incremental(
152
        &mut self,
153
        chunk: &str,
154
        tools: &[Tool],
155
    ) -> ParserResult<StreamingParseResult> {
156
157
158
159
160
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
161
        let has_tool_start = self.has_tool_markers(current_text)
162
            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
163
164
165

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
166
167
168
            if helpers::ends_with_partial_token(&self.buffer, self.individual_tool_start_token)
                .is_none()
            {
169
170
171
172
173
174
175
176
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
177
                // Might be partial individual_tool_start_token, keep buffering
178
                return Ok(StreamingParseResult::default());
179
            }
180
181
        }

182
183
184
185
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);

        // Determine start index for JSON parsing
186
187
188
        let start_idx = if let Some(pos) = current_text.find(self.individual_tool_start_token) {
            pos + self.individual_tool_start_token.len()
        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
            self.tool_call_separator.len()
        } else {
            0
        };

        let mut result = helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )?;

        // Qwen-specific: Handle partial end tokens in normal text
        // After tool calls complete, normal text might contain partial "</tool_call>" tags
        if !result.normal_text.is_empty() {
            self.normal_text_buffer.push_str(&result.normal_text);

            // Check if buffer contains complete end token (without leading newline)
212
            let end_token_without_newline = &self.individual_tool_end_token[1..]; // "</tool_call>"
213
214
215
216
217
218
219
            if self.normal_text_buffer.contains(end_token_without_newline) {
                // Complete end token found - clean it and return
                let cleaned_text = self
                    .normal_text_buffer
                    .replace(end_token_without_newline, "");
                self.normal_text_buffer.clear();
                result.normal_text = cleaned_text;
220
            } else {
221
222
223
224
225
226
227
228
229
230
231
232
233
                // Check if buffer might contain partial end token at the end
                if let Some(partial_match_len) = helpers::ends_with_partial_token(
                    &self.normal_text_buffer,
                    end_token_without_newline,
                ) {
                    // Keep potential partial match in buffer, return the rest
                    let split_point = self.normal_text_buffer.len() - partial_match_len;
                    result.normal_text = self.normal_text_buffer[..split_point].to_string();
                    self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
                } else {
                    // No partial match, return all buffered text
                    result.normal_text = self.normal_text_buffer.clone();
                    self.normal_text_buffer.clear();
234
235
236
237
                }
            }
        }

238
        Ok(result)
239
240
    }

241
242
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<tool_call>")
243
    }
244
245
246
247

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
248
249
250
251
252
253
254
255
256
257

    fn reset(&mut self) {
        helpers::reset_parser_state(
            &mut self.buffer,
            &mut self.prev_tool_call_arr,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
    }
258
}