//! Llama Parser Integration Tests
//!
//! Tests for the Llama parser, which handles the `<|python_tag|>` format and plain JSON.

use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};

/// A single call after `<|python_tag|>` is extracted, and the text that precedes the tag
/// is returned as normal text.
#[tokio::test]
async fn test_llama_python_tag_format() {
    let parser = LlamaParser::new();
    let input = r#"Here are some results: <|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "search");
    assert_eq!(normal_text, "Here are some results: ");

    // The arguments are carried as a JSON string; decode them to inspect individual fields.
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["query"], "weather");
}

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#[tokio::test]
async fn test_llama_with_semicolon_separation() {
    let parser = LlamaParser::new();

    let input = r#"<|python_tag|>{"name": "tool1", "parameters": {}};{"name": "tool2", "parameters": {"y": 2}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 2);
    assert_eq!(tools[0].function.name, "tool1");
    assert_eq!(tools[1].function.name, "tool2");
    assert_eq!(normal_text, "");
}

/// Input without any tool-call markup produces no tools and comes back unchanged as
/// normal text.
#[tokio::test]
async fn test_llama_no_tool_calls() {
    let plain = "This is just plain text with no tool calls";
    let parser = LlamaParser::new();

    let (remaining, calls) = parser.parse_complete(plain).await.unwrap();

    assert_eq!(calls.len(), 0);
    assert_eq!(remaining, plain);
}

44
45
46
#[tokio::test]
async fn test_llama_plain_json_fallback() {
    let parser = LlamaParser::new();
47
    let input = r#"{"name": "calculate", "parameters": {"x": 5, "y": 10}}"#;
48

49
50
51
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "calculate");
52

53
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
54
55
56
57
58
59
60
    assert_eq!(args["x"], 5);
    assert_eq!(args["y"], 10);
}

/// Leading prose before `<|python_tag|>` is returned as normal text (trailing space
/// included) while the tagged call is extracted.
#[tokio::test]
async fn test_llama_with_text_before() {
    let parser = LlamaParser::new();
    let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "parameters": {"timezone": "UTC"}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(normal_text, "Let me help you with that. ");
    assert_eq!(tools[0].function.name, "get_time");

    // The arguments are carried as a JSON string; decode them to inspect individual fields.
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["timezone"], "UTC");
}

/// A pretty-printed, multi-line call with nested objects in `parameters` is parsed, and
/// the nested values survive in the arguments.
#[tokio::test]
async fn test_llama_with_nested_json() {
    let parser = LlamaParser::new();
    let input = r#"<|python_tag|>{
        "name": "update_settings",
        "parameters": {
            "preferences": {
                "theme": "dark",
                "language": "en"
            },
            "notifications": true
        }
    }"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "update_settings");

    // The arguments are carried as a JSON string; decode them to inspect individual fields.
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["preferences"]["theme"], "dark");
    assert_eq!(args["notifications"], true);
}

#[tokio::test]
async fn test_llama_empty_arguments() {
    let parser = LlamaParser::new();

    // With python_tag
100
    let input = r#"<|python_tag|>{"name": "ping", "parameters": {}}"#;
101
102
103
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
104
105

    // Plain JSON
106
    let input = r#"{"name": "ping", "parameters": {}}"#;
107
108
109
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
110
111
112
113
114
115
116
}

/// `detect_format` accepts tagged input and plain JSON carrying a `name` field, and
/// rejects plain text and JSON without a `name` field.
#[tokio::test]
async fn test_llama_format_detection() {
    let parser = LlamaParser::new();

    // Recognised as tool-call input.
    assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
    assert!(parser.detect_format(r#"{"name": "test", "parameters": {}}"#));

    // Not recognised.
    assert!(!parser.detect_format("plain text"));
    assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
}

#[tokio::test]
async fn test_llama_invalid_json_after_tag() {
    let parser = LlamaParser::new();

    let input = r#"<|python_tag|>{"name": invalid}"#;
127
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
128
    assert_eq!(tools.len(), 0);
129
    assert_eq!(normal_text, "<|python_tag|>{\"name\": invalid}");
130
131
132
133
134
135
136
137
138
}

#[tokio::test]
async fn test_llama_real_world_output() {
    let parser = LlamaParser::new();

    // Actual output from Llama 3.2 model - simplified for testing
    let input = r#"I'll search for that information for you.

139
<|python_tag|>{"name": "web_search", "parameters": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;
140

141
142
143
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "web_search");
144
145
146

    let formatted_input = r#"<|python_tag|>{
    "name": "get_current_time",
147
    "parameters": {
148
149
150
151
152
        "timezone": "America/New_York",
        "format": "ISO8601"
    }
}"#;

153
154
155
    let (_normal_text, tools2) = parser.parse_complete(formatted_input).await.unwrap();
    assert_eq!(tools2.len(), 1);
    assert_eq!(tools2[0].function.name, "get_current_time");
156
157
}

158
159
160
#[tokio::test]
async fn test_single_json() {
    let parser = LlamaParser::new();
161
    let text = r#"{"name": "get_weather", "parameters": {"city": "Paris"}}"#;
162

163
164
165
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
166

167
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
168
169
170
171
172
173
    assert_eq!(args["city"], "Paris");
}

#[tokio::test]
async fn test_multiple_json_with_separator() {
    let parser = LlamaParser::new();
174
    let text = r#"<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}"#;
175

176
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
177
    // Note: Current implementation may only parse the first one due to semicolon handling
178
179
    assert!(!tools.is_empty());
    assert_eq!(tools[0].function.name, "get_weather");
180
181
182
183
184
}

#[tokio::test]
async fn test_json_with_trailing_text() {
    let parser = LlamaParser::new();
185
186
    // Valid JSON with trailing text - LlamaParser doesn't support this mixed format
    let text = r#"{"name": "get_weather", "parameters": {}} Some follow-up text"#;
187

188
189
190
191
192
    let (normal_text, tools) = parser.parse_complete(text).await.unwrap();
    // LlamaParser expects pure JSON or <|python_tag|> format, not JSON with trailing text
    // So this returns as normal text
    assert_eq!(tools.len(), 0);
    assert_eq!(normal_text, text);
193
194
195
196
197
}

#[tokio::test]
async fn test_invalid_then_valid_json() {
    let parser = LlamaParser::new();
198
199
    let text =
        r#"{"name": "get_weather", "parameters": {{"name": "get_weather", "parameters": {}}"#;
200

201
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
202
    // Should parse at least one valid JSON
203
204
    if !tools.is_empty() {
        assert_eq!(tools[0].function.name, "get_weather");
205
206
207
208
209
210
211
212
    }
}

#[tokio::test]
async fn test_plain_text_only() {
    let parser = LlamaParser::new();
    let text = "This is just plain explanation text.";

213
214
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 0);
215
216
217
218
219
}

#[tokio::test]
async fn test_with_python_tag_prefix() {
    let parser = LlamaParser::new();
220
    let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}"#;
221

222
223
224
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
225
226
227
228
229
230
231
232
233
234
}

// STREAMING TESTS

/// Feeding a complete tagged call to `parse_incremental` in one piece yields
/// `ToolComplete` immediately.
#[tokio::test]
async fn test_llama_streaming_simple() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    // Send complete JSON at once
    let full_json = r#"<|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;

    let result = parser
        .parse_incremental(full_json, &mut state)
        .await
        .unwrap();

    match result {
        StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "search");
        }
        _ => panic!("Expected ToolComplete for complete JSON input"),
    }
}

/// A tagged call split across chunks (including a split inside the tag itself) eventually
/// yields `ToolComplete` with the expected name.
#[tokio::test]
async fn test_llama_streaming_partial() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    // Stream in chunks
    let chunks = [
        r#"<|python"#,
        r#"_tag|>{"name": "#,
        r#""calculate", "#,
        r#""parameters": {"x": 10}"#,
        r#"}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
        let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
        // Intermediate results are ignored; only completion matters for this test.
        if let StreamResult::ToolComplete(tool) = result {
            assert_eq!(tool.function.name, "calculate");
            got_complete = true;
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

/// A plain JSON call (no `<|python_tag|>`) split across chunks eventually yields
/// `ToolComplete` with the expected name.
#[tokio::test]
async fn test_llama_streaming_plain_json() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    // Stream plain JSON without python_tag
    let chunks = [
        r#"{"name": "#,
        r#""search", "#,
        r#""parameters": "#,
        r#"{"query": "#,
        r#""test"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
        let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
        // Intermediate results are ignored; only completion matters for this test.
        if let StreamResult::ToolComplete(tool) = result {
            assert_eq!(tool.function.name, "search");
            got_complete = true;
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

/// A stream that starts with prose, then the tag, then a chunked call eventually yields
/// `ToolComplete` with the expected name.
#[tokio::test]
async fn test_llama_streaming_with_text_before() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    let chunks = [
        r#"Let me help you. "#,
        r#"<|python_tag|>"#,
        r#"{"name": "get_time","#,
        r#" "parameters": {"#,
        r#""timezone": "UTC"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
        let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
        // Intermediate results are ignored; only completion matters for this test.
        if let StreamResult::ToolComplete(tool) = result {
            assert_eq!(tool.function.name, "get_time");
            got_complete = true;
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

/// Two `;`-separated calls delivered in one chunk are surfaced one per
/// `parse_incremental` call: the first immediately, the second on a follow-up call with
/// empty input.
#[tokio::test]
async fn test_llama_streaming_multiple_tools() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    let text =
        r#"<|python_tag|>{"name": "func1", "parameters": {}};{"name": "func2", "parameters": {}}"#;

    let result = parser.parse_incremental(text, &mut state).await.unwrap();

    // Should get first tool complete
    match result {
        StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "func1");
        }
        _ => panic!("Expected first tool to be complete, got: {:?}", result),
    }

    // Process remaining buffer to get second tool
    let result2 = parser.parse_incremental("", &mut state).await.unwrap();
    match result2 {
        StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "func2");
        }
        _ => panic!("Expected second tool to be complete"),
    }
}

#[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() {
    let parser = LlamaParser::new();
    let mut state = sglang_router_rs::tool_parser::ParseState::new();

    // First chunk - incomplete first JSON
364
    let chunk1 = r#"<|python_tag|>{"name": "get_weather", "parameters""#;
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
    let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();

    // Should be incomplete or have tool name
    match result1 {
        sglang_router_rs::tool_parser::StreamResult::Incomplete
        | sglang_router_rs::tool_parser::StreamResult::ToolName { .. }
        | sglang_router_rs::tool_parser::StreamResult::ToolArguments { .. } => {
            // Expected - could get tool name or be incomplete or even partial args
        }
        _ => panic!(
            "Expected incomplete or tool name for partial JSON, got: {:?}",
            result1
        ),
    }

    // Second chunk - complete first JSON and separator
    let chunk2 = r#": {"city": "Paris"}};{"name": "#;
    let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();

    // Should get first tool complete
    match result2 {
        sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "get_weather");
            let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
            assert_eq!(args["city"], "Paris");
        }
391
        _ => panic!("Expected first tool complete, got: {:?}", result2),
392
393
    }

394
    let chunk3 = r#""get_time", "parameters": {"timezone": "UTC"}}"#;
395
396
397
398
    let result3 = parser.parse_incremental(chunk3, &mut state).await.unwrap();
    match result3 {
        sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "get_time");
399
        }
400
        _ => panic!("Expected tool to be complete, got: {:?}", result3),
401
402
    }
}