//! Kimi K2 Parser Integration Tests
//!
//! Exercises `KimiK2Parser` through complete parsing, incremental (streaming)
//! parsing, and format detection.

use sglang_router_rs::tool_parser::{KimiK2Parser, ParseState, StreamResult, ToolParser};

/// Parses a complete response holding one tool call surrounded by plain text.
///
/// Checks that the text preceding the tool-call section is returned as normal
/// text, that the function name is reported without the `functions.` prefix
/// and `:0` suffix, and that the arguments are valid JSON with the expected
/// fields.
#[tokio::test]
async fn test_kimik2_complete_parsing() {
    let parser = KimiK2Parser::new();

    let input = r#"Let me help you with that.
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
<|tool_calls_section_end|>
The weather in Tokyo is..."#;

    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(normal_text, "Let me help you with that.\n");
    assert_eq!(tools[0].function.name, "get_weather");

    // The arguments are carried as a JSON string; decode them to inspect fields.
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
    assert_eq!(args["location"], "Tokyo");
    assert_eq!(args["units"], "celsius");
}

#[tokio::test]
async fn test_kimik2_multiple_tools() {
    let parser = KimiK2Parser::new();

    let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust tutorials"}<|tool_call_end|>
<|tool_call_begin|>functions.translate:1<|tool_call_argument_begin|>{"text": "Hello", "to": "ja"}<|tool_call_end|>
<|tool_calls_section_end|>"#;

34
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
35
    assert_eq!(tools.len(), 2);
36
    assert_eq!(normal_text, "");
37
38
    assert_eq!(tools[0].function.name, "search");
    assert_eq!(tools[1].function.name, "translate");
39
40
41
42
43
44
45
46
47
48
}

#[tokio::test]
async fn test_kimik2_with_whitespace() {
    let parser = KimiK2Parser::new();

    let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|>
<|tool_calls_section_end|>"#;

49
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
50
    assert_eq!(tools.len(), 1);
51
    assert_eq!(normal_text, "");
52
    assert_eq!(tools[0].function.name, "test");
53

54
    let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
    assert_eq!(args["key"], "value");
    assert_eq!(args["num"], 42);
}

/// Feeds a single Kimi K2 tool call to the parser fragment by fragment.
///
/// Every `ToolName` or `ToolComplete` event seen along the way must refer to
/// `calculate`; the test passes once at least one of those two events has been
/// observed.
#[tokio::test]
async fn test_kimik2_streaming() {
    let parser = KimiK2Parser::new();
    let mut state = ParseState::new();

    // One tool call cut at arbitrary points, including in the middle of the
    // function identifier and in the middle of the JSON arguments.
    let fragments = [
        "<|tool_calls_section_begin|>\n",
        "<|tool_call_begin|>functions.",
        "calculate:0",
        "<|tool_call_argument_begin|>",
        r#"{"x": 10, "#,
        r#""y": 20}"#,
        "<|tool_call_end|>\n",
        "<|tool_calls_section_end|>",
    ];

    let mut saw_name = false;
    let mut saw_complete = false;

    for fragment in fragments {
        let event = parser.parse_incremental(fragment, &mut state).await.unwrap();

        match event {
            StreamResult::ToolName { name, .. } => {
                assert_eq!(name, "calculate");
                saw_name = true;
            }
            StreamResult::ToolComplete(tool) => {
                assert_eq!(tool.function.name, "calculate");
                saw_complete = true;
            }
            // Other stream results are not inspected by this test.
            _ => {}
        }
    }

    assert!(saw_name || saw_complete);
}

/// Checks which inputs `detect_format` accepts as Kimi K2 output.
///
/// Inputs containing a Kimi K2 section or call marker must be detected;
/// markers in other styles and plain text must not.
#[test]
fn test_kimik2_format_detection() {
    let parser = KimiK2Parser::new();

    // Should detect Kimi K2 format
    let kimi_inputs = [
        "<|tool_calls_section_begin|>",
        "<|tool_call_begin|>",
        "text with <|tool_calls_section_begin|> marker",
    ];
    for text in kimi_inputs {
        assert!(parser.detect_format(text));
    }

    // Should not detect other formats
    let other_inputs = ["[TOOL_CALLS]", "<tool_call>", "plain text"];
    for text in other_inputs {
        assert!(!parser.detect_format(text));
    }
}

#[tokio::test]
async fn test_kimik2_sequential_indices() {
    let parser = KimiK2Parser::new();

    let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|>
<|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|>
<|tool_call_begin|>functions.third:2<|tool_call_argument_begin|>{"param": "c"}<|tool_call_end|>
<|tool_calls_section_end|>"#;

123
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
124
    assert_eq!(tools.len(), 3);
125
    assert_eq!(normal_text, "");
126
127
128
    assert_eq!(tools[0].function.name, "first");
    assert_eq!(tools[1].function.name, "second");
    assert_eq!(tools[2].function.name, "third");
129
130
131
132
133
134
135
136
137
138
139
140
}

#[tokio::test]
async fn test_function_index_extraction() {
    let parser = KimiK2Parser::new();

    let input = r#"Text before tool calls.
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|>
<|tool_call_begin|>functions.calc:1<|tool_call_argument_begin|>{"x": 10}<|tool_call_end|>
<|tool_calls_section_end|>"#;

141
    let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
142
    assert_eq!(tools.len(), 2);
143
    assert_eq!(normal_text, "Text before tool calls.\n");
144
145
    assert_eq!(tools[0].function.name, "search");
    assert_eq!(tools[1].function.name, "calc");
146
147
148
149
150
151
152
153
154
155
156
    // TODO: Verify indices are preserved: 0 and 1
}

#[tokio::test]
async fn test_namespace_extraction() {
    let parser = KimiK2Parser::new();

    let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>api.tools.search:0<|tool_call_argument_begin|>{"q": "test"}<|tool_call_end|>
<|tool_calls_section_end|>"#;

157
158
159
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "search"); // Should extract after last dot
160
}