llama_parser.rs 6.38 KB
Newer Older
1
2
use async_trait::async_trait;

3
use super::json_parser::JsonParser;
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
use crate::tool_parser::{
    errors::ToolParserResult,
    state::ParseState,
    traits::ToolParser,
    types::{StreamResult, TokenConfig, ToolCall},
};

/// Tool-call parser for the Llama 3.2 output format.
///
/// Recognises calls emitted as
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
/// and, as a fallback, the same JSON object written without the `<|python_tag|>` prefix.
pub struct LlamaParser {
    /// JSON parser configured with Llama's `<|python_tag|>` start token.
    json_parser: JsonParser,
}

impl LlamaParser {
    /// Build a parser whose inner JSON parser recognises Llama's `<|python_tag|>` prefix.
    ///
    /// The python_tag format has no closing token, so the configured end token is the
    /// empty string. Multiple calls would be separated by `;`, although Llama 3.2 does not
    /// emit parallel calls reliably.
    pub fn new() -> Self {
        let config = TokenConfig {
            start_tokens: vec!["<|python_tag|>".to_string()],
            // python_tag calls are not terminated by any token.
            end_tokens: vec!["".to_string()],
            // Semicolon-separated parallel calls (poorly supported by Llama 3.2).
            separator: ";".to_string(),
        };

        Self {
            json_parser: JsonParser::with_config(config),
        }
    }
}

impl Default for LlamaParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for LlamaParser {
45
    async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
46
        // First try with the configured python_tag parser
47
48
49
50
51
52
53
54
55
56
57
        let (_json_normal_text, tools) = self.json_parser.parse_complete(text).await?;

        if !tools.is_empty() {
            // Extract normal text before the python tag
            // JsonParser doesn't preserve normal text for single start tokens, so we do it manually
            let normal_text = if let Some(tag_pos) = text.find("<|python_tag|>") {
                text[..tag_pos].to_string()
            } else {
                String::new()
            };
            return Ok((normal_text, tools));
58
59
60
61
62
63
        }

        // If no results and text starts with '{', try plain JSON
        if text.trim_start().starts_with('{') {
            // Create a temporary plain JSON parser
            let plain_parser = JsonParser::new();
64
65
66
            let (_json_normal_text, tools) = plain_parser.parse_complete(text).await?;
            // For plain JSON, don't extract normal text (consistent with JsonParser behavior)
            return Ok((String::new(), tools));
67
68
        }

69
70
        // No tool calls found, return original text as normal text
        Ok((text.to_string(), vec![]))
71
72
73
74
75
76
77
    }

    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
78
        // First, try with the configured json_parser (which handles python_tag)
79
80
        let result = self.json_parser.parse_incremental(chunk, state).await?;

81
82
83
84
85
86
        // If we get Incomplete and no python_tag in buffer, might be plain JSON
        if matches!(result, StreamResult::Incomplete) {
            let trimmed = state.buffer.trim_start();
            if trimmed.starts_with('{') && !state.buffer.contains("<|python_tag|>") {
                // Likely plain JSON, try with a plain parser
                // Note: We need to be careful not to double-add the chunk
87
                let plain_parser = JsonParser::new();
88
89
                // The chunk was already added to state.buffer by json_parser above
                // So we call with empty string to just process what's in the buffer
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
                return plain_parser.parse_incremental("", state).await;
            }
        }

        Ok(result)
    }

    fn detect_format(&self, text: &str) -> bool {
        // Llama format if contains python_tag or starts with JSON object
        text.contains("<|python_tag|>")
            || (text.trim_start().starts_with('{')
                && (text.contains(r#""name""#) || text.contains(r#""function""#)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_with_python_tag() {
        let parser = LlamaParser::new();
        let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;

        let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "search");
        assert!(tool_calls[0].function.arguments.contains("weather"));
        assert_eq!(normal_text, ""); // Pure python_tag with JSON should have no normal text
    }

    #[tokio::test]
    async fn test_parse_plain_json() {
        let parser = LlamaParser::new();
        let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;

        let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "calculate");
        assert_eq!(normal_text, ""); // Pure JSON should have no normal text
    }

    #[tokio::test]
    async fn test_parse_with_text_before() {
        let parser = LlamaParser::new();
        let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;

        let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_time");
        assert_eq!(normal_text, "Let me help you with that. ");
    }

    #[test]
    fn test_detect_format() {
        let parser = LlamaParser::new();

        assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
        assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
        assert!(!parser.detect_format("plain text"));
        assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field

        // Leading whitespace is ignored, and a "function" key also qualifies.
        assert!(parser.detect_format(r#"   {"function": {"name": "test"}}"#));
        // A "name" key alone is not enough when the text does not start with an object.
        assert!(!parser.detect_format(r#"see {"name": "test"}"#));
    }

    #[tokio::test]
    async fn test_single_call_with_semicolon() {
        let parser = LlamaParser::new();
        // Note: Llama 3.2 doesn't handle multiple calls well
        let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;

        let (_normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();

        // We expect this to either parse the first JSON object or fail gracefully.
        // Since the semicolon makes it invalid JSON, it will likely return empty.
        // This is acceptable as Llama 3.2 doesn't reliably support parallel calls.

        // If it parses anything, it should be func1
        if let Some(first) = tool_calls.first() {
            assert_eq!(first.function.name, "func1");
        }
    }
}