//! Llama Parser Integration Tests
//!
//! Tests for the Llama parser which handles <|python_tag|> format and plain JSON

use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};

mod common;
use common::{create_test_tools, streaming_helpers::*};

/// A `<|python_tag|>` call preceded by prose: the prose is returned as normal
/// text and the JSON after the tag becomes exactly one tool call.
#[tokio::test]
async fn test_llama_python_tag_format() {
    let parser = LlamaParser::new();
    let input = r#"Here are some results: <|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();

    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "search");
    assert_eq!(normal_text, "Here are some results: ");

    // The model's `parameters` object is exposed as the call's JSON `arguments`.
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["query"], "weather");
}

/// Two calls following one `<|python_tag|>` and joined by `;` are both
/// returned, in order, with nothing left over as normal text.
#[tokio::test]
async fn test_llama_with_semicolon_separation() {
    let llama = LlamaParser::new();
    let text = r#"<|python_tag|>{"name": "tool1", "parameters": {}};{"name": "tool2", "parameters": {"y": 2}}"#;

    let (leftover, calls) = llama.parse_complete(text).await.unwrap();

    let names: Vec<&str> = calls.iter().map(|c| c.function.name.as_str()).collect();
    assert_eq!(names, ["tool1", "tool2"]);
    assert!(leftover.is_empty());
}

/// Prose without any tag or JSON produces no calls and is echoed back verbatim.
#[tokio::test]
async fn test_llama_no_tool_calls() {
    let text = "This is just plain text with no tool calls";
    let llama = LlamaParser::new();

    let (passthrough, calls) = llama.parse_complete(text).await.unwrap();

    assert!(calls.is_empty());
    assert_eq!(passthrough, text);
}

/// A bare JSON object (no `<|python_tag|>`) is still recognised as a tool call.
#[tokio::test]
async fn test_llama_plain_json_fallback() {
    let parser = LlamaParser::new();
    let input = r#"{"name": "calculate", "parameters": {"x": 5, "y": 10}}"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "calculate");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["x"], 5);
    assert_eq!(args["y"], 10);
}

/// Text before `<|python_tag|>` is returned unchanged (trailing space included)
/// as normal text, and the tagged JSON is parsed as a call.
#[tokio::test]
async fn test_llama_with_text_before() {
    let parser = LlamaParser::new();
    let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "parameters": {"timezone": "UTC"}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(normal_text, "Let me help you with that. ");
    assert_eq!(tools[0].function.name, "get_time");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["timezone"], "UTC");
}

/// Multi-line, nested `parameters` objects after the tag are parsed intact.
#[tokio::test]
async fn test_llama_with_nested_json() {
    let parser = LlamaParser::new();
    // The raw string must contain only valid JSON after the tag; any stray
    // characters here would turn this into an invalid-JSON test by accident.
    let input = r#"<|python_tag|>{
        "name": "update_settings",
        "parameters": {
            "preferences": {
                "theme": "dark",
                "language": "en"
            },
            "notifications": true
        }
    }"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "update_settings");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["preferences"]["theme"], "dark");
    assert_eq!(args["notifications"], true);
}

#[tokio::test]
async fn test_llama_empty_arguments() {
    let parser = LlamaParser::new();

    // With python_tag
103
    let input = r#"<|python_tag|>{"name": "ping", "parameters": {}}"#;
104
105
106
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
107
108

    // Plain JSON
109
    let input = r#"{"name": "ping", "parameters": {}}"#;
110
111
112
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
113
114
115
116
117
118
}

#[tokio::test]
async fn test_llama_format_detection() {
    let parser = LlamaParser::new();

119
120
121
    assert!(parser.has_tool_markers(r#"<|python_tag|>{"name": "test"}"#));
    assert!(parser.has_tool_markers(r#"{"name": "test", "parameters": {}}"#));
    assert!(!parser.has_tool_markers("plain text"));
122
123
124
125
126
127
128
}

#[tokio::test]
async fn test_llama_invalid_json_after_tag() {
    let parser = LlamaParser::new();

    let input = r#"<|python_tag|>{"name": invalid}"#;
129
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
130
    assert_eq!(tools.len(), 0);
131
    assert_eq!(normal_text, "<|python_tag|>{\"name\": invalid}");
132
133
134
135
136
137
138
139
140
}

#[tokio::test]
async fn test_llama_real_world_output() {
    let parser = LlamaParser::new();

    // Actual output from Llama 3.2 model - simplified for testing
    let input = r#"I'll search for that information for you.

141
<|python_tag|>{"name": "web_search", "parameters": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;
142

143
144
145
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "web_search");
146
147
148

    let formatted_input = r#"<|python_tag|>{
    "name": "get_current_time",
149
    "parameters": {
150
151
152
153
154
        "timezone": "America/New_York",
        "format": "ISO8601"
    }
}"#;

155
156
157
    let (_normal_text, tools2) = parser.parse_complete(formatted_input).await.unwrap();
    assert_eq!(tools2.len(), 1);
    assert_eq!(tools2[0].function.name, "get_current_time");
158
159
}

/// A single bare JSON call is parsed with its arguments.
#[tokio::test]
async fn test_single_json() {
    let parser = LlamaParser::new();
    let text = r#"{"name": "get_weather", "parameters": {"city": "Paris"}}"#;

    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["city"], "Paris");
}

#[tokio::test]
async fn test_multiple_json_with_separator() {
    let parser = LlamaParser::new();
176
    let text = r#"<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}"#;
177

178
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
179
    // Note: Current implementation may only parse the first one due to semicolon handling
180
181
    assert!(!tools.is_empty());
    assert_eq!(tools[0].function.name, "get_weather");
182
183
184
185
186
}

#[tokio::test]
async fn test_json_with_trailing_text() {
    let parser = LlamaParser::new();
187
188
    // Valid JSON with trailing text - LlamaParser doesn't support this mixed format
    let text = r#"{"name": "get_weather", "parameters": {}} Some follow-up text"#;
189

190
191
192
193
194
    let (normal_text, tools) = parser.parse_complete(text).await.unwrap();
    // LlamaParser expects pure JSON or <|python_tag|> format, not JSON with trailing text
    // So this returns as normal text
    assert_eq!(tools.len(), 0);
    assert_eq!(normal_text, text);
195
196
197
198
199
}

#[tokio::test]
async fn test_invalid_then_valid_json() {
    let parser = LlamaParser::new();
200
201
    let text =
        r#"{"name": "get_weather", "parameters": {{"name": "get_weather", "parameters": {}}"#;
202

203
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
204
    // Should parse at least one valid JSON
205
206
    if !tools.is_empty() {
        assert_eq!(tools[0].function.name, "get_weather");
207
208
209
210
211
212
213
214
    }
}

#[tokio::test]
async fn test_plain_text_only() {
    let parser = LlamaParser::new();
    let text = "This is just plain explanation text.";

215
216
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 0);
217
218
219
220
221
}

#[tokio::test]
async fn test_with_python_tag_prefix() {
    let parser = LlamaParser::new();
222
    let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}"#;
223

224
225
226
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
227
228
229
230
231
232
}

// STREAMING TESTS

#[tokio::test]
async fn test_llama_streaming_simple() {
233
234
235
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
236
237

    // Send complete JSON at once
238
    let full_json = r#"<|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;
239

240
    let result = parser.parse_incremental(full_json, &tools).await.unwrap();
241

242
243
244
245
246
    assert!(
        !result.calls.is_empty(),
        "Expected tool call for complete JSON input"
    );
    assert_eq!(result.calls[0].name.as_ref().unwrap(), "search");
247
248
249
250
}

#[tokio::test]
async fn test_llama_streaming_partial() {
251
252
253
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
254
255
256
257
258
259

    // Stream in chunks
    let chunks = vec![
        r#"<|python"#,
        r#"_tag|>{"name": "#,
        r#""calculate", "#,
260
        r#""parameters": {"x": 10}"#,
261
262
263
264
265
266
        r#"}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
267
268
269
270
271
272
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "calculate");
                got_complete = true;
            }
273
274
275
276
277
278
279
280
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_plain_json() {
281
282
283
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
284
285
286
287
288

    // Stream plain JSON without python_tag
    let chunks = vec![
        r#"{"name": "#,
        r#""search", "#,
289
        r#""parameters": "#,
290
291
292
293
294
295
296
        r#"{"query": "#,
        r#""test"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
297
298
299
300
301
302
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "search");
                got_complete = true;
            }
303
304
305
306
307
308
309
310
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_with_text_before() {
311
312
313
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
314
315
316
317
318

    let chunks = vec![
        r#"Let me help you. "#,
        r#"<|python_tag|>"#,
        r#"{"name": "get_time","#,
319
        r#" "parameters": {"#,
320
321
322
323
324
325
        r#""timezone": "UTC"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
326
327
328
329
330
331
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "get_time");
                got_complete = true;
            }
332
333
334
335
336
337
338
339
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_multiple_tools() {
340
341
342
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
343
344

    let text =
345
        r#"<|python_tag|>{"name": "func1", "parameters": {}};{"name": "func2", "parameters": {}}"#;
346

347
    let result = parser.parse_incremental(text, &tools).await.unwrap();
348

349
    // Should get first tool complete
350
351
352
353
354
355
    assert!(
        !result.calls.is_empty(),
        "Expected first tool to be complete"
    );
    if let Some(name) = &result.calls[0].name {
        assert_eq!(name, "func1");
356
357
358
    }

    // Process remaining buffer to get second tool
359
360
361
362
    let result2 = parser.parse_incremental("", &tools).await.unwrap();
    if !result2.calls.is_empty() {
        if let Some(name) = &result2.calls[0].name {
            assert_eq!(name, "func2");
363
364
365
366
367
368
        }
    }
}

#[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() {
369
370
371
    let mut parser = LlamaParser::new();

    let tools = create_test_tools();
372
373

    // First chunk - incomplete first JSON
374
    let chunk1 = r#"<|python_tag|>{"name": "get_weather", "parameters""#;
375
376
377
378
    let result1 = parser.parse_incremental(chunk1, &tools).await.unwrap();
    if !result1.calls.is_empty() {
        if let Some(name) = &result1.calls[0].name {
            assert_eq!(name, "get_weather");
379
380
381
382
383
        }
    }

    // Second chunk - complete first JSON and separator
    let chunk2 = r#": {"city": "Paris"}};{"name": "#;
384
    let result2 = parser.parse_incremental(chunk2, &tools).await.unwrap();
385

386
387
388
389
    // Should get parameters for first tool (name already sent in result1)
    if !result2.calls.is_empty() {
        let args: serde_json::Value = serde_json::from_str(&result2.calls[0].parameters).unwrap();
        assert_eq!(args["city"], "Paris");
390
391
    }

392
    let chunk3 = r#""get_time", "parameters": {"timezone": "UTC"}}"#;
393
394
395
396
    let result3 = parser.parse_incremental(chunk3, &tools).await.unwrap();
    if !result3.calls.is_empty() {
        if let Some(name) = &result3.calls[0].name {
            assert_eq!(name, "get_time");
397
398
399
        }
    }
}

// =============================================================================
// REALISTIC STREAMING TESTS
// =============================================================================

/// Feeds a tagged call through `create_realistic_chunks` (more than 15 small
/// pieces) and requires that some step reports the `calculate` name.
#[tokio::test]
async fn test_llama_realistic_chunks_with_python_tag() {
    let tools = create_test_tools();
    let mut parser = LlamaParser::new();

    let payload = r#"<|python_tag|>{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
    let pieces = create_realistic_chunks(payload);
    assert!(pieces.len() > 15, "Should have many small chunks");

    let mut saw_name = false;
    for piece in &pieces {
        let step = parser.parse_incremental(piece, &tools).await.unwrap();
        // Every named call in any step must be the expected tool.
        for name in step.calls.into_iter().filter_map(|call| call.name) {
            assert_eq!(name, "calculate");
            saw_name = true;
        }
    }

    assert!(saw_name, "Should have parsed tool name");
}

/// Both the `<|python_tag|>` marker and every JSON token arrive fragmented.
/// Any named call emitted along the way must be `search`, and at least one
/// must appear.
#[tokio::test]
async fn test_llama_python_tag_arrives_in_parts() {
    let tools = create_test_tools();
    let mut parser = LlamaParser::new();

    // Python tag itself arrives in small chunks
    let fragments = [
        "<|p", "yth", "on_", "tag", "|>{", r#"""#, "na", r#"me""#, ": ", r#"""#, "sea", "rch",
        r#"""#, ", ", r#"""#, "par", "ame", "ter", "s", r#"""#, ": {", r#"""#, "q", r#"""#, ": ",
        r#"""#, "tes", "t", r#"""#, "}}",
    ];

    let mut saw_name = false;
    for fragment in fragments {
        let step = parser.parse_incremental(fragment, &tools).await.unwrap();
        for name in step.calls.into_iter().filter_map(|call| call.name) {
            assert_eq!(name, "search");
            saw_name = true;
        }
    }

    assert!(saw_name, "Should have parsed tool name");
}