"megatron/arguments.py" did not exist on "fb4cbdc277a0d345e0343b0f7bb63176aeb41cf9"
tool_parser_llama.rs 12.6 KB
Newer Older
1
2
3
4
5
6
//! Llama Parser Integration Tests
//!
//! Tests for the Llama parser which handles <|python_tag|> format and plain JSON

use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};

7
8
9
mod common;
use common::create_test_tools;

10
11
12
#[tokio::test]
async fn test_llama_python_tag_format() {
    let parser = LlamaParser::new();
    // Prose followed by a tagged call: everything before the tag is normal text.
    let input = r#"Here are some results: <|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;

    let (text, calls) = parser.parse_complete(input).await.unwrap();

    assert_eq!(calls.len(), 1);
    let call = &calls[0];
    assert_eq!(call.function.name, "search");
    assert_eq!(text, "Here are some results: ");

    let parsed: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
    assert_eq!(parsed["query"], "weather");
}

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#[tokio::test]
async fn test_llama_with_semicolon_separation() {
    let parser = LlamaParser::new();

    // Two calls after a single tag, separated by `;`.
    let input = r#"<|python_tag|>{"name": "tool1", "parameters": {}};{"name": "tool2", "parameters": {"y": 2}}"#;

    let (text, calls) = parser.parse_complete(input).await.unwrap();
    assert_eq!(calls.len(), 2);
    for (call, expected) in calls.iter().zip(["tool1", "tool2"]) {
        assert_eq!(call.function.name, expected);
    }
    assert_eq!(text, "");
}

#[tokio::test]
async fn test_llama_no_tool_calls() {
    let input = "This is just plain text with no tool calls";
    let parser = LlamaParser::new();

    // Without a tag or a JSON call object, the input must come back untouched.
    let (text, calls) = parser.parse_complete(input).await.unwrap();
    assert!(calls.is_empty());
    assert_eq!(text, input);
}

47
48
49
#[tokio::test]
async fn test_llama_plain_json_fallback() {
    let parser = LlamaParser::new();
    // No <|python_tag|>: a bare JSON call object must still be recognised.
    let input = r#"{"name": "calculate", "parameters": {"x": 5, "y": 10}}"#;

    let (_, calls) = parser.parse_complete(input).await.unwrap();
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].function.name, "calculate");

    let arguments = &calls[0].function.arguments;
    let parsed: serde_json::Value = serde_json::from_str(arguments).unwrap();
    for (key, expected) in [("x", 5), ("y", 10)] {
        assert_eq!(parsed[key], expected);
    }
}

#[tokio::test]
async fn test_llama_with_text_before() {
    let parser = LlamaParser::new();
    let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "parameters": {"timezone": "UTC"}}"#;

    let (text, calls) = parser.parse_complete(input).await.unwrap();
    assert_eq!(calls.len(), 1);
    // The leading prose, including its trailing space, is kept as normal text.
    assert_eq!(text, "Let me help you with that. ");

    let call = &calls[0];
    assert_eq!(call.function.name, "get_time");

    let parsed: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
    assert_eq!(parsed["timezone"], "UTC");
}

/// A multi-line, nested JSON payload after the tag is parsed, and nested
/// argument objects are preserved in the returned arguments.
#[tokio::test]
async fn test_llama_with_nested_json() {
    let parser = LlamaParser::new();
    // Pretty-printed payload: the embedded newlines and indentation are part of
    // what this test exercises, so the literal must contain only valid JSON.
    let input = r#"<|python_tag|>{
        "name": "update_settings",
        "parameters": {
            "preferences": {
                "theme": "dark",
                "language": "en"
            },
            "notifications": true
        }
    }"#;

    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "update_settings");

    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["preferences"]["theme"], "dark");
    assert_eq!(args["preferences"]["language"], "en");
    assert_eq!(args["notifications"], true);
}

#[tokio::test]
async fn test_llama_empty_arguments() {
    let parser = LlamaParser::new();

    // With python_tag
103
    let input = r#"<|python_tag|>{"name": "ping", "parameters": {}}"#;
104
105
106
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
107
108

    // Plain JSON
109
    let input = r#"{"name": "ping", "parameters": {}}"#;
110
111
112
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "ping");
113
114
115
116
117
118
}

#[tokio::test]
async fn test_llama_format_detection() {
    let parser = LlamaParser::new();

119
120
121
122
    assert!(parser.has_tool_markers(r#"<|python_tag|>{"name": "test"}"#));
    assert!(parser.has_tool_markers(r#"{"name": "test", "parameters": {}}"#));
    assert!(!parser.has_tool_markers("plain text"));
    assert!(!parser.has_tool_markers(r#"{"key": "value"}"#)); // No name field
123
124
125
126
127
128
129
}

#[tokio::test]
async fn test_llama_invalid_json_after_tag() {
    let parser = LlamaParser::new();

    let input = r#"<|python_tag|>{"name": invalid}"#;
130
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
131
    assert_eq!(tools.len(), 0);
132
    assert_eq!(normal_text, "<|python_tag|>{\"name\": invalid}");
133
134
135
136
137
138
139
140
141
}

#[tokio::test]
async fn test_llama_real_world_output() {
    let parser = LlamaParser::new();

    // Actual output from Llama 3.2 model - simplified for testing
    let input = r#"I'll search for that information for you.

142
<|python_tag|>{"name": "web_search", "parameters": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;
143

144
145
146
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "web_search");
147
148
149

    let formatted_input = r#"<|python_tag|>{
    "name": "get_current_time",
150
    "parameters": {
151
152
153
154
155
        "timezone": "America/New_York",
        "format": "ISO8601"
    }
}"#;

156
157
158
    let (_normal_text, tools2) = parser.parse_complete(formatted_input).await.unwrap();
    assert_eq!(tools2.len(), 1);
    assert_eq!(tools2[0].function.name, "get_current_time");
159
160
}

161
162
163
#[tokio::test]
async fn test_single_json() {
    let text = r#"{"name": "get_weather", "parameters": {"city": "Paris"}}"#;
    let parser = LlamaParser::new();

    let (_, calls) = parser.parse_complete(text).await.unwrap();
    assert_eq!(calls.len(), 1);

    let function = &calls[0].function;
    assert_eq!(function.name, "get_weather");

    let parsed = serde_json::from_str::<serde_json::Value>(&function.arguments).unwrap();
    assert_eq!(parsed["city"], "Paris");
}

#[tokio::test]
async fn test_multiple_json_with_separator() {
    let parser = LlamaParser::new();
177
    let text = r#"<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}"#;
178

179
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
180
    // Note: Current implementation may only parse the first one due to semicolon handling
181
182
    assert!(!tools.is_empty());
    assert_eq!(tools[0].function.name, "get_weather");
183
184
185
186
187
}

#[tokio::test]
async fn test_json_with_trailing_text() {
    let parser = LlamaParser::new();
188
189
    // Valid JSON with trailing text - LlamaParser doesn't support this mixed format
    let text = r#"{"name": "get_weather", "parameters": {}} Some follow-up text"#;
190

191
192
193
194
195
    let (normal_text, tools) = parser.parse_complete(text).await.unwrap();
    // LlamaParser expects pure JSON or <|python_tag|> format, not JSON with trailing text
    // So this returns as normal text
    assert_eq!(tools.len(), 0);
    assert_eq!(normal_text, text);
196
197
198
199
200
}

#[tokio::test]
async fn test_invalid_then_valid_json() {
    let parser = LlamaParser::new();
201
202
    let text =
        r#"{"name": "get_weather", "parameters": {{"name": "get_weather", "parameters": {}}"#;
203

204
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
205
    // Should parse at least one valid JSON
206
207
    if !tools.is_empty() {
        assert_eq!(tools[0].function.name, "get_weather");
208
209
210
211
212
213
214
215
    }
}

#[tokio::test]
async fn test_plain_text_only() {
    let parser = LlamaParser::new();
    let text = "This is just plain explanation text.";

216
217
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 0);
218
219
220
221
222
}

#[tokio::test]
async fn test_with_python_tag_prefix() {
    let parser = LlamaParser::new();
223
    let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}"#;
224

225
226
227
    let (_normal_text, tools) = parser.parse_complete(text).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
228
229
230
231
232
233
}

// STREAMING TESTS

#[tokio::test]
async fn test_llama_streaming_simple() {
234
235
236
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
237
238

    // Send complete JSON at once
239
    let full_json = r#"<|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;
240

241
    let result = parser.parse_incremental(full_json, &tools).await.unwrap();
242

243
244
245
246
247
    assert!(
        !result.calls.is_empty(),
        "Expected tool call for complete JSON input"
    );
    assert_eq!(result.calls[0].name.as_ref().unwrap(), "search");
248
249
250
251
}

#[tokio::test]
async fn test_llama_streaming_partial() {
252
253
254
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
255
256
257
258
259
260

    // Stream in chunks
    let chunks = vec![
        r#"<|python"#,
        r#"_tag|>{"name": "#,
        r#""calculate", "#,
261
        r#""parameters": {"x": 10}"#,
262
263
264
265
266
267
        r#"}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
268
269
270
271
272
273
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "calculate");
                got_complete = true;
            }
274
275
276
277
278
279
280
281
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_plain_json() {
282
283
284
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
285
286
287
288
289

    // Stream plain JSON without python_tag
    let chunks = vec![
        r#"{"name": "#,
        r#""search", "#,
290
        r#""parameters": "#,
291
292
293
294
295
296
297
        r#"{"query": "#,
        r#""test"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
298
299
300
301
302
303
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "search");
                got_complete = true;
            }
304
305
306
307
308
309
310
311
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_with_text_before() {
312
313
314
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
315
316
317
318
319

    let chunks = vec![
        r#"Let me help you. "#,
        r#"<|python_tag|>"#,
        r#"{"name": "get_time","#,
320
        r#" "parameters": {"#,
321
322
323
324
325
326
        r#""timezone": "UTC"}}"#,
    ];

    let mut got_complete = false;

    for chunk in chunks {
327
328
329
330
331
332
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        if !result.calls.is_empty() {
            if let Some(name) = &result.calls[0].name {
                assert_eq!(name, "get_time");
                got_complete = true;
            }
333
334
335
336
337
338
339
340
        }
    }

    assert!(got_complete, "Should have completed parsing");
}

#[tokio::test]
async fn test_llama_streaming_multiple_tools() {
341
342
343
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
344
345

    let text =
346
        r#"<|python_tag|>{"name": "func1", "parameters": {}};{"name": "func2", "parameters": {}}"#;
347

348
    let result = parser.parse_incremental(text, &tools).await.unwrap();
349

350
    // Should get first tool complete
351
352
353
354
355
356
    assert!(
        !result.calls.is_empty(),
        "Expected first tool to be complete"
    );
    if let Some(name) = &result.calls[0].name {
        assert_eq!(name, "func1");
357
358
359
    }

    // Process remaining buffer to get second tool
360
361
362
363
    let result2 = parser.parse_incremental("", &tools).await.unwrap();
    if !result2.calls.is_empty() {
        if let Some(name) = &result2.calls[0].name {
            assert_eq!(name, "func2");
364
365
366
367
368
369
        }
    }
}

#[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() {
370
371
372
    let mut parser = LlamaParser::new();

    let tools = create_test_tools();
373
374

    // First chunk - incomplete first JSON
375
    let chunk1 = r#"<|python_tag|>{"name": "get_weather", "parameters""#;
376
377
378
379
    let result1 = parser.parse_incremental(chunk1, &tools).await.unwrap();
    if !result1.calls.is_empty() {
        if let Some(name) = &result1.calls[0].name {
            assert_eq!(name, "get_weather");
380
381
382
383
384
        }
    }

    // Second chunk - complete first JSON and separator
    let chunk2 = r#": {"city": "Paris"}};{"name": "#;
385
    let result2 = parser.parse_incremental(chunk2, &tools).await.unwrap();
386

387
388
389
390
    // Should get parameters for first tool (name already sent in result1)
    if !result2.calls.is_empty() {
        let args: serde_json::Value = serde_json::from_str(&result2.calls[0].parameters).unwrap();
        assert_eq!(args["city"], "Paris");
391
392
    }

393
    let chunk3 = r#""get_time", "parameters": {"timezone": "UTC"}}"#;
394
395
396
397
    let result3 = parser.parse_incremental(chunk3, &tools).await.unwrap();
    if !result3.calls.is_empty() {
        if let Some(name) = &result3.calls[0].name {
            assert_eq!(name, "get_time");
398
399
400
        }
    }
}