gpt_oss_parser.rs 9.34 KB
Newer Older
1
2
3
4
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;

use crate::protocols::spec::Tool;

use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
    parsers::helpers,
    partial_json::PartialJson,
    traits::ToolParser,
    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};

/// Tool-call parser for the GPT-OSS channel format.
///
/// A tool call is encoded as:
/// `<|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{json_args}<|call|>`
///
/// Supported features:
/// - the `commentary` channel header,
/// - dotted (namespaced) function names,
/// - a JSON argument body.
pub struct GptOssParser {
    // Compiled patterns.
    /// Matches a finished call, i.e. one already terminated by `<|call|>`.
    function_call_extractor: Regex,
    /// Matches a call header whose argument body may still be arriving.
    streaming_extractor: Regex,

    // Streaming state.
    /// Lenient JSON parser applied to argument text that is not complete yet.
    partial_json: PartialJson,
    /// Streamed text received so far that has not been consumed.
    buffer: String,
    /// `true` once the current call's name has been emitted while streaming.
    name_sent: bool,
}

impl GptOssParser {
    /// Create a new GPT-OSS parser
    pub fn new() -> Self {
        // Pattern for complete function calls with to= parameter
        // Handles optional <|start|>assistant prefix and whitespace after function name
        let function_call_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*?)<\|call\|>(?:commentary)?";
        let function_call_extractor =
            Regex::new(function_call_pattern).expect("Valid regex pattern");

        // Pattern for streaming function calls (incomplete)
        let streaming_pattern = r"(?s)(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*<\|constrain\|>json<\|message\|>(.*)";
        let streaming_extractor = Regex::new(streaming_pattern).expect("Valid regex pattern");

        Self {
            partial_json: PartialJson::default(),
            function_call_extractor,
            streaming_extractor,
55
56
57

            buffer: String::new(),
            name_sent: false,
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
        }
    }

    /// Extract function name from full namespace (e.g., "functions.get_weather" -> "get_weather")
    fn extract_function_name(&self, full_name: &str) -> String {
        if let Some(dot_pos) = full_name.rfind('.') {
            full_name[dot_pos + 1..].to_string()
        } else {
            full_name.to_string()
        }
    }
}

impl Default for GptOssParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for GptOssParser {
79
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
80
81
        // Check if text contains GPT-OSS format
        if !self.has_tool_markers(text) {
82
            return Ok((text.to_string(), vec![]));
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
        }

        let mut tools = Vec::new();
        let mut _tool_index = 0;

        // Extract all function calls
        for captures in self.function_call_extractor.captures_iter(text) {
            if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
                let full_function_name = name_match.as_str();
                let args_content = args_match.as_str().trim();

                // Extract actual function name
                let function_name = self.extract_function_name(full_function_name);

                // Parse JSON arguments
                let arguments = if args_content.is_empty() {
                    "{}".to_string()
                } else {
                    match serde_json::from_str::<Value>(args_content) {
                        Ok(value) => serde_json::to_string(&value)
                            .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?,
                        Err(_) => {
                            // Skip malformed JSON
                            continue;
                        }
                    }
                };

                tools.push(ToolCall {
                    function: FunctionCall {
                        name: function_name,
                        arguments,
                    },
                });

                _tool_index += 1;
            }
        }

122
        Ok((String::new(), tools)) // GPT-OSS parser returns empty normal text
123
124
125
    }

    async fn parse_incremental(
126
        &mut self,
127
        chunk: &str,
128
129
130
        tools: &[Tool],
    ) -> ToolParserResult<StreamingParseResult> {
        self.buffer.push_str(chunk);
131
132

        // Check for tool markers
133
        if !self.has_tool_markers(&self.buffer) {
134
            // No markers found, clear buffer and return
135
136
            self.buffer.clear();
            return Ok(StreamingParseResult::default());
137
138
139
        }

        // Try to match streaming pattern
140
        if let Some(captures) = self.streaming_extractor.captures(&self.buffer) {
141
142
143
144
145
146
147
148
            if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
                let full_function_name = name_match.as_str();
                let partial_args = args_match.as_str();

                // Extract actual function name
                let function_name = self.extract_function_name(full_function_name);

                // Send function name if not sent yet
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
                if !self.name_sent {
                    // Validate tool name
                    let tool_indices = helpers::get_tool_indices(tools);
                    if !tool_indices.contains_key(&function_name) {
                        // Invalid tool name - skip
                        tracing::warn!("Invalid tool name '{}' - skipping", function_name);
                        self.buffer.clear();
                        self.name_sent = false;
                        return Ok(StreamingParseResult::default());
                    }

                    self.name_sent = true; // Mark name as sent
                    return Ok(StreamingParseResult {
                        normal_text: String::new(),
                        calls: vec![ToolCallItem {
                            tool_index: 0,
                            name: Some(function_name.clone()),
                            parameters: String::new(),
                        }],
168
169
170
171
                    });
                }

                // Check if we have a complete function call
172
                if let Some(complete_match) = self.function_call_extractor.captures(&self.buffer) {
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
                    if let Some(args_match) = complete_match.get(2) {
                        let args_content = args_match.as_str().trim();

                        // Parse JSON arguments
                        let arguments = if args_content.is_empty() {
                            "{}".to_string()
                        } else {
                            match serde_json::from_str::<Value>(args_content) {
                                Ok(value) => serde_json::to_string(&value)
                                    .unwrap_or_else(|_| "{}".to_string()),
                                Err(_) => "{}".to_string(),
                            }
                        };

                        // Remove the processed part from buffer
                        let complete_end = complete_match.get(0).unwrap().end();
189
                        self.buffer.drain(..complete_end);
190
191

                        // Reset state for next tool
192
193
194
195
196
197
198
199
200
201
202
                        self.name_sent = false;

                        // Return final arguments
                        return Ok(StreamingParseResult {
                            normal_text: String::new(),
                            calls: vec![ToolCallItem {
                                tool_index: 0,
                                name: None,
                                parameters: arguments,
                            }],
                        });
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
                    }
                } else {
                    // Try to parse partial JSON for streaming arguments
                    if !partial_args.is_empty() {
                        // Look for the end of JSON (before <|call|>)
                        let json_part = if let Some(call_pos) = partial_args.find("<|call|>") {
                            &partial_args[..call_pos]
                        } else {
                            partial_args
                        };

                        match self.partial_json.parse_value(json_part) {
                            Ok((value, _consumed)) => {
                                let args_str = serde_json::to_string(&value)
                                    .unwrap_or_else(|_| "{}".to_string());

219
220
221
222
223
224
225
                                return Ok(StreamingParseResult {
                                    normal_text: String::new(),
                                    calls: vec![ToolCallItem {
                                        tool_index: 0,
                                        name: None,
                                        parameters: args_str,
                                    }],
226
227
228
229
230
231
232
233
234
235
236
                                });
                            }
                            Err(_) => {
                                // Can't parse yet, keep buffering
                            }
                        }
                    }
                }
            }
        }

237
        Ok(StreamingParseResult::default())
238
239
    }

240
241
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<|channel|>commentary")
242
243
    }
}