llama_parser.rs

use async_trait::async_trait;

use crate::tool_parser::{
    errors::ToolParserResult,
    json_parser::JsonParser,
    state::ParseState,
    traits::ToolParser,
    types::{StreamResult, TokenConfig, ToolCall},
};

/// Llama 3.2 format parser for tool calls
///
/// Handles the Llama 3.2-specific format:
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
///
/// Also supports plain JSON without the python_tag prefix
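///
/// A minimal usage sketch, mirroring the unit tests below (an async runtime
/// such as Tokio is assumed to drive the future):
///
/// ```ignore
/// let parser = LlamaParser::new();
/// let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
/// let calls = parser.parse_complete(input).await.unwrap();
/// assert_eq!(calls[0].function.name, "search");
/// ```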
pub struct LlamaParser {
    /// Underlying JSON parser with Llama-specific configuration
    json_parser: JsonParser,
}

impl LlamaParser {
    /// Create a new Llama parser
    pub fn new() -> Self {
        // Configure the JSON parser with Llama's python_tag start token
        // The format has no closing tag, so an empty string stands in for the end token
        let json_parser = JsonParser::with_config(TokenConfig {
            start_tokens: vec!["<|python_tag|>".to_string()],
            end_tokens: vec!["".to_string()],
            separator: ";".to_string(), // Llama uses a semicolon between multiple calls (parallel calls are not well supported)
        });

        Self { json_parser }
    }
}

impl Default for LlamaParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for LlamaParser {
    async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
        // First try with the configured python_tag parser
        let result = self.json_parser.parse_complete(text).await?;

        if !result.is_empty() {
            return Ok(result);
        }

        // If no results and text starts with '{', try plain JSON
        if text.trim_start().starts_with('{') {
            // Create a temporary plain JSON parser
            let plain_parser = JsonParser::new();
            return plain_parser.parse_complete(text).await;
        }

        Ok(vec![])
    }

    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        // Try with the python_tag parser first
        let result = self.json_parser.parse_incremental(chunk, state).await?;

        // If we get Incomplete and the buffer starts with '{', the input might be plain JSON without the tag
        if matches!(result, StreamResult::Incomplete) && state.buffer.trim_start().starts_with('{')
        {
            // Check whether python_tag appears anywhere in the buffer
            if !state.buffer.contains("<|python_tag|>") {
                // Likely plain JSON; re-parse the buffered text with a temporary plain parser
                let plain_parser = JsonParser::new();
                return plain_parser.parse_incremental("", state).await;
            }
        }

        Ok(result)
    }

    fn detect_format(&self, text: &str) -> bool {
        // Llama format if the text contains python_tag, or starts with a JSON object that carries a "name" or "function" field
        text.contains("<|python_tag|>")
            || (text.trim_start().starts_with('{')
                && (text.contains(r#""name""#) || text.contains(r#""function""#)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_with_python_tag() {
        let parser = LlamaParser::new();
        let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;

        let result = parser.parse_complete(input).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].function.name, "search");
        assert!(result[0].function.arguments.contains("weather"));
    }

    #[tokio::test]
    async fn test_parse_plain_json() {
        let parser = LlamaParser::new();
        let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;

        let result = parser.parse_complete(input).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].function.name, "calculate");
    }

    #[tokio::test]
    async fn test_parse_with_text_before() {
        let parser = LlamaParser::new();
        let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;

        let result = parser.parse_complete(input).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].function.name, "get_time");
    }

    #[test]
    fn test_detect_format() {
        let parser = LlamaParser::new();

        assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
        assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
        assert!(!parser.detect_format("plain text"));
        assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
    }

    #[tokio::test]
    async fn test_single_call_with_semicolon() {
        let parser = LlamaParser::new();
        // Note: Llama 3.2 doesn't handle multiple calls well
        // Test that we can at least parse a single call followed by a trailing semicolon
        let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;

        let result = parser.parse_complete(input).await.unwrap();

        // We expect this either to parse the first JSON object or to fail gracefully
        // Since the trailing semicolon makes the input invalid JSON, it will likely return an empty result
        // This is acceptable, as Llama 3.2 doesn't reliably support parallel calls

        // If it parses anything, it should be func1
        if !result.is_empty() {
            assert_eq!(result[0].function.name, "func1");
        }
    }
}