Unverified commit ce3ca9b0, authored by Jeff Nettleton and committed by GitHub
Browse files

[router] add cargo clippy in CI and fix-up linting errors (#9242)

parent 4d98e486
...@@ -27,6 +27,12 @@ jobs: ...@@ -27,6 +27,12 @@ jobs:
run: | run: |
bash scripts/ci/ci_install_rust.sh bash scripts/ci/ci_install_rust.sh
- name: Run lint
run: |
source "$HOME/.cargo/env"
cd sgl-router/
cargo clippy --all-targets --all-features -- -D warnings
- name: Run fmt - name: Run fmt
run: | run: |
source "$HOME/.cargo/env" source "$HOME/.cargo/env"
......
...@@ -22,7 +22,7 @@ fn create_test_worker() -> BasicWorker { ...@@ -22,7 +22,7 @@ fn create_test_worker() -> BasicWorker {
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) { fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
let hostname = get_hostname(worker.url()); let hostname = get_hostname(worker.url());
let bootstrap_port = match worker.worker_type() { let bootstrap_port = match worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port.clone(), WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None, _ => None,
}; };
(hostname, bootstrap_port) (hostname, bootstrap_port)
......
...@@ -137,8 +137,7 @@ mod tests { ...@@ -137,8 +137,7 @@ mod tests {
fn test_worker_result_type_alias() { fn test_worker_result_type_alias() {
// Test Ok variant // Test Ok variant
let result: WorkerResult<i32> = Ok(42); let result: WorkerResult<i32> = Ok(42);
assert!(result.is_ok()); assert!(matches!(result, Ok(42)));
assert_eq!(result.unwrap(), 42);
// Test Err variant // Test Err variant
let error = WorkerError::WorkerNotFound { let error = WorkerError::WorkerNotFound {
......
...@@ -311,13 +311,7 @@ impl Worker for BasicWorker { ...@@ -311,13 +311,7 @@ impl Worker for BasicWorker {
// Use the shared client with a custom timeout for this request // Use the shared client with a custom timeout for this request
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await { let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
Ok(response) => { Ok(response) => response.status().is_success(),
if response.status().is_success() {
true
} else {
false
}
}
Err(_) => false, Err(_) => false,
}; };
...@@ -571,6 +565,7 @@ impl WorkerFactory { ...@@ -571,6 +565,7 @@ impl WorkerFactory {
} }
/// Create workers from URLs with automatic type detection /// Create workers from URLs with automatic type detection
#[allow(clippy::type_complexity)]
pub fn create_from_urls( pub fn create_from_urls(
regular_urls: Vec<String>, regular_urls: Vec<String>,
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
...@@ -1202,12 +1197,6 @@ mod tests { ...@@ -1202,12 +1197,6 @@ mod tests {
for handle in handles { for handle in handles {
handle.await.unwrap(); handle.await.unwrap();
} }
// Final state should be deterministic (last write wins)
// We can't predict the exact final state due to scheduling,
// but we can verify no data corruption occurred
let final_health = worker.is_healthy();
assert!(final_health == true || final_health == false);
} }
// Test WorkerFactory // Test WorkerFactory
......
...@@ -249,6 +249,7 @@ impl Router { ...@@ -249,6 +249,7 @@ impl Router {
health_check_interval_secs = 60, health_check_interval_secs = 60,
health_check_endpoint = String::from("/health"), health_check_endpoint = String::from("/health"),
))] ))]
#[allow(clippy::too_many_arguments)]
fn new( fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: PolicyType, policy: PolicyType,
......
...@@ -510,25 +510,9 @@ mod tests { ...@@ -510,25 +510,9 @@ mod tests {
// ============= Duration Bucket Tests ============= // ============= Duration Bucket Tests =============
#[test]
fn test_duration_bucket_values() {
let expected_buckets = vec![
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
60.0, 90.0, 120.0, 180.0, 240.0,
];
// The buckets are defined in start_prometheus function
assert_eq!(expected_buckets.len(), 20);
// Verify proper ordering
for i in 1..expected_buckets.len() {
assert!(expected_buckets[i] > expected_buckets[i - 1]);
}
}
#[test] #[test]
fn test_duration_bucket_coverage() { fn test_duration_bucket_coverage() {
let test_cases = vec![ let test_cases: [(f64, &str); 7] = [
(0.0005, "sub-millisecond"), (0.0005, "sub-millisecond"),
(0.005, "5ms"), (0.005, "5ms"),
(0.05, "50ms"), (0.05, "50ms"),
...@@ -538,7 +522,7 @@ mod tests { ...@@ -538,7 +522,7 @@ mod tests {
(240.0, "4m"), (240.0, "4m"),
]; ];
let buckets = vec![ let buckets: [f64; 20] = [
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
60.0, 90.0, 120.0, 180.0, 240.0, 60.0, 90.0, 120.0, 180.0, 240.0,
]; ];
...@@ -546,7 +530,7 @@ mod tests { ...@@ -546,7 +530,7 @@ mod tests {
for (duration, label) in test_cases { for (duration, label) in test_cases {
let bucket_found = buckets let bucket_found = buckets
.iter() .iter()
.any(|&b| ((b - duration) as f64).abs() < 0.0001 || b > duration); .any(|&b| (b - duration).abs() < 0.0001 || b > duration);
assert!(bucket_found, "No bucket found for {} ({})", duration, label); assert!(bucket_found, "No bucket found for {} ({})", duration, label);
} }
} }
...@@ -558,14 +542,13 @@ mod tests { ...@@ -558,14 +542,13 @@ mod tests {
let matcher = Matcher::Suffix(String::from("duration_seconds")); let matcher = Matcher::Suffix(String::from("duration_seconds"));
// Test matching behavior // Test matching behavior
let _matching_metrics = vec![ let _matching_metrics = [
"request_duration_seconds", "request_duration_seconds",
"response_duration_seconds", "response_duration_seconds",
"sgl_router_request_duration_seconds", "sgl_router_request_duration_seconds",
]; ];
let _non_matching_metrics = let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];
vec!["duration_total", "duration_seconds_total", "other_metric"];
// Note: We can't directly test Matcher matching without the internals, // Note: We can't directly test Matcher matching without the internals,
// but we can verify the matcher is created correctly // but we can verify the matcher is created correctly
...@@ -611,8 +594,8 @@ mod tests { ...@@ -611,8 +594,8 @@ mod tests {
#[test] #[test]
fn test_custom_buckets_for_different_metrics() { fn test_custom_buckets_for_different_metrics() {
// Test that we can create different bucket configurations // Test that we can create different bucket configurations
let request_buckets = vec![0.001, 0.01, 0.1, 1.0, 10.0]; let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
let generate_buckets = vec![0.1, 0.5, 1.0, 5.0, 30.0, 60.0]; let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
assert_eq!(request_buckets.len(), 5); assert_eq!(request_buckets.len(), 5);
assert_eq!(generate_buckets.len(), 6); assert_eq!(generate_buckets.len(), 6);
...@@ -730,9 +713,6 @@ mod tests { ...@@ -730,9 +713,6 @@ mod tests {
for handle in handles { for handle in handles {
handle.join().unwrap(); handle.join().unwrap();
} }
// If we get here without panic, concurrent access works
assert!(true);
} }
// ============= Edge Cases Tests ============= // ============= Edge Cases Tests =============
...@@ -743,9 +723,6 @@ mod tests { ...@@ -743,9 +723,6 @@ mod tests {
RouterMetrics::record_request(""); RouterMetrics::record_request("");
RouterMetrics::set_worker_health("", true); RouterMetrics::set_worker_health("", true);
RouterMetrics::record_policy_decision("", ""); RouterMetrics::record_policy_decision("", "");
// If we get here without panic, empty strings are handled
assert!(true);
} }
#[test] #[test]
...@@ -754,14 +731,11 @@ mod tests { ...@@ -754,14 +731,11 @@ mod tests {
RouterMetrics::record_request(&long_label); RouterMetrics::record_request(&long_label);
RouterMetrics::set_worker_health(&long_label, false); RouterMetrics::set_worker_health(&long_label, false);
// If we get here without panic, long labels are handled
assert!(true);
} }
#[test] #[test]
fn test_special_characters_in_labels() { fn test_special_characters_in_labels() {
let special_labels = vec![ let special_labels = [
"test/with/slashes", "test/with/slashes",
"test-with-dashes", "test-with-dashes",
"test_with_underscores", "test_with_underscores",
...@@ -773,9 +747,6 @@ mod tests { ...@@ -773,9 +747,6 @@ mod tests {
RouterMetrics::record_request(label); RouterMetrics::record_request(label);
RouterMetrics::set_worker_health(label, true); RouterMetrics::set_worker_health(label, true);
} }
// If we get here without panic, special characters are handled
assert!(true);
} }
#[test] #[test]
...@@ -788,9 +759,7 @@ mod tests { ...@@ -788,9 +759,7 @@ mod tests {
RouterMetrics::set_worker_load("worker", usize::MAX); RouterMetrics::set_worker_load("worker", usize::MAX);
RouterMetrics::record_request_duration("route", Duration::from_nanos(1)); RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
RouterMetrics::record_request_duration("route", Duration::from_secs(86400)); // 24 hours // 24 hours
RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
// If we get here without panic, extreme values are handled
assert!(true);
} }
} }
...@@ -141,7 +141,7 @@ mod tests { ...@@ -141,7 +141,7 @@ mod tests {
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)]; vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
// Run multiple selections // Run multiple selections
let mut selected_counts = vec![0; 3]; let mut selected_counts = [0; 3];
for _ in 0..100 { for _ in 0..100 {
if let Some(idx) = policy.select_worker(&workers, None) { if let Some(idx) = policy.select_worker(&workers, None) {
selected_counts[idx] += 1; selected_counts[idx] += 1;
......
use axum::body::Body; use axum::body::Body;
use axum::extract::Request; use axum::extract::Request;
use axum::http::{HeaderMap, HeaderName, HeaderValue}; use axum::http::HeaderMap;
/// Copy request headers to a Vec of name-value string pairs /// Copy request headers to a Vec of name-value string pairs
/// Used for forwarding headers to backend workers /// Used for forwarding headers to backend workers
......
...@@ -363,6 +363,7 @@ impl PDRouter { ...@@ -363,6 +363,7 @@ impl PDRouter {
Ok(format!("Successfully removed decode server: {}", url)) Ok(format!("Successfully removed decode server: {}", url))
} }
#[allow(clippy::too_many_arguments)]
pub async fn new( pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>, decode_urls: Vec<String>,
...@@ -733,6 +734,7 @@ impl PDRouter { ...@@ -733,6 +734,7 @@ impl PDRouter {
} }
// Internal method that performs the actual dual dispatch (without retry logic) // Internal method that performs the actual dual dispatch (without retry logic)
#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch_internal( async fn execute_dual_dispatch_internal(
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
...@@ -1145,7 +1147,7 @@ impl PDRouter { ...@@ -1145,7 +1147,7 @@ impl PDRouter {
*response.status_mut() = status; *response.status_mut() = status;
// Use provided headers or create new ones, then ensure content-type is set for streaming // Use provided headers or create new ones, then ensure content-type is set for streaming
let mut headers = headers.unwrap_or_else(HeaderMap::new); let mut headers = headers.unwrap_or_default();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
*response.headers_mut() = headers; *response.headers_mut() = headers;
...@@ -1160,41 +1162,41 @@ impl PDRouter { ...@@ -1160,41 +1162,41 @@ impl PDRouter {
return_logprob: bool, return_logprob: bool,
prefill_body: Option<bytes::Bytes>, prefill_body: Option<bytes::Bytes>,
) -> Response { ) -> Response {
match res.bytes().await { let response = res.bytes().await;
Ok(decode_body) => { let decode_body = match response {
if return_logprob && prefill_body.is_some() { Ok(decode_body) => decode_body,
// Merge logprobs from prefill and decode
let prefill_body = prefill_body.as_ref().unwrap();
match (
serde_json::from_slice::<Value>(prefill_body),
serde_json::from_slice::<Value>(&decode_body),
) {
(Ok(prefill_json), Ok(mut decode_json)) => {
// Use helper to merge logprobs
Self::merge_logprobs_in_json(&prefill_json, &mut decode_json);
// Return merged response
match serde_json::to_vec(&decode_json) {
Ok(body) => (status, body).into_response(),
Err(e) => {
error!("Failed to serialize merged response: {}", e);
(status, decode_body).into_response()
}
}
}
_ => {
// If parsing fails, just return decode response
warn!("Failed to parse responses for logprob merging");
(status, decode_body).into_response()
}
}
} else {
(status, decode_body).into_response()
}
}
Err(e) => { Err(e) => {
error!("Failed to read decode response: {}", e); error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response").into_response() return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
.into_response();
}
};
if !return_logprob {
return (status, decode_body).into_response();
}
let Some(prefill_body) = prefill_body else {
return (status, decode_body).into_response();
};
// Merge logprobs from prefill and decode
let (Ok(prefill_json), Ok(mut decode_json)) = (
serde_json::from_slice::<Value>(&prefill_body),
serde_json::from_slice::<Value>(&decode_body),
) else {
warn!("Failed to parse responses for logprob merging");
return (status, decode_body).into_response();
};
Self::merge_logprobs_in_json(&prefill_json, &mut decode_json);
// Return merged response
match serde_json::to_vec(&decode_json) {
Ok(body) => (status, body).into_response(),
Err(e) => {
error!("Failed to serialize merged response: {}", e);
(status, decode_body).into_response()
} }
} }
} }
......
...@@ -45,6 +45,7 @@ pub struct Router { ...@@ -45,6 +45,7 @@ pub struct Router {
impl Router { impl Router {
/// Create a new router with injected policy and client /// Create a new router with injected policy and client
#[allow(clippy::too_many_arguments)]
pub async fn new( pub async fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
......
...@@ -38,6 +38,7 @@ struct EvictionEntry { ...@@ -38,6 +38,7 @@ struct EvictionEntry {
impl Eq for EvictionEntry {} impl Eq for EvictionEntry {}
#[allow(clippy::non_canonical_partial_ord_impl)]
impl PartialOrd for EvictionEntry { impl PartialOrd for EvictionEntry {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.timestamp.cmp(&other.timestamp)) Some(self.timestamp.cmp(&other.timestamp))
...@@ -862,8 +863,8 @@ mod tests { ...@@ -862,8 +863,8 @@ mod tests {
// spawn 3 threads for insert // spawn 3 threads for insert
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let texts = vec!["hello", "apple", "banana"]; let texts = ["hello", "apple", "banana"];
let tenants = vec!["tenant1", "tenant2", "tenant3"]; let tenants = ["tenant1", "tenant2", "tenant3"];
let mut handles = vec![]; let mut handles = vec![];
...@@ -916,13 +917,12 @@ mod tests { ...@@ -916,13 +917,12 @@ mod tests {
// spawn 3 threads for insert // spawn 3 threads for insert
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let texts = vec!["apple", "apabc", "acbdeds"]; static TEXTS: [&str; 3] = ["apple", "apabc", "acbdeds"];
let mut handles = vec![]; let mut handles = vec![];
for i in 0..3 { for text in TEXTS.iter() {
let tree_clone = Arc::clone(&tree_clone); let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = "tenant0"; let tenant = "tenant0";
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
...@@ -942,14 +942,13 @@ mod tests { ...@@ -942,14 +942,13 @@ mod tests {
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
for i in 0..3 { for text in TEXTS.iter() {
let tree_clone = Arc::clone(&tree_clone); let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = "tenant0"; let tenant = "tenant0";
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text); let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
assert_eq!(matched_text, text); assert_eq!(matched_text, *text);
assert_eq!(matched_tenant, tenant); assert_eq!(matched_tenant, tenant);
}); });
...@@ -964,13 +963,13 @@ mod tests { ...@@ -964,13 +963,13 @@ mod tests {
#[test] #[test]
fn test_group_prefix_insert_match_concurrent() { fn test_group_prefix_insert_match_concurrent() {
let prefix = vec![ static PREFIXES: [&str; 4] = [
"Clock strikes midnight, I'm still wide awake", "Clock strikes midnight, I'm still wide awake",
"Got dreams bigger than these city lights", "Got dreams bigger than these city lights",
"Time waits for no one, gotta make my move", "Time waits for no one, gotta make my move",
"Started from the bottom, that's no metaphor", "Started from the bottom, that's no metaphor",
]; ];
let suffix = vec![ let suffixes = [
"Got too much to prove, ain't got time to lose", "Got too much to prove, ain't got time to lose",
"History in the making, yeah, you can't erase this", "History in the making, yeah, you can't erase this",
]; ];
...@@ -978,10 +977,10 @@ mod tests { ...@@ -978,10 +977,10 @@ mod tests {
let mut handles = vec![]; let mut handles = vec![];
for i in 0..prefix.len() { for (i, prefix) in PREFIXES.iter().enumerate() {
for j in 0..suffix.len() { for suffix in suffixes.iter() {
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let text = format!("{} {}", prefix[i], suffix[j]); let text = format!("{} {}", prefix, suffix);
let tenant = format!("tenant{}", i); let tenant = format!("tenant{}", i);
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
...@@ -1000,17 +999,15 @@ mod tests { ...@@ -1000,17 +999,15 @@ mod tests {
tree.pretty_print(); tree.pretty_print();
// check matching using multi threads // check matching using multi threads
let mut handles = vec![]; let mut handles = vec![];
for i in 0..prefix.len() { for (i, prefix) in PREFIXES.iter().enumerate() {
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let text = prefix[i];
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text); let (matched_text, matched_tenant) = tree_clone.prefix_match(prefix);
let tenant = format!("tenant{}", i); let tenant = format!("tenant{}", i);
assert_eq!(matched_text, text); assert_eq!(matched_text, *prefix);
assert_eq!(matched_tenant, tenant); assert_eq!(matched_tenant, tenant);
}); });
...@@ -1027,13 +1024,13 @@ mod tests { ...@@ -1027,13 +1024,13 @@ mod tests {
fn test_mixed_concurrent_insert_match() { fn test_mixed_concurrent_insert_match() {
// ensure it does not deadlock instead of doing correctness check // ensure it does not deadlock instead of doing correctness check
let prefix = vec![ static PREFIXES: [&str; 4] = [
"Clock strikes midnight, I'm still wide awake", "Clock strikes midnight, I'm still wide awake",
"Got dreams bigger than these city lights", "Got dreams bigger than these city lights",
"Time waits for no one, gotta make my move", "Time waits for no one, gotta make my move",
"Started from the bottom, that's no metaphor", "Started from the bottom, that's no metaphor",
]; ];
let suffix = vec![ let suffixes = [
"Got too much to prove, ain't got time to lose", "Got too much to prove, ain't got time to lose",
"History in the making, yeah, you can't erase this", "History in the making, yeah, you can't erase this",
]; ];
...@@ -1041,10 +1038,10 @@ mod tests { ...@@ -1041,10 +1038,10 @@ mod tests {
let mut handles = vec![]; let mut handles = vec![];
for i in 0..prefix.len() { for (i, prefix) in PREFIXES.iter().enumerate() {
for j in 0..suffix.len() { for suffix in suffixes.iter() {
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let text = format!("{} {}", prefix[i], suffix[j]); let text = format!("{} {}", prefix, suffix);
let tenant = format!("tenant{}", i); let tenant = format!("tenant{}", i);
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
...@@ -1056,13 +1053,11 @@ mod tests { ...@@ -1056,13 +1053,11 @@ mod tests {
} }
// check matching using multi threads // check matching using multi threads
for prefix in PREFIXES.iter() {
for i in 0..prefix.len() {
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let text = prefix[i];
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
let (_matched_text, _matched_tenant) = tree_clone.prefix_match(text); let (_matched_text, _matched_tenant) = tree_clone.prefix_match(prefix);
}); });
handles.push(handle); handles.push(handle);
...@@ -1080,16 +1075,14 @@ mod tests { ...@@ -1080,16 +1075,14 @@ mod tests {
// use .chars() to get the iterator of the utf-8 value // use .chars() to get the iterator of the utf-8 value
let tree = Arc::new(Tree::new()); let tree = Arc::new(Tree::new());
let test_pairs = vec![ static TEST_PAIRS: [(&str, &str); 3] = [
("你好嗎", "tenant1"), ("你好嗎", "tenant1"),
("你好喔", "tenant2"), ("你好喔", "tenant2"),
("你心情好嗎", "tenant3"), ("你心情好嗎", "tenant3"),
]; ];
// Insert sequentially // Insert sequentially
for i in 0..test_pairs.len() { for (text, tenant) in TEST_PAIRS.iter() {
let text = test_pairs[i].0;
let tenant = test_pairs[i].1;
tree.insert(text, tenant); tree.insert(text, tenant);
} }
...@@ -1097,10 +1090,10 @@ mod tests { ...@@ -1097,10 +1090,10 @@ mod tests {
// Test sequentially // Test sequentially
for i in 0..test_pairs.len() { for (text, tenant) in TEST_PAIRS.iter() {
let (matched_text, matched_tenant) = tree.prefix_match(test_pairs[i].0); let (matched_text, matched_tenant) = tree.prefix_match(text);
assert_eq!(matched_text, test_pairs[i].0); assert_eq!(matched_text, *text);
assert_eq!(matched_tenant, test_pairs[i].1); assert_eq!(matched_tenant, *tenant);
} }
} }
...@@ -1108,7 +1101,7 @@ mod tests { ...@@ -1108,7 +1101,7 @@ mod tests {
fn test_utf8_split_concurrent() { fn test_utf8_split_concurrent() {
let tree = Arc::new(Tree::new()); let tree = Arc::new(Tree::new());
let test_pairs = vec![ static TEST_PAIRS: [(&str, &str); 3] = [
("你好嗎", "tenant1"), ("你好嗎", "tenant1"),
("你好喔", "tenant2"), ("你好喔", "tenant2"),
("你心情好嗎", "tenant3"), ("你心情好嗎", "tenant3"),
...@@ -1117,13 +1110,11 @@ mod tests { ...@@ -1117,13 +1110,11 @@ mod tests {
// Create multiple threads for insertion // Create multiple threads for insertion
let mut handles = vec![]; let mut handles = vec![];
for i in 0..test_pairs.len() { for (text, tenant) in TEST_PAIRS.iter() {
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let text = test_pairs[i].0.to_string();
let tenant = test_pairs[i].1.to_string();
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
tree_clone.insert(&text, &tenant); tree_clone.insert(text, tenant);
}); });
handles.push(handle); handles.push(handle);
...@@ -1139,15 +1130,13 @@ mod tests { ...@@ -1139,15 +1130,13 @@ mod tests {
// Create multiple threads for matching // Create multiple threads for matching
let mut handles = vec![]; let mut handles = vec![];
for i in 0..test_pairs.len() { for (text, tenant) in TEST_PAIRS.iter() {
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let text = test_pairs[i].0.to_string();
let tenant = test_pairs[i].1.to_string();
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(&text); let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
assert_eq!(matched_text, text); assert_eq!(matched_text, *text);
assert_eq!(matched_tenant, tenant); assert_eq!(matched_tenant, *tenant);
}); });
handles.push(handle); handles.push(handle);
...@@ -1202,7 +1191,7 @@ mod tests { ...@@ -1202,7 +1191,7 @@ mod tests {
let max_size: usize = 100; let max_size: usize = 100;
// Define prefixes // Define prefixes
let prefixes = vec!["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]; let prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"];
// Insert strings with shared prefixes // Insert strings with shared prefixes
for _i in 0..100 { for _i in 0..100 {
......
...@@ -718,7 +718,7 @@ mod worker_management_tests { ...@@ -718,7 +718,7 @@ mod worker_management_tests {
// Add the worker // Add the worker
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri(&format!("/add_worker?url={}", url)) .uri(format!("/add_worker?url={}", url))
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
...@@ -776,7 +776,7 @@ mod worker_management_tests { ...@@ -776,7 +776,7 @@ mod worker_management_tests {
// Remove the worker // Remove the worker
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri(&format!("/remove_worker?url={}", worker_url)) .uri(format!("/remove_worker?url={}", worker_url))
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
...@@ -856,7 +856,7 @@ mod worker_management_tests { ...@@ -856,7 +856,7 @@ mod worker_management_tests {
// Add worker first time // Add worker first time
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri(&format!("/add_worker?url={}", url)) .uri(format!("/add_worker?url={}", url))
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
let resp = app.clone().oneshot(req).await.unwrap(); let resp = app.clone().oneshot(req).await.unwrap();
...@@ -867,7 +867,7 @@ mod worker_management_tests { ...@@ -867,7 +867,7 @@ mod worker_management_tests {
// Try to add same worker again // Try to add same worker again
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri(&format!("/add_worker?url={}", url)) .uri(format!("/add_worker?url={}", url))
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
let resp = app.oneshot(req).await.unwrap(); let resp = app.oneshot(req).await.unwrap();
...@@ -896,7 +896,7 @@ mod worker_management_tests { ...@@ -896,7 +896,7 @@ mod worker_management_tests {
// Try to add unhealthy worker // Try to add unhealthy worker
let req = Request::builder() let req = Request::builder()
.method("POST") .method("POST")
.uri(&format!("/add_worker?url={}", url)) .uri(format!("/add_worker?url={}", url))
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
let resp = app.oneshot(req).await.unwrap(); let resp = app.oneshot(req).await.unwrap();
...@@ -1412,7 +1412,7 @@ mod pd_mode_tests { ...@@ -1412,7 +1412,7 @@ mod pd_mode_tests {
// Extract port from prefill URL // Extract port from prefill URL
let prefill_port = prefill_url let prefill_port = prefill_url
.split(':') .split(':')
.last() .next_back()
.and_then(|p| p.trim_end_matches('/').parse::<u16>().ok()) .and_then(|p| p.trim_end_matches('/').parse::<u16>().ok())
.unwrap_or(9000); .unwrap_or(9000);
......
...@@ -116,6 +116,7 @@ fn default_completion_request() -> CompletionRequest { ...@@ -116,6 +116,7 @@ fn default_completion_request() -> CompletionRequest {
} }
} }
#[allow(dead_code)]
fn create_test_worker() -> BasicWorker { fn create_test_worker() -> BasicWorker {
BasicWorker::new( BasicWorker::new(
"http://test-server:8000".to_string(), "http://test-server:8000".to_string(),
......
...@@ -8,6 +8,7 @@ use sglang_router_rs::{ ...@@ -8,6 +8,7 @@ use sglang_router_rs::{
use std::sync::Arc; use std::sync::Arc;
/// Create a test Axum application using the actual server's build_app function /// Create a test Axum application using the actual server's build_app function
#[allow(dead_code)]
pub fn create_test_app( pub fn create_test_app(
router: Arc<dyn RouterTrait>, router: Arc<dyn RouterTrait>,
client: Client, client: Client,
......
...@@ -99,7 +99,7 @@ impl TestContext { ...@@ -99,7 +99,7 @@ impl TestContext {
let worker_url = &worker_urls[0]; let worker_url = &worker_urls[0];
let response = client let response = client
.post(&format!("{}{}", worker_url, endpoint)) .post(format!("{}{}", worker_url, endpoint))
.json(&body) .json(&body)
.send() .send()
.await .await
......
...@@ -100,7 +100,7 @@ impl TestContext { ...@@ -100,7 +100,7 @@ impl TestContext {
let worker_url = &worker_urls[0]; let worker_url = &worker_urls[0];
let response = client let response = client
.post(&format!("{}{}", worker_url, endpoint)) .post(format!("{}{}", worker_url, endpoint))
.json(&body) .json(&body)
.send() .send()
.await .await
...@@ -128,8 +128,8 @@ impl TestContext { ...@@ -128,8 +128,8 @@ impl TestContext {
if let Ok(bytes) = chunk { if let Ok(bytes) = chunk {
let text = String::from_utf8_lossy(&bytes); let text = String::from_utf8_lossy(&bytes);
for line in text.lines() { for line in text.lines() {
if line.starts_with("data: ") { if let Some(stripped) = line.strip_prefix("data: ") {
events.push(line[6..].to_string()); events.push(stripped.to_string());
} }
} }
} }
......
#[cfg(test)] #[cfg(test)]
mod test_pd_routing { mod test_pd_routing {
use rand::Rng;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{ use sglang_router_rs::config::{
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
...@@ -421,41 +420,6 @@ mod test_pd_routing { ...@@ -421,41 +420,6 @@ mod test_pd_routing {
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15)); assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
} }
#[test]
fn test_power_of_two_load_selection() {
// Test the power-of-two selection logic with different load scenarios
// Scenario 1: Clear winner for both prefill and decode
let _loads = vec![
("prefill1", 100),
("prefill2", 10), // Should be selected
("decode1", 50),
("decode2", 5), // Should be selected
];
// In actual implementation, the lower load should be selected
assert!(10 < 100);
assert!(5 < 50);
// Scenario 2: Equal loads (should select first)
let _equal_loads = vec![
("prefill1", 20),
("prefill2", 20), // Either could be selected
("decode1", 30),
("decode2", 30), // Either could be selected
];
// When loads are equal, <= comparison means first is selected
assert!(20 <= 20);
assert!(30 <= 30);
// Scenario 3: Missing load data (should default to usize::MAX)
// This tests the unwrap_or(usize::MAX) behavior
let missing_load = usize::MAX;
assert!(10 < missing_load);
assert!(missing_load > 0);
}
#[test] #[test]
fn test_load_monitoring_configuration() { fn test_load_monitoring_configuration() {
// Test that load monitoring is only enabled for PowerOfTwo policy // Test that load monitoring is only enabled for PowerOfTwo policy
...@@ -605,12 +569,10 @@ mod test_pd_routing { ...@@ -605,12 +569,10 @@ mod test_pd_routing {
#[test] #[test]
fn test_streaming_response_parsing() { fn test_streaming_response_parsing() {
// Test SSE format parsing from streaming responses // Test SSE format parsing from streaming responses
let sse_chunks = vec![ let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
"data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}", "data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}", "data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
"data: [DONE]", "data: [DONE]"];
];
for chunk in &sse_chunks[..3] { for chunk in &sse_chunks[..3] {
assert!(chunk.starts_with("data: ")); assert!(chunk.starts_with("data: "));
...@@ -848,7 +810,7 @@ mod test_pd_routing { ...@@ -848,7 +810,7 @@ mod test_pd_routing {
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]); large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
large_batch_request["bootstrap_room"] = json!((0..batch_size) large_batch_request["bootstrap_room"] = json!((0..batch_size)
.map(|_| rand::thread_rng().gen::<u64>()) .map(|_| rand::random::<u64>())
.collect::<Vec<_>>()); .collect::<Vec<_>>());
let elapsed = start.elapsed(); let elapsed = start.elapsed();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment