Unverified Commit aae7ead2 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] remove old/oudated/useless comments across code base (#10968)

parent a7fe6e10
...@@ -205,7 +205,6 @@ impl RoutingMode { ...@@ -205,7 +205,6 @@ impl RoutingMode {
decode_urls, decode_urls,
.. ..
} => prefill_urls.len() + decode_urls.len(), } => prefill_urls.len() + decode_urls.len(),
// OpenAI mode represents a single upstream
RoutingMode::OpenAI { .. } => 1, RoutingMode::OpenAI { .. } => 1,
} }
} }
...@@ -515,8 +514,6 @@ impl RouterConfig { ...@@ -515,8 +514,6 @@ impl RouterConfig {
mod tests { mod tests {
use super::*; use super::*;
// ============= RouterConfig Tests =============
#[test] #[test]
fn test_router_config_default() { fn test_router_config_default() {
let config = RouterConfig::default(); let config = RouterConfig::default();
...@@ -556,7 +553,6 @@ mod tests { ...@@ -556,7 +553,6 @@ mod tests {
} }
assert!(matches!(config.policy, PolicyConfig::RoundRobin)); assert!(matches!(config.policy, PolicyConfig::RoundRobin));
// Other fields should be default
assert_eq!(config.host, "127.0.0.1"); assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 3001); assert_eq!(config.port, 3001);
} }
...@@ -583,13 +579,10 @@ mod tests { ...@@ -583,13 +579,10 @@ mod tests {
assert_eq!(config.max_payload_size, deserialized.max_payload_size); assert_eq!(config.max_payload_size, deserialized.max_payload_size);
assert_eq!(config.log_dir, deserialized.log_dir); assert_eq!(config.log_dir, deserialized.log_dir);
assert_eq!(config.log_level, deserialized.log_level); assert_eq!(config.log_level, deserialized.log_level);
// discovery and metrics are None in Default implementation
assert!(deserialized.discovery.is_none()); assert!(deserialized.discovery.is_none());
assert!(deserialized.metrics.is_none()); assert!(deserialized.metrics.is_none());
} }
// ============= RoutingMode Tests =============
#[test] #[test]
fn test_routing_mode_is_pd_mode() { fn test_routing_mode_is_pd_mode() {
let regular = RoutingMode::Regular { let regular = RoutingMode::Regular {
...@@ -640,7 +633,6 @@ mod tests { ...@@ -640,7 +633,6 @@ mod tests {
#[test] #[test]
fn test_routing_mode_serialization() { fn test_routing_mode_serialization() {
// Test Regular mode
let regular = RoutingMode::Regular { let regular = RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()], worker_urls: vec!["http://worker1".to_string()],
}; };
...@@ -648,7 +640,6 @@ mod tests { ...@@ -648,7 +640,6 @@ mod tests {
assert!(json.contains("\"type\":\"regular\"")); assert!(json.contains("\"type\":\"regular\""));
assert!(json.contains("\"worker_urls\"")); assert!(json.contains("\"worker_urls\""));
// Test PrefillDecode mode
let pd = RoutingMode::PrefillDecode { let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
decode_urls: vec!["http://decode1".to_string()], decode_urls: vec!["http://decode1".to_string()],
...@@ -661,8 +652,6 @@ mod tests { ...@@ -661,8 +652,6 @@ mod tests {
assert!(json.contains("\"decode_urls\"")); assert!(json.contains("\"decode_urls\""));
} }
// ============= PolicyConfig Tests =============
#[test] #[test]
fn test_policy_config_name() { fn test_policy_config_name() {
assert_eq!(PolicyConfig::Random.name(), "random"); assert_eq!(PolicyConfig::Random.name(), "random");
...@@ -685,12 +674,10 @@ mod tests { ...@@ -685,12 +674,10 @@ mod tests {
#[test] #[test]
fn test_policy_config_serialization() { fn test_policy_config_serialization() {
// Test Random
let random = PolicyConfig::Random; let random = PolicyConfig::Random;
let json = serde_json::to_string(&random).unwrap(); let json = serde_json::to_string(&random).unwrap();
assert_eq!(json, r#"{"type":"random"}"#); assert_eq!(json, r#"{"type":"random"}"#);
// Test CacheAware with all parameters
let cache_aware = PolicyConfig::CacheAware { let cache_aware = PolicyConfig::CacheAware {
cache_threshold: 0.8, cache_threshold: 0.8,
balance_abs_threshold: 10, balance_abs_threshold: 10,
...@@ -703,7 +690,6 @@ mod tests { ...@@ -703,7 +690,6 @@ mod tests {
assert!(json.contains("\"cache_threshold\":0.8")); assert!(json.contains("\"cache_threshold\":0.8"));
assert!(json.contains("\"balance_abs_threshold\":10")); assert!(json.contains("\"balance_abs_threshold\":10"));
// Test PowerOfTwo
let power_of_two = PolicyConfig::PowerOfTwo { let power_of_two = PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60, load_check_interval_secs: 60,
}; };
...@@ -756,8 +742,6 @@ mod tests { ...@@ -756,8 +742,6 @@ mod tests {
} }
} }
// ============= DiscoveryConfig Tests =============
#[test] #[test]
fn test_discovery_config_default() { fn test_discovery_config_default() {
let config = DiscoveryConfig::default(); let config = DiscoveryConfig::default();
...@@ -798,14 +782,12 @@ mod tests { ...@@ -798,14 +782,12 @@ mod tests {
#[test] #[test]
fn test_discovery_config_namespace() { fn test_discovery_config_namespace() {
// Test None namespace (all namespaces)
let config = DiscoveryConfig { let config = DiscoveryConfig {
namespace: None, namespace: None,
..Default::default() ..Default::default()
}; };
assert!(config.namespace.is_none()); assert!(config.namespace.is_none());
// Test specific namespace
let config = DiscoveryConfig { let config = DiscoveryConfig {
namespace: Some("production".to_string()), namespace: Some("production".to_string()),
..Default::default() ..Default::default()
...@@ -813,8 +795,6 @@ mod tests { ...@@ -813,8 +795,6 @@ mod tests {
assert_eq!(config.namespace, Some("production".to_string())); assert_eq!(config.namespace, Some("production".to_string()));
} }
// ============= MetricsConfig Tests =============
#[test] #[test]
fn test_metrics_config_default() { fn test_metrics_config_default() {
let config = MetricsConfig::default(); let config = MetricsConfig::default();
...@@ -834,8 +814,6 @@ mod tests { ...@@ -834,8 +814,6 @@ mod tests {
assert_eq!(config.host, "0.0.0.0"); assert_eq!(config.host, "0.0.0.0");
} }
// ============= RouterConfig Utility Methods Tests =============
#[test] #[test]
fn test_mode_type() { fn test_mode_type() {
let config = RouterConfig { let config = RouterConfig {
...@@ -894,8 +872,6 @@ mod tests { ...@@ -894,8 +872,6 @@ mod tests {
assert!(config.has_metrics()); assert!(config.has_metrics());
} }
// ============= Edge Cases =============
#[test] #[test]
fn test_large_worker_lists() { fn test_large_worker_lists() {
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect(); let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
...@@ -906,7 +882,6 @@ mod tests { ...@@ -906,7 +882,6 @@ mod tests {
assert_eq!(mode.worker_count(), 1000); assert_eq!(mode.worker_count(), 1000);
// Test serialization with large list
let config = RouterConfig { let config = RouterConfig {
mode, mode,
..Default::default() ..Default::default()
...@@ -961,8 +936,6 @@ mod tests { ...@@ -961,8 +936,6 @@ mod tests {
assert_eq!(config.log_level, Some("".to_string())); assert_eq!(config.log_level, Some("".to_string()));
} }
// ============= Complex Configuration Tests =============
#[test] #[test]
fn test_full_pd_mode_config() { fn test_full_pd_mode_config() {
let config = RouterConfig { let config = RouterConfig {
...@@ -1149,7 +1122,6 @@ mod tests { ...@@ -1149,7 +1122,6 @@ mod tests {
assert!(config.has_metrics()); assert!(config.has_metrics());
assert_eq!(config.mode_type(), "regular"); assert_eq!(config.mode_type(), "regular");
// Test round-trip serialization
let json = serde_json::to_string_pretty(&config).unwrap(); let json = serde_json::to_string_pretty(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
...@@ -1161,11 +1133,8 @@ mod tests { ...@@ -1161,11 +1133,8 @@ mod tests {
); );
} }
// ============= Policy Fallback Tests =============
#[test] #[test]
fn test_pd_policy_fallback_both_specified() { fn test_pd_policy_fallback_both_specified() {
// When both prefill and decode policies are specified, they should be used
let pd = RoutingMode::PrefillDecode { let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)], prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()], decode_urls: vec!["http://decode1".to_string()],
...@@ -1183,21 +1152,19 @@ mod tests { ...@@ -1183,21 +1152,19 @@ mod tests {
let main_policy = PolicyConfig::Random; let main_policy = PolicyConfig::Random;
// Both specific policies should be used
match pd.get_prefill_policy(&main_policy) { match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware { .. } => {} // Success PolicyConfig::CacheAware { .. } => {}
_ => panic!("Expected CacheAware for prefill"), _ => panic!("Expected CacheAware for prefill"),
} }
match pd.get_decode_policy(&main_policy) { match pd.get_decode_policy(&main_policy) {
PolicyConfig::PowerOfTwo { .. } => {} // Success PolicyConfig::PowerOfTwo { .. } => {}
_ => panic!("Expected PowerOfTwo for decode"), _ => panic!("Expected PowerOfTwo for decode"),
} }
} }
#[test] #[test]
fn test_pd_policy_fallback_only_prefill() { fn test_pd_policy_fallback_only_prefill() {
// When only prefill policy is specified, decode should use main policy
let pd = RoutingMode::PrefillDecode { let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)], prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()], decode_urls: vec!["http://decode1".to_string()],
...@@ -1213,22 +1180,19 @@ mod tests { ...@@ -1213,22 +1180,19 @@ mod tests {
let main_policy = PolicyConfig::RoundRobin; let main_policy = PolicyConfig::RoundRobin;
// Prefill should use specific policy
match pd.get_prefill_policy(&main_policy) { match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware { .. } => {} // Success PolicyConfig::CacheAware { .. } => {}
_ => panic!("Expected CacheAware for prefill"), _ => panic!("Expected CacheAware for prefill"),
} }
// Decode should fall back to main policy
match pd.get_decode_policy(&main_policy) { match pd.get_decode_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success PolicyConfig::RoundRobin => {}
_ => panic!("Expected RoundRobin for decode"), _ => panic!("Expected RoundRobin for decode"),
} }
} }
#[test] #[test]
fn test_pd_policy_fallback_only_decode() { fn test_pd_policy_fallback_only_decode() {
// When only decode policy is specified, prefill should use main policy
let pd = RoutingMode::PrefillDecode { let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)], prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()], decode_urls: vec!["http://decode1".to_string()],
...@@ -1240,22 +1204,19 @@ mod tests { ...@@ -1240,22 +1204,19 @@ mod tests {
let main_policy = PolicyConfig::Random; let main_policy = PolicyConfig::Random;
// Prefill should fall back to main policy
match pd.get_prefill_policy(&main_policy) { match pd.get_prefill_policy(&main_policy) {
PolicyConfig::Random => {} // Success PolicyConfig::Random => {}
_ => panic!("Expected Random for prefill"), _ => panic!("Expected Random for prefill"),
} }
// Decode should use specific policy
match pd.get_decode_policy(&main_policy) { match pd.get_decode_policy(&main_policy) {
PolicyConfig::PowerOfTwo { .. } => {} // Success PolicyConfig::PowerOfTwo { .. } => {}
_ => panic!("Expected PowerOfTwo for decode"), _ => panic!("Expected PowerOfTwo for decode"),
} }
} }
#[test] #[test]
fn test_pd_policy_fallback_none_specified() { fn test_pd_policy_fallback_none_specified() {
// When no specific policies are specified, both should use main policy
let pd = RoutingMode::PrefillDecode { let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)], prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()], decode_urls: vec!["http://decode1".to_string()],
...@@ -1271,7 +1232,6 @@ mod tests { ...@@ -1271,7 +1232,6 @@ mod tests {
max_tree_size: 2000, max_tree_size: 2000,
}; };
// Both should fall back to main policy
match pd.get_prefill_policy(&main_policy) { match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware { PolicyConfig::CacheAware {
cache_threshold, .. cache_threshold, ..
...@@ -1293,21 +1253,19 @@ mod tests { ...@@ -1293,21 +1253,19 @@ mod tests {
#[test] #[test]
fn test_regular_mode_policy_fallback() { fn test_regular_mode_policy_fallback() {
// For regular mode, the helper methods should just return the main policy
let regular = RoutingMode::Regular { let regular = RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()], worker_urls: vec!["http://worker1".to_string()],
}; };
let main_policy = PolicyConfig::RoundRobin; let main_policy = PolicyConfig::RoundRobin;
// Both methods should return main policy for regular mode
match regular.get_prefill_policy(&main_policy) { match regular.get_prefill_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success PolicyConfig::RoundRobin => {}
_ => panic!("Expected RoundRobin for regular mode"), _ => panic!("Expected RoundRobin for regular mode"),
} }
match regular.get_decode_policy(&main_policy) { match regular.get_decode_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success PolicyConfig::RoundRobin => {}
_ => panic!("Expected RoundRobin for regular mode"), _ => panic!("Expected RoundRobin for regular mode"),
} }
} }
......
...@@ -670,7 +670,6 @@ mod tests { ...@@ -670,7 +670,6 @@ mod tests {
#[test] #[test]
fn test_validate_pd_mode_with_separate_policies() { fn test_validate_pd_mode_with_separate_policies() {
// Test PD mode with different policies for prefill and decode
let config = RouterConfig::new( let config = RouterConfig::new(
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: vec![ prefill_urls: vec![
...@@ -701,7 +700,6 @@ mod tests { ...@@ -701,7 +700,6 @@ mod tests {
#[test] #[test]
fn test_validate_pd_mode_power_of_two_insufficient_workers() { fn test_validate_pd_mode_power_of_two_insufficient_workers() {
// Test that power-of-two policy requires at least 2 workers
let config = RouterConfig::new( let config = RouterConfig::new(
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
...@@ -726,7 +724,6 @@ mod tests { ...@@ -726,7 +724,6 @@ mod tests {
#[test] #[test]
fn test_validate_grpc_requires_tokenizer() { fn test_validate_grpc_requires_tokenizer() {
// Test that gRPC connection mode requires tokenizer configuration
let mut config = RouterConfig::new( let mut config = RouterConfig::new(
RoutingMode::Regular { RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()], worker_urls: vec!["grpc://worker:50051".to_string()],
...@@ -748,7 +745,6 @@ mod tests { ...@@ -748,7 +745,6 @@ mod tests {
#[test] #[test]
fn test_validate_grpc_with_model_path() { fn test_validate_grpc_with_model_path() {
// Test that gRPC works with model_path
let mut config = RouterConfig::new( let mut config = RouterConfig::new(
RoutingMode::Regular { RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()], worker_urls: vec!["grpc://worker:50051".to_string()],
...@@ -765,7 +761,6 @@ mod tests { ...@@ -765,7 +761,6 @@ mod tests {
#[test] #[test]
fn test_validate_grpc_with_tokenizer_path() { fn test_validate_grpc_with_tokenizer_path() {
// Test that gRPC works with tokenizer_path
let mut config = RouterConfig::new( let mut config = RouterConfig::new(
RoutingMode::Regular { RoutingMode::Regular {
worker_urls: vec!["grpc://worker:50051".to_string()], worker_urls: vec!["grpc://worker:50051".to_string()],
......
...@@ -336,7 +336,6 @@ mod tests { ...@@ -336,7 +336,6 @@ mod tests {
}; };
let cb = CircuitBreaker::with_config(config); let cb = CircuitBreaker::with_config(config);
// Record failures up to threshold
assert_eq!(cb.state(), CircuitState::Closed); assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure(); cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed); assert_eq!(cb.state(), CircuitState::Closed);
...@@ -344,7 +343,6 @@ mod tests { ...@@ -344,7 +343,6 @@ mod tests {
assert_eq!(cb.state(), CircuitState::Closed); assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure(); cb.record_failure();
// Circuit should now be open
assert_eq!(cb.state(), CircuitState::Open); assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute()); assert!(!cb.can_execute());
assert_eq!(cb.failure_count(), 3); assert_eq!(cb.failure_count(), 3);
...@@ -359,14 +357,11 @@ mod tests { ...@@ -359,14 +357,11 @@ mod tests {
}; };
let cb = CircuitBreaker::with_config(config); let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure(); cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open); assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(150)); thread::sleep(Duration::from_millis(150));
// Circuit should be half-open
assert_eq!(cb.state(), CircuitState::HalfOpen); assert_eq!(cb.state(), CircuitState::HalfOpen);
assert!(cb.can_execute()); assert!(cb.can_execute());
} }
...@@ -381,20 +376,16 @@ mod tests { ...@@ -381,20 +376,16 @@ mod tests {
}; };
let cb = CircuitBreaker::with_config(config); let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure(); cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open); assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(100)); thread::sleep(Duration::from_millis(100));
assert_eq!(cb.state(), CircuitState::HalfOpen); assert_eq!(cb.state(), CircuitState::HalfOpen);
// Record successes
cb.record_success(); cb.record_success();
assert_eq!(cb.state(), CircuitState::HalfOpen); assert_eq!(cb.state(), CircuitState::HalfOpen);
cb.record_success(); cb.record_success();
// Circuit should now be closed
assert_eq!(cb.state(), CircuitState::Closed); assert_eq!(cb.state(), CircuitState::Closed);
assert!(cb.can_execute()); assert!(cb.can_execute());
} }
...@@ -408,18 +399,14 @@ mod tests { ...@@ -408,18 +399,14 @@ mod tests {
}; };
let cb = CircuitBreaker::with_config(config); let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure(); cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open); assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(100)); thread::sleep(Duration::from_millis(100));
assert_eq!(cb.state(), CircuitState::HalfOpen); assert_eq!(cb.state(), CircuitState::HalfOpen);
// Record a failure in half-open state
cb.record_failure(); cb.record_failure();
// Circuit should reopen immediately
assert_eq!(cb.state(), CircuitState::Open); assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute()); assert!(!cb.can_execute());
} }
...@@ -432,17 +419,14 @@ mod tests { ...@@ -432,17 +419,14 @@ mod tests {
}; };
let cb = CircuitBreaker::with_config(config); let cb = CircuitBreaker::with_config(config);
// Record some failures
cb.record_failure(); cb.record_failure();
cb.record_failure(); cb.record_failure();
assert_eq!(cb.failure_count(), 2); assert_eq!(cb.failure_count(), 2);
// Success should reset failure count
cb.record_success(); cb.record_success();
assert_eq!(cb.failure_count(), 0); assert_eq!(cb.failure_count(), 0);
assert_eq!(cb.success_count(), 1); assert_eq!(cb.success_count(), 1);
// Can now record more failures without opening
cb.record_failure(); cb.record_failure();
cb.record_failure(); cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed); assert_eq!(cb.state(), CircuitState::Closed);
...@@ -456,11 +440,9 @@ mod tests { ...@@ -456,11 +440,9 @@ mod tests {
}; };
let cb = CircuitBreaker::with_config(config); let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure(); cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open); assert_eq!(cb.state(), CircuitState::Open);
// Manual reset
cb.reset(); cb.reset();
assert_eq!(cb.state(), CircuitState::Closed); assert_eq!(cb.state(), CircuitState::Closed);
assert_eq!(cb.failure_count(), 0); assert_eq!(cb.failure_count(), 0);
...@@ -505,7 +487,6 @@ mod tests { ...@@ -505,7 +487,6 @@ mod tests {
let cb2 = cb1.clone(); let cb2 = cb1.clone();
assert_eq!(cb2.failure_count(), 1); assert_eq!(cb2.failure_count(), 1);
// Changes to cb1 affect cb2 (shared state)
cb1.record_failure(); cb1.record_failure();
assert_eq!(cb2.failure_count(), 2); assert_eq!(cb2.failure_count(), 2);
} }
......
...@@ -1562,19 +1562,16 @@ mod tests { ...@@ -1562,19 +1562,16 @@ mod tests {
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.build(); .build();
// Test health status
assert!(dp_worker.is_healthy()); assert!(dp_worker.is_healthy());
dp_worker.set_healthy(false); dp_worker.set_healthy(false);
assert!(!dp_worker.is_healthy()); assert!(!dp_worker.is_healthy());
// Test load tracking
assert_eq!(dp_worker.load(), 0); assert_eq!(dp_worker.load(), 0);
dp_worker.increment_load(); dp_worker.increment_load();
assert_eq!(dp_worker.load(), 1); assert_eq!(dp_worker.load(), 1);
dp_worker.decrement_load(); dp_worker.decrement_load();
assert_eq!(dp_worker.load(), 0); assert_eq!(dp_worker.load(), 0);
// Test processed tracking
assert_eq!(dp_worker.processed_requests(), 0); assert_eq!(dp_worker.processed_requests(), 0);
dp_worker.increment_processed(); dp_worker.increment_processed();
assert_eq!(dp_worker.processed_requests(), 1); assert_eq!(dp_worker.processed_requests(), 1);
......
...@@ -1485,7 +1485,6 @@ mod tests { ...@@ -1485,7 +1485,6 @@ mod tests {
#[test] #[test]
fn test_parse_server_info_with_fallback() { fn test_parse_server_info_with_fallback() {
// Test with "model" instead of "model_id"
let json = serde_json::json!({ let json = serde_json::json!({
"model": "gpt-4", "model": "gpt-4",
"dp_size": 2 "dp_size": 2
......
...@@ -459,14 +459,12 @@ mod tests { ...@@ -459,14 +459,12 @@ mod tests {
// Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc) // Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
let worker_id = registry.register(Arc::from(worker)); let worker_id = registry.register(Arc::from(worker));
// Verify registration
assert!(registry.get(&worker_id).is_some()); assert!(registry.get(&worker_id).is_some());
assert!(registry.get_by_url("http://worker1:8080").is_some()); assert!(registry.get_by_url("http://worker1:8080").is_some());
assert_eq!(registry.get_by_model("llama-3-8b").len(), 1); assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1); assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1); assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);
// Test stats
let stats = registry.stats(); let stats = registry.stats();
assert_eq!(stats.total_workers, 1); assert_eq!(stats.total_workers, 1);
assert_eq!(stats.total_models, 1); assert_eq!(stats.total_models, 1);
...@@ -519,27 +517,22 @@ mod tests { ...@@ -519,27 +517,22 @@ mod tests {
registry.register(Arc::from(worker2)); registry.register(Arc::from(worker2));
registry.register(Arc::from(worker3)); registry.register(Arc::from(worker3));
// Test get_by_model_fast for llama-3
let llama_workers = registry.get_by_model_fast("llama-3"); let llama_workers = registry.get_by_model_fast("llama-3");
assert_eq!(llama_workers.len(), 2); assert_eq!(llama_workers.len(), 2);
let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect(); let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
assert!(urls.contains(&"http://worker1:8080".to_string())); assert!(urls.contains(&"http://worker1:8080".to_string()));
assert!(urls.contains(&"http://worker2:8080".to_string())); assert!(urls.contains(&"http://worker2:8080".to_string()));
// Test get_by_model_fast for gpt-4
let gpt_workers = registry.get_by_model_fast("gpt-4"); let gpt_workers = registry.get_by_model_fast("gpt-4");
assert_eq!(gpt_workers.len(), 1); assert_eq!(gpt_workers.len(), 1);
assert_eq!(gpt_workers[0].url(), "http://worker3:8080"); assert_eq!(gpt_workers[0].url(), "http://worker3:8080");
// Test get_by_model_fast for non-existent model
let unknown_workers = registry.get_by_model_fast("unknown-model"); let unknown_workers = registry.get_by_model_fast("unknown-model");
assert_eq!(unknown_workers.len(), 0); assert_eq!(unknown_workers.len(), 0);
// Test that both get_by_model and get_by_model_fast return same results
let llama_workers_slow = registry.get_by_model("llama-3"); let llama_workers_slow = registry.get_by_model("llama-3");
assert_eq!(llama_workers.len(), llama_workers_slow.len()); assert_eq!(llama_workers.len(), llama_workers_slow.len());
// Test removal updates the model index
registry.remove_by_url("http://worker1:8080"); registry.remove_by_url("http://worker1:8080");
let llama_workers_after = registry.get_by_model_fast("llama-3"); let llama_workers_after = registry.get_by_model_fast("llama-3");
assert_eq!(llama_workers_after.len(), 1); assert_eq!(llama_workers_after.len(), 1);
......
...@@ -266,7 +266,6 @@ mod tests { ...@@ -266,7 +266,6 @@ mod tests {
assert_eq!(chain.responses[1].input, "Second"); assert_eq!(chain.responses[1].input, "Second");
assert_eq!(chain.responses[2].input, "Third"); assert_eq!(chain.responses[2].input, "Third");
// Test with max_depth
let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap(); let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap();
assert_eq!(limited_chain.responses.len(), 2); assert_eq!(limited_chain.responses.len(), 2);
assert_eq!(limited_chain.responses[0].input, "Second"); assert_eq!(limited_chain.responses[0].input, "Second");
...@@ -314,7 +313,6 @@ mod tests { ...@@ -314,7 +313,6 @@ mod tests {
let deleted_count = store.delete_user_responses("user1").await.unwrap(); let deleted_count = store.delete_user_responses("user1").await.unwrap();
assert_eq!(deleted_count, 2); assert_eq!(deleted_count, 2);
// Verify they're gone
let user1_responses_after = store.list_user_responses("user1", None).await.unwrap(); let user1_responses_after = store.list_user_responses("user1", None).await.unwrap();
assert_eq!(user1_responses_after.len(), 0); assert_eq!(user1_responses_after.len(), 0);
......
...@@ -223,7 +223,6 @@ mod tests { ...@@ -223,7 +223,6 @@ mod tests {
#[test] #[test]
fn test_proto_types_compilation() { fn test_proto_types_compilation() {
// Test that protobuf types can be constructed
let health_req = proto::HealthCheckRequest { let health_req = proto::HealthCheckRequest {
tokenized: Some(proto::TokenizedInput { tokenized: Some(proto::TokenizedInput {
original_text: "test".to_string(), original_text: "test".to_string(),
...@@ -320,8 +319,6 @@ mod tests { ...@@ -320,8 +319,6 @@ mod tests {
} }
// TODO: SessionParams not in current proto - skip test // TODO: SessionParams not in current proto - skip test
// #[test]
// fn test_session_params() { ... }
#[test] #[test]
fn test_embed_request() { fn test_embed_request() {
...@@ -349,7 +346,6 @@ mod tests { ...@@ -349,7 +346,6 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_client_connect_invalid_endpoint() { async fn test_client_connect_invalid_endpoint() {
// Test connecting to an invalid endpoint should return error
let result = SglangSchedulerClient::connect("invalid://endpoint").await; let result = SglangSchedulerClient::connect("invalid://endpoint").await;
assert!(result.is_err()); assert!(result.is_err());
} }
...@@ -365,7 +361,6 @@ mod tests { ...@@ -365,7 +361,6 @@ mod tests {
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]); assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
} }
// Test response type construction
#[test] #[test]
fn test_generate_stream_chunk() { fn test_generate_stream_chunk() {
let chunk = proto::GenerateStreamChunk { let chunk = proto::GenerateStreamChunk {
...@@ -383,6 +378,4 @@ mod tests { ...@@ -383,6 +378,4 @@ mod tests {
} }
// TODO: ModelInfo not in current proto - skip test // TODO: ModelInfo not in current proto - skip test
// #[test]
// fn test_model_info() { ... }
} }
...@@ -288,8 +288,6 @@ impl McpClientManager { ...@@ -288,8 +288,6 @@ impl McpClientManager {
} }
} }
// ===== Helpers =====
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> { fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
self.clients self.clients
.get(server_name) .get(server_name)
...@@ -317,8 +315,6 @@ impl McpClientManager { ...@@ -317,8 +315,6 @@ impl McpClientManager {
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string())) .ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
} }
// ===== Tool Methods =====
/// Call a tool by name /// Call a tool by name
pub async fn call_tool( pub async fn call_tool(
&self, &self,
...@@ -380,8 +376,6 @@ impl McpClientManager { ...@@ -380,8 +376,6 @@ impl McpClientManager {
self.clients.keys().cloned().collect() self.clients.keys().cloned().collect()
} }
// ===== Prompt Methods =====
/// Get a prompt by name with arguments /// Get a prompt by name with arguments
pub async fn get_prompt( pub async fn get_prompt(
&self, &self,
...@@ -439,8 +433,6 @@ impl McpClientManager { ...@@ -439,8 +433,6 @@ impl McpClientManager {
}) })
} }
// ===== Resource Methods =====
/// Read a resource by URI /// Read a resource by URI
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> { pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
let (server_name, _resource) = self.resource_entry(uri)?; let (server_name, _resource) = self.resource_entry(uri)?;
......
...@@ -598,8 +598,6 @@ mod tests { ...@@ -598,8 +598,6 @@ mod tests {
use super::*; use super::*;
use std::net::TcpListener; use std::net::TcpListener;
// ============= PrometheusConfig Tests =============
#[test] #[test]
fn test_prometheus_config_default() { fn test_prometheus_config_default() {
let config = PrometheusConfig::default(); let config = PrometheusConfig::default();
...@@ -628,8 +626,6 @@ mod tests { ...@@ -628,8 +626,6 @@ mod tests {
assert_eq!(cloned.host, config.host); assert_eq!(cloned.host, config.host);
} }
// ============= IP Address Parsing Tests =============
#[test] #[test]
fn test_valid_ipv4_parsing() { fn test_valid_ipv4_parsing() {
let test_cases = vec!["127.0.0.1", "192.168.1.1", "0.0.0.0"]; let test_cases = vec!["127.0.0.1", "192.168.1.1", "0.0.0.0"];
...@@ -679,8 +675,6 @@ mod tests { ...@@ -679,8 +675,6 @@ mod tests {
} }
} }
// ============= Socket Address Creation Tests =============
#[test] #[test]
fn test_socket_addr_creation() { fn test_socket_addr_creation() {
let test_cases = vec![("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)]; let test_cases = vec![("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)];
...@@ -716,8 +710,6 @@ mod tests { ...@@ -716,8 +710,6 @@ mod tests {
} }
} }
// ============= Duration Bucket Tests =============
#[test] #[test]
fn test_duration_bucket_coverage() { fn test_duration_bucket_coverage() {
let test_cases: [(f64, &str); 7] = [ let test_cases: [(f64, &str); 7] = [
...@@ -743,8 +735,6 @@ mod tests { ...@@ -743,8 +735,6 @@ mod tests {
} }
} }
// ============= Matcher Configuration Tests =============
#[test] #[test]
fn test_duration_suffix_matcher() { fn test_duration_suffix_matcher() {
let matcher = Matcher::Suffix(String::from("duration_seconds")); let matcher = Matcher::Suffix(String::from("duration_seconds"));
...@@ -763,8 +753,6 @@ mod tests { ...@@ -763,8 +753,6 @@ mod tests {
} }
} }
// ============= Builder Configuration Tests =============
#[test] #[test]
fn test_prometheus_builder_configuration() { fn test_prometheus_builder_configuration() {
let _config = PrometheusConfig::default(); let _config = PrometheusConfig::default();
...@@ -783,16 +771,12 @@ mod tests { ...@@ -783,16 +771,12 @@ mod tests {
} }
} }
// ============= Upkeep Timeout Tests =============
#[test] #[test]
fn test_upkeep_timeout_duration() { fn test_upkeep_timeout_duration() {
let timeout = Duration::from_secs(5 * 60); let timeout = Duration::from_secs(5 * 60);
assert_eq!(timeout.as_secs(), 300); assert_eq!(timeout.as_secs(), 300);
} }
// ============= Custom Bucket Tests =============
#[test] #[test]
fn test_custom_buckets_for_different_metrics() { fn test_custom_buckets_for_different_metrics() {
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0]; let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
...@@ -810,8 +794,6 @@ mod tests { ...@@ -810,8 +794,6 @@ mod tests {
} }
} }
// ============= RouterMetrics Tests =============
#[test] #[test]
fn test_metrics_static_methods() { fn test_metrics_static_methods() {
RouterMetrics::record_request("/generate"); RouterMetrics::record_request("/generate");
...@@ -876,8 +858,6 @@ mod tests { ...@@ -876,8 +858,6 @@ mod tests {
TokenizerMetrics::set_vocab_size("huggingface", 50000); TokenizerMetrics::set_vocab_size("huggingface", 50000);
} }
// ============= Port Availability Tests =============
#[test] #[test]
fn test_port_already_in_use() { fn test_port_already_in_use() {
let port = 29123; let port = 29123;
...@@ -892,8 +872,6 @@ mod tests { ...@@ -892,8 +872,6 @@ mod tests {
} }
} }
// ============= Integration Test Helpers =============
#[test] #[test]
fn test_metrics_endpoint_accessibility() { fn test_metrics_endpoint_accessibility() {
let config = PrometheusConfig { let config = PrometheusConfig {
...@@ -937,8 +915,6 @@ mod tests { ...@@ -937,8 +915,6 @@ mod tests {
} }
} }
// ============= Edge Cases Tests =============
#[test] #[test]
fn test_empty_string_metrics() { fn test_empty_string_metrics() {
RouterMetrics::record_request(""); RouterMetrics::record_request("");
......
...@@ -178,8 +178,6 @@ where ...@@ -178,8 +178,6 @@ where
} }
} }
// ============= Logging Middleware =============
/// Custom span maker that includes request ID /// Custom span maker that includes request ID
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct RequestSpan; pub struct RequestSpan;
...@@ -336,8 +334,6 @@ pub fn log_request(entry: RequestLogEntry) { ...@@ -336,8 +334,6 @@ pub fn log_request(entry: RequestLogEntry) {
} }
} }
// ============ Concurrency Limiting with Queue Support ============
/// Request queue entry /// Request queue entry
pub struct QueuedRequest { pub struct QueuedRequest {
/// Time when the request was queued /// Time when the request was queued
......
...@@ -54,21 +54,17 @@ mod tests { ...@@ -54,21 +54,17 @@ mod tests {
#[test] #[test]
fn test_create_from_config() { fn test_create_from_config() {
// Test Random
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
assert_eq!(policy.name(), "random"); assert_eq!(policy.name(), "random");
// Test RoundRobin
let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin); let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin);
assert_eq!(policy.name(), "round_robin"); assert_eq!(policy.name(), "round_robin");
// Test PowerOfTwo
let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo { let policy = PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60, load_check_interval_secs: 60,
}); });
assert_eq!(policy.name(), "power_of_two"); assert_eq!(policy.name(), "power_of_two");
// Test CacheAware
let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware { let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware {
cache_threshold: 0.7, cache_threshold: 0.7,
balance_abs_threshold: 10, balance_abs_threshold: 10,
......
...@@ -75,7 +75,6 @@ mod tests { ...@@ -75,7 +75,6 @@ mod tests {
), ),
]; ];
// Test multiple selections to ensure randomness
let mut counts = HashMap::new(); let mut counts = HashMap::new();
for _ in 0..100 { for _ in 0..100 {
if let Some(idx) = policy.select_worker(&workers, None) { if let Some(idx) = policy.select_worker(&workers, None) {
......
...@@ -49,12 +49,6 @@ use std::collections::HashMap; ...@@ -49,12 +49,6 @@ use std::collections::HashMap;
// - StringOrArray & LoRAPath types // - StringOrArray & LoRAPath types
// - Helper functions // - Helper functions
// ==================================================================
// = OPENAI SPEC - Chat Completions API =
// ==================================================================
// ============= Message Types =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum ChatMessage { pub enum ChatMessage {
...@@ -119,8 +113,6 @@ pub struct ImageUrl { ...@@ -119,8 +113,6 @@ pub struct ImageUrl {
pub detail: Option<String>, // "auto", "low", or "high" pub detail: Option<String>, // "auto", "low", or "high"
} }
// ============= Response Format Types =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum ResponseFormat { pub enum ResponseFormat {
...@@ -140,8 +132,6 @@ pub struct JsonSchemaFormat { ...@@ -140,8 +132,6 @@ pub struct JsonSchemaFormat {
pub strict: Option<bool>, pub strict: Option<bool>,
} }
// ============= Streaming Delta Types =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta { pub struct ChatMessageDelta {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -177,8 +167,6 @@ pub struct FunctionCallDelta { ...@@ -177,8 +167,6 @@ pub struct FunctionCallDelta {
pub arguments: Option<String>, pub arguments: Option<String>,
} }
// ============= Request =============
#[derive(Debug, Clone, Deserialize, Serialize, Default)] #[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct ChatCompletionRequest { pub struct ChatCompletionRequest {
/// A list of messages comprising the conversation so far /// A list of messages comprising the conversation so far
...@@ -299,7 +287,6 @@ pub struct ChatCompletionRequest { ...@@ -299,7 +287,6 @@ pub struct ChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub verbosity: Option<i32>, pub verbosity: Option<i32>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable) /// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>, pub top_k: Option<i32>,
...@@ -423,8 +410,6 @@ impl GenerationRequest for ChatCompletionRequest { ...@@ -423,8 +410,6 @@ impl GenerationRequest for ChatCompletionRequest {
} }
} }
// ============= Regular Response =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse { pub struct ChatCompletionResponse {
pub id: String, pub id: String,
...@@ -453,8 +438,6 @@ pub struct ChatChoice { ...@@ -453,8 +438,6 @@ pub struct ChatChoice {
pub hidden_states: Option<Vec<f32>>, pub hidden_states: Option<Vec<f32>>,
} }
// ============= Streaming Response =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse { pub struct ChatCompletionStreamResponse {
pub id: String, pub id: String,
...@@ -477,9 +460,6 @@ pub struct ChatStreamChoice { ...@@ -477,9 +460,6 @@ pub struct ChatStreamChoice {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
// ==================================================================
// = OPENAI SPEC - Completions API =
// ==================================================================
// Completions API request types (v1/completions) - DEPRECATED but still supported // Completions API request types (v1/completions) - DEPRECATED but still supported
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
...@@ -554,7 +534,6 @@ pub struct CompletionRequest { ...@@ -554,7 +534,6 @@ pub struct CompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>, pub seed: Option<i64>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable) /// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>, pub top_k: Option<i32>,
...@@ -599,7 +578,6 @@ pub struct CompletionRequest { ...@@ -599,7 +578,6 @@ pub struct CompletionRequest {
#[serde(default = "default_true")] #[serde(default = "default_true")]
pub skip_special_tokens: bool, pub skip_special_tokens: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization /// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>, pub lora_path: Option<LoRAPath>,
...@@ -638,8 +616,6 @@ impl GenerationRequest for CompletionRequest { ...@@ -638,8 +616,6 @@ impl GenerationRequest for CompletionRequest {
} }
} }
// ============= Regular Response =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse { pub struct CompletionResponse {
pub id: String, pub id: String,
...@@ -668,8 +644,6 @@ pub struct CompletionChoice { ...@@ -668,8 +644,6 @@ pub struct CompletionChoice {
pub hidden_states: Option<Vec<f32>>, pub hidden_states: Option<Vec<f32>>,
} }
// ============= Streaming Response =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse { pub struct CompletionStreamResponse {
pub id: String, pub id: String,
...@@ -690,12 +664,6 @@ pub struct CompletionStreamChoice { ...@@ -690,12 +664,6 @@ pub struct CompletionStreamChoice {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
// ==================================================================
// = OPENAI SPEC - Responses API =
// ==================================================================
// ============= Tool Definitions =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool { pub struct ResponseTool {
#[serde(rename = "type")] #[serde(rename = "type")]
...@@ -709,8 +677,6 @@ pub enum ResponseToolType { ...@@ -709,8 +677,6 @@ pub enum ResponseToolType {
CodeInterpreter, CodeInterpreter,
} }
// ============= Reasoning Configuration =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam { pub struct ResponseReasoningParam {
#[serde(default = "default_reasoning_effort")] #[serde(default = "default_reasoning_effort")]
...@@ -729,8 +695,6 @@ pub enum ReasoningEffort { ...@@ -729,8 +695,6 @@ pub enum ReasoningEffort {
High, High,
} }
// ============= Input/Output Items =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -790,8 +754,6 @@ pub enum ResponseReasoningContent { ...@@ -790,8 +754,6 @@ pub enum ResponseReasoningContent {
ReasoningText { text: String }, ReasoningText { text: String },
} }
// ============= Output Items for Response =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -823,8 +785,6 @@ pub enum ResponseOutputItem { ...@@ -823,8 +785,6 @@ pub enum ResponseOutputItem {
}, },
} }
// ============= Service Tier =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum ServiceTier { pub enum ServiceTier {
...@@ -841,8 +801,6 @@ impl Default for ServiceTier { ...@@ -841,8 +801,6 @@ impl Default for ServiceTier {
} }
} }
// ============= Truncation =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum Truncation { pub enum Truncation {
...@@ -856,8 +814,6 @@ impl Default for Truncation { ...@@ -856,8 +814,6 @@ impl Default for Truncation {
} }
} }
// ============= Response Status =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum ResponseStatus { pub enum ResponseStatus {
...@@ -868,8 +824,6 @@ pub enum ResponseStatus { ...@@ -868,8 +824,6 @@ pub enum ResponseStatus {
Cancelled, Cancelled,
} }
// ============= Reasoning Info =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReasoningInfo { pub struct ReasoningInfo {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -878,8 +832,6 @@ pub struct ReasoningInfo { ...@@ -878,8 +832,6 @@ pub struct ReasoningInfo {
pub summary: Option<String>, pub summary: Option<String>,
} }
// ============= Text Format =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTextFormat { pub struct ResponseTextFormat {
pub format: TextFormatType, pub format: TextFormatType,
...@@ -891,8 +843,6 @@ pub struct TextFormatType { ...@@ -891,8 +843,6 @@ pub struct TextFormatType {
pub format_type: String, pub format_type: String,
} }
// ============= Include Fields =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum IncludeField { pub enum IncludeField {
...@@ -910,8 +860,6 @@ pub enum IncludeField { ...@@ -910,8 +860,6 @@ pub enum IncludeField {
ReasoningEncryptedContent, ReasoningEncryptedContent,
} }
// ============= Usage Info =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo { pub struct UsageInfo {
pub prompt_tokens: u32, pub prompt_tokens: u32,
...@@ -928,8 +876,6 @@ pub struct PromptTokenUsageInfo { ...@@ -928,8 +876,6 @@ pub struct PromptTokenUsageInfo {
pub cached_tokens: u32, pub cached_tokens: u32,
} }
// ============= Response Usage Format =============
/// OpenAI Responses API usage format (different from standard UsageInfo) /// OpenAI Responses API usage format (different from standard UsageInfo)
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage { pub struct ResponseUsage {
...@@ -1038,7 +984,6 @@ fn generate_request_id() -> String { ...@@ -1038,7 +984,6 @@ fn generate_request_id() -> String {
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesRequest { pub struct ResponsesRequest {
// ============= Core OpenAI API fields =============
/// Run the request in the background /// Run the request in the background
#[serde(default)] #[serde(default)]
pub background: bool, pub background: bool,
...@@ -1122,7 +1067,6 @@ pub struct ResponsesRequest { ...@@ -1122,7 +1067,6 @@ pub struct ResponsesRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>, pub user: Option<String>,
// ============= SGLang Extensions =============
/// Request ID /// Request ID
#[serde(default = "generate_request_id")] #[serde(default = "generate_request_id")]
pub request_id: String, pub request_id: String,
...@@ -1606,8 +1550,6 @@ impl ResponsesResponse { ...@@ -1606,8 +1550,6 @@ impl ResponsesResponse {
} }
} }
// ============= Helper Functions =============
impl ResponseOutputItem { impl ResponseOutputItem {
/// Create a new message output item /// Create a new message output item
pub fn new_message( pub fn new_message(
...@@ -1708,20 +1650,12 @@ impl UsageInfo { ...@@ -1708,20 +1650,12 @@ impl UsageInfo {
} }
} }
// ==================================================================
// = OPENAI SPEC - Common =
// ==================================================================
// ============= Shared Request Components =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamOptions { pub struct StreamOptions {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub include_usage: Option<bool>, pub include_usage: Option<bool>,
} }
// ============= Tool Choice Types =============
/// Tool choice value for simple string options /// Tool choice value for simple string options
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -1793,8 +1727,6 @@ pub struct FunctionCallResponse { ...@@ -1793,8 +1727,6 @@ pub struct FunctionCallResponse {
pub arguments: Option<String>, // JSON string pub arguments: Option<String>, // JSON string
} }
// ============= Usage Tracking =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage { pub struct Usage {
pub prompt_tokens: u32, pub prompt_tokens: u32,
...@@ -1809,8 +1741,6 @@ pub struct CompletionTokensDetails { ...@@ -1809,8 +1741,6 @@ pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>, pub reasoning_tokens: Option<u32>,
} }
// ============= Logprobs Types =============
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs { pub struct LogProbs {
pub tokens: Vec<String>, pub tokens: Vec<String>,
...@@ -1860,10 +1790,6 @@ pub struct ErrorDetail { ...@@ -1860,10 +1790,6 @@ pub struct ErrorDetail {
pub code: Option<String>, pub code: Option<String>,
} }
// ==================================================================
// = SGLANG SPEC - GENERATE API =
// ==================================================================
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum InputIds { pub enum InputIds {
...@@ -1975,7 +1901,6 @@ pub struct GenerateRequest { ...@@ -1975,7 +1901,6 @@ pub struct GenerateRequest {
#[serde(default)] #[serde(default)]
pub return_logprob: bool, pub return_logprob: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization /// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>, pub lora_path: Option<LoRAPath>,
...@@ -2036,10 +1961,6 @@ impl GenerationRequest for GenerateRequest { ...@@ -2036,10 +1961,6 @@ impl GenerationRequest for GenerateRequest {
} }
} }
// ==================================================================
// = SGLANG SPEC - RERANK API =
// ==================================================================
// Constants for rerank API // Constants for rerank API
pub const DEFAULT_MODEL_NAME: &str = "default"; pub const DEFAULT_MODEL_NAME: &str = "default";
...@@ -2237,10 +2158,6 @@ impl RerankResponse { ...@@ -2237,10 +2158,6 @@ impl RerankResponse {
} }
} }
// ==================================================================
// = OPENAI SPEC - Embeddings API =
// ==================================================================
/// Embeddings request compatible with OpenAI API /// Embeddings request compatible with OpenAI API
/// We intentionally keep fields flexible to pass through to workers. /// We intentionally keep fields flexible to pass through to workers.
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
...@@ -2292,10 +2209,6 @@ impl GenerationRequest for EmbeddingRequest { ...@@ -2292,10 +2209,6 @@ impl GenerationRequest for EmbeddingRequest {
} }
} }
// ==================================================================
// = COMMON =
// ==================================================================
/// Helper function for serde default value /// Helper function for serde default value
pub fn default_true() -> bool { pub fn default_true() -> bool {
true true
...@@ -2359,10 +2272,6 @@ mod tests { ...@@ -2359,10 +2272,6 @@ mod tests {
use super::*; use super::*;
use serde_json::{from_str, json, to_string}; use serde_json::{from_str, json, to_string};
// ==================================================================
// = RERANK REQUEST TESTS =
// ==================================================================
#[test] #[test]
fn test_rerank_request_serialization() { fn test_rerank_request_serialization() {
let request = RerankRequest { let request = RerankRequest {
...@@ -2534,10 +2443,6 @@ mod tests { ...@@ -2534,10 +2443,6 @@ mod tests {
assert_eq!(request.effective_top_k(), 3); assert_eq!(request.effective_top_k(), 3);
} }
// ==================================================================
// = RERANK RESPONSE TESTS =
// ==================================================================
#[test] #[test]
fn test_rerank_response_creation() { fn test_rerank_response_creation() {
let results = vec![ let results = vec![
...@@ -2709,10 +2614,6 @@ mod tests { ...@@ -2709,10 +2614,6 @@ mod tests {
assert_eq!(response.results[0].document, None); assert_eq!(response.results[0].document, None);
} }
// ==================================================================
// = RERANK RESULT TESTS =
// ==================================================================
#[test] #[test]
fn test_rerank_result_serialization() { fn test_rerank_result_serialization() {
let result = RerankResult { let result = RerankResult {
...@@ -2755,10 +2656,6 @@ mod tests { ...@@ -2755,10 +2656,6 @@ mod tests {
assert_eq!(deserialized.meta_info, result.meta_info); assert_eq!(deserialized.meta_info, result.meta_info);
} }
// ==================================================================
// = V1 COMPATIBILITY TESTS =
// ==================================================================
#[test] #[test]
fn test_v1_rerank_req_input_serialization() { fn test_v1_rerank_req_input_serialization() {
let v1_input = V1RerankReqInput { let v1_input = V1RerankReqInput {
...@@ -2791,10 +2688,6 @@ mod tests { ...@@ -2791,10 +2688,6 @@ mod tests {
assert_eq!(request.user, None); assert_eq!(request.user, None);
} }
// ==================================================================
// = GENERATION REQUEST TRAIT TESTS =
// ==================================================================
#[test] #[test]
fn test_rerank_request_generation_request_trait() { fn test_rerank_request_generation_request_trait() {
let request = RerankRequest { let request = RerankRequest {
...@@ -2812,10 +2705,6 @@ mod tests { ...@@ -2812,10 +2705,6 @@ mod tests {
assert_eq!(request.extract_text_for_routing(), "test query"); assert_eq!(request.extract_text_for_routing(), "test query");
} }
// ==================================================================
// = EDGE CASES AND STRESS TESTS =
// ==================================================================
#[test] #[test]
fn test_rerank_request_very_long_query() { fn test_rerank_request_very_long_query() {
let long_query = "a".repeat(100000); let long_query = "a".repeat(100000);
...@@ -2918,10 +2807,6 @@ mod tests { ...@@ -2918,10 +2807,6 @@ mod tests {
assert_eq!(usage.total_tokens, 150); assert_eq!(usage.total_tokens, 150);
} }
// ==================================================================
// = INTEGRATION TESTS =
// ==================================================================
#[test] #[test]
fn test_full_rerank_workflow() { fn test_full_rerank_workflow() {
// Create request // Create request
...@@ -2980,7 +2865,6 @@ mod tests { ...@@ -2980,7 +2865,6 @@ mod tests {
// Apply top_k // Apply top_k
response.apply_top_k(request.effective_top_k()); response.apply_top_k(request.effective_top_k());
// Verify results
assert_eq!(response.results.len(), 2); assert_eq!(response.results.len(), 2);
assert_eq!(response.results[0].score, 0.95); assert_eq!(response.results[0].score, 0.95);
assert_eq!(response.results[0].index, 0); assert_eq!(response.results[0].index, 0);
...@@ -2995,10 +2879,6 @@ mod tests { ...@@ -2995,10 +2879,6 @@ mod tests {
assert_eq!(deserialized.model, response.model); assert_eq!(deserialized.model, response.model);
} }
// ==================================================================
// = EMBEDDINGS REQUEST TESTS =
// ==================================================================
#[test] #[test]
fn test_embedding_request_serialization_string_input() { fn test_embedding_request_serialization_string_input() {
let req = EmbeddingRequest { let req = EmbeddingRequest {
......
...@@ -537,10 +537,6 @@ pub trait ValidatableRequest: ...@@ -537,10 +537,6 @@ pub trait ValidatableRequest:
} }
} }
// ==================================================================
// = OPENAI CHAT COMPLETION VALIDATION =
// ==================================================================
impl SamplingOptionsProvider for ChatCompletionRequest { impl SamplingOptionsProvider for ChatCompletionRequest {
fn get_temperature(&self) -> Option<f32> { fn get_temperature(&self) -> Option<f32> {
self.temperature self.temperature
...@@ -909,7 +905,6 @@ mod tests { ...@@ -909,7 +905,6 @@ mod tests {
fn test_chat_cross_parameter_conflicts() { fn test_chat_cross_parameter_conflicts() {
let mut request = create_valid_chat_request(); let mut request = create_valid_chat_request();
// Test 1: max_tokens vs max_completion_tokens conflict
request.max_tokens = Some(100); request.max_tokens = Some(100);
request.max_completion_tokens = Some(200); request.max_completion_tokens = Some(200);
assert!( assert!(
...@@ -921,7 +916,6 @@ mod tests { ...@@ -921,7 +916,6 @@ mod tests {
request.max_tokens = None; request.max_tokens = None;
request.max_completion_tokens = None; request.max_completion_tokens = None;
// Test 2: tools vs functions conflict (deprecated)
request.tools = Some(vec![]); request.tools = Some(vec![]);
request.functions = Some(vec![]); request.functions = Some(vec![]);
assert!( assert!(
...@@ -929,7 +923,6 @@ mod tests { ...@@ -929,7 +923,6 @@ mod tests {
"Should reject both tools and functions" "Should reject both tools and functions"
); );
// Test 3: logprobs=true without top_logprobs should be valid
let mut request = create_valid_chat_request(); let mut request = create_valid_chat_request();
request.logprobs = true; request.logprobs = true;
request.top_logprobs = None; request.top_logprobs = None;
...@@ -938,7 +931,6 @@ mod tests { ...@@ -938,7 +931,6 @@ mod tests {
"logprobs=true without top_logprobs should be valid" "logprobs=true without top_logprobs should be valid"
); );
// Test 4: top_logprobs without logprobs=true should fail (OpenAI rule)
let mut request = create_valid_chat_request(); let mut request = create_valid_chat_request();
request.logprobs = false; request.logprobs = false;
request.top_logprobs = Some(5); request.top_logprobs = Some(5);
...@@ -967,7 +959,6 @@ mod tests { ...@@ -967,7 +959,6 @@ mod tests {
fn test_parameter_ranges() { fn test_parameter_ranges() {
let mut request = create_valid_chat_request(); let mut request = create_valid_chat_request();
// Test temperature range (0.0 to 2.0)
request.temperature = Some(1.5); request.temperature = Some(1.5);
assert!(request.validate().is_ok()); assert!(request.validate().is_ok());
request.temperature = Some(-0.1); request.temperature = Some(-0.1);
...@@ -975,7 +966,6 @@ mod tests { ...@@ -975,7 +966,6 @@ mod tests {
request.temperature = Some(3.0); request.temperature = Some(3.0);
assert!(request.validate().is_err()); assert!(request.validate().is_err());
// Test top_p range (0.0 to 1.0)
request.temperature = Some(1.0); // Reset request.temperature = Some(1.0); // Reset
request.top_p = Some(0.9); request.top_p = Some(0.9);
assert!(request.validate().is_ok()); assert!(request.validate().is_ok());
...@@ -984,7 +974,6 @@ mod tests { ...@@ -984,7 +974,6 @@ mod tests {
request.top_p = Some(1.5); request.top_p = Some(1.5);
assert!(request.validate().is_err()); assert!(request.validate().is_err());
// Test frequency_penalty range (-2.0 to 2.0)
request.top_p = Some(0.9); // Reset request.top_p = Some(0.9); // Reset
request.frequency_penalty = Some(1.5); request.frequency_penalty = Some(1.5);
assert!(request.validate().is_ok()); assert!(request.validate().is_ok());
...@@ -993,7 +982,6 @@ mod tests { ...@@ -993,7 +982,6 @@ mod tests {
request.frequency_penalty = Some(3.0); request.frequency_penalty = Some(3.0);
assert!(request.validate().is_err()); assert!(request.validate().is_err());
// Test presence_penalty range (-2.0 to 2.0)
request.frequency_penalty = Some(0.0); // Reset request.frequency_penalty = Some(0.0); // Reset
request.presence_penalty = Some(-1.5); request.presence_penalty = Some(-1.5);
assert!(request.validate().is_ok()); assert!(request.validate().is_ok());
...@@ -1002,7 +990,6 @@ mod tests { ...@@ -1002,7 +990,6 @@ mod tests {
request.presence_penalty = Some(2.5); request.presence_penalty = Some(2.5);
assert!(request.validate().is_err()); assert!(request.validate().is_err());
// Test repetition_penalty range (0.0 to 2.0)
request.presence_penalty = Some(0.0); // Reset request.presence_penalty = Some(0.0); // Reset
request.repetition_penalty = Some(1.2); request.repetition_penalty = Some(1.2);
assert!(request.validate().is_ok()); assert!(request.validate().is_ok());
...@@ -1011,7 +998,6 @@ mod tests { ...@@ -1011,7 +998,6 @@ mod tests {
request.repetition_penalty = Some(2.1); request.repetition_penalty = Some(2.1);
assert!(request.validate().is_err()); assert!(request.validate().is_err());
// Test min_p range (0.0 to 1.0)
request.repetition_penalty = Some(1.0); // Reset request.repetition_penalty = Some(1.0); // Reset
request.min_p = Some(0.5); request.min_p = Some(0.5);
assert!(request.validate().is_ok()); assert!(request.validate().is_ok());
......
...@@ -373,7 +373,6 @@ mod tests { ...@@ -373,7 +373,6 @@ mod tests {
// Both should use the same passthrough parser instance // Both should use the same passthrough parser instance
assert!(Arc::ptr_eq(&parser1, &parser2)); assert!(Arc::ptr_eq(&parser1, &parser2));
// Verify it's actually a passthrough parser
let parser = parser1.lock().unwrap(); let parser = parser1.lock().unwrap();
assert_eq!(parser.model_type(), "passthrough"); assert_eq!(parser.model_type(), "passthrough");
} }
...@@ -456,7 +455,6 @@ mod tests { ...@@ -456,7 +455,6 @@ mod tests {
match p.detect_and_parse_reasoning(&input) { match p.detect_and_parse_reasoning(&input) {
Ok(result) => { Ok(result) => {
// Verify parsing worked correctly with substantial content
// Note: Some parsers with stream_reasoning=true won't accumulate reasoning text // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
assert!(result assert!(result
.normal_text .normal_text
......
...@@ -88,7 +88,6 @@ mod tests { ...@@ -88,7 +88,6 @@ mod tests {
fn test_kimi_partial_unicode() { fn test_kimi_partial_unicode() {
let mut parser = KimiParser::new(); let mut parser = KimiParser::new();
// Test partial Unicode token buffering
let result1 = parser let result1 = parser
.parse_reasoning_streaming_incremental("◁thi") .parse_reasoning_streaming_incremental("◁thi")
.unwrap(); .unwrap();
......
...@@ -96,8 +96,6 @@ impl GrpcRouter { ...@@ -96,8 +96,6 @@ impl GrpcRouter {
}) })
} }
// ============ Chat Implementation ============
/// Main route_chat implementation /// Main route_chat implementation
async fn route_chat_impl( async fn route_chat_impl(
&self, &self,
...@@ -207,7 +205,6 @@ impl GrpcRouter { ...@@ -207,7 +205,6 @@ impl GrpcRouter {
} }
} }
// ============ Helper Methods ============
/// Select a worker for the request /// Select a worker for the request
fn select_worker_for_request( fn select_worker_for_request(
&self, &self,
...@@ -809,7 +806,6 @@ mod tests { ...@@ -809,7 +806,6 @@ mod tests {
#[test] #[test]
fn test_transform_messages_mixed_content_types() { fn test_transform_messages_mixed_content_types() {
// Test with both text and multimodal content
let messages = vec![ let messages = vec![
ChatMessage::User { ChatMessage::User {
role: "user".to_string(), role: "user".to_string(),
...@@ -833,7 +829,6 @@ mod tests { ...@@ -833,7 +829,6 @@ mod tests {
}, },
]; ];
// Test String format
let result_string = let result_string =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
.unwrap(); .unwrap();
...@@ -842,7 +837,6 @@ mod tests { ...@@ -842,7 +837,6 @@ mod tests {
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text"); assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image"); assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
// Test OpenAI format
let result_openai = let result_openai =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI) GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
.unwrap(); .unwrap();
......
...@@ -957,7 +957,6 @@ impl RouterTrait for PDRouter { ...@@ -957,7 +957,6 @@ impl RouterTrait for PDRouter {
} }
async fn health_generate(&self, _req: Request<Body>) -> Response { async fn health_generate(&self, _req: Request<Body>) -> Response {
// Test model generation capability by selecting a random pair and testing them
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair // Note: This endpoint actually causes the model to generate tokens, so we only test one pair
// Select a random worker pair using the policy // Select a random worker pair using the policy
...@@ -972,7 +971,6 @@ impl RouterTrait for PDRouter { ...@@ -972,7 +971,6 @@ impl RouterTrait for PDRouter {
} }
}; };
// Test prefill server's health_generate
let prefill_url = format!("{}/health_generate", prefill.url()); let prefill_url = format!("{}/health_generate", prefill.url());
let (prefill_result, decode_result) = tokio::join!( let (prefill_result, decode_result) = tokio::join!(
self.client.get(&prefill_url).send(), self.client.get(&prefill_url).send(),
......
...@@ -1018,7 +1018,6 @@ mod tests { ...@@ -1018,7 +1018,6 @@ mod tests {
}; };
let port = 8080u16; let port = 8080u16;
// Test that unified handler works for regular mode
handle_pod_event( handle_pod_event(
&pod_info, &pod_info,
Arc::clone(&tracked_pods), Arc::clone(&tracked_pods),
...@@ -1045,7 +1044,6 @@ mod tests { ...@@ -1045,7 +1044,6 @@ mod tests {
}; };
let port = 8080u16; let port = 8080u16;
// Test that unified handler works for PD mode with prefill
handle_pod_event( handle_pod_event(
&pod_info, &pod_info,
Arc::clone(&tracked_pods), Arc::clone(&tracked_pods),
...@@ -1080,7 +1078,6 @@ mod tests { ...@@ -1080,7 +1078,6 @@ mod tests {
let port = 8080u16; let port = 8080u16;
// Test that unified handler works for deletion in PD mode
handle_pod_deletion( handle_pod_deletion(
&pod_info, &pod_info,
Arc::clone(&tracked_pods), Arc::clone(&tracked_pods),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment