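//! Tests for the tool parser module: factory/registry lookup, partial-JSON helpers,
//! `PartialToolCall` state, and `JsonParser` parsing behavior.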
use super::*;
use crate::tool_parser::parsers::JsonParser;
use crate::tool_parser::partial_json::{
    compute_diff, find_common_prefix, is_complete_json, PartialJson,
};
use crate::tool_parser::traits::ToolParser;

8
9
10
#[tokio::test]
async fn test_tool_parser_factory() {
    let factory = ToolParserFactory::new();
11

12
13
14
    // Test that we can get a pooled parser
    let pooled_parser = factory.get_pooled("gpt-4");
    let parser = pooled_parser.lock().await;
15
    assert!(parser.has_tool_markers(r#"{"name": "test", "arguments": {}}"#));
16
17
}

18
19
20
#[tokio::test]
async fn test_tool_parser_factory_model_mapping() {
    let factory = ToolParserFactory::new();
21

22
23
    // Test model mapping
    factory.registry().map_model("test-model", "json");
24

25
26
27
    // Get parser for the test model
    let pooled_parser = factory.get_pooled("test-model");
    let parser = pooled_parser.lock().await;
28
    assert!(parser.has_tool_markers(r#"{"name": "test", "arguments": {}}"#));
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
}

/// A `ToolCall` survives a serde JSON round trip with its function name and raw
/// argument string intact.
#[test]
fn test_tool_call_serialization() {
    let tool_call = ToolCall {
        function: FunctionCall {
            name: "search".to_string(),
            arguments: r#"{"query": "rust programming"}"#.to_string(),
        },
    };

    let json = serde_json::to_string(&tool_call).unwrap();
    assert!(json.contains("search"));
    assert!(json.contains("rust programming"));

    let parsed: ToolCall = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed.function.name, "search");
    assert_eq!(
        parsed.function.arguments,
        r#"{"query": "rust programming"}"#
    );
}

/// `PartialJson::parse_value` accepts complete JSON and recovers what it can from
/// truncated objects, strings and arrays.
#[test]
fn test_partial_json_parser() {
    let parser = PartialJson::default();

    // Complete object: the whole input is consumed.
    let complete = r#"{"name": "test", "value": 42}"#;
    let (value, consumed) = parser.parse_value(complete).unwrap();
    assert_eq!(value["name"], "test");
    assert_eq!(value["value"], 42);
    assert_eq!(consumed, complete.len());

    // Object truncated right after a key: that key indexes as null.
    let (value, _) = parser.parse_value(r#"{"name": "test", "value": "#).unwrap();
    assert_eq!(value["name"], "test");
    assert!(value["value"].is_null());

    // Object truncated inside a string value: the partial string is kept.
    let (value, _) = parser.parse_value(r#"{"name": "tes"#).unwrap();
    assert_eq!(value["name"], "tes");

    // Array truncated after a comma: the elements seen so far are kept.
    let (value, _) = parser.parse_value(r#"[1, 2, "#).unwrap();
    assert!(value.is_array());
    assert_eq!(value[0], 1);
    assert_eq!(value[1], 2);
}

/// A `PartialJson` built with a depth limit of 3 accepts up to three levels of
/// nesting and rejects a fourth.
#[test]
fn test_partial_json_depth_limit() {
    // allow_incomplete is disabled so over-deep input surfaces as an error
    // instead of a partial result.
    let parser = PartialJson::new(3, false);

    let within_limit = [
        r#"{"a": 1}"#,               // one level
        r#"{"a": {"b": {"c": 1}}}"#, // exactly three levels
    ];
    for &input in &within_limit {
        assert!(parser.parse_value(input).is_ok());
    }

    // Four levels of nesting exceeds the limit.
    let too_deep = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
    assert!(parser.parse_value(too_deep).is_err());
}

/// `is_complete_json` accepts finished JSON values of every top-level kind and
/// rejects truncated ones.
#[test]
fn test_is_complete_json() {
    // Self-contained object, array, string, number, boolean and null.
    let complete = [
        r#"{"name": "test"}"#,
        r#"[1, 2, 3]"#,
        r#""string""#,
        "42",
        "true",
        "null",
    ];
    for &text in &complete {
        assert!(is_complete_json(text));
    }

    // Truncated object, array and string.
    let truncated = [r#"{"name": "#, r#"[1, 2, "#, r#""unclosed"#];
    for &text in &truncated {
        assert!(!is_complete_json(text));
    }
}

/// `find_common_prefix` returns the length of the shared leading run of two strings.
#[test]
fn test_find_common_prefix() {
    let cases = [
        ("hello", "hello", 5), // identical strings
        ("hello", "help", 3),  // diverge part-way
        ("hello", "world", 0), // differ at the first character
        ("", "hello", 0),      // empty left side
        ("hello", "", 0),      // empty right side
    ];
    for &(left, right, expected) in &cases {
        assert_eq!(find_common_prefix(left, right), expected);
    }
}

/// `compute_diff` yields the suffix of the new text beyond the old text when the old
/// text is a prefix, and the whole new text when it is not.
#[test]
fn test_compute_diff() {
    let cases = [
        ("hello", "hello world", " world"), // appended suffix
        ("", "hello", "hello"),             // nothing seen yet
        ("hello", "hello", ""),             // no change
        ("test", "hello", "hello"),         // old text is not a prefix
    ];
    for &(old, new, expected) in &cases {
        assert_eq!(compute_diff(old, new), expected);
    }
}

// NOTE: test_stream_result_variants was removed when the StreamResult enum was
// replaced by StreamingParseResult.

/// The fields of a `PartialToolCall` can be filled in step by step as a call streams in.
#[test]
fn test_partial_tool_call() {
    const ARGS: &str = r#"{"key": "value"}"#;
    const STREAMED: &str = r#"{"key": "#;

    // Start from an empty in-progress call.
    let mut partial = PartialToolCall {
        name: None,
        arguments_buffer: String::new(),
        start_position: 0,
        name_sent: false,
        streamed_args: String::new(),
    };

    // The function name arrives.
    partial.name = Some(String::from("test_function"));
    assert_eq!(partial.name.as_deref(), Some("test_function"));

    // Argument text is accumulated in the buffer.
    partial.arguments_buffer.push_str(ARGS);
    assert_eq!(partial.arguments_buffer, ARGS);

    // Record what has already been streamed to the client.
    partial.name_sent = true;
    partial.streamed_args = STREAMED.to_owned();
    assert!(partial.name_sent);
    assert_eq!(partial.streamed_args, STREAMED);
}

#[tokio::test]
async fn test_json_parser_complete_single() {
    let parser = JsonParser::new();

    let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
163
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
164

165
166
167
168
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
    assert!(tools[0].function.arguments.contains("San Francisco"));
    assert!(tools[0].function.arguments.contains("celsius"));
169
170
171
172
173
174
175
176
177
178
179
}

#[tokio::test]
async fn test_json_parser_complete_array() {
    let parser = JsonParser::new();

    let input = r#"[
        {"name": "get_weather", "arguments": {"location": "SF"}},
        {"name": "get_news", "arguments": {"query": "technology"}}
    ]"#;

180
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
181

182
183
184
    assert_eq!(tools.len(), 2);
    assert_eq!(tools[0].function.name, "get_weather");
    assert_eq!(tools[1].function.name, "get_news");
185
186
187
188
189
190
191
}

#[tokio::test]
async fn test_json_parser_with_parameters() {
    let parser = JsonParser::new();

    let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
192
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
193

194
195
196
197
198
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "calculate");
    assert!(tools[0].function.arguments.contains("10"));
    assert!(tools[0].function.arguments.contains("20"));
    assert!(tools[0].function.arguments.contains("add"));
199
200
}

// The TokenConfig tests were removed: JsonParser no longer supports TokenConfig.

#[tokio::test]
async fn test_multiline_json_array() {
    let parser = JsonParser::new();

    let input = r#"[
    {
        "name": "function1",
        "arguments": {
            "param1": "value1",
            "param2": 42
        }
    },
    {
        "name": "function2",
        "parameters": {
            "data": [1, 2, 3],
            "flag": false
        }
    }
]"#;

224
225
226
227
228
229
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 2);
    assert_eq!(tools[0].function.name, "function1");
    assert_eq!(tools[1].function.name, "function2");
    assert!(tools[0].function.arguments.contains("value1"));
    assert!(tools[1].function.arguments.contains("[1,2,3]"));
230
231
232
233
234
235
236
}

/// `has_tool_markers` flags JSON tool-call shapes (object or array) but not plain text.
#[test]
fn test_json_parser_format_detection() {
    let parser = JsonParser::new();

    // Should detect valid tool call formats
    assert!(parser.has_tool_markers(r#"{"name": "test", "arguments": {}}"#));
    assert!(parser.has_tool_markers(r#"{"name": "test", "parameters": {"x": 1}}"#));
    assert!(parser.has_tool_markers(r#"[{"name": "test"}]"#));

    // Should not detect non-tool formats
    assert!(!parser.has_tool_markers("plain text"));
}

#[tokio::test]
246
247
async fn test_factory_with_json_parser() {
    let factory = ToolParserFactory::new();
248
249

    // Should get JSON parser for OpenAI models
250
251
    let pooled_parser = factory.get_pooled("gpt-4-turbo");
    let parser = pooled_parser.lock().await;
252
253

    let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
254
255
256
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "test");
257
258
259
260
261
262
263
}

#[tokio::test]
async fn test_json_parser_invalid_input() {
    let parser = JsonParser::new();

    // Invalid JSON should return empty results
264
265
266
    assert_eq!(parser.parse_complete("not json").await.unwrap().1.len(), 0);
    assert_eq!(parser.parse_complete("{invalid}").await.unwrap().1.len(), 0);
    assert_eq!(parser.parse_complete("").await.unwrap().1.len(), 0);
267
268
269
270
271
272
273
274
}

#[tokio::test]
async fn test_json_parser_empty_arguments() {
    let parser = JsonParser::new();

    // Tool call with no arguments
    let input = r#"{"name": "get_time"}"#;
275
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
276

277
278
279
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_time");
    assert_eq!(tools[0].function.arguments, "{}");
280
281
282
283
284
285
286
287
288
289
290
291
}

/// Inputs that are malformed, or only partially valid, as tool calls.
#[cfg(test)]
mod failure_cases {
    use super::*;

    /// A missing `name` yields no call, but an empty-string name is accepted.
    #[tokio::test]
    async fn test_malformed_tool_missing_name() {
        let parser = JsonParser::new();

        // Missing name field
        let input = r#"{"arguments": {"x": 1}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0, "Should return empty for tool without name");

        // Empty name
        let input = r#"{"name": "", "arguments": {"x": 1}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1, "Should accept empty name string");
        assert_eq!(tools[0].function.name, "");
    }

    /// Non-object `arguments` values (string, number, null) are kept and serialized
    /// as JSON text.
    #[tokio::test]
    async fn test_invalid_arguments_json() {
        let parser = JsonParser::new();

        // Arguments is a string instead of object
        let input = r#"{"name": "test", "arguments": "not an object"}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        // Should serialize the string as JSON
        assert!(tools[0].function.arguments.contains("not an object"));

        // Arguments is a number
        let input = r#"{"name": "test", "arguments": 42}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.arguments, "42");

        // Arguments is null
        let input = r#"{"name": "test", "arguments": null}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.arguments, "null");
    }

    // A wrapper-token test used to live here; that functionality moved to the
    // model-specific parsers.

    /// Syntactically invalid or truncated JSON yields no tool calls.
    #[tokio::test]
    async fn test_invalid_json_structures() {
        let parser = JsonParser::new();

        // Trailing comma
        let input = r#"{"name": "test", "arguments": {"x": 1,}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0, "Should reject JSON with trailing comma");

        // Missing quotes on keys
        let input = r#"{name: "test", arguments: {}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0, "Should reject invalid JSON syntax");

        // Unclosed object
        let input = r#"{"name": "test", "arguments": {"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0, "Should reject incomplete JSON");
    }
}

/// Valid but unusual inputs: non-ASCII text, escapes, large payloads, mixed arrays,
/// duplicate keys, special values, alternative field names and odd whitespace.
#[cfg(test)]
mod edge_cases {
    use super::*;

    /// Non-ASCII function names and argument values pass through unchanged.
    #[tokio::test]
    async fn test_unicode_in_names_and_arguments() {
        let parser = JsonParser::new();

        // Unicode in function name
        let input = r#"{"name": "获取天气", "arguments": {"location": "北京"}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "获取天气");
        assert!(tools[0].function.arguments.contains("北京"));

        // Emoji in arguments
        let input = r#"{"name": "send_message", "arguments": {"text": "Hello 👋 World 🌍"}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("👋"));
        assert!(tools[0].function.arguments.contains("🌍"));
    }

    /// Escape sequences in argument strings survive re-serialization in escaped form.
    #[tokio::test]
    async fn test_escaped_characters() {
        let parser = JsonParser::new();

        // Escaped quotes in arguments
        let input = r#"{"name": "echo", "arguments": {"text": "He said \"hello\""}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains(r#"\"hello\""#));

        // Escaped backslashes
        let input = r#"{"name": "path", "arguments": {"dir": "C:\\Users\\test"}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("\\\\"));

        // Newlines and tabs
        let input = r#"{"name": "format", "arguments": {"text": "line1\nline2\ttabbed"}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("\\n"));
        assert!(tools[0].function.arguments.contains("\\t"));
    }

    /// A ~1000-field arguments object and a 100-element call array both parse fully.
    #[tokio::test]
    async fn test_very_large_payloads() {
        let parser = JsonParser::new();

        // Large arguments object
        let mut large_args = r#"{"name": "process", "arguments": {"#.to_string();
        for i in 0..1000 {
            large_args.push_str(&format!(r#""field_{}": "value_{}","#, i, i));
        }
        large_args.push_str(r#""final": "value"}}"#);

        let (_normal_text, tools) = parser.parse_complete(&large_args).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "process");
        assert!(tools[0].function.arguments.contains("field_999"));

        // Large array of tool calls
        let mut large_array = "[".to_string();
        for i in 0..100 {
            if i > 0 {
                large_array.push(',');
            }
            large_array.push_str(&format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i));
        }
        large_array.push(']');

        let (_normal_text, tools) = parser.parse_complete(&large_array).await.unwrap();
        assert_eq!(tools.len(), 100);
        assert_eq!(tools[99].function.name, "func_99");
    }

    /// Array elements without a tool-call shape are skipped; the rest keep their order.
    #[tokio::test]
    async fn test_mixed_array_tools_and_non_tools() {
        let parser = JsonParser::new();

        // Array with both tool calls and non-tool objects
        let input = r#"[
            {"name": "tool1", "arguments": {}},
            {"not_a_tool": "just_data"},
            {"name": "tool2", "parameters": {"x": 1}},
            {"key": "value", "another": "field"}
        ]"#;

        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2, "Should only parse valid tool calls");
        assert_eq!(tools[0].function.name, "tool1");
        assert_eq!(tools[1].function.name, "tool2");
    }

    /// With duplicate keys, the last occurrence wins for both name and argument values.
    #[tokio::test]
    async fn test_duplicate_keys_in_json() {
        let parser = JsonParser::new();

        // JSON with duplicate keys (last one wins in most parsers)
        let input = r#"{"name": "first", "name": "second", "arguments": {"x": 1, "x": 2}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(
            tools[0].function.name, "second",
            "Last duplicate key should win"
        );
        assert!(
            tools[0].function.arguments.contains("2"),
            "Last duplicate value should win"
        );
    }

    /// `null` inside argument objects and arrays is preserved.
    #[tokio::test]
    async fn test_null_values_in_arguments() {
        let parser = JsonParser::new();

        // Null values in arguments
        let input = r#"{"name": "test", "arguments": {"required": "value", "optional": null}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("null"));

        // Array with null
        let input = r#"{"name": "test", "arguments": {"items": [1, null, "three"]}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("null"));
    }

    /// Booleans, integers, floats, negatives and empty containers are preserved.
    #[tokio::test]
    async fn test_special_json_values() {
        let parser = JsonParser::new();

        // Boolean values
        let input = r#"{"name": "toggle", "arguments": {"enabled": true, "disabled": false}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("true"));
        assert!(tools[0].function.arguments.contains("false"));

        // Numbers (including float and negative)
        let input = r#"{"name": "calc", "arguments": {"int": 42, "float": 3.14, "negative": -17}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("42"));
        assert!(tools[0].function.arguments.contains("3.14"));
        assert!(tools[0].function.arguments.contains("-17"));

        // Empty arrays and objects
        let input = r#"{"name": "test", "arguments": {"empty_arr": [], "empty_obj": {}}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("[]"));
        assert!(tools[0].function.arguments.contains("{}"));
    }

    /// A `function` key can carry the name; `name` takes precedence when both exist.
    #[tokio::test]
    async fn test_function_field_alternative() {
        let parser = JsonParser::new();

        // Using "function" instead of "name"
        let input = r#"{"function": "test_func", "arguments": {"x": 1}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "test_func");

        // Both "name" and "function" present (name should take precedence)
        let input = r#"{"name": "primary", "function": "secondary", "arguments": {}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "primary");
    }

    /// Both heavily padded and fully minified JSON parse.
    #[tokio::test]
    async fn test_whitespace_handling() {
        let parser = JsonParser::new();

        // Extra whitespace everywhere
        let input = r#"  {
            "name"   :   "test"  ,
            "arguments"   :   {
                "key"   :   "value"
            }
        }  "#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "test");

        // Minified JSON (no whitespace)
        let input = r#"{"name":"compact","arguments":{"a":1,"b":2}}"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "compact");
    }
}

/// Deep nesting and concurrent use of a shared parser.
#[cfg(test)]
mod stress_tests {
    use super::*;

    /// Arguments nested five objects deep are preserved in full.
    #[tokio::test]
    async fn test_deeply_nested_arguments() {
        let parser = JsonParser::new();

        // Deeply nested structure
        let input = r#"{
            "name": "nested",
            "arguments": {
                "level1": {
                    "level2": {
                        "level3": {
                            "level4": {
                                "level5": {
                                    "value": "deep"
                                }
                            }
                        }
                    }
                }
            }
        }"#;

        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools[0].function.arguments.contains("deep"));
    }

    /// Ten spawned tasks share one `JsonParser` through an `Arc`, and each one parses
    /// its own call correctly.
    #[tokio::test]
    async fn test_concurrent_parser_usage() {
        let parser = std::sync::Arc::new(JsonParser::new());

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let parser = std::sync::Arc::clone(&parser);
                tokio::spawn(async move {
                    let input = format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i);
                    let (_normal_text, tools) = parser.parse_complete(&input).await.unwrap();
                    assert_eq!(tools.len(), 1);
                    assert_eq!(tools[0].function.name, format!("func_{}", i));
                })
            })
            .collect();

        // An assertion failure inside a task surfaces here as a JoinError, so
        // unwrapping each handle fails the test.
        for handle in handles {
            handle.await.unwrap();
        }
    }
}