llama_parser.rs 7.84 KB
Newer Older
1
use async_trait::async_trait;
2
use serde_json::Value;
3

4
5
use crate::protocols::spec::Tool;

6
use crate::tool_parser::{
7
    errors::{ToolParserError, ToolParserResult},
8
    parsers::helpers,
9
    partial_json::PartialJson,
10
    traits::ToolParser,
11
    types::{FunctionCall, StreamingParseResult, ToolCall},
12
13
14
15
16
};

/// Llama 3.2 format parser for tool calls
///
/// Handles the Llama 3.2 specific format:
17
/// `<|python_tag|>{"name": "func", "parameters": {...}}`
18
19
20
///
/// Also supports plain JSON without the python_tag prefix
pub struct LlamaParser {
21
22
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Token configuration
    bot_token: &'static str,
    tool_call_separator: &'static str,
42
43
44
45
46
}

impl LlamaParser {
    /// Create a new Llama parser with empty streaming state.
    ///
    /// The tool-call start token is `<|python_tag|>` and consecutive tool
    /// calls are separated by `;`.
    pub fn new() -> Self {
        Self {
            partial_json: PartialJson::default(),
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            bot_token: "<|python_tag|>",
            tool_call_separator: ";",
        }
    }

    /// Split `text` at the first occurrence of the python tag.
    ///
    /// Returns `(text_before_tag, text_after_tag)`, or `None` when the tag is
    /// not present in `text`.
    fn extract_content_after_python_tag(&self, text: &str) -> Option<(String, String)> {
        // Use the configured token so this stays in sync with the streaming path.
        text.split_once(self.bot_token)
            .map(|(normal_text, json_content)| {
                (normal_text.to_string(), json_content.to_string())
            })
    }

    /// Parse a single JSON value into a `ToolCall` (Llama format: name + parameters).
    ///
    /// Returns `Ok(None)` when `obj` has no string-valued `"name"` entry.
    /// A missing `"parameters"` entry is serialized as an empty object (`{}`).
    ///
    /// # Errors
    ///
    /// Returns `ToolParserError::ParsingFailed` if the parameters cannot be
    /// serialized back to a JSON string.
    fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
        // Llama format only: {"name": "function_name", "parameters": {...}}
        let name = match obj.get("name").and_then(Value::as_str) {
            Some(name) => name,
            None => return Ok(None),
        };

        // Llama uses the "parameters" key; default to an empty object when absent.
        let empty_obj = Value::Object(serde_json::Map::new());
        let parameters = obj.get("parameters").unwrap_or(&empty_obj);

        let arguments = serde_json::to_string(parameters)
            .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

        Ok(Some(ToolCall {
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }))
    }

    /// Split `content` on every `separator` that lies outside a JSON string literal.
    ///
    /// A plain `str::split` would also cut at separators inside string values
    /// (for example `{"q": "a;b"}`), turning a valid object into two invalid
    /// fragments. This scanner tracks double-quoted strings and backslash
    /// escapes so only separators between JSON values produce a split.
    fn split_outside_strings<'a>(content: &'a str, separator: &str) -> Vec<&'a str> {
        if separator.is_empty() {
            return vec![content];
        }

        let mut parts = Vec::new();
        let mut in_string = false;
        let mut escaped = false;
        let mut part_start = 0;
        // Byte offset up to which characters belong to an already-consumed separator.
        let mut skip_until = 0;

        for (idx, ch) in content.char_indices() {
            if idx < skip_until {
                continue;
            }

            if in_string {
                if escaped {
                    // The character after a backslash never terminates the string.
                    escaped = false;
                } else if ch == '\\' {
                    escaped = true;
                } else if ch == '"' {
                    in_string = false;
                }
                continue;
            }

            if ch == '"' {
                in_string = true;
            } else if content[idx..].starts_with(separator) {
                parts.push(&content[part_start..idx]);
                part_start = idx + separator.len();
                skip_until = part_start;
            }
        }

        parts.push(&content[part_start..]);
        parts
    }

    /// Parse separator-delimited (`;`) JSON objects into tool calls.
    ///
    /// Empty segments are ignored. Segments that are not valid JSON are logged
    /// with a warning and skipped; segments without a `"name"` produce no tool.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `parse_single_object`.
    fn parse_semicolon_separated(&self, content: &str) -> ToolParserResult<Vec<ToolCall>> {
        let mut all_tools = Vec::new();

        for part in Self::split_outside_strings(content, self.tool_call_separator) {
            let trimmed = part.trim();
            if trimmed.is_empty() {
                continue;
            }

            match serde_json::from_str::<Value>(trimmed) {
                Ok(value) => all_tools.extend(self.parse_single_object(&value)?),
                Err(e) => {
                    // Skip invalid JSON parts in the separated list
                    tracing::warn!("Failed to parse tool call: {}", e);
                }
            }
        }

        Ok(all_tools)
    }
}

impl Default for LlamaParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for LlamaParser {
134
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
135
136
137
138
139
140
        // Extract normal text and JSON content
        let (normal_text, json_content) =
            if let Some((normal, json)) = self.extract_content_after_python_tag(text) {
                (normal, json)
            } else if text.trim_start().starts_with('{') {
                (String::new(), text.to_string())
141
            } else {
142
143
                // No JSON structure found
                return Ok((text.to_string(), vec![]));
144
            };
145

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
        // Parse the JSON content (may contain semicolon-separated objects)
        let tools = if json_content.contains(';') {
            self.parse_semicolon_separated(&json_content)?
        } else {
            // Try single JSON object
            let parsed = serde_json::from_str::<Value>(json_content.trim())
                .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
                .and_then(|v| {
                    self.parse_single_object(&v)
                        .map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool]))
                });

            parsed.unwrap_or_else(|e| {
                tracing::warn!("Failed to parse tool call: {:?}", e);
                vec![]
            })
        };

        // If we couldn't parse any tools, return the original text
        if tools.is_empty() {
            return Ok((text.to_string(), vec![]));
167
168
        }

169
        Ok((normal_text, tools))
170
171
172
    }

    async fn parse_incremental(
173
        &mut self,
174
        chunk: &str,
175
176
177
178
179
180
181
        tools: &[Tool],
    ) -> ToolParserResult<StreamingParseResult> {
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

        // Check if current_text has tool_call
182
        let has_tool_start = self.has_tool_markers(current_text)
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
                let normal_text = self.buffer.clone();
                self.buffer.clear();

                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
                });
            } else {
                // Might be partial bot_token, keep buffering
                return Ok(StreamingParseResult::default());
198
199
            }
        }
200

201
202
203
204
205
206
207
208
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);

        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
209
        } else {
210
            0
211
212
        };

213
214
215
216
217
218
219
220
221
222
223
        helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )
224
225
    }

226
    fn has_tool_markers(&self, text: &str) -> bool {
227
        // Llama format if contains python_tag or starts with JSON object
228
        text.contains("<|python_tag|>") || text.trim_start().starts_with('{')
229
    }
230
231
232
233

    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
234
}