llama_parser.rs 8.12 KB
Newer Older
1
use async_trait::async_trait;
2
use serde_json::Value;
3

4
5
6
7
8
9
10
11
12
use crate::{
    protocols::common::Tool,
    tool_parser::{
        errors::{ParserError, ParserResult},
        parsers::helpers,
        partial_json::PartialJson,
        traits::ToolParser,
        types::{FunctionCall, StreamingParseResult, ToolCall},
    },
13
14
15
16
17
};

/// Llama 3.2 format parser for tool calls
///
/// Handles the Llama 3.2 specific format:
18
/// `<|python_tag|>{"name": "func", "parameters": {...}}`
19
20
21
///
/// Also supports plain JSON without the python_tag prefix
pub struct LlamaParser {
22
23
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Token configuration
    bot_token: &'static str,
    tool_call_separator: &'static str,
43
44
45
46
47
}

impl LlamaParser {
    /// Create a new Llama parser
    pub fn new() -> Self {
48
49
        Self {
            partial_json: PartialJson::default(),
50
51
52
53
54
55
56
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            bot_token: "<|python_tag|>",
            tool_call_separator: ";",
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
        }
    }

    /// Extract content after python_tag token
    fn extract_content_after_python_tag(&self, text: &str) -> Option<(String, String)> {
        const PYTHON_TAG: &str = "<|python_tag|>";

        if let Some(tag_pos) = text.find(PYTHON_TAG) {
            let normal_text = text[..tag_pos].to_string();
            let json_content = text[tag_pos + PYTHON_TAG.len()..].to_string();
            Some((normal_text, json_content))
        } else {
            None
        }
    }

    /// Parse a single JSON object into a ToolCall (Llama format: name + parameters)
74
    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
75
76
77
78
79
80
81
82
83
84
        // Llama format only: {"name": "function_name", "parameters": {...}}
        let name = obj.get("name").and_then(|v| v.as_str());

        if let Some(name) = name {
            // Llama uses "parameters" key
            let empty_obj = Value::Object(serde_json::Map::new());
            let parameters = obj.get("parameters").unwrap_or(&empty_obj);

            // Convert parameters to JSON string
            let arguments = serde_json::to_string(parameters)
85
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
86
87
88
89
90
91
92
93
94
95
96
97
98

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }

    /// Parse semicolon-separated JSON objects
99
    fn parse_semicolon_separated(&self, content: &str) -> ParserResult<Vec<ToolCall>> {
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
        let mut all_tools = Vec::new();

        // Split by semicolon and parse each JSON object
        for part in content.split(';') {
            let trimmed = part.trim();
            if trimmed.is_empty() {
                continue;
            }

            // Try to parse this part as a single JSON object
            match serde_json::from_str::<Value>(trimmed) {
                Ok(value) => {
                    if let Some(tool) = self.parse_single_object(&value)? {
                        all_tools.push(tool);
                    }
                }
                Err(e) => {
                    // Skip invalid JSON parts in semicolon-separated list
                    tracing::warn!("Failed to parse tool call: {}", e);
                }
            }
        }

        Ok(all_tools)
124
125
126
127
128
129
130
131
132
133
134
    }
}

impl Default for LlamaParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for LlamaParser {
135
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
136
137
138
139
140
141
        // Extract normal text and JSON content
        let (normal_text, json_content) =
            if let Some((normal, json)) = self.extract_content_after_python_tag(text) {
                (normal, json)
            } else if text.trim_start().starts_with('{') {
                (String::new(), text.to_string())
142
            } else {
143
144
                // No JSON structure found
                return Ok((text.to_string(), vec![]));
145
            };
146

147
148
149
150
151
152
        // Parse the JSON content (may contain semicolon-separated objects)
        let tools = if json_content.contains(';') {
            self.parse_semicolon_separated(&json_content)?
        } else {
            // Try single JSON object
            let parsed = serde_json::from_str::<Value>(json_content.trim())
153
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))
154
155
156
157
158
159
160
161
162
163
164
165
166
167
                .and_then(|v| {
                    self.parse_single_object(&v)
                        .map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool]))
                });

            parsed.unwrap_or_else(|e| {
                tracing::warn!("Failed to parse tool call: {:?}", e);
                vec![]
            })
        };

        // If we couldn't parse any tools, return the original text
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
168
169
        }

170
        Ok((normal_text, tools))
171
172
173
    }

    async fn parse_incremental(
174
        &mut self,
175
        chunk: &str,
176
        tools: &[Tool],
177
    ) -> ParserResult<StreamingParseResult> {
178
179
180
181
182
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
183
        let has_tool_start = self.has_tool_markers(current_text)
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
199
200
            }
        }
201

202
203
204
205
206
207
208
209
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);

        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
210
        } else {
211
            0
212
213
        };

214
215
216
217
218
219
220
221
222
223
224
        helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )
225
226
    }

227
    fn has_tool_markers(&self, text: &str) -> bool {
228
        // Llama format if contains python_tag or starts with JSON object
229
        text.contains("<|python_tag|>") || text.trim_start().starts_with('{')
230
    }
231
232
233
234

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
235
236
237
238
239
240
241
242
243
244

    fn reset(&mut self) {
        helpers::reset_parser_state(
            &mut self.buffer,
            &mut self.prev_tool_call_arr,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
    }
245
}