//! Streaming Parser Tests
//!
//! Tests for the incremental (streaming) parsing capabilities of every tool parser.

use sglang_router_rs::tool_parser::{
6
    JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, ToolParser,
7
8
};

9
10
11
mod common;
use common::create_test_tools;

12
13
#[tokio::test]
async fn test_json_streaming_simple() {
14
15
16
    let tools = create_test_tools();

    let mut parser = JsonParser::new();
17
18
19

    let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;

20
    let result = parser.parse_incremental(full_json, &tools).await.unwrap();
21

22
23
    assert!(!result.calls.is_empty(), "Should have parsed a tool call");
    assert_eq!(result.calls[0].name, Some("get_weather".to_string()));
24
25
26
27
}

#[tokio::test]
async fn test_json_streaming_array() {
28
29
30
    let tools = create_test_tools();

    let mut parser = JsonParser::new();
31
32
33
34
35
36
37
38
39
40
41
42
43

    let chunks = vec![
        r#"["#,
        r#"{"name": "tool1", "#,
        r#""arguments": {}}, "#,
        r#"{"name": "tool2", "#,
        r#""arguments": {"x": 1"#,
        r#"}}]"#,
    ];

    let mut tool_count = 0;

    for chunk in chunks {
44
45
46
47
48
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        for call in result.calls {
            if call.name.is_some() {
                tool_count += 1;
            }
49
50
51
52
53
54
55
56
57
        }
    }

    // Current implementation may handle this differently
    assert!(tool_count <= 2, "Should parse at most 2 tools");
}

#[tokio::test]
async fn test_mistral_streaming() {
58
59
60
    let tools = create_test_tools();

    let mut parser = MistralParser::new();
61
62
63
64
65
66
67
68
69
70
71
72

    let chunks = vec![
        r#"Here is the result: "#,
        r#"[TOOL_CALLS] ["#,
        r#"{"name": "#,
        r#""search", "#,
        r#""arguments": "#,
        r#"{"query": "#,
        r#""rust lang""#,
        r#"}}]"#,
    ];

73
    let mut got_tool_name = false;
74
75

    for chunk in chunks {
76
77
78
79
80
81
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        for call in result.calls {
            if let Some(name) = call.name {
                assert_eq!(name, "search");
                got_tool_name = true;
            }
82
83
84
        }
    }

85
    assert!(got_tool_name, "Should have found tool name");
86
87
88
89
}

#[tokio::test]
async fn test_pythonic_streaming() {
90
91
92
    let tools = create_test_tools();

    let mut parser = PythonicParser::new();
93
94
95

    let full_input = r#"[get_weather(city="London", units="celsius")]"#;

96
    let result = parser.parse_incremental(full_input, &tools).await.unwrap();
97

98
99
100
101
    assert!(!result.calls.is_empty(), "Should have parsed a tool call");
    assert_eq!(result.calls[0].name, Some("get_weather".to_string()));
    let args: serde_json::Value = serde_json::from_str(&result.calls[0].parameters).unwrap();
    assert_eq!(args["city"], "London");
102
103
104
105
}

#[tokio::test]
async fn test_llama_streaming_with_python_tag() {
106
107
108
    let tools = create_test_tools();

    let mut parser = LlamaParser::new();
109
110
111
112
113
114
115
116
117
118
119
120

    let chunks = vec![
        r#"Let me help. "#,
        r#"<|python"#,
        r#"_tag|>"#,
        r#"{"name": "#,
        r#""calculate", "#,
        r#""arguments": "#,
        r#"{"x": 10}"#,
        r#"}"#,
    ];

121
    let mut got_tool_name = false;
122
123

    for chunk in chunks {
124
125
126
127
128
129
        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
        for call in result.calls {
            if let Some(name) = call.name {
                assert_eq!(name, "calculate");
                got_tool_name = true;
            }
130
131
132
        }
    }

133
    assert!(got_tool_name, "Should have found tool name");
134
135
136
137
}

#[tokio::test]
async fn test_qwen_streaming() {
138
139
140
    let tools = create_test_tools();

    let mut parser = QwenParser::new();
141
142
143
144

    // Note: Parser expects newline after both tags
    let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>";

145
    let result = parser.parse_incremental(full_input, &tools).await.unwrap();
146

147
148
    assert!(!result.calls.is_empty(), "Should have parsed a tool call");
    assert_eq!(result.calls[0].name, Some("translate".to_string()));
149
150
151
152
}

/// Feeds two fragments that together still form incomplete JSON and checks
/// that no call is returned after either one.
#[tokio::test]
async fn test_streaming_incomplete_stays_incomplete() {
    let tools = create_test_tools();
    let mut parser = JsonParser::new();

    for fragment in [r#"{"na"#, r#"me": "#] {
        let parsed = parser.parse_incremental(fragment, &tools).await.unwrap();
        assert!(
            parsed.calls.is_empty(),
            "Should return empty calls for partial JSON, got: {:?}",
            parsed
        );
    }
}

#[tokio::test]
async fn test_streaming_buffer_accumulation() {
171
    let tools = create_test_tools();
172

173
    let mut parser = JsonParser::new();
174

175
    let result1 = parser.parse_incremental(r#"{"na"#, &tools).await.unwrap();
176

177
    assert!(result1.calls.is_empty(), "Should not parse incomplete JSON");
178
179

    let result2 = parser
180
        .parse_incremental(r#"me": "test", "arguments": {}}"#, &tools)
181
182
183
        .await
        .unwrap();

184
185
186
187
188
    assert!(
        !result2.calls.is_empty(),
        "Should parse complete JSON after buffering"
    );
    assert_eq!(result2.calls[0].name, Some("test".to_string()));
189
190
191
192
}

#[tokio::test]
async fn test_streaming_multiple_tools_sequential() {
193
194
195
    let tools = create_test_tools();

    let mut parser = QwenParser::new();
196
197
198
199
200

    let full_input = r#"<tool_call>
{"name": "tool1", "arguments": {}}
</tool_call>"#;

201
    let result = parser.parse_incremental(full_input, &tools).await.unwrap();
202

203
204
    assert!(!result.calls.is_empty(), "Should have parsed a tool call");
    assert_eq!(result.calls[0].name, Some("tool1".to_string()));
205
206
207
208
}

#[tokio::test]
async fn test_streaming_reset_after_error() {
209
    let tools = create_test_tools();
210

211
212
213
214
    let mut parser1 = JsonParser::new();

    let _ = parser1
        .parse_incremental(r#"{"name": invalid}"#, &tools)
215
216
        .await;

217
218
219
220
    // Use a new parser instance for clean state
    let mut parser2 = JsonParser::new();
    let result = parser2
        .parse_incremental(r#"{"name": "test", "arguments": {}}"#, &tools)
221
222
223
        .await
        .unwrap();

224
225
    assert!(!result.calls.is_empty(), "Should parse valid JSON");
    assert_eq!(result.calls[0].name, Some("test".to_string()));
226
227
228
229
}

#[tokio::test]
async fn test_streaming_with_unicode_chunks() {
230
231
232
    let tools = create_test_tools();

    let mut parser = JsonParser::new();
233
234
235

    let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#;

236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    let result = parser.parse_incremental(full_input, &tools).await.unwrap();

    assert!(!result.calls.is_empty(), "Should have parsed a tool call");

    // Check if we got the tool name
    if let Some(name) = &result.calls[0].name {
        assert_eq!(name, "translate");
    }

    // In streaming mode, need to make another call to get parameters
    let result2 = parser.parse_incremental("", &tools).await.unwrap();

    // Parameters should be in either result.calls[1] or result2.calls[0]
    let params = if result.calls.len() > 1 {
        &result.calls[1].parameters
    } else if !result2.calls.is_empty() {
        &result2.calls[0].parameters
    } else {
        &result.calls[0].parameters
    };

    if !params.is_empty() {
        let args: serde_json::Value = serde_json::from_str(params).unwrap();
        assert!(args["text"].as_str().unwrap().contains("世界"));
    }
}

#[tokio::test]
async fn test_streaming_with_partial_chunks() {
    let mut parser = JsonParser::new();
    let tools = create_test_tools();

    let partial = r#"{"#;
    let result = parser.parse_incremental(partial, &tools).await.unwrap();
    assert!(
        result.calls.is_empty(),
        "Should return empty calls for just opening brace"
    );

    let mut parser2 = JsonParser::new();
    let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
    let result = parser2.parse_incremental(complete, &tools).await.unwrap();

    assert!(
        !result.calls.is_empty(),
        "Expected tool call for complete JSON"
    );
    assert_eq!(result.calls[0].name.as_ref().unwrap(), "get_weather");

    // In streaming mode, need to make another call to get parameters
    let result2 = parser2.parse_incremental("", &tools).await.unwrap();

    // Parameters should be in either result.calls[1] or result2.calls[0]
    let params = if result.calls.len() > 1 {
        &result.calls[1].parameters
    } else if !result2.calls.is_empty() {
        &result2.calls[0].parameters
    } else {
        &result.calls[0].parameters
    };

    if !params.is_empty() {
        let args: serde_json::Value = serde_json::from_str(params).unwrap();
        assert_eq!(args["location"], "SF");
    }

    // The PartialJson parser can complete partial JSON by filling in missing values
    let mut parser3 = JsonParser::new();
    let partial_with_name = r#"{"name": "test", "argum"#;
    let result = parser3
        .parse_incremental(partial_with_name, &tools)
307
308
309
        .await
        .unwrap();

310
311
312
    // Parser behavior may vary - either complete with partial data or wait for more
    if !result.calls.is_empty() {
        assert_eq!(result.calls[0].name.as_ref().unwrap(), "test");
313
314
    }
}