Unverified Commit aae7ead2 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] remove old/oudated/useless comments across code base (#10968)

parent a7fe6e10
...@@ -84,8 +84,6 @@ fn create_minimal_completion_request() -> CompletionRequest { ...@@ -84,8 +84,6 @@ fn create_minimal_completion_request() -> CompletionRequest {
} }
} }
// ============= Basic Unit Tests =============
/// Test basic OpenAI router creation and configuration /// Test basic OpenAI router creation and configuration
#[tokio::test] #[tokio::test]
async fn test_openai_router_creation() { async fn test_openai_router_creation() {
...@@ -575,7 +573,6 @@ async fn test_unsupported_endpoints() { ...@@ -575,7 +573,6 @@ async fn test_unsupported_endpoints() {
.await .await
.unwrap(); .unwrap();
// Test generate endpoint (SGLang-specific, should not be supported)
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
prompt: None, prompt: None,
text: Some("Hello world".to_string()), text: Some("Hello world".to_string()),
...@@ -593,7 +590,6 @@ async fn test_unsupported_endpoints() { ...@@ -593,7 +590,6 @@ async fn test_unsupported_endpoints() {
let response = router.route_generate(None, &generate_request, None).await; let response = router.route_generate(None, &generate_request, None).await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
// Test completion endpoint (should also not be supported)
let completion_request = create_minimal_completion_request(); let completion_request = create_minimal_completion_request();
let response = router let response = router
.route_completion(None, &completion_request, None) .route_completion(None, &completion_request, None)
...@@ -601,8 +597,6 @@ async fn test_unsupported_endpoints() { ...@@ -601,8 +597,6 @@ async fn test_unsupported_endpoints() {
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
} }
// ============= Mock Server E2E Tests =============
/// Test chat completion with mock OpenAI server /// Test chat completion with mock OpenAI server
#[tokio::test] #[tokio::test]
async fn test_openai_router_chat_completion_with_mock() { async fn test_openai_router_chat_completion_with_mock() {
...@@ -635,7 +629,6 @@ async fn test_openai_router_chat_completion_with_mock() { ...@@ -635,7 +629,6 @@ async fn test_openai_router_chat_completion_with_mock() {
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap(); let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap(); let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap();
// Verify it's a valid chat completion response
assert_eq!(chat_response["object"], "chat.completion"); assert_eq!(chat_response["object"], "chat.completion");
assert_eq!(chat_response["model"], "gpt-3.5-turbo"); assert_eq!(chat_response["model"], "gpt-3.5-turbo");
assert!(!chat_response["choices"].as_array().unwrap().is_empty()); assert!(!chat_response["choices"].as_array().unwrap().is_empty());
...@@ -704,7 +697,6 @@ async fn test_openai_e2e_with_server() { ...@@ -704,7 +697,6 @@ async fn test_openai_e2e_with_server() {
.unwrap(); .unwrap();
let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
// Verify the response structure
assert_eq!(response_json["object"], "chat.completion"); assert_eq!(response_json["object"], "chat.completion");
assert_eq!(response_json["model"], "gpt-3.5-turbo"); assert_eq!(response_json["model"], "gpt-3.5-turbo");
assert!(!response_json["choices"].as_array().unwrap().is_empty()); assert!(!response_json["choices"].as_array().unwrap().is_empty());
......
...@@ -9,7 +9,6 @@ mod test_pd_routing { ...@@ -9,7 +9,6 @@ mod test_pd_routing {
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy; use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::RouterFactory; use sglang_router_rs::routers::RouterFactory;
// Test-only struct to help validate PD request parsing
#[derive(Debug)] #[derive(Debug)]
struct PDRequest { struct PDRequest {
pub is_stream: bool, pub is_stream: bool,
...@@ -17,14 +16,12 @@ mod test_pd_routing { ...@@ -17,14 +16,12 @@ mod test_pd_routing {
} }
impl PDRequest { impl PDRequest {
// Extract PD-relevant info from JSON for testing
pub fn from_json(json: &serde_json::Value) -> Self { pub fn from_json(json: &serde_json::Value) -> Self {
let is_stream = json let is_stream = json
.get("stream") .get("stream")
.and_then(|v| v.as_bool()) .and_then(|v| v.as_bool())
.unwrap_or(false); .unwrap_or(false);
// Detect batch size from text or input_ids
let batch_size = if let Some(text) = json.get("text") { let batch_size = if let Some(text) = json.get("text") {
text.as_array().map(|arr| arr.len()) text.as_array().map(|arr| arr.len())
} else if let Some(input_ids) = json.get("input_ids") { } else if let Some(input_ids) = json.get("input_ids") {
...@@ -40,15 +37,10 @@ mod test_pd_routing { ...@@ -40,15 +37,10 @@ mod test_pd_routing {
} }
} }
// ========================================================================
// Phase 1: Basic PD Components and Router Creation
// ========================================================================
#[test] #[test]
fn test_worker_types() { fn test_worker_types() {
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
// Test worker creation for prefill servers
let prefill_worker: Box<dyn Worker> = Box::new( let prefill_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080") BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
...@@ -65,7 +57,6 @@ mod test_pd_routing { ...@@ -65,7 +57,6 @@ mod test_pd_routing {
_ => panic!("Expected Prefill worker type"), _ => panic!("Expected Prefill worker type"),
} }
// Test worker creation for decode servers
let decode_worker: Box<dyn Worker> = Box::new( let decode_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://decode:8080") BasicWorkerBuilder::new("http://decode:8080")
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
...@@ -78,7 +69,6 @@ mod test_pd_routing { ...@@ -78,7 +69,6 @@ mod test_pd_routing {
_ => panic!("Expected Decode worker type"), _ => panic!("Expected Decode worker type"),
} }
// Test regular worker creation
let regular_worker: Box<dyn Worker> = Box::new( let regular_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://regular:8080") BasicWorkerBuilder::new("http://regular:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
...@@ -94,7 +84,6 @@ mod test_pd_routing { ...@@ -94,7 +84,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_pd_selection_policies() { fn test_pd_selection_policies() {
// Test all PD selection policy variants
// Note: These policies are only used when pd_disaggregation=true // Note: These policies are only used when pd_disaggregation=true
let policies = vec![ let policies = vec![
PDSelectionPolicy::Random, PDSelectionPolicy::Random,
...@@ -107,7 +96,6 @@ mod test_pd_routing { ...@@ -107,7 +96,6 @@ mod test_pd_routing {
]; ];
for policy in policies { for policy in policies {
// Verify each policy can be created and matched
match &policy { match &policy {
PDSelectionPolicy::Random => { PDSelectionPolicy::Random => {
assert!(matches!(policy, PDSelectionPolicy::Random)); assert!(matches!(policy, PDSelectionPolicy::Random));
...@@ -126,7 +114,6 @@ mod test_pd_routing { ...@@ -126,7 +114,6 @@ mod test_pd_routing {
#[tokio::test] #[tokio::test]
async fn test_pd_router_configuration() { async fn test_pd_router_configuration() {
// Test PD router configuration with various policies
// In the new structure, RoutingMode and PolicyConfig are separate // In the new structure, RoutingMode and PolicyConfig are separate
let test_cases = vec![ let test_cases = vec![
( (
...@@ -221,7 +208,6 @@ mod test_pd_routing { ...@@ -221,7 +208,6 @@ mod test_pd_routing {
"Router creation should succeed with empty worker" "Router creation should succeed with empty worker"
); );
// Verify that no workers are registered since we didn't initialize them
let stats = app_context.worker_registry.stats(); let stats = app_context.worker_registry.stats();
assert_eq!( assert_eq!(
stats.total_workers, 0, stats.total_workers, 0,
...@@ -230,13 +216,8 @@ mod test_pd_routing { ...@@ -230,13 +216,8 @@ mod test_pd_routing {
} }
} }
// ========================================================================
// Phase 2: Bootstrap Injection and Request Handling
// ========================================================================
#[test] #[test]
fn test_pd_request_from_json() { fn test_pd_request_from_json() {
// Test PDRequest parsing from single text request
let single_json = json!({ let single_json = json!({
"text": "Hello world", "text": "Hello world",
"stream": false, "stream": false,
...@@ -248,7 +229,6 @@ mod test_pd_routing { ...@@ -248,7 +229,6 @@ mod test_pd_routing {
assert!(!pd_req.is_stream); assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, None); assert_eq!(pd_req.batch_size, None);
// Test PDRequest parsing from batch text request
let batch_json = json!({ let batch_json = json!({
"text": ["Hello", "World", "Test"], "text": ["Hello", "World", "Test"],
"stream": true, "stream": true,
...@@ -259,7 +239,6 @@ mod test_pd_routing { ...@@ -259,7 +239,6 @@ mod test_pd_routing {
assert!(pd_req.is_stream); assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(3)); assert_eq!(pd_req.batch_size, Some(3));
// Test PDRequest parsing from input_ids request
let ids_json = json!({ let ids_json = json!({
"input_ids": [[1, 2, 3], [4, 5, 6]], "input_ids": [[1, 2, 3], [4, 5, 6]],
"stream": false "stream": false
...@@ -269,7 +248,6 @@ mod test_pd_routing { ...@@ -269,7 +248,6 @@ mod test_pd_routing {
assert!(!pd_req.is_stream); assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(2)); assert_eq!(pd_req.batch_size, Some(2));
// Test PDRequest parsing from chat request
let chat_json = json!({ let chat_json = json!({
"messages": [ "messages": [
{"role": "system", "content": "You are a helpful assistant"}, {"role": "system", "content": "You are a helpful assistant"},
...@@ -288,14 +266,12 @@ mod test_pd_routing { ...@@ -288,14 +266,12 @@ mod test_pd_routing {
// Since we can't test the actual inject_bootstrap_fields function here // Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior // (it's private in the router module), we'll test the expected behavior
// Simulate bootstrap injection for single request
let mut single_json = json!({ let mut single_json = json!({
"text": "Hello world", "text": "Hello world",
"stream": false, "stream": false,
"temperature": 0.7 "temperature": 0.7
}); });
// Create a prefill worker to simulate injection
let prefill_worker: Box<dyn Worker> = Box::new( let prefill_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill1:8080") BasicWorkerBuilder::new("http://prefill1:8080")
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
...@@ -305,24 +281,20 @@ mod test_pd_routing { ...@@ -305,24 +281,20 @@ mod test_pd_routing {
.build(), .build(),
); );
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() { let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port, WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None, _ => None,
}; };
// Simulate what inject_bootstrap_fields would do
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url())); single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
single_json["bootstrap_port"] = json!(bootstrap_port); single_json["bootstrap_port"] = json!(bootstrap_port);
single_json["bootstrap_room"] = json!(12345u64); // Random room ID single_json["bootstrap_room"] = json!(12345u64); // Random room ID
// Verify bootstrap fields are added correctly
assert_eq!(single_json["bootstrap_host"], "prefill1"); assert_eq!(single_json["bootstrap_host"], "prefill1");
assert_eq!(single_json["bootstrap_port"], json!(Some(9000))); assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
assert!(single_json["bootstrap_room"].is_u64()); assert!(single_json["bootstrap_room"].is_u64());
assert_eq!(single_json["temperature"], 0.7); // Original field preserved assert_eq!(single_json["temperature"], 0.7); // Original field preserved
// Simulate bootstrap injection for batch request
let mut batch_json = json!({ let mut batch_json = json!({
"text": ["Hello", "World", "Test"], "text": ["Hello", "World", "Test"],
"stream": true "stream": true
...@@ -334,7 +306,6 @@ mod test_pd_routing { ...@@ -334,7 +306,6 @@ mod test_pd_routing {
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]); batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
// Verify batch bootstrap fields
assert!(batch_json["bootstrap_host"].is_array()); assert!(batch_json["bootstrap_host"].is_array());
assert_eq!( assert_eq!(
batch_json["bootstrap_host"].as_array().unwrap().len(), batch_json["bootstrap_host"].as_array().unwrap().len(),
...@@ -347,7 +318,6 @@ mod test_pd_routing { ...@@ -347,7 +318,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_request_serialization() { fn test_request_serialization() {
// Test that requests can be properly serialized and deserialized
let request = json!({ let request = json!({
"text": "Test prompt", "text": "Test prompt",
"stream": false, "stream": false,
...@@ -360,13 +330,10 @@ mod test_pd_routing { ...@@ -360,13 +330,10 @@ mod test_pd_routing {
"bootstrap_room": 12345u64 "bootstrap_room": 12345u64
}); });
// Convert to bytes (as would happen in the router)
let bytes = serde_json::to_vec(&request).unwrap(); let bytes = serde_json::to_vec(&request).unwrap();
// Parse back from bytes
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
// Verify all fields are preserved
assert_eq!(parsed["text"], "Test prompt"); assert_eq!(parsed["text"], "Test prompt");
assert_eq!(parsed["stream"], false); assert_eq!(parsed["stream"], false);
assert_eq!(parsed["temperature"], 0.7); assert_eq!(parsed["temperature"], 0.7);
...@@ -378,7 +345,6 @@ mod test_pd_routing { ...@@ -378,7 +345,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_hostname_extraction() { fn test_hostname_extraction() {
// Test various URL formats
let test_cases = vec![ let test_cases = vec![
("http://localhost:8080", "localhost"), ("http://localhost:8080", "localhost"),
("http://10.0.0.1:8080", "10.0.0.1"), ("http://10.0.0.1:8080", "10.0.0.1"),
...@@ -395,13 +361,11 @@ mod test_pd_routing { ...@@ -395,13 +361,11 @@ mod test_pd_routing {
#[test] #[test]
fn test_pd_request_edge_cases() { fn test_pd_request_edge_cases() {
// Test empty request
let empty_json = json!({}); let empty_json = json!({});
let pd_req = PDRequest::from_json(&empty_json); let pd_req = PDRequest::from_json(&empty_json);
assert!(!pd_req.is_stream); assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, None); assert_eq!(pd_req.batch_size, None);
// Test request with only stream field
let stream_only = json!({ let stream_only = json!({
"stream": true "stream": true
}); });
...@@ -409,14 +373,12 @@ mod test_pd_routing { ...@@ -409,14 +373,12 @@ mod test_pd_routing {
assert!(pd_req.is_stream); assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, None); assert_eq!(pd_req.batch_size, None);
// Test request with empty text array
let empty_batch = json!({ let empty_batch = json!({
"text": [] "text": []
}); });
let pd_req = PDRequest::from_json(&empty_batch); let pd_req = PDRequest::from_json(&empty_batch);
assert_eq!(pd_req.batch_size, Some(0)); assert_eq!(pd_req.batch_size, Some(0));
// Test request with non-array text (should be None)
let non_array_text = json!({ let non_array_text = json!({
"text": "single string" "text": "single string"
}); });
...@@ -424,29 +386,21 @@ mod test_pd_routing { ...@@ -424,29 +386,21 @@ mod test_pd_routing {
assert_eq!(pd_req.batch_size, None); assert_eq!(pd_req.batch_size, None);
} }
// ========================================================================
// Phase 2: Background Load Monitoring Tests
// ========================================================================
#[tokio::test] #[tokio::test]
async fn test_background_load_monitoring() { async fn test_background_load_monitoring() {
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::watch; use tokio::sync::watch;
// Create a watch channel for testing
let (tx, rx) = watch::channel(HashMap::new()); let (tx, rx) = watch::channel(HashMap::new());
// Simulate load updates
let mut loads = HashMap::new(); let mut loads = HashMap::new();
loads.insert("http://prefill1:8080".to_string(), 10); loads.insert("http://prefill1:8080".to_string(), 10);
loads.insert("http://prefill2:8080".to_string(), 20); loads.insert("http://prefill2:8080".to_string(), 20);
loads.insert("http://decode1:8080".to_string(), 5); loads.insert("http://decode1:8080".to_string(), 5);
loads.insert("http://decode2:8080".to_string(), 15); loads.insert("http://decode2:8080".to_string(), 15);
// Send the loads
tx.send(loads.clone()).unwrap(); tx.send(loads.clone()).unwrap();
// Verify receiver gets the update
let received_loads = rx.borrow(); let received_loads = rx.borrow();
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10)); assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20)); assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
...@@ -456,7 +410,6 @@ mod test_pd_routing { ...@@ -456,7 +410,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_load_monitoring_configuration() { fn test_load_monitoring_configuration() {
// Test that load monitoring is only enabled for PowerOfTwo policy
let policies = vec![ let policies = vec![
(PDSelectionPolicy::Random, false), (PDSelectionPolicy::Random, false),
(PDSelectionPolicy::PowerOfTwo, true), (PDSelectionPolicy::PowerOfTwo, true),
...@@ -483,42 +436,31 @@ mod test_pd_routing { ...@@ -483,42 +436,31 @@ mod test_pd_routing {
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::watch; use tokio::sync::watch;
// Test watch channel's broadcast behavior
let (tx, rx1) = watch::channel(HashMap::new()); let (tx, rx1) = watch::channel(HashMap::new());
let rx2 = rx1.clone(); let rx2 = rx1.clone();
// Initial state - empty map
assert!(rx1.borrow().is_empty()); assert!(rx1.borrow().is_empty());
assert!(rx2.borrow().is_empty()); assert!(rx2.borrow().is_empty());
// Update 1
let mut loads = HashMap::new(); let mut loads = HashMap::new();
loads.insert("worker1".to_string(), 10); loads.insert("worker1".to_string(), 10);
tx.send(loads.clone()).unwrap(); tx.send(loads.clone()).unwrap();
// Both receivers see the update
assert_eq!(rx1.borrow().get("worker1"), Some(&10)); assert_eq!(rx1.borrow().get("worker1"), Some(&10));
assert_eq!(rx2.borrow().get("worker1"), Some(&10)); assert_eq!(rx2.borrow().get("worker1"), Some(&10));
// Update 2 - overwrites previous
loads.insert("worker1".to_string(), 20); loads.insert("worker1".to_string(), 20);
loads.insert("worker2".to_string(), 30); loads.insert("worker2".to_string(), 30);
tx.send(loads).unwrap(); tx.send(loads).unwrap();
// Both receivers see the latest state
assert_eq!(rx1.borrow().get("worker1"), Some(&20)); assert_eq!(rx1.borrow().get("worker1"), Some(&20));
assert_eq!(rx2.borrow().get("worker2"), Some(&30)); assert_eq!(rx2.borrow().get("worker2"), Some(&30));
} }
// ========================================================================
// Tests based on bench_one_batch_server.py patterns
// ========================================================================
#[test] #[test]
fn test_generate_request_formats() { fn test_generate_request_formats() {
// Based on bench_one_batch_server.py request patterns // Based on bench_one_batch_server.py request patterns
// Test 1: Batch request with input_ids (most common in benchmarks)
let batch_request = json!({ let batch_request = json!({
"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], "input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
"sampling_params": { "sampling_params": {
...@@ -534,7 +476,6 @@ mod test_pd_routing { ...@@ -534,7 +476,6 @@ mod test_pd_routing {
assert!(pd_req.is_stream); assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(3)); assert_eq!(pd_req.batch_size, Some(3));
// Test 2: Request with return_logprob (critical for PD)
let logprob_request = json!({ let logprob_request = json!({
"input_ids": [[1, 2, 3]], "input_ids": [[1, 2, 3]],
"sampling_params": { "sampling_params": {
...@@ -548,7 +489,6 @@ mod test_pd_routing { ...@@ -548,7 +489,6 @@ mod test_pd_routing {
assert_eq!(logprob_request["return_logprob"], true); assert_eq!(logprob_request["return_logprob"], true);
assert_eq!(logprob_request["stream"], false); assert_eq!(logprob_request["stream"], false);
// Test 3: Large batch sizes from benchmark
let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
for bs in batch_sizes { for bs in batch_sizes {
let request = json!({ let request = json!({
...@@ -567,7 +507,6 @@ mod test_pd_routing { ...@@ -567,7 +507,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_sampling_params_handling() { fn test_sampling_params_handling() {
// Test various sampling parameters from bench_one_batch_server.py
let sampling_params_variations = vec![ let sampling_params_variations = vec![
json!({ json!({
"temperature": 0.0, "temperature": 0.0,
...@@ -595,14 +534,12 @@ mod test_pd_routing { ...@@ -595,14 +534,12 @@ mod test_pd_routing {
"stream": false "stream": false
}); });
// Verify params are preserved
assert_eq!(request["sampling_params"], params); assert_eq!(request["sampling_params"], params);
} }
} }
#[test] #[test]
fn test_streaming_response_parsing() { fn test_streaming_response_parsing() {
// Test SSE format parsing from streaming responses
let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}", let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}", "data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}", "data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
...@@ -615,13 +552,11 @@ mod test_pd_routing { ...@@ -615,13 +552,11 @@ mod test_pd_routing {
assert!(parsed["meta_info"]["completion_tokens"].is_u64()); assert!(parsed["meta_info"]["completion_tokens"].is_u64());
} }
// Test [DONE] detection
assert_eq!(sse_chunks[3], "data: [DONE]"); assert_eq!(sse_chunks[3], "data: [DONE]");
} }
#[test] #[test]
fn test_ttft_calculation() { fn test_ttft_calculation() {
// Test Time To First Token calculation pattern
let first_token_response = json!({ let first_token_response = json!({
"text": "Hello", "text": "Hello",
"meta_info": { "meta_info": {
...@@ -637,7 +572,6 @@ mod test_pd_routing { ...@@ -637,7 +572,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_throughput_metrics() { fn test_throughput_metrics() {
// Test throughput calculation patterns from bench_one_batch_server.py
let batch_size = 16; let batch_size = 16;
let input_len = 1024; let input_len = 1024;
let output_len = 16; let output_len = 16;
...@@ -655,7 +589,6 @@ mod test_pd_routing { ...@@ -655,7 +589,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_error_response_handling() { fn test_error_response_handling() {
// Test error response format from bench_one_batch_server.py
let error_response = json!({ let error_response = json!({
"error": "Request has failed. Invalid input format." "error": "Request has failed. Invalid input format."
}); });
...@@ -666,7 +599,6 @@ mod test_pd_routing { ...@@ -666,7 +599,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_structured_output_request() { fn test_structured_output_request() {
// Test structured output format (json_schema)
let structured_request = json!({ let structured_request = json!({
"text": "What is the capital of France? Answer in JSON.", "text": "What is the capital of France? Answer in JSON.",
"sampling_params": { "sampling_params": {
...@@ -687,7 +619,6 @@ mod test_pd_routing { ...@@ -687,7 +619,6 @@ mod test_pd_routing {
fn test_bootstrap_injection_with_benchmark_requests() { fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
// Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({ let mut benchmark_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16 "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
"sampling_params": { "sampling_params": {
...@@ -699,7 +630,6 @@ mod test_pd_routing { ...@@ -699,7 +630,6 @@ mod test_pd_routing {
"stream": true "stream": true
}); });
// Create a prefill worker to simulate injection
let prefill_worker: Box<dyn Worker> = Box::new( let prefill_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080") BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
...@@ -709,7 +639,6 @@ mod test_pd_routing { ...@@ -709,7 +639,6 @@ mod test_pd_routing {
.build(), .build(),
); );
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() { let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port, WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None, _ => None,
...@@ -722,7 +651,6 @@ mod test_pd_routing { ...@@ -722,7 +651,6 @@ mod test_pd_routing {
benchmark_request["bootstrap_room"] = benchmark_request["bootstrap_room"] =
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>()); json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
// Verify bootstrap fields match batch size
assert_eq!( assert_eq!(
benchmark_request["bootstrap_host"] benchmark_request["bootstrap_host"]
.as_array() .as_array()
...@@ -745,14 +673,12 @@ mod test_pd_routing { ...@@ -745,14 +673,12 @@ mod test_pd_routing {
batch_size batch_size
); );
// Verify original fields are preserved
assert_eq!(benchmark_request["return_logprob"], true); assert_eq!(benchmark_request["return_logprob"], true);
assert_eq!(benchmark_request["stream"], true); assert_eq!(benchmark_request["stream"], true);
} }
#[test] #[test]
fn test_server_info_response_format() { fn test_server_info_response_format() {
// Test server info format expected by bench_one_batch_server.py
let server_info = json!({ let server_info = json!({
"internal_states": [{ "internal_states": [{
"avg_spec_accept_length": 3.5, "avg_spec_accept_length": 3.5,
...@@ -769,16 +695,13 @@ mod test_pd_routing { ...@@ -769,16 +695,13 @@ mod test_pd_routing {
] ]
}); });
// Verify structure matches what benchmark expects
assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64()); assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64()); assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
assert!(server_info["prefill"].is_array()); assert!(server_info["prefill"].is_array());
assert!(server_info["decode"].is_array()); assert!(server_info["decode"].is_array());
} }
// ========================================================================
// Comprehensive Endpoint Coverage Test // Comprehensive Endpoint Coverage Test
// ========================================================================
#[test] #[test]
fn test_pd_endpoints_coverage() { fn test_pd_endpoints_coverage() {
...@@ -807,7 +730,6 @@ mod test_pd_routing { ...@@ -807,7 +730,6 @@ mod test_pd_routing {
assert_eq!(implemented_count, 10); assert_eq!(implemented_count, 10);
assert_eq!(total_count, 11); assert_eq!(total_count, 11);
// Document the missing endpoint
let missing: Vec<_> = implemented_endpoints let missing: Vec<_> = implemented_endpoints
.iter() .iter()
.filter(|(_, _, impl_status)| !impl_status) .filter(|(_, _, impl_status)| !impl_status)
...@@ -819,14 +741,12 @@ mod test_pd_routing { ...@@ -819,14 +741,12 @@ mod test_pd_routing {
#[test] #[test]
fn test_large_batch_bootstrap_injection() { fn test_large_batch_bootstrap_injection() {
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario // This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192]; let large_batch_sizes = vec![1024, 4096, 8192];
for batch_size in large_batch_sizes { for batch_size in large_batch_sizes {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
// Simulate a large batch request
let mut large_batch_request = json!({ let mut large_batch_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; batch_size], "input_ids": vec![vec![1, 2, 3, 4]; batch_size],
"sampling_params": { "sampling_params": {
...@@ -836,7 +756,6 @@ mod test_pd_routing { ...@@ -836,7 +756,6 @@ mod test_pd_routing {
"stream": true "stream": true
}); });
// Create a prefill worker to simulate injection
let prefill_worker: Box<dyn Worker> = Box::new( let prefill_worker: Box<dyn Worker> = Box::new(
BasicWorkerBuilder::new("http://prefill:8080") BasicWorkerBuilder::new("http://prefill:8080")
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
...@@ -846,7 +765,6 @@ mod test_pd_routing { ...@@ -846,7 +765,6 @@ mod test_pd_routing {
.build(), .build(),
); );
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() { let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port, WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None, _ => None,
...@@ -861,7 +779,6 @@ mod test_pd_routing { ...@@ -861,7 +779,6 @@ mod test_pd_routing {
let elapsed = start.elapsed(); let elapsed = start.elapsed();
// Verify bootstrap fields are correctly sized
assert_eq!( assert_eq!(
large_batch_request["bootstrap_host"] large_batch_request["bootstrap_host"]
.as_array() .as_array()
...@@ -899,7 +816,6 @@ mod test_pd_routing { ...@@ -899,7 +816,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_payload_size_calculation() { fn test_payload_size_calculation() {
// Test payload size estimation for bench_one_batch_server.py scenarios
let test_cases = vec![ let test_cases = vec![
(1, 1024, 16), // Small batch (1, 1024, 16), // Small batch
(16, 1024, 16), // Medium batch (16, 1024, 16), // Medium batch
...@@ -937,14 +853,12 @@ mod test_pd_routing { ...@@ -937,14 +853,12 @@ mod test_pd_routing {
#[test] #[test]
fn test_policy_type_to_pd_selection_policy_mapping() { fn test_policy_type_to_pd_selection_policy_mapping() {
// Test that PDSelectionPolicy doesn't include RoundRobin
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
assert_eq!( assert_eq!(
pd_policy_count, 3, pd_policy_count, 3,
"PDSelectionPolicy should have exactly 3 variants" "PDSelectionPolicy should have exactly 3 variants"
); );
// Verify that each PDSelectionPolicy variant can be created
let _random = PDSelectionPolicy::Random; let _random = PDSelectionPolicy::Random;
let _po2 = PDSelectionPolicy::PowerOfTwo; let _po2 = PDSelectionPolicy::PowerOfTwo;
let _cache_aware = PDSelectionPolicy::CacheAware { let _cache_aware = PDSelectionPolicy::CacheAware {
......
...@@ -84,7 +84,6 @@ fn test_sequence_operations() { ...@@ -84,7 +84,6 @@ fn test_sequence_operations() {
for prompt in TEST_PROMPTS.iter() { for prompt in TEST_PROMPTS.iter() {
let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt"); let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
// Test Sequence with append_text
let mut sequence = Sequence::new(tokenizer.clone()); let mut sequence = Sequence::new(tokenizer.clone());
sequence.append_text(prompt).expect("Failed to append text"); sequence.append_text(prompt).expect("Failed to append text");
...@@ -95,7 +94,6 @@ fn test_sequence_operations() { ...@@ -95,7 +94,6 @@ fn test_sequence_operations() {
); );
assert_eq!(sequence.text().unwrap(), *prompt, "Sequence text mismatch"); assert_eq!(sequence.text().unwrap(), *prompt, "Sequence text mismatch");
// Test incremental decoding with append_token
let mut decoder = Sequence::new(tokenizer.clone()); let mut decoder = Sequence::new(tokenizer.clone());
let mut output = String::new(); let mut output = String::new();
...@@ -178,7 +176,6 @@ fn test_stop_sequence_decoder() { ...@@ -178,7 +176,6 @@ fn test_stop_sequence_decoder() {
.expect("Failed to load tokenizer"), .expect("Failed to load tokenizer"),
); );
// Test with various stop sequences
let test_cases = vec![ let test_cases = vec![
( (
"Hello world! Stop here. Continue after.", "Hello world! Stop here. Continue after.",
...@@ -237,7 +234,6 @@ fn test_stop_sequence_decoder() { ...@@ -237,7 +234,6 @@ fn test_stop_sequence_decoder() {
#[test] #[test]
fn test_factory_creation() { fn test_factory_creation() {
// Test factory creation method
let tokenizer_path = ensure_tokenizer_cached(); let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = factory::create_tokenizer(tokenizer_path.to_str().unwrap()) let tokenizer = factory::create_tokenizer(tokenizer_path.to_str().unwrap())
.expect("Failed to create tokenizer via factory"); .expect("Failed to create tokenizer via factory");
......
...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, To ...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, To
async fn test_deepseek_complete_parsing() { async fn test_deepseek_complete_parsing() {
let parser = DeepSeekParser::new(); let parser = DeepSeekParser::new();
// Test single tool call
let input = r#"Let me help you with that. let input = r#"Let me help you with that.
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json ```json
...@@ -18,7 +17,6 @@ The weather in Tokyo is..."#; ...@@ -18,7 +17,6 @@ The weather in Tokyo is..."#;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather"); assert_eq!(result[0].function.name, "get_weather");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["location"], "Tokyo"); assert_eq!(args["location"], "Tokyo");
assert_eq!(args["units"], "celsius"); assert_eq!(args["units"], "celsius");
......
...@@ -167,8 +167,6 @@ async fn test_unicode_edge_cases() { ...@@ -167,8 +167,6 @@ async fn test_unicode_edge_cases() {
#[tokio::test] #[tokio::test]
async fn test_nested_brackets_in_strings() { async fn test_nested_brackets_in_strings() {
// Test that parsers correctly handle brackets within string literals
let mistral_parser = MistralParser::new(); let mistral_parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array: [1, 2, 3]"}}]"#; let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array: [1, 2, 3]"}}]"#;
let result = mistral_parser.parse_complete(input).await.unwrap(); let result = mistral_parser.parse_complete(input).await.unwrap();
...@@ -186,8 +184,6 @@ async fn test_nested_brackets_in_strings() { ...@@ -186,8 +184,6 @@ async fn test_nested_brackets_in_strings() {
#[tokio::test] #[tokio::test]
async fn test_multiple_formats_in_text() { async fn test_multiple_formats_in_text() {
// Test that parsers don't get confused by other formats in the text
let json_parser = JsonParser::new(); let json_parser = JsonParser::new();
let input = r#" let input = r#"
Here's some text with [TOOL_CALLS] that shouldn't trigger. Here's some text with [TOOL_CALLS] that shouldn't trigger.
...@@ -272,7 +268,6 @@ async fn test_partial_token_at_buffer_boundary() { ...@@ -272,7 +268,6 @@ async fn test_partial_token_at_buffer_boundary() {
let parser = QwenParser::new(); let parser = QwenParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Test case that would fail with the bug:
// Send exactly "<tool" which is a 5-character prefix of "<tool_call>\n" // Send exactly "<tool" which is a 5-character prefix of "<tool_call>\n"
let result = parser.parse_incremental("<tool", &mut state).await.unwrap(); let result = parser.parse_incremental("<tool", &mut state).await.unwrap();
assert!(matches!(result, StreamResult::Incomplete)); assert!(matches!(result, StreamResult::Incomplete));
...@@ -303,7 +298,6 @@ async fn test_partial_token_at_buffer_boundary() { ...@@ -303,7 +298,6 @@ async fn test_partial_token_at_buffer_boundary() {
async fn test_exact_prefix_lengths() { async fn test_exact_prefix_lengths() {
let parser = QwenParser::new(); let parser = QwenParser::new();
// Test various exact prefix lengths that would be missed by exclusive range
let test_cases = vec![ let test_cases = vec![
("<", 1), // 1-char prefix ("<", 1), // 1-char prefix
("<t", 2), // 2-char prefix ("<t", 2), // 2-char prefix
......
...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{Glm4MoeParser, ParseState, StreamResult, Too ...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{Glm4MoeParser, ParseState, StreamResult, Too
async fn test_glm4_complete_parsing() { async fn test_glm4_complete_parsing() {
let parser = Glm4MoeParser::new(); let parser = Glm4MoeParser::new();
// Test single tool call
let input = r#"Let me search for that. let input = r#"Let me search for that.
<tool_call>get_weather <tool_call>get_weather
<arg_key>city</arg_key> <arg_key>city</arg_key>
...@@ -20,7 +19,6 @@ The weather will be..."#; ...@@ -20,7 +19,6 @@ The weather will be..."#;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather"); assert_eq!(result[0].function.name, "get_weather");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["city"], "Beijing"); assert_eq!(args["city"], "Beijing");
assert_eq!(args["date"], "2024-12-25"); assert_eq!(args["date"], "2024-12-25");
...@@ -51,7 +49,6 @@ async fn test_glm4_multiple_tools() { ...@@ -51,7 +49,6 @@ async fn test_glm4_multiple_tools() {
async fn test_glm4_type_conversion() { async fn test_glm4_type_conversion() {
let parser = Glm4MoeParser::new(); let parser = Glm4MoeParser::new();
// Test various value types
let input = r#"<tool_call>process let input = r#"<tool_call>process
<arg_key>count</arg_key> <arg_key>count</arg_key>
<arg_value>42</arg_value> <arg_value>42</arg_value>
...@@ -132,7 +129,6 @@ fn test_glm4_format_detection() { ...@@ -132,7 +129,6 @@ fn test_glm4_format_detection() {
async fn test_glm4_python_literal_values() { async fn test_glm4_python_literal_values() {
let parser = Glm4MoeParser::new(); let parser = Glm4MoeParser::new();
// Test Python-style boolean values
let input = r#"<tool_call>config let input = r#"<tool_call>config
<arg_key>debug</arg_key> <arg_key>debug</arg_key>
<arg_value>True</arg_value> <arg_value>True</arg_value>
......
...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{GptOssParser, ParseState, StreamResult, Tool ...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{GptOssParser, ParseState, StreamResult, Tool
async fn test_gpt_oss_complete_parsing() { async fn test_gpt_oss_complete_parsing() {
let parser = GptOssParser::new(); let parser = GptOssParser::new();
// Test single tool call
let input = r#"Let me search for that information. let input = r#"Let me search for that information.
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|> <|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|>
Here are the results..."#; Here are the results..."#;
...@@ -15,7 +14,6 @@ Here are the results..."#; ...@@ -15,7 +14,6 @@ Here are the results..."#;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search"); assert_eq!(result[0].function.name, "search");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "rust programming"); assert_eq!(args["query"], "rust programming");
assert_eq!(args["limit"], 10); assert_eq!(args["limit"], 10);
...@@ -38,7 +36,6 @@ async fn test_gpt_oss_multiple_tools() { ...@@ -38,7 +36,6 @@ async fn test_gpt_oss_multiple_tools() {
async fn test_gpt_oss_with_namespace() { async fn test_gpt_oss_with_namespace() {
let parser = GptOssParser::new(); let parser = GptOssParser::new();
// Test with different namespace patterns
let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|> let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|>
<|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#; <|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#;
...@@ -52,7 +49,6 @@ async fn test_gpt_oss_with_namespace() { ...@@ -52,7 +49,6 @@ async fn test_gpt_oss_with_namespace() {
async fn test_gpt_oss_with_assistant_prefix() { async fn test_gpt_oss_with_assistant_prefix() {
let parser = GptOssParser::new(); let parser = GptOssParser::new();
// Test with <|start|>assistant prefix
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#; let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
...@@ -64,7 +60,6 @@ async fn test_gpt_oss_with_assistant_prefix() { ...@@ -64,7 +60,6 @@ async fn test_gpt_oss_with_assistant_prefix() {
async fn test_gpt_oss_empty_args() { async fn test_gpt_oss_empty_args() {
let parser = GptOssParser::new(); let parser = GptOssParser::new();
// Test with empty arguments
let input = let input =
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#; r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
...@@ -130,7 +125,6 @@ fn test_gpt_oss_format_detection() { ...@@ -130,7 +125,6 @@ fn test_gpt_oss_format_detection() {
async fn test_gpt_oss_with_whitespace() { async fn test_gpt_oss_with_whitespace() {
let parser = GptOssParser::new(); let parser = GptOssParser::new();
// Test with whitespace after function name
let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#; let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
...@@ -142,7 +136,6 @@ async fn test_gpt_oss_with_whitespace() { ...@@ -142,7 +136,6 @@ async fn test_gpt_oss_with_whitespace() {
async fn test_gpt_oss_complex_json() { async fn test_gpt_oss_complex_json() {
let parser = GptOssParser::new(); let parser = GptOssParser::new();
// Test with complex nested JSON
let input = r#"<|channel|>commentary to=functions.process<|constrain|>json<|message|>{ let input = r#"<|channel|>commentary to=functions.process<|constrain|>json<|message|>{
"nested": { "nested": {
"data": [1, 2, 3], "data": [1, 2, 3],
......
...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{KimiK2Parser, ParseState, StreamResult, Tool ...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{KimiK2Parser, ParseState, StreamResult, Tool
async fn test_kimik2_complete_parsing() { async fn test_kimik2_complete_parsing() {
let parser = KimiK2Parser::new(); let parser = KimiK2Parser::new();
// Test single tool call
let input = r#"Let me help you with that. let input = r#"Let me help you with that.
<|tool_calls_section_begin|> <|tool_calls_section_begin|>
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|> <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
...@@ -17,7 +16,6 @@ The weather in Tokyo is..."#; ...@@ -17,7 +16,6 @@ The weather in Tokyo is..."#;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather"); assert_eq!(result[0].function.name, "get_weather");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["location"], "Tokyo"); assert_eq!(args["location"], "Tokyo");
assert_eq!(args["units"], "celsius"); assert_eq!(args["units"], "celsius");
...@@ -42,7 +40,6 @@ async fn test_kimik2_multiple_tools() { ...@@ -42,7 +40,6 @@ async fn test_kimik2_multiple_tools() {
async fn test_kimik2_with_whitespace() { async fn test_kimik2_with_whitespace() {
let parser = KimiK2Parser::new(); let parser = KimiK2Parser::new();
// Test with extra whitespace
let input = r#"<|tool_calls_section_begin|> let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|> <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|>
<|tool_calls_section_end|>"#; <|tool_calls_section_end|>"#;
...@@ -114,7 +111,6 @@ fn test_kimik2_format_detection() { ...@@ -114,7 +111,6 @@ fn test_kimik2_format_detection() {
async fn test_kimik2_sequential_indices() { async fn test_kimik2_sequential_indices() {
let parser = KimiK2Parser::new(); let parser = KimiK2Parser::new();
// Test with proper sequential indexing
let input = r#"<|tool_calls_section_begin|> let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|> <|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|>
<|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|> <|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|>
......
...@@ -116,7 +116,6 @@ async fn test_llama_real_world_output() { ...@@ -116,7 +116,6 @@ async fn test_llama_real_world_output() {
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "web_search"); assert_eq!(result[0].function.name, "web_search");
// Test with nicely formatted JSON
let formatted_input = r#"<|python_tag|>{ let formatted_input = r#"<|python_tag|>{
"name": "get_current_time", "name": "get_current_time",
"arguments": { "arguments": {
...@@ -144,7 +143,6 @@ async fn test_llama_json_array_format() { ...@@ -144,7 +143,6 @@ async fn test_llama_json_array_format() {
#[tokio::test] #[tokio::test]
async fn test_single_json() { async fn test_single_json() {
// Test parsing plain JSON without python_tag
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#; let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
...@@ -158,7 +156,6 @@ async fn test_single_json() { ...@@ -158,7 +156,6 @@ async fn test_single_json() {
#[tokio::test] #[tokio::test]
async fn test_multiple_json_with_separator() { async fn test_multiple_json_with_separator() {
// Test multiple JSON objects with semicolon separator
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#; let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;
...@@ -170,7 +167,6 @@ async fn test_multiple_json_with_separator() { ...@@ -170,7 +167,6 @@ async fn test_multiple_json_with_separator() {
#[tokio::test] #[tokio::test]
async fn test_multiple_json_with_separator_customized() { async fn test_multiple_json_with_separator_customized() {
// Test multiple JSON objects with python_tag repeated
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#; let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;
...@@ -182,7 +178,6 @@ async fn test_multiple_json_with_separator_customized() { ...@@ -182,7 +178,6 @@ async fn test_multiple_json_with_separator_customized() {
#[tokio::test] #[tokio::test]
async fn test_json_with_trailing_text() { async fn test_json_with_trailing_text() {
// Test JSON with trailing text after
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#; let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;
...@@ -193,7 +188,6 @@ async fn test_json_with_trailing_text() { ...@@ -193,7 +188,6 @@ async fn test_json_with_trailing_text() {
#[tokio::test] #[tokio::test]
async fn test_invalid_then_valid_json() { async fn test_invalid_then_valid_json() {
// Test error recovery - invalid JSON followed by valid JSON
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#; let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;
...@@ -206,7 +200,6 @@ async fn test_invalid_then_valid_json() { ...@@ -206,7 +200,6 @@ async fn test_invalid_then_valid_json() {
#[tokio::test] #[tokio::test]
async fn test_plain_text_only() { async fn test_plain_text_only() {
// Test plain text with no tool calls
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let text = "This is just plain explanation text."; let text = "This is just plain explanation text.";
...@@ -216,7 +209,6 @@ async fn test_plain_text_only() { ...@@ -216,7 +209,6 @@ async fn test_plain_text_only() {
#[tokio::test] #[tokio::test]
async fn test_with_python_tag_prefix() { async fn test_with_python_tag_prefix() {
// Test text before python_tag
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#; let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;
...@@ -225,9 +217,7 @@ async fn test_with_python_tag_prefix() { ...@@ -225,9 +217,7 @@ async fn test_with_python_tag_prefix() {
assert_eq!(result[0].function.name, "get_weather"); assert_eq!(result[0].function.name, "get_weather");
} }
// ============================================================================
// STREAMING TESTS // STREAMING TESTS
// ============================================================================
#[tokio::test] #[tokio::test]
async fn test_llama_streaming_simple() { async fn test_llama_streaming_simple() {
...@@ -332,7 +322,6 @@ async fn test_llama_streaming_with_text_before() { ...@@ -332,7 +322,6 @@ async fn test_llama_streaming_with_text_before() {
#[tokio::test] #[tokio::test]
async fn test_llama_streaming_multiple_tools() { async fn test_llama_streaming_multiple_tools() {
// Test streaming multiple tool calls with semicolon separator
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -361,7 +350,6 @@ async fn test_llama_streaming_multiple_tools() { ...@@ -361,7 +350,6 @@ async fn test_llama_streaming_multiple_tools() {
#[tokio::test] #[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() { async fn test_llama_streaming_multiple_tools_chunked() {
// Test streaming multiple tool calls arriving in chunks
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
......
...@@ -10,8 +10,6 @@ use sglang_router_rs::tool_parser::{ ...@@ -10,8 +10,6 @@ use sglang_router_rs::tool_parser::{
#[tokio::test] #[tokio::test]
async fn test_mixed_formats_in_text() { async fn test_mixed_formats_in_text() {
// Test that parsers correctly ignore other formats' markers
let json_parser = JsonParser::new(); let json_parser = JsonParser::new();
let input = r#" let input = r#"
Some text with [TOOL_CALLS] marker that shouldn't trigger. Some text with [TOOL_CALLS] marker that shouldn't trigger.
...@@ -37,8 +35,6 @@ async fn test_mixed_formats_in_text() { ...@@ -37,8 +35,6 @@ async fn test_mixed_formats_in_text() {
#[tokio::test] #[tokio::test]
async fn test_format_markers_in_string_content() { async fn test_format_markers_in_string_content() {
// Test that format markers inside string content don't interfere
let pythonic_parser = PythonicParser::new(); let pythonic_parser = PythonicParser::new();
let input = r#"[echo(text="Use [TOOL_CALLS] and <tool_call> in text")]"#; let input = r#"[echo(text="Use [TOOL_CALLS] and <tool_call> in text")]"#;
...@@ -101,7 +97,6 @@ async fn test_multiple_sequential_calls_different_formats() { ...@@ -101,7 +97,6 @@ async fn test_multiple_sequential_calls_different_formats() {
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "call1"); assert_eq!(result[0].function.name, "call1");
// Test plain JSON separately
let input2 = r#"{"name": "call2", "arguments": {"x": 1}}"#; let input2 = r#"{"name": "call2", "arguments": {"x": 1}}"#;
let result2 = llama_parser.parse_complete(input2).await.unwrap(); let result2 = llama_parser.parse_complete(input2).await.unwrap();
assert_eq!(result2.len(), 1); assert_eq!(result2.len(), 1);
...@@ -133,7 +128,6 @@ async fn test_empty_and_whitespace_variations() { ...@@ -133,7 +128,6 @@ async fn test_empty_and_whitespace_variations() {
async fn test_special_json_values() { async fn test_special_json_values() {
let json_parser = JsonParser::new(); let json_parser = JsonParser::new();
// Test various special JSON values
let input = r#"{ let input = r#"{
"name": "test_special", "name": "test_special",
"arguments": { "arguments": {
...@@ -183,8 +177,6 @@ async fn test_parser_recovery_after_invalid_input() { ...@@ -183,8 +177,6 @@ async fn test_parser_recovery_after_invalid_input() {
#[tokio::test] #[tokio::test]
async fn test_boundary_cases_for_extraction() { async fn test_boundary_cases_for_extraction() {
// Test edge cases in JSON extraction from text
let json_parser = JsonParser::new(); let json_parser = JsonParser::new();
// JSON at the very beginning // JSON at the very beginning
...@@ -259,7 +251,6 @@ async fn test_mistral_with_pretty_json() { ...@@ -259,7 +251,6 @@ async fn test_mistral_with_pretty_json() {
async fn test_qwen_with_cdata_like_content() { async fn test_qwen_with_cdata_like_content() {
let parser = QwenParser::new(); let parser = QwenParser::new();
// Test with content that looks like CDATA but isn't
// Note: QwenParser expects exactly "<tool_call>\n" with the newline // Note: QwenParser expects exactly "<tool_call>\n" with the newline
let input = r#"<tool_call> let input = r#"<tool_call>
{"name": "process", "arguments": {"xml": "<![CDATA[some data]]>"}} {"name": "process", "arguments": {"xml": "<![CDATA[some data]]>"}}
......
...@@ -180,7 +180,6 @@ These functions will provide the information you need."#; ...@@ -180,7 +180,6 @@ These functions will provide the information you need."#;
async fn test_pythonic_nested_brackets_in_lists() { async fn test_pythonic_nested_brackets_in_lists() {
let parser = PythonicParser::new(); let parser = PythonicParser::new();
// Test nested brackets within list arguments
let input = r#"[process_matrix(data=[[1, 2], [3, 4]], labels=["row[0]", "row[1]"])]"#; let input = r#"[process_matrix(data=[[1, 2], [3, 4]], labels=["row[0]", "row[1]"])]"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
...@@ -196,7 +195,6 @@ async fn test_pythonic_nested_brackets_in_lists() { ...@@ -196,7 +195,6 @@ async fn test_pythonic_nested_brackets_in_lists() {
async fn test_pythonic_nested_brackets_in_dicts() { async fn test_pythonic_nested_brackets_in_dicts() {
let parser = PythonicParser::new(); let parser = PythonicParser::new();
// Test nested brackets within dictionary arguments
let input = let input =
r#"[analyze(config={"patterns": ["[a-z]+", "[0-9]+"], "nested": {"list": [1, [2, 3]]}})]"#; r#"[analyze(config={"patterns": ["[a-z]+", "[0-9]+"], "nested": {"list": [1, [2, 3]]}})]"#;
...@@ -213,7 +211,6 @@ async fn test_pythonic_nested_brackets_in_dicts() { ...@@ -213,7 +211,6 @@ async fn test_pythonic_nested_brackets_in_dicts() {
async fn test_pythonic_mixed_quotes() { async fn test_pythonic_mixed_quotes() {
let parser = PythonicParser::new(); let parser = PythonicParser::new();
// Test mixed quote types in arguments
let input = r#"[format_text(single='Hello', double="World", mixed="It's \"quoted\"")]"#; let input = r#"[format_text(single='Hello', double="World", mixed="It's \"quoted\"")]"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
...@@ -230,7 +227,6 @@ async fn test_pythonic_mixed_quotes() { ...@@ -230,7 +227,6 @@ async fn test_pythonic_mixed_quotes() {
async fn test_pythonic_complex_nesting() { async fn test_pythonic_complex_nesting() {
let parser = PythonicParser::new(); let parser = PythonicParser::new();
// Test complex nested structures
let input = r#"[transform( let input = r#"[transform(
matrix=[[1, [2, 3]], [4, [5, [6, 7]]]], matrix=[[1, [2, 3]], [4, [5, [6, 7]]]],
operations=[{"type": "scale", "factor": [2, 3]}, {"type": "rotate", "angle": 90}], operations=[{"type": "scale", "factor": [2, 3]}, {"type": "rotate", "angle": 90}],
...@@ -250,7 +246,6 @@ async fn test_pythonic_complex_nesting() { ...@@ -250,7 +246,6 @@ async fn test_pythonic_complex_nesting() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_no_brackets() { async fn test_parse_streaming_no_brackets() {
// Test parsing text with no brackets (no tool calls)
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -268,7 +263,6 @@ async fn test_parse_streaming_no_brackets() { ...@@ -268,7 +263,6 @@ async fn test_parse_streaming_no_brackets() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_complete_tool_call() { async fn test_parse_streaming_complete_tool_call() {
// Test parsing a complete tool call
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -289,7 +283,6 @@ async fn test_parse_streaming_complete_tool_call() { ...@@ -289,7 +283,6 @@ async fn test_parse_streaming_complete_tool_call() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_text_before_tool_call() { async fn test_parse_streaming_text_before_tool_call() {
// Test parsing text that appears before a tool call
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -308,7 +301,6 @@ async fn test_parse_streaming_text_before_tool_call() { ...@@ -308,7 +301,6 @@ async fn test_parse_streaming_text_before_tool_call() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_partial_tool_call() { async fn test_parse_streaming_partial_tool_call() {
// Test parsing a partial tool call that spans multiple chunks
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -340,7 +332,6 @@ async fn test_parse_streaming_partial_tool_call() { ...@@ -340,7 +332,6 @@ async fn test_parse_streaming_partial_tool_call() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_bracket_without_text_before() { async fn test_parse_streaming_bracket_without_text_before() {
// Test parsing a tool call that starts at the beginning of the text
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -359,7 +350,6 @@ async fn test_parse_streaming_bracket_without_text_before() { ...@@ -359,7 +350,6 @@ async fn test_parse_streaming_bracket_without_text_before() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_text_after_tool_call() { async fn test_parse_streaming_text_after_tool_call() {
// Test parsing text that appears after a tool call
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -379,7 +369,6 @@ async fn test_parse_streaming_text_after_tool_call() { ...@@ -379,7 +369,6 @@ async fn test_parse_streaming_text_after_tool_call() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_multiple_tool_calls() { async fn test_parse_streaming_multiple_tool_calls() {
// Test parsing multiple tool calls in sequence
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -401,7 +390,6 @@ async fn test_parse_streaming_multiple_tool_calls() { ...@@ -401,7 +390,6 @@ async fn test_parse_streaming_multiple_tool_calls() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_opening_bracket_only() { async fn test_parse_streaming_opening_bracket_only() {
// Test parsing text with only an opening bracket but no closing bracket
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -418,7 +406,6 @@ async fn test_parse_streaming_opening_bracket_only() { ...@@ -418,7 +406,6 @@ async fn test_parse_streaming_opening_bracket_only() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_nested_brackets() { async fn test_parse_streaming_nested_brackets() {
// Test parsing tool calls with nested brackets in arguments
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -439,7 +426,6 @@ async fn test_parse_streaming_nested_brackets() { ...@@ -439,7 +426,6 @@ async fn test_parse_streaming_nested_brackets() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_nested_brackets_dict() { async fn test_parse_streaming_nested_brackets_dict() {
// Test parsing tool calls with nested dictionaries and lists
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -460,7 +446,6 @@ async fn test_parse_streaming_nested_brackets_dict() { ...@@ -460,7 +446,6 @@ async fn test_parse_streaming_nested_brackets_dict() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_multiple_tools_with_nested_brackets() { async fn test_parse_streaming_multiple_tools_with_nested_brackets() {
// Test parsing multiple tool calls with nested brackets
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -480,7 +465,6 @@ async fn test_parse_streaming_multiple_tools_with_nested_brackets() { ...@@ -480,7 +465,6 @@ async fn test_parse_streaming_multiple_tools_with_nested_brackets() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_partial_nested_brackets() { async fn test_parse_streaming_partial_nested_brackets() {
// Test parsing partial tool calls with nested brackets across chunks
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -514,7 +498,6 @@ async fn test_parse_streaming_partial_nested_brackets() { ...@@ -514,7 +498,6 @@ async fn test_parse_streaming_partial_nested_brackets() {
#[tokio::test] #[tokio::test]
async fn test_parse_streaming_with_python_start_and_end_token() { async fn test_parse_streaming_with_python_start_and_end_token() {
// Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new(); let mut state = sglang_router_rs::tool_parser::ParseState::new();
...@@ -544,7 +527,6 @@ async fn test_parse_streaming_with_python_start_and_end_token() { ...@@ -544,7 +527,6 @@ async fn test_parse_streaming_with_python_start_and_end_token() {
#[tokio::test] #[tokio::test]
async fn test_detect_and_parse_with_python_start_and_end_token() { async fn test_detect_and_parse_with_python_start_and_end_token() {
// Test parsing a message that starts with <|python_start|> and contains a valid tool call
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."; let text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars.";
......
...@@ -189,7 +189,6 @@ async fn test_buffer_drain_optimization() { ...@@ -189,7 +189,6 @@ async fn test_buffer_drain_optimization() {
// First chunk - incomplete tool call // First chunk - incomplete tool call
let chunk1 = "<tool_call>\n{\"name\": \"test1\", "; let chunk1 = "<tool_call>\n{\"name\": \"test1\", ";
let _result = parser.parse_incremental(chunk1, &mut state).await.unwrap(); let _result = parser.parse_incremental(chunk1, &mut state).await.unwrap();
// Phase 2 simplified streaming might not handle partial JSON correctly
// The important thing is buffer accumulation works // The important thing is buffer accumulation works
assert!(!state.buffer.is_empty()); assert!(!state.buffer.is_empty());
...@@ -197,32 +196,23 @@ async fn test_buffer_drain_optimization() { ...@@ -197,32 +196,23 @@ async fn test_buffer_drain_optimization() {
let chunk2 = "\"arguments\": {}}\n</tool_call><tool_call>\n{\"name\": \"test2\", "; let chunk2 = "\"arguments\": {}}\n</tool_call><tool_call>\n{\"name\": \"test2\", ";
let result = parser.parse_incremental(chunk2, &mut state).await.unwrap(); let result = parser.parse_incremental(chunk2, &mut state).await.unwrap();
match result { if let StreamResult::ToolComplete(tool) = result {
StreamResult::ToolComplete(tool) => { assert_eq!(tool.function.name, "test1");
assert_eq!(tool.function.name, "test1"); // After consuming the first tool, buffer should contain only the second tool start
// After consuming the first tool, buffer should contain only the second tool start assert!(state.buffer.starts_with("<tool_call>"));
assert!(state.buffer.starts_with("<tool_call>")); assert!(state.buffer.contains("test2"));
assert!(state.buffer.contains("test2")); } else {
} // The important thing is the buffer is managed correctly
_ => {
// Phase 2 simplified streaming might return Incomplete
// The important thing is the buffer is managed correctly
}
} }
// Complete the second tool // Complete the second tool
let chunk3 = "\"arguments\": {\"x\": 1}}\n</tool_call>"; let chunk3 = "\"arguments\": {\"x\": 1}}\n</tool_call>";
let result = parser.parse_incremental(chunk3, &mut state).await.unwrap(); let result = parser.parse_incremental(chunk3, &mut state).await.unwrap();
match result { if let StreamResult::ToolComplete(tool) = result {
StreamResult::ToolComplete(tool) => { assert_eq!(tool.function.name, "test2");
assert_eq!(tool.function.name, "test2"); // Buffer should be empty after consuming all tools
// Buffer should be empty after consuming all tools assert!(state.buffer.is_empty() || !state.buffer.contains("</tool_call>"));
assert!(state.buffer.is_empty() || !state.buffer.contains("</tool_call>"));
}
_ => {
// Phase 2 simplified streaming might handle this differently
}
} }
} }
...@@ -253,7 +243,4 @@ async fn test_buffer_efficiency_with_multiple_tools() { ...@@ -253,7 +243,4 @@ async fn test_buffer_efficiency_with_multiple_tools() {
// Simplified streaming might return Incomplete // Simplified streaming might return Incomplete
} }
} }
// Verify no memory issues or panics occurred with drain()
// Test passes if we reach this point without panic
} }
...@@ -126,7 +126,6 @@ async fn test_unknown_model_fallback() { ...@@ -126,7 +126,6 @@ async fn test_unknown_model_fallback() {
async fn test_pattern_specificity() { async fn test_pattern_specificity() {
let registry = ParserRegistry::new(); let registry = ParserRegistry::new();
// Test that more specific patterns take precedence
// llama-4* should match before llama-* // llama-4* should match before llama-*
let parser = registry.get_parser("llama-4-70b").unwrap(); let parser = registry.get_parser("llama-4-70b").unwrap();
assert!(parser.detect_format(r#"[test_function(x=1)]"#)); // Pythonic format assert!(parser.detect_format(r#"[test_function(x=1)]"#)); // Pythonic format
...@@ -139,7 +138,6 @@ async fn test_pattern_specificity() { ...@@ -139,7 +138,6 @@ async fn test_pattern_specificity() {
async fn test_real_world_model_outputs() { async fn test_real_world_model_outputs() {
let registry = ParserRegistry::new(); let registry = ParserRegistry::new();
// Test with realistic outputs from different models
let test_cases = vec![ let test_cases = vec![
( (
"gpt-4", "gpt-4",
......
...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{ParseState, Step3Parser, StreamResult, ToolP ...@@ -6,7 +6,6 @@ use sglang_router_rs::tool_parser::{ParseState, Step3Parser, StreamResult, ToolP
async fn test_step3_complete_parsing() { async fn test_step3_complete_parsing() {
let parser = Step3Parser::new(); let parser = Step3Parser::new();
// Test single tool call
let input = r#"Let me help you. let input = r#"Let me help you.
<|tool_calls_begin|> <|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search"> <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
...@@ -20,7 +19,6 @@ Here are the results..."#; ...@@ -20,7 +19,6 @@ Here are the results..."#;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search"); assert_eq!(result[0].function.name, "search");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "rust programming"); assert_eq!(args["query"], "rust programming");
assert_eq!(args["limit"], 10); assert_eq!(args["limit"], 10);
...@@ -127,7 +125,6 @@ fn test_step3_format_detection() { ...@@ -127,7 +125,6 @@ fn test_step3_format_detection() {
async fn test_step3_nested_steptml() { async fn test_step3_nested_steptml() {
let parser = Step3Parser::new(); let parser = Step3Parser::new();
// Test with complex parameter values
let input = r#"<|tool_calls_begin|> let input = r#"<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="config"> <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="config">
<steptml:parameter name="settings">{"nested": {"key": "value"}}</steptml:parameter> <steptml:parameter name="settings">{"nested": {"key": "value"}}</steptml:parameter>
...@@ -148,7 +145,6 @@ async fn test_step3_nested_steptml() { ...@@ -148,7 +145,6 @@ async fn test_step3_nested_steptml() {
async fn test_step3_python_literals() { async fn test_step3_python_literals() {
let parser = Step3Parser::new(); let parser = Step3Parser::new();
// Test Python-style literals
let input = r#"<|tool_calls_begin|> let input = r#"<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="test"> <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="test">
<steptml:parameter name="bool_true">True</steptml:parameter> <steptml:parameter name="bool_true">True</steptml:parameter>
...@@ -211,7 +207,6 @@ async fn test_json_parameter_values() { ...@@ -211,7 +207,6 @@ async fn test_json_parameter_values() {
async fn test_step3_parameter_with_angle_brackets() { async fn test_step3_parameter_with_angle_brackets() {
let parser = Step3Parser::new(); let parser = Step3Parser::new();
// Test parameter value containing < character
let input = r#"<|tool_calls_begin|> let input = r#"<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="compare"> <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="compare">
<steptml:parameter name="expression">a < b && b > c</steptml:parameter> <steptml:parameter name="expression">a < b && b > c</steptml:parameter>
...@@ -223,7 +218,6 @@ async fn test_step3_parameter_with_angle_brackets() { ...@@ -223,7 +218,6 @@ async fn test_step3_parameter_with_angle_brackets() {
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "compare"); assert_eq!(result[0].function.name, "compare");
// Verify the parameter value was parsed correctly
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["expression"], "a < b && b > c"); assert_eq!(args["expression"], "a < b && b > c");
assert_eq!(args["context"], "comparison test"); assert_eq!(args["context"], "comparison test");
...@@ -233,7 +227,6 @@ async fn test_step3_parameter_with_angle_brackets() { ...@@ -233,7 +227,6 @@ async fn test_step3_parameter_with_angle_brackets() {
async fn test_step3_empty_function_name() { async fn test_step3_empty_function_name() {
let parser = Step3Parser::new(); let parser = Step3Parser::new();
// Test empty function name
let input = r#"<|tool_calls_begin|> let input = r#"<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name=""> <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="">
<steptml:parameter name="param">value</steptml:parameter> <steptml:parameter name="param">value</steptml:parameter>
......
...@@ -12,8 +12,6 @@ async fn test_json_streaming_simple() { ...@@ -12,8 +12,6 @@ async fn test_json_streaming_simple() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Phase 2 note: This test sends the full JSON at once in the last chunk
// In real streaming, chunks would be smaller
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#; let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
let result = parser let result = parser
...@@ -21,7 +19,6 @@ async fn test_json_streaming_simple() { ...@@ -21,7 +19,6 @@ async fn test_json_streaming_simple() {
.await .await
.unwrap(); .unwrap();
// With complete JSON sent at once, we should get ToolComplete
match result { match result {
StreamResult::ToolComplete(tool) => { StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather"); assert_eq!(tool.function.name, "get_weather");
...@@ -37,7 +34,6 @@ async fn test_json_streaming_array() { ...@@ -37,7 +34,6 @@ async fn test_json_streaming_array() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Stream a JSON array of tools
let chunks = vec![ let chunks = vec![
r#"["#, r#"["#,
r#"{"name": "tool1", "#, r#"{"name": "tool1", "#,
...@@ -57,7 +53,6 @@ async fn test_json_streaming_array() { ...@@ -57,7 +53,6 @@ async fn test_json_streaming_array() {
} }
// Current implementation may handle this differently // Current implementation may handle this differently
// We're mainly testing that it doesn't crash
assert!(tool_count <= 2, "Should parse at most 2 tools"); assert!(tool_count <= 2, "Should parse at most 2 tools");
} }
...@@ -95,7 +90,6 @@ async fn test_pythonic_streaming() { ...@@ -95,7 +90,6 @@ async fn test_pythonic_streaming() {
let parser = PythonicParser::new(); let parser = PythonicParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Send complete pythonic format at once
let full_input = r#"[get_weather(city="London", units="celsius")]"#; let full_input = r#"[get_weather(city="London", units="celsius")]"#;
let result = parser let result = parser
...@@ -149,7 +143,6 @@ async fn test_qwen_streaming() { ...@@ -149,7 +143,6 @@ async fn test_qwen_streaming() {
let parser = QwenParser::new(); let parser = QwenParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Send complete Qwen format at once (with exact format expected by parser)
// Note: Parser expects newline after both tags // Note: Parser expects newline after both tags
let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>"; let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>";
...@@ -176,12 +169,10 @@ async fn test_streaming_incomplete_stays_incomplete() { ...@@ -176,12 +169,10 @@ async fn test_streaming_incomplete_stays_incomplete() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Send truly incomplete JSON that can't be auto-completed
let chunks = vec![r#"{"na"#, r#"me": "#]; let chunks = vec![r#"{"na"#, r#"me": "#];
for chunk in chunks { for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
// Should return Incomplete for partial JSON that can't be auto-completed
assert!( assert!(
matches!(result, StreamResult::Incomplete), matches!(result, StreamResult::Incomplete),
"Should return Incomplete for partial JSON, got: {:?}", "Should return Incomplete for partial JSON, got: {:?}",
...@@ -189,7 +180,6 @@ async fn test_streaming_incomplete_stays_incomplete() { ...@@ -189,7 +180,6 @@ async fn test_streaming_incomplete_stays_incomplete() {
); );
} }
// Buffer should contain the accumulated incomplete JSON
assert!(!state.buffer.is_empty()); assert!(!state.buffer.is_empty());
} }
...@@ -198,8 +188,6 @@ async fn test_streaming_with_text_before_tool() { ...@@ -198,8 +188,6 @@ async fn test_streaming_with_text_before_tool() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// For streaming, the parser expects clean JSON
// Mixed text extraction only works in parse_complete, not parse_incremental
let full_input = r#"{"name": "test", "arguments": {}}"#; let full_input = r#"{"name": "test", "arguments": {}}"#;
let result = parser let result = parser
...@@ -221,10 +209,8 @@ async fn test_streaming_with_text_before_tool() { ...@@ -221,10 +209,8 @@ async fn test_streaming_with_text_before_tool() {
async fn test_streaming_buffer_accumulation() { async fn test_streaming_buffer_accumulation() {
let parser = JsonParser::new(); let parser = JsonParser::new();
// Test: Complete JSON should clear buffer after parsing
let mut state = ParseState::new(); let mut state = ParseState::new();
// Send partial JSON that can't be interpreted as complete
let result1 = parser let result1 = parser
.parse_incremental(r#"{"na"#, &mut state) .parse_incremental(r#"{"na"#, &mut state)
.await .await
...@@ -236,7 +222,6 @@ async fn test_streaming_buffer_accumulation() { ...@@ -236,7 +222,6 @@ async fn test_streaming_buffer_accumulation() {
"Buffer should accumulate incomplete JSON" "Buffer should accumulate incomplete JSON"
); );
// Send rest of JSON
let result2 = parser let result2 = parser
.parse_incremental(r#"me": "test", "arguments": {}}"#, &mut state) .parse_incremental(r#"me": "test", "arguments": {}}"#, &mut state)
.await .await
...@@ -262,7 +247,6 @@ async fn test_streaming_multiple_tools_sequential() { ...@@ -262,7 +247,6 @@ async fn test_streaming_multiple_tools_sequential() {
let parser = QwenParser::new(); let parser = QwenParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Send complete Qwen format with newlines
let full_input = r#"<tool_call> let full_input = r#"<tool_call>
{"name": "tool1", "arguments": {}} {"name": "tool1", "arguments": {}}
</tool_call>"#; </tool_call>"#;
...@@ -286,13 +270,11 @@ async fn test_streaming_multiple_tools_sequential() { ...@@ -286,13 +270,11 @@ async fn test_streaming_multiple_tools_sequential() {
async fn test_streaming_reset_after_error() { async fn test_streaming_reset_after_error() {
let parser = JsonParser::new(); let parser = JsonParser::new();
// First attempt with invalid JSON
let mut state1 = ParseState::new(); let mut state1 = ParseState::new();
let _ = parser let _ = parser
.parse_incremental(r#"{"name": invalid}"#, &mut state1) .parse_incremental(r#"{"name": invalid}"#, &mut state1)
.await; .await;
// Second attempt with valid JSON should work with fresh state
let mut state2 = ParseState::new(); let mut state2 = ParseState::new();
let result = parser let result = parser
.parse_incremental(r#"{"name": "test", "arguments": {}}"#, &mut state2) .parse_incremental(r#"{"name": "test", "arguments": {}}"#, &mut state2)
...@@ -309,7 +291,6 @@ async fn test_streaming_with_unicode_chunks() { ...@@ -309,7 +291,6 @@ async fn test_streaming_with_unicode_chunks() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Send complete JSON with unicode
let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#; let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#;
let result = parser let result = parser
...@@ -317,8 +298,6 @@ async fn test_streaming_with_unicode_chunks() { ...@@ -317,8 +298,6 @@ async fn test_streaming_with_unicode_chunks() {
.await .await
.unwrap(); .unwrap();
// Phase 2 may return partial results even with complete JSON
// The important thing is that unicode is handled without crashes
match result { match result {
StreamResult::ToolComplete(tool) => { StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "translate"); assert_eq!(tool.function.name, "translate");
...@@ -327,10 +306,8 @@ async fn test_streaming_with_unicode_chunks() { ...@@ -327,10 +306,8 @@ async fn test_streaming_with_unicode_chunks() {
} }
StreamResult::ToolName { name, .. } => { StreamResult::ToolName { name, .. } => {
assert_eq!(name, "translate"); assert_eq!(name, "translate");
// Phase 2 partial streaming behavior - acceptable
} }
StreamResult::ToolArguments { arguments, .. } => { StreamResult::ToolArguments { arguments, .. } => {
// Verify unicode was preserved
let args: serde_json::Value = serde_json::from_str(&arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&arguments).unwrap();
assert!(args["text"].as_str().unwrap().contains("世界")); assert!(args["text"].as_str().unwrap().contains("世界"));
} }
......
...@@ -25,20 +25,17 @@ async fn test_json_with_xml_style_wrapper() { ...@@ -25,20 +25,17 @@ async fn test_json_with_xml_style_wrapper() {
#[tokio::test] #[tokio::test]
async fn test_json_with_multiple_wrapper_pairs() { async fn test_json_with_multiple_wrapper_pairs() {
// Test with multiple start/end token pairs
let parser = JsonParser::with_config(TokenConfig { let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string(), "<<TOOL>>".to_string()], start_tokens: vec!["<tool>".to_string(), "<<TOOL>>".to_string()],
end_tokens: vec!["</tool>".to_string(), "<</TOOL>>".to_string()], end_tokens: vec!["</tool>".to_string(), "<</TOOL>>".to_string()],
separator: ", ".to_string(), separator: ", ".to_string(),
}); });
// Test first pair
let input1 = r#"<tool>{"name": "tool1", "arguments": {}}</tool>"#; let input1 = r#"<tool>{"name": "tool1", "arguments": {}}</tool>"#;
let result1 = parser.parse_complete(input1).await.unwrap(); let result1 = parser.parse_complete(input1).await.unwrap();
assert_eq!(result1.len(), 1); assert_eq!(result1.len(), 1);
assert_eq!(result1[0].function.name, "tool1"); assert_eq!(result1[0].function.name, "tool1");
// Test second pair
let input2 = r#"<<TOOL>>{"name": "tool2", "arguments": {}}<</TOOL>>"#; let input2 = r#"<<TOOL>>{"name": "tool2", "arguments": {}}<</TOOL>>"#;
let result2 = parser.parse_complete(input2).await.unwrap(); let result2 = parser.parse_complete(input2).await.unwrap();
assert_eq!(result2.len(), 1); assert_eq!(result2.len(), 1);
...@@ -47,7 +44,6 @@ async fn test_json_with_multiple_wrapper_pairs() { ...@@ -47,7 +44,6 @@ async fn test_json_with_multiple_wrapper_pairs() {
#[tokio::test] #[tokio::test]
async fn test_json_with_only_start_token() { async fn test_json_with_only_start_token() {
// Test when only start token is provided (no end token)
let parser = JsonParser::with_config(TokenConfig { let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec![">>>FUNCTION:".to_string()], start_tokens: vec![">>>FUNCTION:".to_string()],
end_tokens: vec!["".to_string()], // Empty end token end_tokens: vec!["".to_string()], // Empty end token
...@@ -232,7 +228,6 @@ async fn test_json_incomplete_wrapper_tokens() { ...@@ -232,7 +228,6 @@ async fn test_json_incomplete_wrapper_tokens() {
#[tokio::test] #[tokio::test]
async fn test_json_empty_wrapper_tokens() { async fn test_json_empty_wrapper_tokens() {
// Test with empty wrapper tokens (should behave like default)
let parser = JsonParser::with_config(TokenConfig { let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec![], start_tokens: vec![],
end_tokens: vec![], end_tokens: vec![],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment