llama_parser.rs 6.14 KB
Newer Older
1
2
use async_trait::async_trait;

3
use super::json_parser::JsonParser;
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
use crate::tool_parser::{
    errors::ToolParserResult,
    state::ParseState,
    traits::ToolParser,
    types::{StreamResult, TokenConfig, ToolCall},
};

/// Tool-call parser for the Llama 3.2 output format.
///
/// The format this parser targets is a JSON object introduced by a special
/// token:
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
///
/// Plain JSON without the `<|python_tag|>` prefix is accepted as well.
///
/// All JSON handling is delegated to the wrapped [`JsonParser`]; this type
/// only holds that parser.
pub struct LlamaParser {
    /// JSON parser that performs the actual tool-call extraction.
    json_parser: JsonParser,
}

impl LlamaParser {
    /// Build a parser whose inner [`JsonParser`] recognises the Llama
    /// `<|python_tag|>` prefix.
    ///
    /// The token configuration is:
    /// - one start token, `<|python_tag|>`;
    /// - a single empty end token, because the python_tag format has no
    ///   closing marker;
    /// - `;` as the separator between multiple calls (Llama uses a semicolon
    ///   here, though multiple calls are not well supported).
    pub fn new() -> Self {
        let config = TokenConfig {
            start_tokens: vec![String::from("<|python_tag|>")],
            end_tokens: vec![String::new()],
            separator: String::from(";"),
        };

        Self {
            json_parser: JsonParser::with_config(config),
        }
    }
}

impl Default for LlamaParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for LlamaParser {
45
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
46
        // First try with the configured python_tag parser
47
48
49
50
51
52
53
54
55
56
57
        let (_json_normal_text, tools) = self.json_parser.parse_complete(text).await?;

        if !tools.is_empty() {
            // Extract normal text before the python tag
            // JsonParser doesn't preserve normal text for single start tokens, so we do it manually
            let normal_text = if let Some(tag_pos) = text.find("<|python_tag|>") {
                text[..tag_pos].to_string()
            } else {
                String::new()
            };
            return Ok((normal_text, tools));
58
59
60
61
62
63
        }

        // If no results and text starts with '{', try plain JSON
        if text.trim_start().starts_with('{') {
            // Create a temporary plain JSON parser
            let plain_parser = JsonParser::new();
64
65
66
            let (_json_normal_text, tools) = plain_parser.parse_complete(text).await?;
            // For plain JSON, don't extract normal text (consistent with JsonParser behavior)
            return Ok((String::new(), tools));
67
68
        }

69
70
        // No tool calls found, return original text as normal text
        Ok((text.to_string(), vec![]))
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    }

    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        // Try with the python_tag parser first
        let result = self.json_parser.parse_incremental(chunk, state).await?;

        // If we get Incomplete and buffer starts with '{', might be plain JSON
        if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{')
        {
            // Check if we have python_tag in the buffer
            if !state.buffer.contains("<|python_tag|>") {
                // Likely plain JSON, create temporary parser
                let plain_parser = JsonParser::new();
                return plain_parser.parse_incremental("", state).await;
            }
        }

        Ok(result)
    }

    fn detect_format(&self, text: &str) -> bool {
        // Llama format if contains python_tag or starts with JSON object
        text.contains("<|python_tag|>")
            || (text.trim_start().starts_with('{')
                && (text.contains(r#""name""#) || text.contains(r#""function""#)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `<|python_tag|>` payload yields one call and no normal text.
    #[tokio::test]
    async fn test_parse_with_python_tag() {
        let parser = LlamaParser::new();
        let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;

        let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "search");
        assert!(tool_calls[0].function.arguments.contains("weather"));
        assert_eq!(normal_text, ""); // Pure python_tag with JSON should have no normal text
    }

    /// Plain JSON without the tag is still recognised as a tool call.
    #[tokio::test]
    async fn test_parse_plain_json() {
        let parser = LlamaParser::new();
        let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;

        let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "calculate");
        assert_eq!(normal_text, ""); // Pure JSON should have no normal text
    }

    /// Text preceding the tag is returned verbatim as normal text.
    #[tokio::test]
    async fn test_parse_with_text_before() {
        let parser = LlamaParser::new();
        let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;

        let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_time");
        assert_eq!(normal_text, "Let me help you with that. ");
    }

    /// Format detection accepts the tag or name-bearing JSON, and rejects
    /// plain text and JSON without a name field.
    #[test]
    fn test_detect_format() {
        let parser = LlamaParser::new();

        assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
        assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
        assert!(!parser.detect_format("plain text"));
        assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
    }

    /// A trailing semicolon after a single call must not cause an error.
    #[tokio::test]
    async fn test_single_call_with_semicolon() {
        let parser = LlamaParser::new();
        // Note: Llama 3.2 doesn't handle multiple calls well
        let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;

        let (_normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();

        // We expect this to either parse the first JSON object or fail gracefully.
        // Since the semicolon makes it invalid JSON, it will likely return empty.
        // This is acceptable as Llama 3.2 doesn't reliably support parallel calls.

        // If it parses anything, it should be func1
        if !tool_calls.is_empty() {
            assert_eq!(tool_calls[0].function.name, "func1");
        }
    }
}