Unverified Commit aae7ead2 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] remove old/outdated/useless comments across code base (#10968)

parent a7fe6e10
...@@ -279,11 +279,9 @@ mod tests { ...@@ -279,11 +279,9 @@ mod tests {
#[test] #[test]
fn test_create_tiktoken_tokenizer() { fn test_create_tiktoken_tokenizer() {
// Test creating tokenizer for GPT models
let tokenizer = create_tokenizer("gpt-4").unwrap(); let tokenizer = create_tokenizer("gpt-4").unwrap();
assert!(tokenizer.vocab_size() > 0); assert!(tokenizer.vocab_size() > 0);
// Test encoding and decoding
let text = "Hello, world!"; let text = "Hello, world!";
let encoding = tokenizer.encode(text).unwrap(); let encoding = tokenizer.encode(text).unwrap();
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap(); let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
...@@ -292,7 +290,6 @@ mod tests { ...@@ -292,7 +290,6 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_download_tokenizer_from_hf() { async fn test_download_tokenizer_from_hf() {
// Test with a small model that should have tokenizer files
// Skip this test if HF_TOKEN is not set and we're in CI // Skip this test if HF_TOKEN is not set and we're in CI
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() { if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
println!("Skipping HF download test in CI without HF_TOKEN"); println!("Skipping HF download test in CI without HF_TOKEN");
......
...@@ -206,7 +206,6 @@ mod tests { ...@@ -206,7 +206,6 @@ mod tests {
// The incremental text should be " world" (with the space that the mock tokenizer adds) // The incremental text should be " world" (with the space that the mock tokenizer adds)
assert_eq!(text2, " world"); assert_eq!(text2, " world");
// Verify the full text
assert_eq!(seq.text().unwrap(), "Hello world"); assert_eq!(seq.text().unwrap(), "Hello world");
} }
......
...@@ -398,7 +398,6 @@ mod tests { ...@@ -398,7 +398,6 @@ mod tests {
// The fix ensures we only output NEW text, not accumulated text // The fix ensures we only output NEW text, not accumulated text
assert_eq!(outputs.len(), 3); assert_eq!(outputs.len(), 3);
// Verify no text is repeated
for i in 0..outputs.len() { for i in 0..outputs.len() {
for j in i + 1..outputs.len() { for j in i + 1..outputs.len() {
// No output should contain another (no accumulation) // No output should contain another (no accumulation)
......
...@@ -36,22 +36,17 @@ fn test_tokenizer_wrapper() { ...@@ -36,22 +36,17 @@ fn test_tokenizer_wrapper() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new()); let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer); let tokenizer = Tokenizer::from_arc(mock_tokenizer);
// Test encoding
let encoding = tokenizer.encode("Hello world").unwrap(); let encoding = tokenizer.encode("Hello world").unwrap();
assert_eq!(encoding.token_ids(), &[1, 2]); assert_eq!(encoding.token_ids(), &[1, 2]);
// Test decoding
let text = tokenizer.decode(&[1, 2], false).unwrap(); let text = tokenizer.decode(&[1, 2], false).unwrap();
assert_eq!(text, "Hello world"); assert_eq!(text, "Hello world");
// Test vocab size
assert_eq!(tokenizer.vocab_size(), 8); assert_eq!(tokenizer.vocab_size(), 8);
// Test token to ID
assert_eq!(tokenizer.token_to_id("Hello"), Some(1)); assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
assert_eq!(tokenizer.token_to_id("unknown"), None); assert_eq!(tokenizer.token_to_id("unknown"), None);
// Test ID to token
assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string())); assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
assert_eq!(tokenizer.id_to_token(9999), None); assert_eq!(tokenizer.id_to_token(9999), None);
} }
......
...@@ -246,7 +246,6 @@ mod tests { ...@@ -246,7 +246,6 @@ mod tests {
#[test] #[test]
fn test_unrecognized_model_name_returns_error() { fn test_unrecognized_model_name_returns_error() {
// Test that unrecognized model names return an error
let result = TiktokenTokenizer::from_model_name("distilgpt-2"); let result = TiktokenTokenizer::from_model_name("distilgpt-2");
assert!(result.is_err()); assert!(result.is_err());
if let Err(e) = result { if let Err(e) = result {
...@@ -268,7 +267,6 @@ mod tests { ...@@ -268,7 +267,6 @@ mod tests {
#[test] #[test]
fn test_recognized_model_names() { fn test_recognized_model_names() {
// Test that recognized model names work correctly
assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok()); assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok()); assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok()); assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
......
...@@ -139,7 +139,6 @@ mod tests { ...@@ -139,7 +139,6 @@ mod tests {
async fn test_single_call_with_semicolon() { async fn test_single_call_with_semicolon() {
let parser = LlamaParser::new(); let parser = LlamaParser::new();
// Note: Llama 3.2 doesn't handle multiple calls well // Note: Llama 3.2 doesn't handle multiple calls well
// Test that we can at least parse a single call followed by semicolon
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#; let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
......
...@@ -102,7 +102,6 @@ impl PythonicParser { ...@@ -102,7 +102,6 @@ impl PythonicParser {
if bracket_count == 0 { if bracket_count == 0 {
// Found the matching bracket // Found the matching bracket
let extracted: String = chars[start_idx..=i].iter().collect(); let extracted: String = chars[start_idx..=i].iter().collect();
// Verify this actually contains a function call
if extracted.contains('(') && extracted.contains(')') { if extracted.contains('(') && extracted.contains(')') {
return Some(extracted); return Some(extracted);
} }
......
...@@ -21,21 +21,18 @@ fn test_parse_state_new() { ...@@ -21,21 +21,18 @@ fn test_parse_state_new() {
fn test_parse_state_process_char() { fn test_parse_state_process_char() {
let mut state = ParseState::new(); let mut state = ParseState::new();
// Test bracket tracking
state.process_char('{'); state.process_char('{');
assert_eq!(state.bracket_depth, 1); assert_eq!(state.bracket_depth, 1);
state.process_char('}'); state.process_char('}');
assert_eq!(state.bracket_depth, 0); assert_eq!(state.bracket_depth, 0);
// Test string tracking
state.process_char('"'); state.process_char('"');
assert!(state.in_string); assert!(state.in_string);
state.process_char('"'); state.process_char('"');
assert!(!state.in_string); assert!(!state.in_string);
// Test escape handling
state.process_char('"'); state.process_char('"');
state.process_char('\\'); state.process_char('\\');
assert!(state.escape_next); assert!(state.escape_next);
...@@ -63,10 +60,8 @@ fn test_token_config() { ...@@ -63,10 +60,8 @@ fn test_token_config() {
fn test_parser_registry() { fn test_parser_registry() {
let registry = ParserRegistry::new(); let registry = ParserRegistry::new();
// Test has default mappings
assert!(!registry.list_mappings().is_empty()); assert!(!registry.list_mappings().is_empty());
// Test model pattern matching
let mappings = registry.list_mappings(); let mappings = registry.list_mappings();
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt")); let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
assert!(has_gpt); assert!(has_gpt);
...@@ -76,10 +71,8 @@ fn test_parser_registry() { ...@@ -76,10 +71,8 @@ fn test_parser_registry() {
fn test_parser_registry_pattern_matching() { fn test_parser_registry_pattern_matching() {
let mut registry = ParserRegistry::new_for_testing(); let mut registry = ParserRegistry::new_for_testing();
// Test that model mappings work by checking the list
registry.map_model("test-model", "json"); registry.map_model("test-model", "json");
// Verify through list_mappings
let mappings = registry.list_mappings(); let mappings = registry.list_mappings();
let has_test = mappings let has_test = mappings
.iter() .iter()
...@@ -112,25 +105,21 @@ fn test_tool_call_serialization() { ...@@ -112,25 +105,21 @@ fn test_tool_call_serialization() {
fn test_partial_json_parser() { fn test_partial_json_parser() {
let parser = PartialJson::default(); let parser = PartialJson::default();
// Test complete JSON
let input = r#"{"name": "test", "value": 42}"#; let input = r#"{"name": "test", "value": 42}"#;
let (value, consumed) = parser.parse_value(input).unwrap(); let (value, consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "test"); assert_eq!(value["name"], "test");
assert_eq!(value["value"], 42); assert_eq!(value["value"], 42);
assert_eq!(consumed, input.len()); assert_eq!(consumed, input.len());
// Test incomplete JSON object
let input = r#"{"name": "test", "value": "#; let input = r#"{"name": "test", "value": "#;
let (value, _consumed) = parser.parse_value(input).unwrap(); let (value, _consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "test"); assert_eq!(value["name"], "test");
assert!(value["value"].is_null()); assert!(value["value"].is_null());
// Test incomplete string
let input = r#"{"name": "tes"#; let input = r#"{"name": "tes"#;
let (value, _consumed) = parser.parse_value(input).unwrap(); let (value, _consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "tes"); assert_eq!(value["name"], "tes");
// Test incomplete array
let input = r#"[1, 2, "#; let input = r#"[1, 2, "#;
let (value, _consumed) = parser.parse_value(input).unwrap(); let (value, _consumed) = parser.parse_value(input).unwrap();
assert!(value.is_array()); assert!(value.is_array());
...@@ -193,11 +182,9 @@ fn test_compute_diff() { ...@@ -193,11 +182,9 @@ fn test_compute_diff() {
#[test] #[test]
fn test_stream_result_variants() { fn test_stream_result_variants() {
// Test Incomplete
let result = StreamResult::Incomplete; let result = StreamResult::Incomplete;
matches!(result, StreamResult::Incomplete); matches!(result, StreamResult::Incomplete);
// Test ToolName
let result = StreamResult::ToolName { let result = StreamResult::ToolName {
index: 0, index: 0,
name: "test".to_string(), name: "test".to_string(),
...@@ -209,7 +196,6 @@ fn test_stream_result_variants() { ...@@ -209,7 +196,6 @@ fn test_stream_result_variants() {
panic!("Expected ToolName variant"); panic!("Expected ToolName variant");
} }
// Test ToolComplete
let tool = ToolCall { let tool = ToolCall {
id: "123".to_string(), id: "123".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
...@@ -255,7 +241,6 @@ fn test_partial_tool_call() { ...@@ -255,7 +241,6 @@ fn test_partial_tool_call() {
async fn test_json_parser_complete_single() { async fn test_json_parser_complete_single() {
let parser = JsonParser::new(); let parser = JsonParser::new();
// Test single tool call with arguments
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#; let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
...@@ -269,7 +254,6 @@ async fn test_json_parser_complete_single() { ...@@ -269,7 +254,6 @@ async fn test_json_parser_complete_single() {
async fn test_json_parser_complete_array() { async fn test_json_parser_complete_array() {
let parser = JsonParser::new(); let parser = JsonParser::new();
// Test array of tool calls
let input = r#"[ let input = r#"[
{"name": "get_weather", "arguments": {"location": "SF"}}, {"name": "get_weather", "arguments": {"location": "SF"}},
{"name": "get_news", "arguments": {"query": "technology"}} {"name": "get_news", "arguments": {"query": "technology"}}
...@@ -286,7 +270,6 @@ async fn test_json_parser_complete_array() { ...@@ -286,7 +270,6 @@ async fn test_json_parser_complete_array() {
async fn test_json_parser_with_parameters() { async fn test_json_parser_with_parameters() {
let parser = JsonParser::new(); let parser = JsonParser::new();
// Test with "parameters" instead of "arguments"
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#; let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
...@@ -299,7 +282,6 @@ async fn test_json_parser_with_parameters() { ...@@ -299,7 +282,6 @@ async fn test_json_parser_with_parameters() {
#[tokio::test] #[tokio::test]
async fn test_json_parser_with_tokens() { async fn test_json_parser_with_tokens() {
// Test with custom wrapper tokens
let parser = JsonParser::with_config(TokenConfig { let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["[TOOL_CALLS] [".to_string()], start_tokens: vec!["[TOOL_CALLS] [".to_string()],
end_tokens: vec!["]".to_string()], end_tokens: vec!["]".to_string()],
...@@ -315,7 +297,6 @@ async fn test_json_parser_with_tokens() { ...@@ -315,7 +297,6 @@ async fn test_json_parser_with_tokens() {
#[tokio::test] #[tokio::test]
async fn test_multiline_json_with_tokens() { async fn test_multiline_json_with_tokens() {
// Test that regex with (?s) flag properly handles multi-line JSON
let parser = JsonParser::with_config(TokenConfig { let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string()], start_tokens: vec!["<tool>".to_string()],
end_tokens: vec!["</tool>".to_string()], end_tokens: vec!["</tool>".to_string()],
...@@ -342,7 +323,6 @@ async fn test_multiline_json_with_tokens() { ...@@ -342,7 +323,6 @@ async fn test_multiline_json_with_tokens() {
#[tokio::test] #[tokio::test]
async fn test_multiline_json_array() { async fn test_multiline_json_array() {
// Test multi-line JSON array without wrapper tokens
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"[ let input = r#"[
...@@ -390,7 +370,6 @@ async fn test_json_parser_streaming() { ...@@ -390,7 +370,6 @@ async fn test_json_parser_streaming() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let mut state = ParseState::new(); let mut state = ParseState::new();
// Test with complete JSON
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#; let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
let result = parser let result = parser
...@@ -417,7 +396,6 @@ async fn test_registry_with_json_parser() { ...@@ -417,7 +396,6 @@ async fn test_registry_with_json_parser() {
// Should get JSON parser for OpenAI models // Should get JSON parser for OpenAI models
let parser = registry.get_parser("gpt-4-turbo").unwrap(); let parser = registry.get_parser("gpt-4-turbo").unwrap();
// Test that the parser works
let input = r#"{"name": "test", "arguments": {"x": 1}}"#; let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
let result = parser.parse_complete(input).await.unwrap(); let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
...@@ -677,7 +655,6 @@ mod edge_cases { ...@@ -677,7 +655,6 @@ mod edge_cases {
#[tokio::test] #[tokio::test]
async fn test_multiple_token_pairs_with_conflicts() { async fn test_multiple_token_pairs_with_conflicts() {
// Test with overlapping token patterns
let parser = JsonParser::with_config(TokenConfig { let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<<".to_string(), "<tool>".to_string()], start_tokens: vec!["<<".to_string(), "<tool>".to_string()],
end_tokens: vec![">>".to_string(), "</tool>".to_string()], end_tokens: vec![">>".to_string(), "</tool>".to_string()],
...@@ -708,7 +685,6 @@ mod edge_cases { ...@@ -708,7 +685,6 @@ mod edge_cases {
async fn test_streaming_with_partial_chunks() { async fn test_streaming_with_partial_chunks() {
let parser = JsonParser::new(); let parser = JsonParser::new();
// Test 1: Very incomplete JSON (just opening brace) should return Incomplete
let mut state1 = ParseState::new(); let mut state1 = ParseState::new();
let partial = r#"{"#; let partial = r#"{"#;
let result = parser let result = parser
...@@ -720,7 +696,6 @@ mod edge_cases { ...@@ -720,7 +696,6 @@ mod edge_cases {
"Should return Incomplete for just opening brace" "Should return Incomplete for just opening brace"
); );
// Test 2: Complete JSON should return ToolComplete
let mut state2 = ParseState::new(); let mut state2 = ParseState::new();
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#; let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
let result = parser let result = parser
...@@ -738,7 +713,6 @@ mod edge_cases { ...@@ -738,7 +713,6 @@ mod edge_cases {
_ => panic!("Expected ToolComplete for complete JSON"), _ => panic!("Expected ToolComplete for complete JSON"),
} }
// Test 3: Partial JSON with name
// The PartialJson parser can complete partial JSON by filling in missing values // The PartialJson parser can complete partial JSON by filling in missing values
let mut state3 = ParseState::new(); let mut state3 = ParseState::new();
let partial_with_name = r#"{"name": "test", "argum"#; let partial_with_name = r#"{"name": "test", "argum"#;
...@@ -863,7 +837,6 @@ mod stress_tests { ...@@ -863,7 +837,6 @@ mod stress_tests {
#[tokio::test] #[tokio::test]
async fn test_concurrent_parser_usage() { async fn test_concurrent_parser_usage() {
// Test that parser can be used concurrently
let parser = std::sync::Arc::new(JsonParser::new()); let parser = std::sync::Arc::new(JsonParser::new());
let mut handles = vec![]; let mut handles = vec![];
......
...@@ -679,7 +679,6 @@ mod tests { ...@@ -679,7 +679,6 @@ mod tests {
fn test_get_smallest_tenant() { fn test_get_smallest_tenant() {
let tree = Tree::new(); let tree = Tree::new();
// Test empty tree
assert_eq!(tree.get_smallest_tenant(), "empty"); assert_eq!(tree.get_smallest_tenant(), "empty");
// Insert data for tenant1 - "ap" + "icot" = 6 chars // Insert data for tenant1 - "ap" + "icot" = 6 chars
...@@ -689,7 +688,6 @@ mod tests { ...@@ -689,7 +688,6 @@ mod tests {
// Insert data for tenant2 - "cat" = 3 chars // Insert data for tenant2 - "cat" = 3 chars
tree.insert("cat", "tenant2"); tree.insert("cat", "tenant2");
// Test - tenant2 should be smallest with 3 chars vs 6 chars
assert_eq!( assert_eq!(
tree.get_smallest_tenant(), tree.get_smallest_tenant(),
"tenant2", "tenant2",
...@@ -702,7 +700,6 @@ mod tests { ...@@ -702,7 +700,6 @@ mod tests {
tree.insert("do", "tenant3"); tree.insert("do", "tenant3");
tree.insert("hi", "tenant4"); tree.insert("hi", "tenant4");
// Test - should return either tenant3 or tenant4 (both have 2 chars)
let smallest = tree.get_smallest_tenant(); let smallest = tree.get_smallest_tenant();
assert!( assert!(
smallest == "tenant3" || smallest == "tenant4", smallest == "tenant3" || smallest == "tenant4",
...@@ -720,7 +717,6 @@ mod tests { ...@@ -720,7 +717,6 @@ mod tests {
"Expected tenant3 to be smallest with 2 characters" "Expected tenant3 to be smallest with 2 characters"
); );
// Test eviction
tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars
let post_eviction_smallest = tree.get_smallest_tenant(); let post_eviction_smallest = tree.get_smallest_tenant();
...@@ -731,7 +727,6 @@ mod tests { ...@@ -731,7 +727,6 @@ mod tests {
fn test_tenant_char_count() { fn test_tenant_char_count() {
let tree = Tree::new(); let tree = Tree::new();
// Phase 1: Initial insertions
tree.insert("apple", "tenant1"); tree.insert("apple", "tenant1");
tree.insert("apricot", "tenant1"); tree.insert("apricot", "tenant1");
tree.insert("banana", "tenant1"); tree.insert("banana", "tenant1");
...@@ -755,7 +750,6 @@ mod tests { ...@@ -755,7 +750,6 @@ mod tests {
"Phase 1: Initial insertions" "Phase 1: Initial insertions"
); );
// Phase 2: Additional insertions
tree.insert("apartment", "tenant1"); tree.insert("apartment", "tenant1");
tree.insert("appetite", "tenant2"); tree.insert("appetite", "tenant2");
tree.insert("ball", "tenant1"); tree.insert("ball", "tenant1");
...@@ -778,7 +772,6 @@ mod tests { ...@@ -778,7 +772,6 @@ mod tests {
"Phase 2: Additional insertions" "Phase 2: Additional insertions"
); );
// Phase 3: Overlapping insertions
tree.insert("zebra", "tenant1"); tree.insert("zebra", "tenant1");
tree.insert("zebra", "tenant2"); tree.insert("zebra", "tenant2");
tree.insert("zero", "tenant1"); tree.insert("zero", "tenant1");
...@@ -801,7 +794,6 @@ mod tests { ...@@ -801,7 +794,6 @@ mod tests {
"Phase 3: Overlapping insertions" "Phase 3: Overlapping insertions"
); );
// Phase 4: Eviction test
tree.evict_tenant_by_size(10); tree.evict_tenant_by_size(10);
let computed_sizes = tree.get_used_size_per_tenant(); let computed_sizes = tree.get_used_size_per_tenant();
...@@ -1088,8 +1080,6 @@ mod tests { ...@@ -1088,8 +1080,6 @@ mod tests {
tree.pretty_print(); tree.pretty_print();
// Test sequentially
for (text, tenant) in TEST_PAIRS.iter() { for (text, tenant) in TEST_PAIRS.iter() {
let (matched_text, matched_tenant) = tree.prefix_match(text); let (matched_text, matched_tenant) = tree.prefix_match(text);
assert_eq!(matched_text, *text); assert_eq!(matched_text, *text);
...@@ -1162,7 +1152,6 @@ mod tests { ...@@ -1162,7 +1152,6 @@ mod tests {
tree.pretty_print(); tree.pretty_print();
// Verify initial sizes
let sizes_before = tree.get_used_size_per_tenant(); let sizes_before = tree.get_used_size_per_tenant();
assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5 assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5
assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10 assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10
...@@ -1172,12 +1161,10 @@ mod tests { ...@@ -1172,12 +1161,10 @@ mod tests {
tree.pretty_print(); tree.pretty_print();
// Verify sizes after eviction
let sizes_after = tree.get_used_size_per_tenant(); let sizes_after = tree.get_used_size_per_tenant();
assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged
assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains
// Verify "world" remains for tenant2
let (matched, tenant) = tree.prefix_match("world"); let (matched, tenant) = tree.prefix_match("world");
assert_eq!(matched, "world"); assert_eq!(matched, "world");
assert_eq!(tenant, "tenant2"); assert_eq!(tenant, "tenant2");
...@@ -1208,7 +1195,6 @@ mod tests { ...@@ -1208,7 +1195,6 @@ mod tests {
// Check sizes after eviction // Check sizes after eviction
let sizes_after = tree.get_used_size_per_tenant(); let sizes_after = tree.get_used_size_per_tenant();
// Verify all tenants are under their size limits
for (tenant, &size) in sizes_after.iter() { for (tenant, &size) in sizes_after.iter() {
assert!( assert!(
size <= max_size, size <= max_size,
...@@ -1287,7 +1273,6 @@ mod tests { ...@@ -1287,7 +1273,6 @@ mod tests {
let final_sizes = tree.get_used_size_per_tenant(); let final_sizes = tree.get_used_size_per_tenant();
println!("Final sizes after test completion: {:?}", final_sizes); println!("Final sizes after test completion: {:?}", final_sizes);
// Verify all tenants are under limit
for (_, &size) in final_sizes.iter() { for (_, &size) in final_sizes.iter() {
assert!( assert!(
size <= max_size, size <= max_size,
...@@ -1364,14 +1349,12 @@ mod tests { ...@@ -1364,14 +1349,12 @@ mod tests {
tree.insert("help", "tenant1"); // tenant1: hel -> p tree.insert("help", "tenant1"); // tenant1: hel -> p
tree.insert("helicopter", "tenant2"); // tenant2: hel -> icopter tree.insert("helicopter", "tenant2"); // tenant2: hel -> icopter
// Test tenant1's data
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "hello"); // Full match for tenant1 assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "hello"); // Full match for tenant1
assert_eq!(tree.prefix_match_tenant("help", "tenant1"), "help"); // Exclusive to tenant1 assert_eq!(tree.prefix_match_tenant("help", "tenant1"), "help"); // Exclusive to tenant1
assert_eq!(tree.prefix_match_tenant("hel", "tenant1"), "hel"); // Shared prefix assert_eq!(tree.prefix_match_tenant("hel", "tenant1"), "hel"); // Shared prefix
assert_eq!(tree.prefix_match_tenant("hello world", "tenant1"), "hello"); // Should stop at tenant1's boundary assert_eq!(tree.prefix_match_tenant("hello world", "tenant1"), "hello"); // Should stop at tenant1's boundary
assert_eq!(tree.prefix_match_tenant("helicopter", "tenant1"), "hel"); // Should stop at tenant1's boundary assert_eq!(tree.prefix_match_tenant("helicopter", "tenant1"), "hel"); // Should stop at tenant1's boundary
// Test tenant2's data
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); // Full match for tenant2 assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); // Full match for tenant2
assert_eq!( assert_eq!(
tree.prefix_match_tenant("hello world", "tenant2"), tree.prefix_match_tenant("hello world", "tenant2"),
...@@ -1384,7 +1367,6 @@ mod tests { ...@@ -1384,7 +1367,6 @@ mod tests {
assert_eq!(tree.prefix_match_tenant("hel", "tenant2"), "hel"); // Shared prefix assert_eq!(tree.prefix_match_tenant("hel", "tenant2"), "hel"); // Shared prefix
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "hel"); // Should stop at tenant2's boundary assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "hel"); // Should stop at tenant2's boundary
// Test non-existent tenant
assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant
assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant
} }
...@@ -1399,7 +1381,6 @@ mod tests { ...@@ -1399,7 +1381,6 @@ mod tests {
tree.insert("hello", "tenant2"); tree.insert("hello", "tenant2");
tree.insert("help", "tenant2"); tree.insert("help", "tenant2");
// Verify initial state
let initial_sizes = tree.get_used_size_per_tenant(); let initial_sizes = tree.get_used_size_per_tenant();
assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world" assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world"
assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p" assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p"
...@@ -1407,7 +1388,6 @@ mod tests { ...@@ -1407,7 +1388,6 @@ mod tests {
// Evict tenant1 // Evict tenant1
tree.remove_tenant("tenant1"); tree.remove_tenant("tenant1");
// Verify after eviction
let final_sizes = tree.get_used_size_per_tenant(); let final_sizes = tree.get_used_size_per_tenant();
assert!( assert!(
!final_sizes.contains_key("tenant1"), !final_sizes.contains_key("tenant1"),
...@@ -1419,11 +1399,9 @@ mod tests { ...@@ -1419,11 +1399,9 @@ mod tests {
"tenant2 should be unaffected" "tenant2 should be unaffected"
); );
// Verify tenant1's data is inaccessible
assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), ""); assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), "");
assert_eq!(tree.prefix_match_tenant("world", "tenant1"), ""); assert_eq!(tree.prefix_match_tenant("world", "tenant1"), "");
// Verify tenant2's data is still accessible
assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello");
assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help"); assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help");
} }
...@@ -1441,7 +1419,6 @@ mod tests { ...@@ -1441,7 +1419,6 @@ mod tests {
tree.insert("banana", "tenant2"); tree.insert("banana", "tenant2");
tree.insert("ball", "tenant2"); tree.insert("ball", "tenant2");
// Verify initial state
let initial_sizes = tree.get_used_size_per_tenant(); let initial_sizes = tree.get_used_size_per_tenant();
println!("Initial sizes: {:?}", initial_sizes); println!("Initial sizes: {:?}", initial_sizes);
tree.pretty_print(); tree.pretty_print();
...@@ -1449,29 +1426,24 @@ mod tests { ...@@ -1449,29 +1426,24 @@ mod tests {
// Evict tenant1 // Evict tenant1
tree.remove_tenant("tenant1"); tree.remove_tenant("tenant1");
// Verify final state
let final_sizes = tree.get_used_size_per_tenant(); let final_sizes = tree.get_used_size_per_tenant();
println!("Final sizes: {:?}", final_sizes); println!("Final sizes: {:?}", final_sizes);
tree.pretty_print(); tree.pretty_print();
// Verify tenant1 is completely removed
assert!( assert!(
!final_sizes.contains_key("tenant1"), !final_sizes.contains_key("tenant1"),
"tenant1 should be completely removed" "tenant1 should be completely removed"
); );
// Verify all tenant1's data is inaccessible
assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), ""); assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), "");
assert_eq!(tree.prefix_match_tenant("application", "tenant1"), ""); assert_eq!(tree.prefix_match_tenant("application", "tenant1"), "");
assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), ""); assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), "");
// Verify tenant2's data is intact
assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple"); assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple");
assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite"); assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite");
assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana"); assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana");
assert_eq!(tree.prefix_match_tenant("ball", "tenant2"), "ball"); assert_eq!(tree.prefix_match_tenant("ball", "tenant2"), "ball");
// Verify the tree structure is still valid for tenant2
let tenant2_size = final_sizes.get("tenant2").unwrap(); let tenant2_size = final_sizes.get("tenant2").unwrap();
assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll" assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll"
} }
......
...@@ -576,7 +576,6 @@ mod model_info_tests { ...@@ -576,7 +576,6 @@ mod model_info_tests {
let ctx = TestContext::new(vec![]).await; let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await; let app = ctx.create_app().await;
// Test server info with no workers
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri("/get_server_info") .uri("/get_server_info")
...@@ -593,7 +592,6 @@ mod model_info_tests { ...@@ -593,7 +592,6 @@ mod model_info_tests {
resp.status() resp.status()
); );
// Test model info with no workers
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri("/get_model_info") .uri("/get_model_info")
...@@ -610,7 +608,6 @@ mod model_info_tests { ...@@ -610,7 +608,6 @@ mod model_info_tests {
resp.status() resp.status()
); );
// Test v1/models with no workers
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri("/v1/models") .uri("/v1/models")
...@@ -652,7 +649,6 @@ mod model_info_tests { ...@@ -652,7 +649,6 @@ mod model_info_tests {
let app = ctx.create_app().await; let app = ctx.create_app().await;
// Test that model info is consistent across workers
for _ in 0..5 { for _ in 0..5 {
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
...@@ -795,7 +791,6 @@ mod worker_management_tests { ...@@ -795,7 +791,6 @@ mod worker_management_tests {
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
// Verify it's removed
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri("/list_workers") .uri("/list_workers")
...@@ -1302,7 +1297,6 @@ mod error_tests { ...@@ -1302,7 +1297,6 @@ mod error_tests {
let app = ctx.create_app().await; let app = ctx.create_app().await;
// Test unknown endpoint
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri("/unknown_endpoint") .uri("/unknown_endpoint")
...@@ -1312,7 +1306,6 @@ mod error_tests { ...@@ -1312,7 +1306,6 @@ mod error_tests {
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.status(), StatusCode::NOT_FOUND);
// Test POST to unknown endpoint
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri("/api/v2/generate") .uri("/api/v2/generate")
...@@ -1606,7 +1599,6 @@ mod cache_tests { ...@@ -1606,7 +1599,6 @@ mod cache_tests {
.unwrap(); .unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
// Verify the response contains load information
assert!(body_json.is_object()); assert!(body_json.is_object());
// The exact structure depends on the implementation // The exact structure depends on the implementation
// but should contain worker load information // but should contain worker load information
...@@ -1797,7 +1789,6 @@ mod request_id_tests { ...@@ -1797,7 +1789,6 @@ mod request_id_tests {
let app = ctx.create_app().await; let app = ctx.create_app().await;
// Test 1: Request without any request ID header should generate one
let payload = json!({ let payload = json!({
"text": "Test request", "text": "Test request",
"stream": false "stream": false
...@@ -1830,7 +1821,6 @@ mod request_id_tests { ...@@ -1830,7 +1821,6 @@ mod request_id_tests {
"Request ID should have content after prefix" "Request ID should have content after prefix"
); );
// Test 2: Request with custom x-request-id should preserve it
let custom_id = "custom-request-id-123"; let custom_id = "custom-request-id-123";
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
...@@ -1847,7 +1837,6 @@ mod request_id_tests { ...@@ -1847,7 +1837,6 @@ mod request_id_tests {
assert!(response_id.is_some()); assert!(response_id.is_some());
assert_eq!(response_id.unwrap(), custom_id); assert_eq!(response_id.unwrap(), custom_id);
// Test 3: Different endpoints should have different prefixes
let chat_payload = json!({ let chat_payload = json!({
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"model": "test-model" "model": "test-model"
...@@ -1871,7 +1860,6 @@ mod request_id_tests { ...@@ -1871,7 +1860,6 @@ mod request_id_tests {
.unwrap() .unwrap()
.starts_with("chatcmpl-")); .starts_with("chatcmpl-"));
// Test 4: Alternative request ID headers should be recognized
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri("/generate") .uri("/generate")
...@@ -1948,7 +1936,6 @@ mod request_id_tests { ...@@ -1948,7 +1936,6 @@ mod request_id_tests {
"stream": false "stream": false
}); });
// Test custom header is recognized
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri("/generate") .uri("/generate")
...@@ -2013,7 +2000,6 @@ mod rerank_tests { ...@@ -2013,7 +2000,6 @@ mod rerank_tests {
.unwrap(); .unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
// Verify response structure
assert!(body_json.get("results").is_some()); assert!(body_json.get("results").is_some());
assert!(body_json.get("model").is_some()); assert!(body_json.get("model").is_some());
assert_eq!(body_json["model"], "test-rerank-model"); assert_eq!(body_json["model"], "test-rerank-model");
...@@ -2021,7 +2007,6 @@ mod rerank_tests { ...@@ -2021,7 +2007,6 @@ mod rerank_tests {
let results = body_json["results"].as_array().unwrap(); let results = body_json["results"].as_array().unwrap();
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
// Verify results are sorted by score (highest first)
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap()); assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
ctx.shutdown().await; ctx.shutdown().await;
...@@ -2164,7 +2149,6 @@ mod rerank_tests { ...@@ -2164,7 +2149,6 @@ mod rerank_tests {
let app = ctx.create_app().await; let app = ctx.create_app().await;
// Test V1 API format (simplified input)
let payload = json!({ let payload = json!({
"query": "machine learning algorithms", "query": "machine learning algorithms",
"documents": [ "documents": [
...@@ -2189,7 +2173,6 @@ mod rerank_tests { ...@@ -2189,7 +2173,6 @@ mod rerank_tests {
.unwrap(); .unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
// Verify response structure
assert!(body_json.get("results").is_some()); assert!(body_json.get("results").is_some());
assert!(body_json.get("model").is_some()); assert!(body_json.get("model").is_some());
...@@ -2199,7 +2182,6 @@ mod rerank_tests { ...@@ -2199,7 +2182,6 @@ mod rerank_tests {
let results = body_json["results"].as_array().unwrap(); let results = body_json["results"].as_array().unwrap();
assert_eq!(results.len(), 3); // All documents should be returned assert_eq!(results.len(), 3); // All documents should be returned
// Verify results are sorted by score (highest first)
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap()); assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
assert!(results[1]["score"].as_f64().unwrap() >= results[2]["score"].as_f64().unwrap()); assert!(results[1]["score"].as_f64().unwrap() >= results[2]["score"].as_f64().unwrap());
...@@ -2224,7 +2206,6 @@ mod rerank_tests { ...@@ -2224,7 +2206,6 @@ mod rerank_tests {
let app = ctx.create_app().await; let app = ctx.create_app().await;
// Test empty query string (validation should fail)
let payload = json!({ let payload = json!({
"query": "", "query": "",
"documents": ["Document 1", "Document 2"], "documents": ["Document 1", "Document 2"],
...@@ -2241,7 +2222,6 @@ mod rerank_tests { ...@@ -2241,7 +2222,6 @@ mod rerank_tests {
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Test query with only whitespace (validation should fail)
let payload = json!({ let payload = json!({
"query": " ", "query": " ",
"documents": ["Document 1", "Document 2"], "documents": ["Document 1", "Document 2"],
...@@ -2258,7 +2238,6 @@ mod rerank_tests { ...@@ -2258,7 +2238,6 @@ mod rerank_tests {
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Test empty documents list (validation should fail)
let payload = json!({ let payload = json!({
"query": "test query", "query": "test query",
"documents": [], "documents": [],
...@@ -2275,7 +2254,6 @@ mod rerank_tests { ...@@ -2275,7 +2254,6 @@ mod rerank_tests {
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Test invalid top_k (validation should fail)
let payload = json!({ let payload = json!({
"query": "test query", "query": "test query",
"documents": ["Document 1", "Document 2"], "documents": ["Document 1", "Document 2"],
......
...@@ -93,19 +93,16 @@ fn test_mixed_model_ids() { ...@@ -93,19 +93,16 @@ fn test_mixed_model_ids() {
policy.add_worker(&worker3); policy.add_worker(&worker3);
policy.add_worker(&worker4); policy.add_worker(&worker4);
// Test selection with default workers only
let default_workers: Vec<Arc<dyn Worker>> = let default_workers: Vec<Arc<dyn Worker>> =
vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())]; vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
let selected = policy.select_worker(&default_workers, Some("test request")); let selected = policy.select_worker(&default_workers, Some("test request"));
assert!(selected.is_some(), "Should select from default workers"); assert!(selected.is_some(), "Should select from default workers");
// Test selection with specific model workers only
let llama_workers: Vec<Arc<dyn Worker>> = let llama_workers: Vec<Arc<dyn Worker>> =
vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())]; vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
let selected = policy.select_worker(&llama_workers, Some("test request")); let selected = policy.select_worker(&llama_workers, Some("test request"));
assert!(selected.is_some(), "Should select from llama-3 workers"); assert!(selected.is_some(), "Should select from llama-3 workers");
// Test selection with mixed workers
let all_workers: Vec<Arc<dyn Worker>> = vec![ let all_workers: Vec<Arc<dyn Worker>> = vec![
Arc::new(worker1.clone()), Arc::new(worker1.clone()),
Arc::new(worker2.clone()), Arc::new(worker2.clone()),
...@@ -144,7 +141,6 @@ fn test_remove_worker_by_url_backward_compat() { ...@@ -144,7 +141,6 @@ fn test_remove_worker_by_url_backward_compat() {
// Should remove from all trees since we don't know the model // Should remove from all trees since we don't know the model
policy.remove_worker_by_url("http://worker1:8080"); policy.remove_worker_by_url("http://worker1:8080");
// Verify removal worked
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())]; let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())];
let selected = policy.select_worker(&workers, Some("test")); let selected = policy.select_worker(&workers, Some("test"));
assert_eq!(selected, Some(0), "Should only have worker2 left"); assert_eq!(selected, Some(0), "Should only have worker2 left");
......
...@@ -89,7 +89,6 @@ fn test_chat_template_with_tokens() { ...@@ -89,7 +89,6 @@ fn test_chat_template_with_tokens() {
#[test] #[test]
fn test_llama_style_template() { fn test_llama_style_template() {
// Test a Llama-style chat template
let template = r#" let template = r#"
{%- if messages[0]['role'] == 'system' -%} {%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%} {%- set system_message = messages[0]['content'] -%}
...@@ -160,7 +159,6 @@ fn test_llama_style_template() { ...@@ -160,7 +159,6 @@ fn test_llama_style_template() {
#[test] #[test]
fn test_chatml_template() { fn test_chatml_template() {
// Test a ChatML-style template
let template = r#" let template = r#"
{%- for message in messages %} {%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }} {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
...@@ -241,13 +239,11 @@ assistant: ...@@ -241,13 +239,11 @@ assistant:
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
// Test without generation prompt
let result = processor let result = processor
.apply_chat_template(&json_messages, ChatTemplateParams::default()) .apply_chat_template(&json_messages, ChatTemplateParams::default())
.unwrap(); .unwrap();
assert_eq!(result.trim(), "user: Test"); assert_eq!(result.trim(), "user: Test");
// Test with generation prompt
let result_with_prompt = processor let result_with_prompt = processor
.apply_chat_template( .apply_chat_template(
&json_messages, &json_messages,
...@@ -275,7 +271,6 @@ fn test_empty_messages_template() { ...@@ -275,7 +271,6 @@ fn test_empty_messages_template() {
#[test] #[test]
fn test_content_format_detection() { fn test_content_format_detection() {
// Test string format detection
let string_template = r#" let string_template = r#"
{%- for message in messages -%} {%- for message in messages -%}
{{ message.role }}: {{ message.content }} {{ message.role }}: {{ message.content }}
...@@ -286,7 +281,6 @@ fn test_content_format_detection() { ...@@ -286,7 +281,6 @@ fn test_content_format_detection() {
ChatTemplateContentFormat::String ChatTemplateContentFormat::String
); );
// Test OpenAI format detection
let openai_template = r#" let openai_template = r#"
{%- for message in messages -%} {%- for message in messages -%}
{%- for content in message.content -%} {%- for content in message.content -%}
...@@ -302,7 +296,6 @@ fn test_content_format_detection() { ...@@ -302,7 +296,6 @@ fn test_content_format_detection() {
#[test] #[test]
fn test_template_with_multimodal_content() { fn test_template_with_multimodal_content() {
// Test that multimodal messages work correctly when serialized to JSON
let template = r#" let template = r#"
{%- for message in messages %} {%- for message in messages %}
{{ message.role }}: {{ message.role }}:
......
...@@ -57,7 +57,6 @@ mod tests { ...@@ -57,7 +57,6 @@ mod tests {
) )
.unwrap(); .unwrap();
// Test that the custom template is used
let messages = vec![ let messages = vec![
spec::ChatMessage::User { spec::ChatMessage::User {
role: "user".to_string(), role: "user".to_string(),
...@@ -89,7 +88,6 @@ mod tests { ...@@ -89,7 +88,6 @@ mod tests {
.apply_chat_template(&json_messages, params) .apply_chat_template(&json_messages, params)
.unwrap(); .unwrap();
// Verify the custom template format
assert!(result.contains("<|user|>Hello")); assert!(result.contains("<|user|>Hello"));
assert!(result.contains("<|assistant|>Hi there")); assert!(result.contains("<|assistant|>Hi there"));
assert!(result.ends_with("<|assistant|>")); assert!(result.ends_with("<|assistant|>"));
......
...@@ -148,7 +148,6 @@ mod tests { ...@@ -148,7 +148,6 @@ mod tests {
async fn test_mock_server_with_rmcp_client() { async fn test_mock_server_with_rmcp_client() {
let mut server = MockMCPServer::start().await.unwrap(); let mut server = MockMCPServer::start().await.unwrap();
// Test that we can connect with rmcp client
use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::StreamableHttpClientTransport;
use rmcp::ServiceExt; use rmcp::ServiceExt;
...@@ -158,7 +157,6 @@ mod tests { ...@@ -158,7 +157,6 @@ mod tests {
assert!(client.is_ok(), "Should be able to connect to mock server"); assert!(client.is_ok(), "Should be able to connect to mock server");
if let Ok(client) = client { if let Ok(client) = client {
// Test listing tools
let tools = client.peer().list_all_tools().await; let tools = client.peer().list_all_tools().await;
assert!(tools.is_ok(), "Should be able to list tools"); assert!(tools.is_ok(), "Should be able to list tools");
......
...@@ -71,7 +71,6 @@ pub fn ensure_tokenizer_cached() -> PathBuf { ...@@ -71,7 +71,6 @@ pub fn ensure_tokenizer_cached() -> PathBuf {
let content = response.bytes().expect("Failed to read tokenizer content"); let content = response.bytes().expect("Failed to read tokenizer content");
// Verify we got actual JSON content
if content.len() < 100 { if content.len() < 100 {
panic!("Downloaded content too small: {} bytes", content.len()); panic!("Downloaded content too small: {} bytes", content.len());
} }
......
// This test suite validates the complete MCP implementation against the // This test suite validates the complete MCP implementation against the
// functionality required for SGLang responses API integration. // functionality required for SGLang responses API integration.
// //
// Test Coverage:
// - Core MCP server functionality // - Core MCP server functionality
// - Tool session management (individual and multi-tool) // - Tool session management (individual and multi-tool)
// - Tool execution and error handling // - Tool execution and error handling
...@@ -26,7 +25,6 @@ async fn create_mock_server() -> MockMCPServer { ...@@ -26,7 +25,6 @@ async fn create_mock_server() -> MockMCPServer {
#[tokio::test] #[tokio::test]
async fn test_mcp_server_initialization() { async fn test_mcp_server_initialization() {
// Test that we can create an empty configuration
let config = McpConfig { servers: vec![] }; let config = McpConfig { servers: vec![] };
// Should fail with no servers // Should fail with no servers
...@@ -329,7 +327,6 @@ async fn test_tool_info_structure() { ...@@ -329,7 +327,6 @@ async fn test_tool_info_structure() {
#[tokio::test] #[tokio::test]
async fn test_sse_connection() { async fn test_sse_connection() {
// Test with a non-existent command using STDIO to avoid retry delays
// This tests that SSE configuration is properly handled even when connection fails // This tests that SSE configuration is properly handled even when connection fails
let config = McpConfig { let config = McpConfig {
servers: vec![McpServerConfig { servers: vec![McpServerConfig {
...@@ -351,8 +348,6 @@ async fn test_sse_connection() { ...@@ -351,8 +348,6 @@ async fn test_sse_connection() {
#[tokio::test] #[tokio::test]
async fn test_transport_types() { async fn test_transport_types() {
// Test different transport configurations
// HTTP/Streamable transport // HTTP/Streamable transport
let http_config = McpServerConfig { let http_config = McpServerConfig {
name: "http_server".to_string(), name: "http_server".to_string(),
...@@ -444,7 +439,6 @@ async fn test_complete_workflow() { ...@@ -444,7 +439,6 @@ async fn test_complete_workflow() {
// 7. Clean shutdown // 7. Clean shutdown
manager.shutdown().await; manager.shutdown().await;
// Verify all required capabilities for responses API integration
let capabilities = [ let capabilities = [
"MCP server initialization", "MCP server initialization",
"Tool server connection and discovery", "Tool server connection and discovery",
......
...@@ -20,8 +20,6 @@ async fn test_policy_registry_with_router_manager() { ...@@ -20,8 +20,6 @@ async fn test_policy_registry_with_router_manager() {
// Create RouterManager with shared registries // Create RouterManager with shared registries
let _router_manager = RouterManager::new(worker_registry.clone()); let _router_manager = RouterManager::new(worker_registry.clone());
// Test adding workers with different models and policies
// Add first worker for llama-3 with cache_aware policy hint // Add first worker for llama-3 with cache_aware policy hint
let mut labels1 = HashMap::new(); let mut labels1 = HashMap::new();
labels1.insert("policy".to_string(), "cache_aware".to_string()); labels1.insert("policy".to_string(), "cache_aware".to_string());
...@@ -44,7 +42,6 @@ async fn test_policy_registry_with_router_manager() { ...@@ -44,7 +42,6 @@ async fn test_policy_registry_with_router_manager() {
// This would normally connect to a real worker, but for testing we'll just verify the structure // This would normally connect to a real worker, but for testing we'll just verify the structure
// In a real test, we'd need to mock the worker or use a test server // In a real test, we'd need to mock the worker or use a test server
// Verify PolicyRegistry has the correct policy for llama-3
let _llama_policy = policy_registry.get_policy("llama-3"); let _llama_policy = policy_registry.get_policy("llama-3");
// After first worker is added, llama-3 should have a policy // After first worker is added, llama-3 should have a policy
...@@ -88,10 +85,8 @@ async fn test_policy_registry_with_router_manager() { ...@@ -88,10 +85,8 @@ async fn test_policy_registry_with_router_manager() {
chat_template: None, chat_template: None,
}; };
// Verify gpt-4 has random policy
let _gpt_policy = policy_registry.get_policy("gpt-4"); let _gpt_policy = policy_registry.get_policy("gpt-4");
// Test removing workers
// When we remove both llama-3 workers, the policy should be cleaned up // When we remove both llama-3 workers, the policy should be cleaned up
println!("PolicyRegistry integration test structure created"); println!("PolicyRegistry integration test structure created");
...@@ -113,7 +108,6 @@ fn test_policy_registry_cleanup() { ...@@ -113,7 +108,6 @@ fn test_policy_registry_cleanup() {
let policy2 = registry.on_worker_added("model-1", Some("random")); let policy2 = registry.on_worker_added("model-1", Some("random"));
assert_eq!(policy2.name(), "cache_aware"); // Should still be cache_aware assert_eq!(policy2.name(), "cache_aware"); // Should still be cache_aware
// Verify policy exists
assert!(registry.get_policy("model-1").is_some()); assert!(registry.get_policy("model-1").is_some());
// Remove first worker - policy should remain // Remove first worker - policy should remain
...@@ -143,7 +137,6 @@ fn test_policy_registry_multiple_models() { ...@@ -143,7 +137,6 @@ fn test_policy_registry_multiple_models() {
assert_eq!(gpt_policy.name(), "random"); assert_eq!(gpt_policy.name(), "random");
assert_eq!(mistral_policy.name(), "round_robin"); // Default assert_eq!(mistral_policy.name(), "round_robin"); // Default
// Verify all policies are stored
assert!(registry.get_policy("llama-3").is_some()); assert!(registry.get_policy("llama-3").is_some());
assert!(registry.get_policy("gpt-4").is_some()); assert!(registry.get_policy("gpt-4").is_some());
assert!(registry.get_policy("mistral").is_some()); assert!(registry.get_policy("mistral").is_some());
......
...@@ -126,7 +126,6 @@ mod request_format_tests { ...@@ -126,7 +126,6 @@ mod request_format_tests {
}]) }])
.await; .await;
// Test 1: Basic text request
let payload = json!({ let payload = json!({
"text": "Hello, world!", "text": "Hello, world!",
"stream": false "stream": false
...@@ -135,7 +134,6 @@ mod request_format_tests { ...@@ -135,7 +134,6 @@ mod request_format_tests {
let result = ctx.make_request("/generate", payload).await; let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Test 2: Request with sampling parameters
let payload = json!({ let payload = json!({
"text": "Tell me a story", "text": "Tell me a story",
"sampling_params": { "sampling_params": {
...@@ -149,7 +147,6 @@ mod request_format_tests { ...@@ -149,7 +147,6 @@ mod request_format_tests {
let result = ctx.make_request("/generate", payload).await; let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Test 3: Request with input_ids
let payload = json!({ let payload = json!({
"input_ids": [1, 2, 3, 4, 5], "input_ids": [1, 2, 3, 4, 5],
"sampling_params": { "sampling_params": {
...@@ -176,7 +173,6 @@ mod request_format_tests { ...@@ -176,7 +173,6 @@ mod request_format_tests {
}]) }])
.await; .await;
// Test 1: Basic chat completion
let payload = json!({ let payload = json!({
"model": "test-model", "model": "test-model",
"messages": [ "messages": [
...@@ -197,7 +193,6 @@ mod request_format_tests { ...@@ -197,7 +193,6 @@ mod request_format_tests {
Some("chat.completion") Some("chat.completion")
); );
// Test 2: Chat completion with parameters
let payload = json!({ let payload = json!({
"model": "test-model", "model": "test-model",
"messages": [ "messages": [
...@@ -226,7 +221,6 @@ mod request_format_tests { ...@@ -226,7 +221,6 @@ mod request_format_tests {
}]) }])
.await; .await;
// Test 1: Basic completion
let payload = json!({ let payload = json!({
"model": "test-model", "model": "test-model",
"prompt": "Once upon a time", "prompt": "Once upon a time",
...@@ -244,7 +238,6 @@ mod request_format_tests { ...@@ -244,7 +238,6 @@ mod request_format_tests {
Some("text_completion") Some("text_completion")
); );
// Test 2: Completion with array prompt
let payload = json!({ let payload = json!({
"model": "test-model", "model": "test-model",
"prompt": ["First prompt", "Second prompt"], "prompt": ["First prompt", "Second prompt"],
...@@ -255,7 +248,6 @@ mod request_format_tests { ...@@ -255,7 +248,6 @@ mod request_format_tests {
let result = ctx.make_request("/v1/completions", payload).await; let result = ctx.make_request("/v1/completions", payload).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Test 3: Completion with logprobs
let payload = json!({ let payload = json!({
"model": "test-model", "model": "test-model",
"prompt": "The capital of France is", "prompt": "The capital of France is",
...@@ -281,7 +273,6 @@ mod request_format_tests { ...@@ -281,7 +273,6 @@ mod request_format_tests {
}]) }])
.await; .await;
// Test batch text generation
let payload = json!({ let payload = json!({
"text": ["First text", "Second text", "Third text"], "text": ["First text", "Second text", "Third text"],
"sampling_params": { "sampling_params": {
...@@ -294,7 +285,6 @@ mod request_format_tests { ...@@ -294,7 +285,6 @@ mod request_format_tests {
let result = ctx.make_request("/generate", payload).await; let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Test batch with input_ids
let payload = json!({ let payload = json!({
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
"stream": false "stream": false
...@@ -317,7 +307,6 @@ mod request_format_tests { ...@@ -317,7 +307,6 @@ mod request_format_tests {
}]) }])
.await; .await;
// Test with return_logprob
let payload = json!({ let payload = json!({
"text": "Test", "text": "Test",
"return_logprob": true, "return_logprob": true,
...@@ -327,7 +316,6 @@ mod request_format_tests { ...@@ -327,7 +316,6 @@ mod request_format_tests {
let result = ctx.make_request("/generate", payload).await; let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Test with json_schema
let payload = json!({ let payload = json!({
"text": "Generate JSON", "text": "Generate JSON",
"sampling_params": { "sampling_params": {
...@@ -340,7 +328,6 @@ mod request_format_tests { ...@@ -340,7 +328,6 @@ mod request_format_tests {
let result = ctx.make_request("/generate", payload).await; let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Test with ignore_eos
let payload = json!({ let payload = json!({
"text": "Continue forever", "text": "Continue forever",
"sampling_params": { "sampling_params": {
...@@ -368,7 +355,6 @@ mod request_format_tests { ...@@ -368,7 +355,6 @@ mod request_format_tests {
}]) }])
.await; .await;
// Test with empty body - should still work with mock worker
let payload = json!({}); let payload = json!({});
let result = ctx.make_request("/generate", payload).await; let result = ctx.make_request("/generate", payload).await;
......
...@@ -44,7 +44,6 @@ fn test_responses_request_creation() { ...@@ -44,7 +44,6 @@ fn test_responses_request_creation() {
repetition_penalty: 1.0, repetition_penalty: 1.0,
}; };
// Test GenerationRequest trait implementation
assert!(!request.is_stream()); assert!(!request.is_stream());
assert_eq!(request.get_model(), Some("test-model")); assert_eq!(request.get_model(), Some("test-model"));
let routing_text = request.extract_text_for_routing(); let routing_text = request.extract_text_for_routing();
...@@ -139,7 +138,6 @@ fn test_usage_conversion() { ...@@ -139,7 +138,6 @@ fn test_usage_conversion() {
8 8
); );
// Test reverse conversion
let back_to_usage = response_usage.to_usage_info(); let back_to_usage = response_usage.to_usage_info();
assert_eq!(back_to_usage.prompt_tokens, 15); assert_eq!(back_to_usage.prompt_tokens, 15);
assert_eq!(back_to_usage.completion_tokens, 25); assert_eq!(back_to_usage.completion_tokens, 25);
...@@ -152,7 +150,6 @@ fn test_reasoning_param_default() { ...@@ -152,7 +150,6 @@ fn test_reasoning_param_default() {
effort: Some(ReasoningEffort::Medium), effort: Some(ReasoningEffort::Medium),
}; };
// Test JSON serialization/deserialization preserves default
let json = serde_json::to_string(&param).unwrap(); let json = serde_json::to_string(&param).unwrap();
let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap(); let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap();
...@@ -197,7 +194,6 @@ fn test_json_serialization() { ...@@ -197,7 +194,6 @@ fn test_json_serialization() {
repetition_penalty: 1.2, repetition_penalty: 1.2,
}; };
// Test that everything can be serialized to JSON and back
let json = serde_json::to_string(&request).expect("Serialization should work"); let json = serde_json::to_string(&request).expect("Serialization should work");
let parsed: ResponsesRequest = let parsed: ResponsesRequest =
serde_json::from_str(&json).expect("Deserialization should work"); serde_json::from_str(&json).expect("Deserialization should work");
......
...@@ -197,7 +197,6 @@ mod streaming_tests { ...@@ -197,7 +197,6 @@ mod streaming_tests {
let events = result.unwrap(); let events = result.unwrap();
assert!(events.len() >= 2); // At least one chunk + [DONE] assert!(events.len() >= 2); // At least one chunk + [DONE]
// Verify events are valid JSON (except [DONE])
for event in &events { for event in &events {
if event != "[DONE]" { if event != "[DONE]" {
let parsed: Result<serde_json::Value, _> = serde_json::from_str(event); let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
...@@ -329,7 +328,6 @@ mod streaming_tests { ...@@ -329,7 +328,6 @@ mod streaming_tests {
#[tokio::test] #[tokio::test]
async fn test_sse_format_parsing() { async fn test_sse_format_parsing() {
// Test SSE format parsing
let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> { let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> {
let text = String::from_utf8_lossy(chunk); let text = String::from_utf8_lossy(chunk);
text.lines() text.lines()
...@@ -347,7 +345,6 @@ mod streaming_tests { ...@@ -347,7 +345,6 @@ mod streaming_tests {
assert_eq!(events[1], "{\"text\":\" world\"}"); assert_eq!(events[1], "{\"text\":\" world\"}");
assert_eq!(events[2], "[DONE]"); assert_eq!(events[2], "[DONE]");
// Test with mixed content
let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n"; let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
let events = parse_sse_chunk(mixed); let events = parse_sse_chunk(mixed);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment