//! Llama Parser Integration Tests
//!
//! Tests for the Llama parser, which handles the `<|python_tag|>` format and plain JSON.

use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};

/// A single call introduced by `<|python_tag|>` is parsed into its name and JSON arguments.
#[tokio::test]
async fn test_llama_python_tag_format() {
    let parser = LlamaParser::new();
    let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "search");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["query"], "weather");
}

/// Without `<|python_tag|>`, a bare JSON object with `name`/`arguments` is still parsed.
#[tokio::test]
async fn test_llama_plain_json_fallback() {
    let parser = LlamaParser::new();
    let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "calculate");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["x"], 5);
    assert_eq!(args["y"], 10);
}

/// Text preceding `<|python_tag|>` is returned as normal text; the call after it is parsed.
#[tokio::test]
async fn test_llama_with_text_before() {
    let parser = LlamaParser::new();
    let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(normal_text, "Let me help you with that. ");
    assert_eq!(tools[0].function.name, "get_time");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["timezone"], "UTC");
}

/// Pretty-printed JSON with nested objects and booleans in `arguments` is parsed intact.
#[tokio::test]
async fn test_llama_with_nested_json() {
    let parser = LlamaParser::new();
    let input = r#"<|python_tag|>{
        "name": "update_settings",
        "arguments": {
            "preferences": {
                "theme": "dark",
                "language": "en"
            },
            "notifications": true
        }
    }"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "update_settings");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["preferences"]["theme"], "dark");
    assert_eq!(args["preferences"]["language"], "en");
    assert_eq!(args["notifications"], true);
}

#[tokio::test]
async fn test_llama_empty_arguments() {
    let parser = LlamaParser::new();

    // With python_tag
    let input = r#"<|python_tag|>{"name": "ping", "arguments": {}}"#;
77
78
79
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
80
81
82

    // Plain JSON
    let input = r#"{"name": "ping", "arguments": {}}"#;
83
84
85
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
}

/// `detect_format` recognises a python-tagged call or a plain JSON object carrying a
/// `name` field, and rejects plain prose or JSON lacking `name`.
#[tokio::test]
async fn test_llama_format_detection() {
    let parser = LlamaParser::new();

    let accepted = [
        r#"<|python_tag|>{"name": "test"}"#,
        r#"{"name": "test", "arguments": {}}"#,
    ];
    let rejected = [
        "plain text",
        r#"{"key": "value"}"#, // No name field
    ];

    for text in accepted {
        assert!(parser.detect_format(text));
    }
    for text in rejected {
        assert!(!parser.detect_format(text));
    }
}

#[tokio::test]
async fn test_llama_invalid_json_after_tag() {
    let parser = LlamaParser::new();

    let input = r#"<|python_tag|>{"name": invalid}"#;
103
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
104
    assert_eq!(tools.len(), 0);
105
    assert_eq!(normal_text, "<|python_tag|>{\"name\": invalid}");
106
107
108
109
110
111
112
113
114
115
116
}

#[tokio::test]
async fn test_llama_real_world_output() {
    let parser = LlamaParser::new();

    // Actual output from Llama 3.2 model - simplified for testing
    let input = r#"I'll search for that information for you.

<|python_tag|>{"name": "web_search", "arguments": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;

117
118
119
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "web_search");
120
121
122
123
124
125
126
127
128

    let formatted_input = r#"<|python_tag|>{
    "name": "get_current_time",
    "arguments": {
        "timezone": "America/New_York",
        "format": "ISO8601"
    }
}"#;

129
130
131
    let (_normal_text, tools2) = parser.parse_complete(formatted_input).await.unwrap();
    assert_eq!(tools2.len(), 1);
    assert_eq!(tools2[0].function.name, "get_current_time");
132
133
134
135
136
137
138
139
140
}

#[tokio::test]
async fn test_llama_json_array_format() {
    let parser = LlamaParser::new();

    // Plain JSON array (should work as fallback)
    let input = r#"[{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}]"#;

141
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
142
    // Current implementation might handle this through JSON fallback
143
    assert!(!tools.is_empty());
144
}
145
146
147
148
149
150

/// A single untagged JSON call is parsed with its arguments.
#[tokio::test]
async fn test_single_json() {
    let parser = LlamaParser::new();
    let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;

    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["city"], "Paris");
}

#[tokio::test]
async fn test_multiple_json_with_separator() {
    let parser = LlamaParser::new();
    let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;

164
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
165
    // Note: Current implementation may only parse the first one due to semicolon handling
166
167
    assert!(!tools.is_empty());
    assert_eq!(tools[0].function.name, "get_weather");
168
169
170
171
172
173
174
}

#[tokio::test]
async fn test_multiple_json_with_separator_customized() {
    let parser = LlamaParser::new();
    let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;

175
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
176
    // Current implementation may handle this differently
177
178
    assert!(!tools.is_empty());
    assert_eq!(tools[0].function.name, "get_weather");
179
180
181
182
183
184
185
}

#[tokio::test]
async fn test_json_with_trailing_text() {
    let parser = LlamaParser::new();
    let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;

186
187
188
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
189
190
191
192
193
194
195
}

#[tokio::test]
async fn test_invalid_then_valid_json() {
    let parser = LlamaParser::new();
    let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;

196
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
197
    // Should parse at least one valid JSON
198
199
    if !tools.is_empty() {
        assert_eq!(tools[0].function.name, "get_weather");
200
201
202
203
204
205
206
207
    }
}

#[tokio::test]
async fn test_plain_text_only() {
    let parser = LlamaParser::new();
    let text = "This is just plain explanation text.";

208
209
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 0);
210
211
212
213
214
215
216
}

#[tokio::test]
async fn test_with_python_tag_prefix() {
    let parser = LlamaParser::new();
    let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;

217
218
219
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
}

// STREAMING TESTS

/// Feeding a whole python-tagged call as a single streaming chunk yields `ToolComplete`.
#[tokio::test]
async fn test_llama_streaming_simple() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    // The entire call arrives in one chunk.
    let full_json = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
    let outcome = parser
        .parse_incremental(full_json, &mut state)
        .await
        .unwrap();

    let tool = match outcome {
        StreamResult::ToolComplete(tool) => tool,
        _ => panic!("Expected ToolComplete for complete JSON input"),
    };
    assert_eq!(tool.function.name, "search");
}

/// The tag and JSON body split across chunk boundaries still complete as one call.
#[tokio::test]
async fn test_llama_streaming_partial() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    // Even the `<|python_tag|>` marker itself is cut in half.
    let pieces = [
        r#"<|python"#,
        r#"_tag|>{"name": "#,
        r#""calculate", "#,
        r#""arguments": {"x": 10}"#,
        r#"}"#,
    ];

    let mut saw_complete = false;
    for piece in pieces {
        let step = parser.parse_incremental(piece, &mut state).await.unwrap();
        if let StreamResult::ToolComplete(tool) = step {
            assert_eq!(tool.function.name, "calculate");
            saw_complete = true;
        }
    }

    assert!(saw_complete, "Should have completed parsing");
}

/// Untagged JSON streamed in fragments still produces a completed call.
#[tokio::test]
async fn test_llama_streaming_plain_json() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    // No python_tag at all, only bare JSON fragments.
    let pieces = [
        r#"{"name": "#,
        r#""search", "#,
        r#""arguments": "#,
        r#"{"query": "#,
        r#""test"}}"#,
    ];

    let mut saw_complete = false;
    for piece in pieces {
        let step = parser.parse_incremental(piece, &mut state).await.unwrap();
        if let StreamResult::ToolComplete(tool) = step {
            assert_eq!(tool.function.name, "search");
            saw_complete = true;
        }
    }

    assert!(saw_complete, "Should have completed parsing");
}

/// Leading prose streamed before the tag does not prevent the tagged call from completing.
#[tokio::test]
async fn test_llama_streaming_with_text_before() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    let pieces = [
        r#"Let me help you. "#,
        r#"<|python_tag|>"#,
        r#"{"name": "get_time","#,
        r#" "arguments": {"#,
        r#""timezone": "UTC"}}"#,
    ];

    let mut saw_complete = false;
    for piece in pieces {
        let step = parser.parse_incremental(piece, &mut state).await.unwrap();
        if let StreamResult::ToolComplete(tool) = step {
            assert_eq!(tool.function.name, "get_time");
            saw_complete = true;
        }
    }

    assert!(saw_complete, "Should have completed parsing");
}

/// Two `;`-separated calls delivered in one chunk: the first completes immediately,
/// the second on a follow-up empty poll that drains the buffered remainder.
#[tokio::test]
async fn test_llama_streaming_multiple_tools() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    let text =
        r#"<|python_tag|>{"name": "func1", "arguments": {}};{"name": "func2", "arguments": {}}"#;

    let result = parser.parse_incremental(text, &mut state).await.unwrap();

    // Should get first tool complete
    match result {
        StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "func1");
        }
        other => panic!("Expected first tool to be complete, got: {:?}", other),
    }

    // Process remaining buffer to get second tool
    let result2 = parser.parse_incremental("", &mut state).await.unwrap();
    match result2 {
        StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "func2");
        }
        other => panic!("Expected second tool to be complete, got: {:?}", other),
    }
}

/// Two `;`-separated calls streamed in three chunks whose boundaries fall mid-JSON.
/// Checks that the first call completes once its closing brace and the separator arrive,
/// and that the second completes on its final chunk or one extra empty poll.
#[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() {
    use sglang_router_rs::tool_parser::{ParseState, StreamResult};

    let parser = LlamaParser::new();
    let mut state = ParseState::new();

    // First chunk - incomplete first JSON
    let chunk1 = r#"<|python_tag|>{"name": "get_weather", "arguments""#;
    let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();

    // Partial JSON may legitimately yield nothing yet, the tool name, or partial arguments.
    match result1 {
        StreamResult::Incomplete
        | StreamResult::ToolName { .. }
        | StreamResult::ToolArguments { .. } => {}
        _ => panic!(
            "Expected incomplete or tool name for partial JSON, got: {:?}",
            result1
        ),
    }

    // Second chunk - complete first JSON and separator
    let chunk2 = r#": {"city": "Paris"}};{"name": "#;
    let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();

    // Should get first tool complete
    match result2 {
        StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "get_weather");
            let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
            assert_eq!(args["city"], "Paris");
        }
        other => panic!(
            "Expected first tool to be complete after separator, got: {:?}",
            other
        ),
    }

    // Third chunk - complete second JSON
    let chunk3 = r#""get_time", "arguments": {"timezone": "UTC"}}"#;
    let result3 = parser.parse_incremental(chunk3, &mut state).await.unwrap();

    // The second tool may complete on this chunk, or only after one more empty poll.
    let second = match result3 {
        StreamResult::ToolComplete(tool) => tool,
        _ => match parser.parse_incremental("", &mut state).await.unwrap() {
            StreamResult::ToolComplete(tool) => tool,
            other => panic!("Expected second tool to be complete, got: {:?}", other),
        },
    };
    assert_eq!(second.function.name, "get_time");
    let args: serde_json::Value = serde_json::from_str(&second.function.arguments).unwrap();
    assert_eq!(args["timezone"], "UTC");
}