Unverified Commit 1fc455e8 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add ut for pd request, metrics and config (#8184)

parent 465968b2
...@@ -214,83 +214,590 @@ impl RouterConfig { ...@@ -214,83 +214,590 @@ impl RouterConfig {
pub fn has_metrics(&self) -> bool { pub fn has_metrics(&self) -> bool {
self.metrics.is_some() self.metrics.is_some()
} }
}
#[cfg(test)]
mod tests {
use super::*;
// ============= RouterConfig Tests =============
/* Commented out - no longer needed without compatibility layer #[test]
/// Convert to routing PolicyConfig for internal use fn test_router_config_default() {
pub fn to_routing_policy_config(&self) -> ConfigResult<crate::router::PolicyConfig> { let config = RouterConfig::default();
match (&self.mode, &self.policy) {
( assert!(
RoutingMode::PrefillDecode { matches!(config.mode, RoutingMode::Regular { worker_urls } if worker_urls.is_empty())
prefill_urls, );
decode_urls, assert!(matches!(config.policy, PolicyConfig::Random));
}, assert_eq!(config.host, "127.0.0.1");
policy, assert_eq!(config.port, 3001);
) => { assert_eq!(config.max_payload_size, 268_435_456);
// Map policy to PDSelectionPolicy assert_eq!(config.request_timeout_secs, 600);
let selection_policy = match policy { assert_eq!(config.worker_startup_timeout_secs, 300);
PolicyConfig::Random => crate::pd_types::PDSelectionPolicy::Random, assert_eq!(config.worker_startup_check_interval_secs, 10);
PolicyConfig::PowerOfTwo { .. } => { assert!(config.discovery.is_none());
crate::pd_types::PDSelectionPolicy::PowerOfTwo assert!(config.metrics.is_none());
} assert!(config.log_dir.is_none());
PolicyConfig::CacheAware { .. } => { assert!(config.log_level.is_none());
return Err(ConfigError::IncompatibleConfig { }
reason: "CacheAware policy is not supported in PD disaggregated mode"
.to_string(), #[test]
}); fn test_router_config_new() {
} let mode = RoutingMode::Regular {
PolicyConfig::RoundRobin => { worker_urls: vec!["http://worker1".to_string(), "http://worker2".to_string()],
return Err(ConfigError::IncompatibleConfig { };
reason: "RoundRobin policy is not supported in PD disaggregated mode" let policy = PolicyConfig::RoundRobin;
.to_string(),
}); let config = RouterConfig::new(mode, policy);
}
}; match config.mode {
RoutingMode::Regular { worker_urls } => {
Ok(crate::router::PolicyConfig::PrefillDecodeConfig { assert_eq!(worker_urls.len(), 2);
selection_policy, assert_eq!(worker_urls[0], "http://worker1");
prefill_urls: prefill_urls.clone(), assert_eq!(worker_urls[1], "http://worker2");
decode_urls: decode_urls.clone(),
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
})
} }
(RoutingMode::Regular { .. }, PolicyConfig::Random) => { _ => panic!("Expected Regular mode"),
Ok(crate::router::PolicyConfig::RandomConfig { }
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs, assert!(matches!(config.policy, PolicyConfig::RoundRobin));
}) // Other fields should be default
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 3001);
}
/// Round-trips a fully populated `RouterConfig` through JSON and compares a few fields.
#[test]
fn test_router_config_serialization() {
    let original = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![String::from("http://worker1")],
        },
        policy: PolicyConfig::Random,
        host: String::from("0.0.0.0"),
        port: 8080,
        max_payload_size: 1024,
        request_timeout_secs: 30,
        worker_startup_timeout_secs: 60,
        worker_startup_check_interval_secs: 5,
        discovery: Some(DiscoveryConfig::default()),
        metrics: Some(MetricsConfig::default()),
        log_dir: Some(String::from("/var/log")),
        log_level: Some(String::from("debug")),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let restored: RouterConfig = serde_json::from_str(&encoded).unwrap();

    // Scalar fields must match; optional sections must still be present.
    assert_eq!(original.host, restored.host);
    assert_eq!(original.port, restored.port);
    assert_eq!(original.max_payload_size, restored.max_payload_size);
    assert!(restored.discovery.is_some());
    assert!(restored.metrics.is_some());
}
// ============= RoutingMode Tests =============
/// `is_pd_mode` is false for `Regular` and true for `PrefillDecode`.
#[test]
fn test_routing_mode_is_pd_mode() {
    let regular_mode = RoutingMode::Regular {
        worker_urls: vec![String::from("http://worker1")],
    };
    let pd_mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![(String::from("http://prefill1"), Some(8001))],
        decode_urls: vec![String::from("http://decode1")],
    };

    assert!(!regular_mode.is_pd_mode());
    assert!(pd_mode.is_pd_mode());
}
/// Checks `worker_count` for a regular list, a prefill/decode split, and an empty list.
#[test]
fn test_routing_mode_worker_count() {
    let three_workers = RoutingMode::Regular {
        worker_urls: vec![
            String::from("http://worker1"),
            String::from("http://worker2"),
            String::from("http://worker3"),
        ],
    };
    assert_eq!(three_workers.worker_count(), 3);

    // Two prefill entries plus three decode entries.
    let prefill_decode = RoutingMode::PrefillDecode {
        prefill_urls: vec![
            (String::from("http://prefill1"), Some(8001)),
            (String::from("http://prefill2"), None),
        ],
        decode_urls: vec![
            String::from("http://decode1"),
            String::from("http://decode2"),
            String::from("http://decode3"),
        ],
    };
    assert_eq!(prefill_decode.worker_count(), 5);

    let no_workers = RoutingMode::Regular {
        worker_urls: Vec::new(),
    };
    assert_eq!(no_workers.worker_count(), 0);
}
/// Serializes both routing modes and looks for the expected JSON fragments.
#[test]
fn test_routing_mode_serialization() {
    // Regular mode
    let regular_json = serde_json::to_string(&RoutingMode::Regular {
        worker_urls: vec![String::from("http://worker1")],
    })
    .unwrap();
    for fragment in ["\"type\":\"regular\"", "\"worker_urls\""] {
        assert!(regular_json.contains(fragment));
    }

    // PrefillDecode mode
    let pd_json = serde_json::to_string(&RoutingMode::PrefillDecode {
        prefill_urls: vec![(String::from("http://prefill1"), Some(8001))],
        decode_urls: vec![String::from("http://decode1")],
    })
    .unwrap();
    for fragment in [
        "\"type\":\"prefill_decode\"",
        "\"prefill_urls\"",
        "\"decode_urls\"",
    ] {
        assert!(pd_json.contains(fragment));
    }
}
// ============= PolicyConfig Tests =============
/// Each `PolicyConfig` variant reports its snake_case name.
#[test]
fn test_policy_config_name() {
    let cache_aware_policy = PolicyConfig::CacheAware {
        cache_threshold: 0.8,
        balance_abs_threshold: 10,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 300,
        max_tree_size: 1000,
    };
    let power_of_two_policy = PolicyConfig::PowerOfTwo {
        load_check_interval_secs: 60,
    };

    assert_eq!(PolicyConfig::Random.name(), "random");
    assert_eq!(PolicyConfig::RoundRobin.name(), "round_robin");
    assert_eq!(cache_aware_policy.name(), "cache_aware");
    assert_eq!(power_of_two_policy.name(), "power_of_two");
}
/// Serializes several `PolicyConfig` variants and inspects the resulting JSON.
#[test]
fn test_policy_config_serialization() {
    // Random serializes to exactly the tagged object.
    let random_json = serde_json::to_string(&PolicyConfig::Random).unwrap();
    assert_eq!(random_json, r#"{"type":"random"}"#);

    // CacheAware with every parameter set.
    let cache_aware_json = serde_json::to_string(&PolicyConfig::CacheAware {
        cache_threshold: 0.8,
        balance_abs_threshold: 10,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 300,
        max_tree_size: 1000,
    })
    .unwrap();
    assert!(cache_aware_json.contains("\"type\":\"cache_aware\""));
    assert!(cache_aware_json.contains("\"cache_threshold\":0.8"));
    assert!(cache_aware_json.contains("\"balance_abs_threshold\":10"));

    // PowerOfTwo
    let power_of_two_json = serde_json::to_string(&PolicyConfig::PowerOfTwo {
        load_check_interval_secs: 60,
    })
    .unwrap();
    assert!(power_of_two_json.contains("\"type\":\"power_of_two\""));
    assert!(power_of_two_json.contains("\"load_check_interval_secs\":60"));
}
#[test]
fn test_cache_aware_parameters() {
let cache_aware = PolicyConfig::CacheAware {
cache_threshold: 0.75,
balance_abs_threshold: 20,
balance_rel_threshold: 2.0,
eviction_interval_secs: 600,
max_tree_size: 5000,
};
match cache_aware {
PolicyConfig::CacheAware {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
} => {
assert!((cache_threshold - 0.75).abs() < 0.0001);
assert_eq!(balance_abs_threshold, 20);
assert!((balance_rel_threshold - 2.0).abs() < 0.0001);
assert_eq!(eviction_interval_secs, 600);
assert_eq!(max_tree_size, 5000);
} }
(RoutingMode::Regular { .. }, PolicyConfig::RoundRobin) => { _ => panic!("Expected CacheAware"),
Ok(crate::router::PolicyConfig::RoundRobinConfig { }
timeout_secs: self.worker_startup_timeout_secs, }
interval_secs: self.worker_startup_check_interval_secs,
}) #[test]
fn test_power_of_two_parameters() {
let power_of_two = PolicyConfig::PowerOfTwo {
load_check_interval_secs: 120,
};
match power_of_two {
PolicyConfig::PowerOfTwo {
load_check_interval_secs,
} => {
assert_eq!(load_check_interval_secs, 120);
} }
( _ => panic!("Expected PowerOfTwo"),
RoutingMode::Regular { .. }, }
PolicyConfig::CacheAware { }
cache_threshold,
balance_abs_threshold, // ============= DiscoveryConfig Tests =============
balance_rel_threshold,
eviction_interval_secs, #[test]
max_tree_size, fn test_discovery_config_default() {
}, let config = DiscoveryConfig::default();
) => Ok(crate::router::PolicyConfig::CacheAwareConfig {
cache_threshold: *cache_threshold, assert!(!config.enabled);
balance_abs_threshold: *balance_abs_threshold, assert!(config.namespace.is_none());
balance_rel_threshold: *balance_rel_threshold, assert_eq!(config.port, 8000);
eviction_interval_secs: *eviction_interval_secs, assert_eq!(config.check_interval_secs, 60);
max_tree_size: *max_tree_size, assert!(config.selector.is_empty());
timeout_secs: self.worker_startup_timeout_secs, assert!(config.prefill_selector.is_empty());
interval_secs: self.worker_startup_check_interval_secs, assert!(config.decode_selector.is_empty());
assert_eq!(config.bootstrap_port_annotation, "sglang.ai/bootstrap-port");
}
/// Builds a `DiscoveryConfig` with explicit selectors and reads the fields back.
#[test]
fn test_discovery_config_with_selectors() {
    let labels: HashMap<String, String> = [("app", "sglang"), ("role", "worker")]
        .iter()
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect();

    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: Some(String::from("default")),
        port: 9000,
        check_interval_secs: 30,
        selector: labels.clone(),
        prefill_selector: labels.clone(),
        decode_selector: labels.clone(),
        bootstrap_port_annotation: String::from("custom.io/port"),
    };

    assert!(discovery.enabled);
    assert_eq!(discovery.namespace, Some(String::from("default")));
    assert_eq!(discovery.port, 9000);
    assert_eq!(discovery.selector.len(), 2);
    assert_eq!(
        discovery.selector.get("app"),
        Some(&String::from("sglang"))
    );
}
/// The `namespace` field keeps whatever value it was constructed with.
#[test]
fn test_discovery_config_namespace() {
    // No namespace set.
    let unscoped = DiscoveryConfig {
        namespace: None,
        ..Default::default()
    };
    assert!(unscoped.namespace.is_none());

    // A specific namespace.
    let scoped = DiscoveryConfig {
        namespace: Some(String::from("production")),
        ..Default::default()
    };
    assert_eq!(scoped.namespace, Some(String::from("production")));
}
// ============= MetricsConfig Tests =============
/// The default `MetricsConfig` uses port 29000 on 127.0.0.1.
#[test]
fn test_metrics_config_default() {
    let defaults = MetricsConfig::default();

    assert_eq!(defaults.host, "127.0.0.1");
    assert_eq!(defaults.port, 29000);
}
/// A hand-built `MetricsConfig` exposes the values it was given.
#[test]
fn test_metrics_config_custom() {
    let custom = MetricsConfig {
        port: 9090,
        host: String::from("0.0.0.0"),
    };

    assert_eq!(custom.host, "0.0.0.0");
    assert_eq!(custom.port, 9090);
}
// ============= RouterConfig Utility Methods Tests =============
/// `mode_type` returns the string tag matching the configured routing mode.
#[test]
fn test_mode_type() {
    let regular = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: Vec::new(),
        },
        ..Default::default()
    };
    let prefill_decode = RouterConfig {
        mode: RoutingMode::PrefillDecode {
            prefill_urls: Vec::new(),
            decode_urls: Vec::new(),
        },
        ..Default::default()
    };

    assert_eq!(regular.mode_type(), "regular");
    assert_eq!(prefill_decode.mode_type(), "prefill_decode");
}
#[test]
fn test_has_service_discovery() {
let config = RouterConfig::default();
assert!(!config.has_service_discovery());
let config = RouterConfig {
discovery: Some(DiscoveryConfig {
enabled: false,
..Default::default()
}), }),
(RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => { ..Default::default()
Err(ConfigError::IncompatibleConfig { };
reason: "PowerOfTwo policy is only supported in PD disaggregated mode" assert!(!config.has_service_discovery());
.to_string(),
}) let config = RouterConfig {
discovery: Some(DiscoveryConfig {
enabled: true,
..Default::default()
}),
..Default::default()
};
assert!(config.has_service_discovery());
}
/// `has_metrics` is false by default and true once a `MetricsConfig` is set.
#[test]
fn test_has_metrics() {
    let without_metrics = RouterConfig::default();
    let with_metrics = RouterConfig {
        metrics: Some(MetricsConfig::default()),
        ..Default::default()
    };

    assert!(!without_metrics.has_metrics());
    assert!(with_metrics.has_metrics());
}
// ============= Edge Cases =============
#[test]
fn test_large_worker_lists() {
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
let mode = RoutingMode::Regular {
worker_urls: large_urls.clone(),
};
assert_eq!(mode.worker_count(), 1000);
// Test serialization with large list
let config = RouterConfig {
mode,
..Default::default()
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
match deserialized.mode {
RoutingMode::Regular { worker_urls } => {
assert_eq!(worker_urls.len(), 1000);
} }
_ => panic!("Expected Regular mode"),
} }
} }
*/
/// Non-ASCII worker URLs and log directories survive a JSON round trip.
#[test]
fn test_unicode_in_config() {
    let original = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![
                String::from("http://работник1"),
                String::from("http://工作者2"),
            ],
        },
        log_dir: Some(String::from("/日志/目录")),
        ..Default::default()
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let restored: RouterConfig = serde_json::from_str(&encoded).unwrap();

    if let RoutingMode::Regular { worker_urls } = restored.mode {
        assert_eq!(worker_urls[0], "http://работник1");
        assert_eq!(worker_urls[1], "http://工作者2");
    } else {
        panic!("Expected Regular mode");
    }
    assert_eq!(restored.log_dir, Some(String::from("/日志/目录")));
}
/// Empty strings are accepted verbatim for `host`, `log_dir`, and `log_level`.
#[test]
fn test_empty_string_fields() {
    let blank = RouterConfig {
        host: String::new(),
        log_dir: Some(String::new()),
        log_level: Some(String::new()),
        ..Default::default()
    };

    assert_eq!(blank.host, "");
    assert_eq!(blank.log_dir, Some(String::new()));
    assert_eq!(blank.log_level, Some(String::new()));
}
// ============= Complex Configuration Tests =============
/// Exercises the helper methods on a fully specified prefill/decode configuration.
#[test]
fn test_full_pd_mode_config() {
    let mode = RoutingMode::PrefillDecode {
        prefill_urls: vec![
            (String::from("http://prefill1:8000"), Some(8001)),
            (String::from("http://prefill2:8000"), None),
        ],
        decode_urls: vec![
            String::from("http://decode1:8000"),
            String::from("http://decode2:8000"),
        ],
    };
    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: Some(String::from("sglang")),
        ..Default::default()
    };
    let metrics = MetricsConfig {
        port: 9090,
        host: String::from("0.0.0.0"),
    };

    let pd_config = RouterConfig {
        mode,
        policy: PolicyConfig::PowerOfTwo {
            load_check_interval_secs: 30,
        },
        host: String::from("0.0.0.0"),
        port: 3000,
        max_payload_size: 1_048_576,
        request_timeout_secs: 120,
        worker_startup_timeout_secs: 60,
        worker_startup_check_interval_secs: 5,
        discovery: Some(discovery),
        metrics: Some(metrics),
        log_dir: Some(String::from("/var/log/sglang")),
        log_level: Some(String::from("info")),
    };

    assert!(pd_config.mode.is_pd_mode());
    // Two prefill entries plus two decode entries.
    assert_eq!(pd_config.mode.worker_count(), 4);
    assert_eq!(pd_config.policy.name(), "power_of_two");
    assert!(pd_config.has_service_discovery());
    assert!(pd_config.has_metrics());
}
/// Exercises the helper methods on a fully specified regular-mode configuration.
#[test]
fn test_full_regular_mode_config() {
    let mut app_selector = HashMap::new();
    app_selector.insert(String::from("app"), String::from("sglang"));

    let mode = RoutingMode::Regular {
        worker_urls: vec![
            String::from("http://worker1:8000"),
            String::from("http://worker2:8000"),
            String::from("http://worker3:8000"),
        ],
    };
    let policy = PolicyConfig::CacheAware {
        cache_threshold: 0.9,
        balance_abs_threshold: 5,
        balance_rel_threshold: 1.2,
        eviction_interval_secs: 600,
        max_tree_size: 10000,
    };
    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: None,
        port: 8080,
        check_interval_secs: 45,
        selector: app_selector,
        ..Default::default()
    };

    let regular_config = RouterConfig {
        mode,
        policy,
        host: String::from("0.0.0.0"),
        port: 3001,
        max_payload_size: 536_870_912,
        request_timeout_secs: 300,
        worker_startup_timeout_secs: 180,
        worker_startup_check_interval_secs: 15,
        discovery: Some(discovery),
        metrics: Some(MetricsConfig::default()),
        log_dir: None,
        log_level: Some(String::from("debug")),
    };

    assert!(!regular_config.mode.is_pd_mode());
    assert_eq!(regular_config.mode.worker_count(), 3);
    assert_eq!(regular_config.policy.name(), "cache_aware");
    assert!(regular_config.has_service_discovery());
    assert!(regular_config.has_metrics());
}
/// Sets every `RouterConfig` option, checks the helpers, then round-trips through JSON.
#[test]
fn test_config_with_all_options() {
    let mut label_selectors = HashMap::new();
    label_selectors.insert(String::from("env"), String::from("prod"));
    label_selectors.insert(String::from("version"), String::from("v1"));

    let discovery = DiscoveryConfig {
        enabled: true,
        namespace: Some(String::from("production")),
        port: 8443,
        check_interval_secs: 120,
        selector: label_selectors.clone(),
        prefill_selector: label_selectors.clone(),
        decode_selector: label_selectors,
        bootstrap_port_annotation: String::from("mycompany.io/bootstrap"),
    };
    let metrics = MetricsConfig {
        port: 9999,
        host: String::from("::"), // IPv6 any
    };

    let full_config = RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![String::from("http://worker1")],
        },
        policy: PolicyConfig::RoundRobin,
        host: String::from("::1"), // IPv6
        port: 8888,
        max_payload_size: 1024 * 1024 * 512, // 512MB
        request_timeout_secs: 900,
        worker_startup_timeout_secs: 600,
        worker_startup_check_interval_secs: 20,
        discovery: Some(discovery),
        metrics: Some(metrics),
        log_dir: Some(String::from("/opt/logs/sglang")),
        log_level: Some(String::from("trace")),
    };

    assert!(full_config.has_service_discovery());
    assert!(full_config.has_metrics());
    assert_eq!(full_config.mode_type(), "regular");

    // Round trip through pretty-printed JSON.
    let pretty = serde_json::to_string_pretty(&full_config).unwrap();
    let restored: RouterConfig = serde_json::from_str(&pretty).unwrap();

    assert_eq!(restored.host, "::1");
    assert_eq!(restored.port, 8888);
    let restored_discovery = restored.discovery.unwrap();
    assert_eq!(
        restored_discovery.namespace,
        Some(String::from("production"))
    );
}
} }
...@@ -322,3 +322,414 @@ impl RouterMetrics { ...@@ -322,3 +322,414 @@ impl RouterMetrics {
.set(count as f64); .set(count as f64);
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use std::net::TcpListener;
// ============= PrometheusConfig Tests =============
/// The default `PrometheusConfig` uses port 29000 on 0.0.0.0.
#[test]
fn test_prometheus_config_default() {
    let defaults = PrometheusConfig::default();

    assert_eq!(defaults.host, "0.0.0.0");
    assert_eq!(defaults.port, 29000);
}
/// A hand-built `PrometheusConfig` exposes the values it was given.
#[test]
fn test_prometheus_config_custom() {
    let custom = PrometheusConfig {
        port: 8080,
        host: String::from("127.0.0.1"),
    };

    assert_eq!(custom.host, "127.0.0.1");
    assert_eq!(custom.port, 8080);
}
/// Cloning a `PrometheusConfig` yields equal `port` and `host` fields.
#[test]
fn test_prometheus_config_clone() {
    let source = PrometheusConfig {
        port: 9090,
        host: String::from("192.168.1.1"),
    };

    let copy = source.clone();

    assert_eq!(copy.port, source.port);
    assert_eq!(copy.host, source.host);
}
// ============= IP Address Parsing Tests =============
/// IPv4 host strings stored in the config parse into `IpAddr::V4`.
#[test]
fn test_valid_ipv4_parsing() {
    for host in ["127.0.0.1", "192.168.1.1", "0.0.0.0"] {
        let config = PrometheusConfig {
            port: 29000,
            host: host.to_string(),
        };

        let parsed: IpAddr = config.host.parse().unwrap();
        assert!(parsed.is_ipv4());
    }
}
/// IPv6 host strings stored in the config parse into `IpAddr::V6`.
#[test]
fn test_valid_ipv6_parsing() {
    for host in ["::1", "2001:db8::1", "::"] {
        let config = PrometheusConfig {
            port: 29000,
            host: host.to_string(),
        };

        let parsed: IpAddr = config.host.parse().unwrap();
        assert!(parsed.is_ipv6());
    }
}
/// Unparseable host strings resolve to the 0.0.0.0 fallback used in this test.
#[test]
fn test_invalid_ip_parsing() {
    let fallback = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));

    for bad_host in ["invalid", "256.256.256.256", "hostname"] {
        let config = PrometheusConfig {
            port: 29000,
            host: bad_host.to_string(),
        };

        // Parsing fails, so the fallback address is what we end up with.
        let resolved = config.host.parse::<IpAddr>().unwrap_or(fallback);
        assert_eq!(resolved, fallback);
    }
}
// ============= Socket Address Creation Tests =============
/// A `SocketAddr` built from the config reports the same IP and port.
#[test]
fn test_socket_addr_creation() {
    let endpoints = [("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)];

    for (host, port) in endpoints {
        let config = PrometheusConfig {
            port,
            host: host.to_string(),
        };

        let ip: IpAddr = config.host.parse().unwrap();
        let addr = SocketAddr::new(ip, config.port);

        assert_eq!(addr.port(), port);
        assert_eq!(addr.ip().to_string(), host);
    }
}
/// Ports across the full range, including 0 and 65535, are carried into the `SocketAddr`.
#[test]
fn test_socket_addr_with_different_ports() {
    for port in [0, 80, 8080, 65535] {
        let config = PrometheusConfig {
            port,
            host: String::from("127.0.0.1"),
        };

        let ip: IpAddr = config.host.parse().unwrap();
        let addr = SocketAddr::new(ip, config.port);

        assert_eq!(addr.port(), port);
    }
}
// ============= Duration Bucket Tests =============
/// The expected duration bucket list has 20 entries in strictly increasing order.
#[test]
fn test_duration_bucket_values() {
    let buckets = [
        0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
        60.0, 90.0, 120.0, 180.0, 240.0,
    ];

    assert_eq!(buckets.len(), 20);

    // Each bucket must be larger than the one before it.
    for pair in buckets.windows(2) {
        assert!(pair[1] > pair[0]);
    }
}
/// Every sample duration is matched (within 0.0001) or exceeded by some bucket.
#[test]
fn test_duration_bucket_coverage() {
    let buckets: [f64; 20] = [
        0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
        60.0, 90.0, 120.0, 180.0, 240.0,
    ];
    let samples: [(f64, &str); 7] = [
        (0.0005, "sub-millisecond"),
        (0.005, "5ms"),
        (0.05, "50ms"),
        (1.0, "1s"),
        (10.0, "10s"),
        (60.0, "1m"),
        (240.0, "4m"),
    ];

    for (duration, label) in samples {
        let bucket_found = buckets
            .iter()
            .any(|&upper| upper > duration || (upper - duration).abs() < 0.0001);
        assert!(bucket_found, "No bucket found for {} ({})", duration, label);
    }
}
// ============= Matcher Configuration Tests =============
/// A `Matcher::Suffix` built for duration metrics holds the expected suffix string.
#[test]
fn test_duration_suffix_matcher() {
    let suffix_matcher = Matcher::Suffix(String::from("duration_seconds"));

    // Illustrative names only: the matching logic is not exercised here, so these
    // lists are not asserted against.
    let _matching_metrics = vec![
        "request_duration_seconds",
        "response_duration_seconds",
        "sgl_router_request_duration_seconds",
    ];
    let _non_matching_metrics = vec!["duration_total", "duration_seconds_total", "other_metric"];

    // Only the construction of the matcher is verified.
    match suffix_matcher {
        Matcher::Suffix(value) => assert_eq!(value, "duration_seconds"),
        _ => panic!("Expected Suffix matcher"),
    }
}
// ============= Builder Configuration Tests =============
/// Checks the matcher and bucket inputs used in this test without starting Prometheus.
#[test]
fn test_prometheus_builder_configuration() {
    let _config = PrometheusConfig::default();

    let bucket_bounds = [
        0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
        60.0, 90.0, 120.0, 180.0, 240.0,
    ];
    let suffix_matcher = Matcher::Suffix(String::from("duration_seconds"));

    // Bucket configuration
    assert_eq!(bucket_bounds.len(), 20);

    // Matcher must be the suffix variant
    match suffix_matcher {
        Matcher::Suffix(value) => assert_eq!(value, "duration_seconds"),
        _ => panic!("Expected Suffix matcher"),
    }
}
// ============= Upkeep Timeout Tests =============
/// Five minutes expressed as a `Duration` equals 300 seconds.
#[test]
fn test_upkeep_timeout_duration() {
    let minutes: u64 = 5;
    let upkeep_timeout = Duration::from_secs(minutes * 60);

    assert_eq!(upkeep_timeout.as_secs(), 300);
}
// ============= Custom Bucket Tests =============
/// Two independent bucket lists have the expected sizes and are strictly increasing.
#[test]
fn test_custom_buckets_for_different_metrics() {
    let request_bounds = [0.001, 0.01, 0.1, 1.0, 10.0];
    let generate_bounds = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];

    assert_eq!(request_bounds.len(), 5);
    assert_eq!(generate_bounds.len(), 6);

    // Each list must be sorted in ascending order with no repeats.
    for pair in request_bounds.windows(2) {
        assert!(pair[1] > pair[0]);
    }
    for pair in generate_bounds.windows(2) {
        assert!(pair[1] > pair[0]);
    }
}
// ============= RouterMetrics Tests =============
/// Smoke test: every `RouterMetrics` recording helper can be invoked without panicking.
#[test]
fn test_metrics_static_methods() {
    let route = "/generate";
    let worker = "http://worker1";
    let pd_route = "/v1/chat/completions";
    let prefill = "http://prefill1";
    let decode = "http://decode1";

    // Request-level metrics
    RouterMetrics::record_request(route);
    RouterMetrics::record_request_duration(route, Duration::from_millis(100));
    RouterMetrics::record_request_error(route, "timeout");
    RouterMetrics::record_retry(route);

    // Worker-level metrics
    RouterMetrics::set_active_workers(5);
    RouterMetrics::set_worker_health(worker, true);
    RouterMetrics::set_worker_load(worker, 10);
    RouterMetrics::record_processed_request(worker);

    // Policy, cache, and load-balancing metrics
    RouterMetrics::record_policy_decision("random", worker);
    RouterMetrics::record_cache_hit();
    RouterMetrics::record_cache_miss();
    RouterMetrics::set_tree_size(worker, 1000);
    RouterMetrics::record_load_balancing_event();
    RouterMetrics::set_load_range(20, 5);

    // Prefill/decode metrics
    RouterMetrics::record_pd_request(pd_route);
    RouterMetrics::record_pd_request_duration(pd_route, Duration::from_secs(1));
    RouterMetrics::record_pd_prefill_request(prefill);
    RouterMetrics::record_pd_decode_request(decode);
    RouterMetrics::record_pd_error("invalid_request");
    RouterMetrics::record_pd_prefill_error(prefill);
    RouterMetrics::record_pd_decode_error(decode);
    RouterMetrics::record_pd_stream_error(decode);

    // Discovery and generation metrics
    RouterMetrics::record_discovery_update(3, 1);
    RouterMetrics::record_generate_duration(Duration::from_secs(2));
    RouterMetrics::set_running_requests(worker, 15);
}
// ============= Port Availability Tests =============
/// Holds a listener on a fixed port (when available) and checks the config built for it.
#[test]
fn test_port_already_in_use() {
    // A port different from the default to reduce the chance of conflicts.
    let candidate_port = 29123;

    // If the port cannot be bound there is nothing to check.
    let listener = TcpListener::bind(("127.0.0.1", candidate_port));
    if listener.is_err() {
        return;
    }

    let config = PrometheusConfig {
        port: candidate_port,
        host: String::from("127.0.0.1"),
    };

    // Only the config contents are verified here.
    assert_eq!(config.port, candidate_port);
}
// ============= Integration Test Helpers =============
/// The socket address derived from the config renders as `127.0.0.1:29000`.
#[test]
fn test_metrics_endpoint_accessibility() {
    // No endpoint is contacted; only the address formatting is checked.
    let config = PrometheusConfig {
        port: 29000,
        host: String::from("127.0.0.1"),
    };

    let ip: IpAddr = config.host.parse().unwrap();
    let rendered = SocketAddr::new(ip, config.port).to_string();

    assert_eq!(rendered, "127.0.0.1:29000");
}
/// Smoke test: three threads update metrics at the same time without panicking.
///
/// There is no explicit assertion; the test fails only if a worker thread
/// panics, which surfaces through `join`.
#[test]
fn test_concurrent_metric_updates() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let done = Arc::new(AtomicBool::new(false));

    let handles: Vec<_> = (0..3)
        .map(|i| {
            let done = Arc::clone(&done);
            thread::spawn(move || {
                let worker = format!("http://worker{}", i);
                // Keep recording until the main thread raises the stop flag.
                while !done.load(Ordering::Relaxed) {
                    RouterMetrics::set_worker_load(&worker, i * 10);
                    RouterMetrics::record_processed_request(&worker);
                    thread::sleep(Duration::from_millis(1));
                }
            })
        })
        .collect();

    // Let the threads run briefly, then ask them to stop.
    thread::sleep(Duration::from_millis(10));
    done.store(true, Ordering::Relaxed);

    // A panic inside any worker thread is reported here.
    for handle in handles {
        handle.join().unwrap();
    }
}
// ============= Edge Cases Tests =============
/// Smoke test: empty label strings are accepted by the metric helpers.
///
/// Passing means none of the calls panicked; no explicit assertion is needed.
#[test]
fn test_empty_string_metrics() {
    RouterMetrics::record_request("");
    RouterMetrics::set_worker_health("", true);
    RouterMetrics::record_policy_decision("", "");
}
/// Smoke test: a 1000-character label is accepted by the metric helpers.
///
/// Passing means none of the calls panicked; no explicit assertion is needed.
#[test]
fn test_very_long_metric_labels() {
    let long_label = "a".repeat(1000);

    RouterMetrics::record_request(&long_label);
    RouterMetrics::set_worker_health(&long_label, false);
}
/// Smoke test: labels containing punctuation are accepted by the metric helpers.
///
/// Passing means none of the calls panicked; no explicit assertion is needed.
#[test]
fn test_special_characters_in_labels() {
    let special_labels = [
        "test/with/slashes",
        "test-with-dashes",
        "test_with_underscores",
        "test.with.dots",
        "test:with:colons",
    ];

    for label in special_labels {
        RouterMetrics::record_request(label);
        RouterMetrics::set_worker_health(label, true);
    }
}
/// Smoke test: boundary values (zero, `usize::MAX`, 1 ns, 24 h) are accepted.
///
/// Passing means none of the calls panicked; no explicit assertion is needed.
#[test]
fn test_extreme_metric_values() {
    RouterMetrics::set_active_workers(0);
    RouterMetrics::set_active_workers(usize::MAX);

    RouterMetrics::set_worker_load("worker", 0);
    RouterMetrics::set_worker_load("worker", usize::MAX);

    RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
    RouterMetrics::record_request_duration("route", Duration::from_secs(86400)); // 24 hours
}
}
...@@ -58,7 +58,7 @@ pub enum PDSelectionPolicy { ...@@ -58,7 +58,7 @@ pub enum PDSelectionPolicy {
}, },
} }
// Bootstrap types from PDLB // Bootstrap types from PDLB
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)] #[serde(untagged)]
pub enum SingleOrBatch<T> { pub enum SingleOrBatch<T> {
Single(T), Single(T),
......
...@@ -211,6 +211,7 @@ impl ToPdRequest for ChatCompletionRequest { ...@@ -211,6 +211,7 @@ impl ToPdRequest for ChatCompletionRequest {
self.temperature => "temperature", self.temperature => "temperature",
self.top_p => "top_p", self.top_p => "top_p",
self.n => "n", self.n => "n",
self.stream_options => "stream_options",
self.stop => "stop", self.stop => "stop",
self.max_tokens => "max_tokens", self.max_tokens => "max_tokens",
self.max_completion_tokens => "max_completion_tokens", self.max_completion_tokens => "max_completion_tokens",
...@@ -262,3 +263,1015 @@ pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone { ...@@ -262,3 +263,1015 @@ pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
impl RouteableRequest for GenerateRequest {} impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {} impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {} impl RouteableRequest for ChatCompletionRequest {}
#[cfg(test)]
mod tests {
use super::*;
use crate::openai_api_types::*;
use serde_json::json;
use std::collections::HashMap;
// ============= GenerateRequest to_pd_request Tests =============
/// Converts a text-only `GenerateRequest` and checks every field of the result.
#[test]
fn test_generate_to_pd_request_with_text_only() {
    let req = GenerateRequest {
        text: Some("Hello world".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
    };

    let pd_req = req.to_pd_request();

    // The text field becomes a single (non-batch) value; no input ids are set.
    assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Hello world"));
    assert!(pd_req.input_ids.is_none());

    // Bootstrap fields are left unset by the conversion.
    assert!(pd_req.bootstrap_host.is_none());
    assert!(pd_req.bootstrap_port.is_none());
    assert!(pd_req.bootstrap_room.is_none());

    // Stream flag is carried over.
    assert!(!pd_req.stream);

    // The remaining request fields appear in the `other` JSON object.
    let other = pd_req.other.as_object().unwrap();
    assert_eq!(other.get("stream"), Some(&json!(false)));
    assert_eq!(other.get("return_logprob"), Some(&json!(false)));
}
/// Converts a request carrying a string `prompt` with streaming and logprobs enabled.
#[test]
fn test_generate_to_pd_request_with_prompt_string() {
    let req = GenerateRequest {
        text: None,
        prompt: Some(StringOrArray::String("Test prompt".to_string())),
        input_ids: None,
        stream: true,
        parameters: None,
        sampling_params: None,
        return_logprob: true,
    };

    let pd_req = req.to_pd_request();

    // The prompt string ends up in the `text` field as a single value.
    assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Test prompt"));
    assert!(pd_req.input_ids.is_none());
    assert!(pd_req.stream);

    let other = pd_req.other.as_object().unwrap();
    assert_eq!(other.get("stream"), Some(&json!(true)));
    assert_eq!(other.get("return_logprob"), Some(&json!(true)));
}
/// An array `prompt` is converted into a batch `text` with the same order.
#[test]
fn test_generate_to_pd_request_with_prompt_array() {
    let prompts = vec![
        String::from("Prompt 1"),
        String::from("Prompt 2"),
        String::from("Prompt 3"),
    ];
    let request = GenerateRequest {
        text: None,
        prompt: Some(StringOrArray::Array(prompts)),
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    if let Some(SingleOrBatch::Batch(ref batch)) = converted.text {
        assert_eq!(batch.len(), 3);
        assert_eq!(batch[0], "Prompt 1");
        assert_eq!(batch[1], "Prompt 2");
        assert_eq!(batch[2], "Prompt 3");
    } else {
        panic!("Expected batch text");
    }
}
/// A single `input_ids` list is converted into a single (non-batch) value.
#[test]
fn test_generate_to_pd_request_with_single_input_ids() {
    let request = GenerateRequest {
        text: None,
        prompt: None,
        input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])),
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    assert!(matches!(
        converted.input_ids,
        Some(SingleOrBatch::Single(ref ids)) if ids == &vec![100, 200, 300, 400]
    ));
    assert!(converted.text.is_none());
}
/// Batched `input_ids` are converted into a batch with the same rows in order.
#[test]
fn test_generate_to_pd_request_with_batch_input_ids() {
    let rows = vec![vec![1, 2, 3], vec![4, 5, 6, 7], vec![8, 9]];
    let request = GenerateRequest {
        text: None,
        prompt: None,
        input_ids: Some(InputIds::Batch(rows)),
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    if let Some(SingleOrBatch::Batch(ref batch)) = converted.input_ids {
        assert_eq!(batch.len(), 3);
        assert_eq!(batch[0], vec![1, 2, 3]);
        assert_eq!(batch[1], vec![4, 5, 6, 7]);
        assert_eq!(batch[2], vec![8, 9]);
    } else {
        panic!("Expected batch input_ids");
    }
}
/// When `text`, `prompt` and `input_ids` are all supplied, the converted
/// request must carry the `text` value and no `input_ids`.
#[test]
fn test_generate_to_pd_request_priority_text_over_prompt() {
    let request = GenerateRequest {
        text: Some("SGLang text".to_string()),
        prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
        input_ids: Some(InputIds::Single(vec![1, 2, 3])),
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    // `text` wins over both competing inputs.
    let text_won = match &converted.text {
        Some(SingleOrBatch::Single(s)) => s == "SGLang text",
        _ => false,
    };
    assert!(text_won);
    assert!(converted.input_ids.is_none());
}
/// With no `text` but both `prompt` and `input_ids` supplied, the converted
/// request must carry the prompt as `text` and no `input_ids`.
#[test]
fn test_generate_to_pd_request_priority_prompt_over_input_ids() {
    let request = GenerateRequest {
        text: None,
        prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
        input_ids: Some(InputIds::Single(vec![1, 2, 3])),
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    // `prompt` wins over `input_ids`.
    let prompt_won = match &converted.text {
        Some(SingleOrBatch::Single(s)) => s == "OpenAI prompt",
        _ => false,
    };
    assert!(prompt_won);
    assert!(converted.input_ids.is_none());
}
/// Checks where `GenerateParameters` values land in the PD request's `other`
/// object: `max_new_tokens` and `temperature` are asserted at the top level,
/// the remaining values under the nested `"parameters"` object.
#[test]
fn test_generate_to_pd_request_with_parameters() {
    let params = GenerateParameters {
        max_new_tokens: Some(100),
        temperature: Some(0.8),
        top_p: Some(0.95),
        seed: Some(12345),
        stop: Some(vec!["END".to_string(), "STOP".to_string()]),
        repetition_penalty: Some(1.1),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(params),
        sampling_params: None,
        return_logprob: false,
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Floats are compared by absolute difference: a one-sided
    // `actual - expected < eps` check would also accept any value below `expected`.
    let top_level_float = |key: &str| other.get(key).unwrap().as_f64().unwrap();

    // Check that max_new_tokens and temperature were extracted to top level
    assert_eq!(other.get("max_new_tokens"), Some(&json!(100)));
    assert!((top_level_float("temperature") - 0.8).abs() < 0.0001);

    // Check that other parameters remain under "parameters"
    let params = other.get("parameters").unwrap().as_object().unwrap();
    let nested_float = |key: &str| params.get(key).unwrap().as_f64().unwrap();
    assert!((nested_float("top_p") - 0.95).abs() < 0.0001);
    assert_eq!(params.get("seed"), Some(&json!(12345)));
    assert_eq!(params.get("stop"), Some(&json!(vec!["END", "STOP"])));
    assert!((nested_float("repetition_penalty") - 1.1).abs() < 0.0001);
}
/// Checks that `sampling_params` are preserved in full under
/// `"sampling_params"` and that `max_new_tokens` / `temperature` are also
/// present at the top level of the `other` object.
#[test]
fn test_generate_to_pd_request_with_sampling_params() {
    let sampling = SamplingParams {
        max_new_tokens: Some(200),
        temperature: Some(0.7),
        top_p: Some(0.9),
        top_k: Some(50),
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.2),
        repetition_penalty: Some(1.05),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: Some(sampling),
        return_logprob: false,
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Floats are compared by absolute difference: a one-sided
    // `actual - expected < eps` check would also accept any value below `expected`.
    let top_level_float = |key: &str| other.get(key).unwrap().as_f64().unwrap();

    // Check extracted top-level fields
    assert_eq!(other.get("max_new_tokens"), Some(&json!(200)));
    assert!((top_level_float("temperature") - 0.7).abs() < 0.0001);

    // Check full sampling_params is preserved
    let sampling = other.get("sampling_params").unwrap().as_object().unwrap();
    let sampling_float = |key: &str| sampling.get(key).unwrap().as_f64().unwrap();
    assert_eq!(sampling.get("max_new_tokens"), Some(&json!(200)));
    assert!((sampling_float("temperature") - 0.7).abs() < 0.0001);
    assert!((sampling_float("top_p") - 0.9).abs() < 0.0001);
    assert_eq!(sampling.get("top_k"), Some(&json!(50)));
    assert!((sampling_float("frequency_penalty") - 0.1).abs() < 0.0001);
    assert!((sampling_float("presence_penalty") - 0.2).abs() < 0.0001);
}
/// When both `parameters` and `sampling_params` set `max_new_tokens` /
/// `temperature`, the top-level values must be the ones from
/// `sampling_params`.
#[test]
fn test_generate_to_pd_request_sampling_params_override_parameters() {
    let params = GenerateParameters {
        max_new_tokens: Some(100),
        temperature: Some(0.5),
        ..Default::default()
    };
    let sampling = SamplingParams {
        max_new_tokens: Some(200),
        temperature: Some(0.9),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(params),
        sampling_params: Some(sampling),
        return_logprob: false,
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Should use values from sampling_params.
    assert_eq!(other.get("max_new_tokens"), Some(&json!(200)));

    // Absolute difference is essential here: with a one-sided
    // `actual - 0.9 < eps` check, the overridden value 0.5 from `parameters`
    // would pass as well and the test could not detect a wrong precedence.
    let temperature = other.get("temperature").unwrap().as_f64().unwrap();
    assert!((temperature - 0.9).abs() < 0.0001);
}
/// A `GenerateParameters` value with every field left at its default must
/// not produce `parameters`, `max_new_tokens` or `temperature` keys.
#[test]
fn test_generate_to_pd_request_empty_parameters() {
    let request = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        parameters: Some(GenerateParameters::default()),
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let converted = request.to_pd_request();
    let extras = converted.other.as_object().unwrap();

    // None of the parameter-derived keys may be emitted for an all-default value.
    for absent_key in ["parameters", "max_new_tokens", "temperature"] {
        assert!(!extras.contains_key(absent_key));
    }
}
/// End-to-end check of a `GenerateRequest` with every input populated:
/// `text` is used, bootstrap fields stay unset, flags are mirrored into
/// `other`, and `sampling_params` values take precedence at the top level
/// while both nested objects remain present.
#[test]
fn test_generate_to_pd_request_all_fields() {
    let params = GenerateParameters {
        max_new_tokens: Some(150),
        temperature: Some(0.6),
        top_k: Some(40),
        ..Default::default()
    };
    let sampling = SamplingParams {
        max_new_tokens: Some(250), // Will override parameters
        temperature: Some(0.8),    // Will override parameters
        presence_penalty: Some(0.1),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("Complex test".to_string()),
        prompt: Some(StringOrArray::String("Ignored prompt".to_string())),
        input_ids: None,
        stream: true,
        parameters: Some(params),
        sampling_params: Some(sampling),
        return_logprob: true,
    };
    let pd_req = req.to_pd_request();

    // Verify all fields
    assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complex test"));
    assert!(pd_req.input_ids.is_none());
    assert!(pd_req.stream);
    assert!(pd_req.bootstrap_host.is_none());
    assert!(pd_req.bootstrap_port.is_none());
    assert!(pd_req.bootstrap_room.is_none());

    let other = pd_req.other.as_object().unwrap();
    assert_eq!(other.get("stream"), Some(&json!(true)));
    assert_eq!(other.get("return_logprob"), Some(&json!(true)));

    // Sampling params override parameters
    assert_eq!(other.get("max_new_tokens"), Some(&json!(250)));
    // Absolute difference: a one-sided `actual - 0.8 < eps` check would also
    // accept the overridden 0.6 from `parameters`.
    let temperature = other.get("temperature").unwrap().as_f64().unwrap();
    assert!((temperature - 0.8).abs() < 0.0001);

    assert!(other.contains_key("parameters"));
    assert!(other.contains_key("sampling_params"));
}
// ============= CompletionRequest to_pd_request Tests =============
/// A minimal `CompletionRequest` with a string prompt must yield single
/// `text`, no `input_ids`, and carry `model` / `stream` in `other`.
#[test]
fn test_completion_to_pd_request_basic() {
    let request = CompletionRequest {
        // Fields that carry actual values in this scenario.
        model: "gpt-3.5-turbo".to_string(),
        prompt: StringOrArray::String("Complete this sentence".to_string()),
        stream: false,
        echo: false,
        // Everything else is left unset.
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        logprobs: None,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        suffix: None,
    };

    let converted = request.to_pd_request();

    let text_is_prompt = match &converted.text {
        Some(SingleOrBatch::Single(s)) => s == "Complete this sentence",
        _ => false,
    };
    assert!(text_is_prompt);
    assert!(converted.input_ids.is_none());
    assert!(!converted.stream);

    let extras = converted.other.as_object().unwrap();
    assert_eq!(extras.get("model"), Some(&json!("gpt-3.5-turbo")));
    assert_eq!(extras.get("stream"), Some(&json!(false)));
}
/// An array prompt on a `CompletionRequest` must yield batched `text` with
/// the prompts in their original order.
#[test]
fn test_completion_to_pd_request_array_prompt() {
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::Array(vec![
            "First prompt".to_string(),
            "Second prompt".to_string(),
        ]),
        stream: false,
        echo: false,
        // Everything else is left unset.
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        logprobs: None,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        suffix: None,
    };

    let converted = request.to_pd_request();

    if let Some(SingleOrBatch::Batch(ref prompts)) = converted.text {
        assert_eq!(prompts.len(), 2);
        assert_eq!(prompts[0], "First prompt");
        assert_eq!(prompts[1], "Second prompt");
    } else {
        panic!("Expected batch text");
    }
}
/// Checks the mapping of OpenAI-style completion fields onto the nested
/// `"parameters"` object of the PD request (e.g. `max_tokens` ->
/// `max_new_tokens`, `n` -> `best_of`, `logprobs` -> `top_n_tokens`,
/// `echo` -> `return_full_text`), plus `model` / `stream` at the top level.
#[test]
fn test_completion_to_pd_request_parameter_mapping() {
    let req = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("test".to_string()),
        max_tokens: Some(150), // -> max_new_tokens
        temperature: Some(0.75),
        top_p: Some(0.92),
        n: Some(3), // -> best_of
        stream: true,
        stream_options: None,
        logprobs: Some(10), // -> top_n_tokens
        echo: true,         // -> return_full_text
        stop: Some(StringOrArray::Array(vec![
            "\\n".to_string(),
            "END".to_string(),
        ])),
        presence_penalty: Some(0.5), // -> repetition_penalty = 1.5
        frequency_penalty: Some(0.2),
        best_of: Some(5),
        logit_bias: None,
        user: Some("user123".to_string()),
        seed: Some(42),
        suffix: Some("...".to_string()),
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();
    let params = other.get("parameters").unwrap().as_object().unwrap();

    // Floats are compared by absolute difference: a one-sided
    // `actual - expected < eps` check would also accept any value below
    // `expected` (e.g. an unmapped 0.5 instead of the expected 1.5).
    let param_float = |key: &str| params.get(key).unwrap().as_f64().unwrap();

    // Check parameter mappings
    assert_eq!(params.get("max_new_tokens"), Some(&json!(150)));
    assert!((param_float("temperature") - 0.75).abs() < 0.0001);
    assert!((param_float("top_p") - 0.92).abs() < 0.0001);
    assert_eq!(params.get("best_of"), Some(&json!(3)));
    assert_eq!(params.get("top_n_tokens"), Some(&json!(10)));
    assert_eq!(params.get("return_full_text"), Some(&json!(true)));
    assert_eq!(params.get("stop"), Some(&json!(vec!["\\n", "END"])));
    assert!((param_float("repetition_penalty") - 1.5).abs() < 0.0001);
    assert_eq!(params.get("seed"), Some(&json!(42)));

    // Check other fields
    assert_eq!(other.get("model"), Some(&json!("test")));
    assert_eq!(other.get("stream"), Some(&json!(true)));
}
/// A single-string `stop` on a `CompletionRequest` must appear as a
/// one-element array under `parameters.stop`.
#[test]
fn test_completion_to_pd_request_stop_string() {
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("test".to_string()),
        // The field under test.
        stop: Some(StringOrArray::String("STOP".to_string())),
        stream: false,
        echo: false,
        // Everything else is left unset.
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        logprobs: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        suffix: None,
    };

    let converted = request.to_pd_request();
    let extras = converted.other.as_object().unwrap();
    let nested = extras.get("parameters").unwrap().as_object().unwrap();

    // Single string stop should be converted to array
    let expected_stop = json!(vec!["STOP"]);
    assert_eq!(nested.get("stop"), Some(&expected_stop));
}
/// Without a `presence_penalty`, the nested `parameters` object must not
/// contain a `repetition_penalty` key.
#[test]
fn test_completion_to_pd_request_no_presence_penalty() {
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("test".to_string()),
        // The field under test: explicitly absent.
        presence_penalty: None,
        stream: false,
        echo: false,
        // Everything else is left unset as well.
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        logprobs: None,
        stop: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        suffix: None,
    };

    let converted = request.to_pd_request();
    let extras = converted.other.as_object().unwrap();
    let nested = extras.get("parameters").unwrap().as_object().unwrap();

    // Should not have repetition_penalty if presence_penalty is None
    let has_repetition_penalty = nested.contains_key("repetition_penalty");
    assert!(!has_repetition_penalty);
}
// ============= ChatCompletionRequest to_pd_request Tests =============
/// A two-message chat request with no optional fields must keep bootstrap
/// fields unset and carry `messages`, `model` and `stream` in `other`.
#[test]
fn test_chat_to_pd_request_basic() {
    let system_msg = ChatMessage::System {
        role: "system".to_string(),
        content: "You are a helpful assistant".to_string(),
        name: None,
    };
    let user_msg = ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Hello!".to_string()),
        name: None,
    };

    let request = ChatCompletionRequest {
        messages: vec![system_msg, user_msg],
        model: "gpt-4".to_string(),
        stream: false,
        logprobs: false,
        // Everything else is left unset.
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        top_logprobs: None,
        user: None,
        seed: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
    };

    let converted = request.to_pd_request();

    assert!(!converted.stream);
    assert!(converted.bootstrap_host.is_none());
    assert!(converted.bootstrap_port.is_none());
    assert!(converted.bootstrap_room.is_none());

    let extras = converted.other.as_object().unwrap();
    assert!(extras.contains_key("messages"));
    assert_eq!(extras.get("model"), Some(&json!("gpt-4")));
    assert_eq!(extras.get("stream"), Some(&json!(false)));

    // Check messages are preserved
    let forwarded = extras.get("messages").unwrap().as_array().unwrap();
    assert_eq!(forwarded.len(), 2);
}
/// Populates every optional `ChatCompletionRequest` field (except the legacy
/// `functions` / `function_call`) and checks that each one is present in the
/// `other` object of the converted PD request.
#[test]
fn test_chat_to_pd_request_with_all_optional_fields() {
    let messages = vec![ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Test".to_string()),
        name: Some("test_user".to_string()),
    }];
    let mut logit_bias = HashMap::new();
    logit_bias.insert("50256".to_string(), -100);
    let tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather info".to_string()),
            parameters: json!({"type": "object"}),
        },
    };
    let req = ChatCompletionRequest {
        messages,
        model: "gpt-4".to_string(),
        temperature: Some(0.8),
        top_p: Some(0.95),
        n: Some(2),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: Some(true),
        }),
        stop: Some(StringOrArray::String("\\n\\n".to_string())),
        max_tokens: Some(200),
        max_completion_tokens: Some(150),
        presence_penalty: Some(0.1),
        frequency_penalty: Some(0.2),
        logit_bias: Some(logit_bias),
        logprobs: true,
        top_logprobs: Some(5),
        user: Some("user456".to_string()),
        seed: Some(12345),
        response_format: Some(ResponseFormat::JsonObject),
        tools: Some(vec![tool]),
        tool_choice: Some(ToolChoice::Auto),
        parallel_tool_calls: Some(false),
        functions: None,
        function_call: None,
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Floats are compared by absolute difference: a one-sided
    // `actual - expected < eps` check would also accept any value below `expected`.
    let float_field = |key: &str| other.get(key).unwrap().as_f64().unwrap();

    // Check all fields are preserved
    assert!((float_field("temperature") - 0.8).abs() < 0.0001);
    assert!((float_field("top_p") - 0.95).abs() < 0.0001);
    assert_eq!(other.get("n"), Some(&json!(2)));
    assert_eq!(other.get("stream"), Some(&json!(true)));
    assert!(other.contains_key("stream_options"));
    assert!(other.contains_key("stop"));
    assert_eq!(other.get("max_tokens"), Some(&json!(200)));
    assert_eq!(other.get("max_completion_tokens"), Some(&json!(150)));
    assert!((float_field("presence_penalty") - 0.1).abs() < 0.0001);
    assert!((float_field("frequency_penalty") - 0.2).abs() < 0.0001);
    assert!(other.contains_key("logit_bias"));
    assert_eq!(other.get("logprobs"), Some(&json!(true)));
    assert_eq!(other.get("top_logprobs"), Some(&json!(5)));
    assert_eq!(other.get("user"), Some(&json!("user456")));
    assert_eq!(other.get("seed"), Some(&json!(12345)));
    assert!(other.contains_key("response_format"));
    assert!(other.contains_key("tools"));
    assert!(other.contains_key("tool_choice"));
    assert_eq!(other.get("parallel_tool_calls"), Some(&json!(false)));
}
/// A user message made of a text part plus an image-URL part must be
/// forwarded under `messages` with its `content` serialized as an array.
#[test]
fn test_chat_to_pd_request_multimodal_content() {
    let text_part = ContentPart::Text {
        text: "What's in this image?".to_string(),
    };
    let image_part = ContentPart::ImageUrl {
        image_url: ImageUrl {
            url: "https://example.com/image.jpg".to_string(),
            detail: Some("high".to_string()),
        },
    };
    let user_msg = ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Parts(vec![text_part, image_part]),
        name: None,
    };

    let request = ChatCompletionRequest {
        messages: vec![user_msg],
        model: "gpt-4-vision".to_string(),
        stream: false,
        logprobs: false,
        // Everything else is left unset.
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        top_logprobs: None,
        user: None,
        seed: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
    };

    let converted = request.to_pd_request();
    let extras = converted.other.as_object().unwrap();

    // Messages with multimodal content should be preserved
    assert!(extras.contains_key("messages"));
    let forwarded = extras.get("messages").unwrap().as_array().unwrap();
    assert_eq!(forwarded.len(), 1);

    // Verify the message structure is preserved
    let first = &forwarded[0];
    assert_eq!(first["role"], "user");
    assert!(first["content"].is_array());
}
/// The boolean `logprobs` flag and the numeric `top_logprobs` of a chat
/// request must both be forwarded unchanged in `other`.
#[test]
fn test_chat_to_pd_request_logprobs_boolean() {
    let user_msg = ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Test".to_string()),
        name: None,
    };

    let request = ChatCompletionRequest {
        messages: vec![user_msg],
        model: "test".to_string(),
        // The fields under test.
        logprobs: true, // Boolean logprobs flag
        top_logprobs: Some(3),
        stream: false,
        // Everything else is left unset.
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        user: None,
        seed: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
    };

    let converted = request.to_pd_request();
    let extras = converted.other.as_object().unwrap();

    let expected_flag = json!(true);
    let expected_count = json!(3);
    assert_eq!(extras.get("logprobs"), Some(&expected_flag));
    assert_eq!(extras.get("top_logprobs"), Some(&expected_count));
}
/// With only required chat fields set, `other` must contain `messages`,
/// `model` and `stream`, and must omit the unset optional sampling fields.
#[test]
fn test_chat_to_pd_request_minimal_fields() {
    let assistant_msg = ChatMessage::Assistant {
        role: "assistant".to_string(),
        content: Some("I can help with that.".to_string()),
        name: None,
        tool_calls: None,
        function_call: None,
    };

    let request = ChatCompletionRequest {
        messages: vec![assistant_msg],
        model: "gpt-3.5-turbo".to_string(),
        stream: false,
        logprobs: false,
        // Everything else is left unset.
        temperature: None,
        top_p: None,
        n: None,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        top_logprobs: None,
        user: None,
        seed: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
    };

    let converted = request.to_pd_request();
    let extras = converted.other.as_object().unwrap();

    // Should only have required fields
    for required_key in ["messages", "model", "stream"] {
        assert!(extras.contains_key(required_key));
    }

    // Optional fields should not be present
    for optional_key in ["temperature", "top_p", "max_tokens", "stop"] {
        assert!(!extras.contains_key(optional_key));
    }
}
/// `to_json` on a `GenerateRequest` must succeed and expose the `text` and
/// `stream` fields with their original values.
#[test]
fn test_routeable_request_to_json() {
    let request = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let serialized = request.to_json().unwrap();

    assert_eq!(serialized["text"], "test");
    assert_eq!(serialized["stream"], false);
}
// ============= Macro Tests =============
/// `insert_if_some!` must insert the key for a `Some` value and skip the key
/// for a `None` value.
#[test]
fn test_insert_if_some_macro() {
    let filled: Option<i32> = Some(42);
    let empty: Option<i32> = None;
    let mut fields = serde_json::Map::new();

    insert_if_some!(fields,
        filled => "present",
        empty => "absent"
    );

    let expected = json!(42);
    assert_eq!(fields.get("present"), Some(&expected));
    assert!(!fields.contains_key("absent"));
}
/// `insert_value!` must insert every given value under its key, here one
/// string and one integer.
#[test]
fn test_insert_value_macro() {
    let text_value = "test";
    let int_value = 42;
    let mut fields = serde_json::Map::new();

    insert_value!(fields,
        text_value => "string_field",
        int_value => "int_field"
    );

    let expected_text = json!("test");
    let expected_int = json!(42);
    assert_eq!(fields.get("string_field"), Some(&expected_text));
    assert_eq!(fields.get("int_field"), Some(&expected_int));
}
// ============= Edge Cases and Error Handling =============
/// Explicitly setting `max_new_tokens` and `temperature` to `None` (with all
/// other fields defaulted) must not produce a `parameters` key.
#[test]
fn test_null_value_handling() {
    let all_none = GenerateParameters {
        max_new_tokens: None,
        temperature: None,
        ..Default::default()
    };

    let request = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        parameters: Some(all_none),
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let converted = request.to_pd_request();
    let extras = converted.other.as_object().unwrap();

    // Should not have parameters field if all fields are None
    let has_parameters = extras.contains_key("parameters");
    assert!(!has_parameters);
}
/// A 1000-element prompt array must convert to a batched `text` of the same
/// length with the first and last items intact.
#[test]
fn test_large_batch_conversion() {
    let prompts: Vec<String> = (0..1000).map(|i| format!("item_{}", i)).collect();

    let request = GenerateRequest {
        text: None,
        prompt: Some(StringOrArray::Array(prompts.clone())),
        input_ids: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    match converted.text {
        Some(SingleOrBatch::Batch(items)) => {
            assert_eq!(items.len(), 1000);
            // Spot-check both ends of the batch.
            assert_eq!(items[0], "item_0");
            assert_eq!(items[999], "item_999");
        }
        _ => panic!("Expected batch text"),
    }
}
/// Text mixing Latin, CJK, emoji, Devanagari and Cyrillic characters must
/// come out of the conversion unchanged.
#[test]
fn test_unicode_string_handling() {
    let original = "Hello 世界 🌍 नमस्ते мир".to_string();

    let request = GenerateRequest {
        text: Some(original.clone()),
        prompt: None,
        input_ids: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    match converted.text {
        Some(SingleOrBatch::Single(round_tripped)) => {
            assert_eq!(round_tripped, original);
        }
        _ => panic!("Expected single text"),
    }
}
/// Checks that `max_new_tokens` from `GenerateParameters` reaches the top
/// level of the PD request's `other` object with its value intact.
///
/// The earlier version of this test also built a deeply nested JSON map, but
/// that map was never attached to the request and so had no effect on the
/// assertion; the dead local has been removed.
/// NOTE(review): if nested pass-through parameters are meant to be covered,
/// the request needs a field that can actually carry them; confirm against
/// the `GenerateParameters` definition.
#[test]
fn test_deeply_nested_parameters() {
    let params = GenerateParameters {
        max_new_tokens: Some(100),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(params),
        sampling_params: None,
        return_logprob: false,
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Assert the value, not just the key's presence, so a wrong value is caught.
    assert_eq!(other.get("max_new_tokens"), Some(&json!(100)));
}
// ============= Bootstrap Field Tests =============
/// A freshly converted request must leave all three bootstrap fields
/// (`bootstrap_host`, `bootstrap_port`, `bootstrap_room`) unset.
#[test]
fn test_bootstrap_fields_none() {
    let request = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
    };

    let converted = request.to_pd_request();

    assert!(converted.bootstrap_host.is_none());
    assert!(converted.bootstrap_port.is_none());
    assert!(converted.bootstrap_room.is_none());
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment