//! Llama parser integration tests.
//!
//! Tests for the Llama parser, which handles the `<|python_tag|>` format and plain JSON.

use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};

// Shared integration-test helpers (provides `create_test_tools`).
mod common;
use common::create_test_tools;

10
11
12
#[tokio::test]
async fn test_llama_python_tag_format() {
    let parser = LlamaParser::new();
13
    let input = r#"Here are some results: <|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;
14

15
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
16
17
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "search");
18
    assert_eq!(normal_text, "Here are some results: ");
19

20
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
21
22
23
    assert_eq!(args["query"], "weather");
}

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#[tokio::test]
async fn test_llama_with_semicolon_separation() {
    let parser = LlamaParser::new();

    let input = r#"<|python_tag|>{"name": "tool1", "parameters": {}};{"name": "tool2", "parameters": {"y": 2}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 2);
    assert_eq!(tools[0].function.name, "tool1");
    assert_eq!(tools[1].function.name, "tool2");
    assert_eq!(normal_text, "");
}

/// Input with no tool-call markers passes through unchanged and yields no calls.
#[tokio::test]
async fn test_llama_no_tool_calls() {
    let input = "This is just plain text with no tool calls";
    let parser = LlamaParser::new();

    let (remaining_text, parsed_calls) = parser.parse_complete(input).await.unwrap();

    assert!(parsed_calls.is_empty());
    assert_eq!(remaining_text, input);
}

47
48
49
#[tokio::test]
async fn test_llama_plain_json_fallback() {
    let parser = LlamaParser::new();
50
    let input = r#"{"name": "calculate", "parameters": {"x": 5, "y": 10}}"#;
51

52
53
54
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "calculate");
55

56
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
57
58
59
60
61
62
63
    assert_eq!(args["x"], 5);
    assert_eq!(args["y"], 10);
}

/// Prose before `<|python_tag|>` is preserved verbatim (including its trailing
/// space) as normal text, while the tagged JSON becomes the tool call.
#[tokio::test]
async fn test_llama_with_text_before() {
    let parser = LlamaParser::new();
    let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "parameters": {"timezone": "UTC"}}"#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(normal_text, "Let me help you with that. ");
    assert_eq!(tools[0].function.name, "get_time");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["timezone"], "UTC");
}

/// A multi-line tagged call with nested objects keeps the nested structure in
/// its parsed arguments.
#[tokio::test]
async fn test_llama_with_nested_json() {
    let parser = LlamaParser::new();
    // The raw string must contain only JSON: any stray lines inside it would make
    // the payload unparsable and defeat the purpose of the test.
    let input = r#"<|python_tag|>{
        "name": "update_settings",
        "parameters": {
            "preferences": {
                "theme": "dark",
                "language": "en"
            },
            "notifications": true
        }
    }"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "update_settings");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["preferences"]["theme"], "dark");
    assert_eq!(args["notifications"], true);
}

#[tokio::test]
async fn test_llama_empty_arguments() {
    let parser = LlamaParser::new();

    // With python_tag
103
    let input = r#"<|python_tag|>{"name": "ping", "parameters": {}}"#;
104
105
106
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
107
108

    // Plain JSON
109
    let input = r#"{"name": "ping", "parameters": {}}"#;
110
111
112
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
113
114
115
116
117
118
}

#[tokio::test]
async fn test_llama_format_detection() {
    let parser = LlamaParser::new();

119
120
121
    assert!(parser.has_tool_markers(r#"<|python_tag|>{"name": "test"}"#));
    assert!(parser.has_tool_markers(r#"{"name": "test", "parameters": {}}"#));
    assert!(!parser.has_tool_markers("plain text"));
122
123
124
125
126
127
128
}

#[tokio::test]
async fn test_llama_invalid_json_after_tag() {
    let parser = LlamaParser::new();

    let input = r#"<|python_tag|>{"name": invalid}"#;
129
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
130
    assert_eq!(tools.len(), 0);
131
    assert_eq!(normal_text, "<|python_tag|>{\"name\": invalid}");
132
133
134
135
136
137
138
139
140
}

#[tokio::test]
async fn test_llama_real_world_output() {
    let parser = LlamaParser::new();

    // Actual output from Llama 3.2 model - simplified for testing
    let input = r#"I'll search for that information for you.

141
<|python_tag|>{"name": "web_search", "parameters": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;
142

143
144
145
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "web_search");
146
147
148

    let formatted_input = r#"<|python_tag|>{
    "name": "get_current_time",
149
    "parameters": {
150
151
152
153
154
        "timezone": "America/New_York",
        "format": "ISO8601"
    }
}"#;

155
156
157
    let (_normal_text, tools2) = parser.parse_complete(formatted_input).await.unwrap();
    assert_eq!(tools2.len(), 1);
    assert_eq!(tools2[0].function.name, "get_current_time");
158
159
}

160
161
162
#[tokio::test]
async fn test_single_json() {
    let parser = LlamaParser::new();
163
    let text = r#"{"name": "get_weather", "parameters": {"city": "Paris"}}"#;
164

165
166
167
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
168

169
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
170
171
172
173
174
175
    assert_eq!(args["city"], "Paris");
}

#[tokio::test]
async fn test_multiple_json_with_separator() {
    let parser = LlamaParser::new();
176
    let text = r#"<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}"#;
177

178
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
179
    // Note: Current implementation may only parse the first one due to semicolon handling
180
181
    assert!(!tools.is_empty());
    assert_eq!(tools[0].function.name, "get_weather");
182
183
184
185
186
}

#[tokio::test]
async fn test_json_with_trailing_text() {
    let parser = LlamaParser::new();
187
188
    // Valid JSON with trailing text - LlamaParser doesn't support this mixed format
    let text = r#"{"name": "get_weather", "parameters": {}} Some follow-up text"#;
189

190
191
192
193
194
    let (normal_text, tools) = parser.parse_complete(text).await.unwrap();
    // LlamaParser expects pure JSON or <|python_tag|> format, not JSON with trailing text
    // So this returns as normal text
    assert_eq!(tools.len(), 0);
    assert_eq!(normal_text, text);
195
196
197
198
199
}

#[tokio::test]
async fn test_invalid_then_valid_json() {
    let parser = LlamaParser::new();
200
201
    let text =
        r#"{"name": "get_weather", "parameters": {{"name": "get_weather", "parameters": {}}"#;
202

203
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
204
    // Should parse at least one valid JSON
205
206
    if !tools.is_empty() {
        assert_eq!(tools[0].function.name, "get_weather");
207
208
209
210
211
212
213
214
    }
}

#[tokio::test]
async fn test_plain_text_only() {
    let parser = LlamaParser::new();
    let text = "This is just plain explanation text.";

215
216
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 0);
217
218
219
220
221
}

#[tokio::test]
async fn test_with_python_tag_prefix() {
    let parser = LlamaParser::new();
222
    let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}"#;
223

224
225
226
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
227
228
229
230
231
232
}

// STREAMING TESTS

#[tokio::test]
async fn test_llama_streaming_simple() {
233
234
235
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
236
237

    // Send complete JSON at once
238
    let full_json = r#"<|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;
239

240
    let result = parser.parse_incremental(full_json, &tools).await.unwrap();
241

242
243
244
245
246
    assert!(
        !result.calls.is_empty(),
        "Expected tool call for complete JSON input"
    );
    assert_eq!(result.calls[0].name.as_ref().unwrap(), "search");
247
248
249
250
}

#[tokio::test]
async fn test_llama_streaming_partial() {
251
252
253
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
254
255
256
257
258
259

    // Stream in chunks
    let chunks = vec![
        r#"<|python"#,
        r#"_tag|>{"name": "#,
        r#""calculate", "#,
260
        r#""parameters": {"x": 10}"#,
261
262
263
264
265
266
        r#"}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
267
268
269
270
271
272
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "calculate");
                got_complete = true;
            }
273
274
275
276
277
278
279
280
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_plain_json() {
281
282
283
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
284
285
286
287
288

    // Stream plain JSON without python_tag
    let chunks = vec![
        r#"{"name": "#,
        r#""search", "#,
289
        r#""parameters": "#,
290
291
292
293
294
295
296
        r#"{"query": "#,
        r#""test"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
297
298
299
300
301
302
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "search");
                got_complete = true;
            }
303
304
305
306
307
308
309
310
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_with_text_before() {
311
312
313
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
314
315
316
317
318

    let chunks = vec![
        r#"Let me help you. "#,
        r#"<|python_tag|>"#,
        r#"{"name": "get_time","#,
319
        r#" "parameters": {"#,
320
321
322
323
324
325
        r#""timezone": "UTC"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
326
327
328
329
330
331
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "get_time");
                got_complete = true;
            }
332
333
334
335
336
337
338
339
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_multiple_tools() {
340
341
342
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
343
344

    let text =
345
        r#"<|python_tag|>{"name": "func1", "parameters": {}};{"name": "func2", "parameters": {}}"#;
346

347
    let result = parser.parse_incremental(text, &tools).await.unwrap();
348

349
    // Should get first tool complete
350
351
352
353
354
355
    assert!(
        !result.calls.is_empty(),
        "Expected first tool to be complete"
    );
    if let Some(name) = &result.calls[0].name {
        assert_eq!(name, "func1");
356
357
358
    }

    // Process remaining buffer to get second tool
359
360
361
362
    let result2 = parser.parse_incremental("", &tools).await.unwrap();
    if !result2.calls.is_empty() {
        if let Some(name) = &result2.calls[0].name {
            assert_eq!(name, "func2");
363
364
365
366
367
368
        }
    }
}

#[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() {
369
370
371
    let mut parser = LlamaParser::new();

    let tools = create_test_tools();
372
373

    // First chunk - incomplete first JSON
374
    let chunk1 = r#"<|python_tag|>{"name": "get_weather", "parameters""#;
375
376
377
378
    let result1 = parser.parse_incremental(chunk1, &tools).await.unwrap();
    if !result1.calls.is_empty() {
        if let Some(name) = &result1.calls[0].name {
            assert_eq!(name, "get_weather");
379
380
381
382
383
        }
    }

    // Second chunk - complete first JSON and separator
    let chunk2 = r#": {"city": "Paris"}};{"name": "#;
384
    let result2 = parser.parse_incremental(chunk2, &tools).await.unwrap();
385

386
387
388
389
    // Should get parameters for first tool (name already sent in result1)
    if !result2.calls.is_empty() {
        let args: serde_json::Value = serde_json::from_str(&result2.calls[0].parameters).unwrap();
        assert_eq!(args["city"], "Paris");
390
391
    }

392
    let chunk3 = r#""get_time", "parameters": {"timezone": "UTC"}}"#;
393
394
395
396
    let result3 = parser.parse_incremental(chunk3, &tools).await.unwrap();
    if !result3.calls.is_empty() {
        if let Some(name) = &result3.calls[0].name {
            assert_eq!(name, "get_time");
397
398
399
        }
    }
}