Commit 7f85dcc3 authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

test: add unit tests for RuntimeConfig (#215)

parent 9e4a548d
...@@ -3179,9 +3179,9 @@ dependencies = [ ...@@ -3179,9 +3179,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.20.3" version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]] [[package]]
name = "onig" name = "onig"
......
...@@ -2551,6 +2551,15 @@ version = "0.12.16" ...@@ -2551,6 +2551,15 @@ version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "temp-env"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96374855068f47402c3121c6eed88d29cb1de8f3ab27090e273e420bdabcf050"
dependencies = [
"parking_lot",
]
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.17.1" version = "3.17.1"
...@@ -2957,6 +2966,7 @@ dependencies = [ ...@@ -2957,6 +2966,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"socket2", "socket2",
"temp-env",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
......
...@@ -76,3 +76,4 @@ rand = { version = "0.8"} ...@@ -76,3 +76,4 @@ rand = { version = "0.8"}
assert_matches = "1.5.0" assert_matches = "1.5.0"
env_logger = "0.11" env_logger = "0.11"
rstest = "0.23.0" rstest = "0.23.0"
temp-env = "0.3.6"
...@@ -35,7 +35,7 @@ impl WorkerConfig { ...@@ -35,7 +35,7 @@ impl WorkerConfig {
// All calls should be global and thread safe. // All calls should be global and thread safe.
Figment::new() Figment::new()
.merge(Serialized::defaults(Self::default())) .merge(Serialized::defaults(Self::default()))
.merge(Env::prefixed("TRITON_WORKER_")) .merge(Env::prefixed("TRD_WORKER_"))
.extract() .extract()
.unwrap() // safety: Called on startup, so panic is reasonable .unwrap() // safety: Called on startup, so panic is reasonable
} }
...@@ -58,12 +58,12 @@ impl Default for WorkerConfig { ...@@ -58,12 +58,12 @@ impl Default for WorkerConfig {
#[derive(Serialize, Deserialize, Validate, Debug, Builder, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Builder, Clone)]
#[builder(build_fn(private, name = "build_internal"), derive(Debug, Serialize))] #[builder(build_fn(private, name = "build_internal"), derive(Debug, Serialize))]
pub struct RuntimeConfig { pub struct RuntimeConfig {
/// Maximum number of async worker threads /// Number of async worker threads
/// If set to 1, the runtime will run in single-threaded mode /// If set to 1, the runtime will run in single-threaded mode
#[validate(range(min = 1))] #[validate(range(min = 1))]
#[builder(default = "16")] #[builder(default = "16")]
#[builder_field_attr(serde(skip_serializing_if = "Option::is_none"))] #[builder_field_attr(serde(skip_serializing_if = "Option::is_none"))]
pub max_worker_threads: usize, pub num_worker_threads: usize,
/// Maximum number of blocking threads /// Maximum number of blocking threads
/// Blocking threads are used for blocking operations, this value must be greater than 0. /// Blocking threads are used for blocking operations, this value must be greater than 0.
...@@ -83,16 +83,24 @@ impl RuntimeConfig { ...@@ -83,16 +83,24 @@ impl RuntimeConfig {
.merge(Serialized::defaults(RuntimeConfig::default())) .merge(Serialized::defaults(RuntimeConfig::default()))
.merge(Toml::file("/opt/triton/defaults/runtime.toml")) .merge(Toml::file("/opt/triton/defaults/runtime.toml"))
.merge(Toml::file("/opt/triton/etc/runtime.toml")) .merge(Toml::file("/opt/triton/etc/runtime.toml"))
.merge(Env::prefixed("TRITON_RUNTIME_")) .merge(Env::prefixed("TRD_RUNTIME_").filter_map(|k| {
let full_key = format!("TRD_RUNTIME_{}", k.as_str());
// filters out empty environment variables
match std::env::var(&full_key) {
Ok(v) if !v.is_empty() => Some(k.into()),
_ => None,
}
}))
} }
/// Load the runtime configuration from the environment and configuration files /// Load the runtime configuration from the environment and configuration files
/// Configuration is priorities in the following order, where the last has the lowest priority: /// Configuration is priorities in the following order, where the last has the lowest priority:
/// 1. Environment variables (top priority) /// 1. Environment variables (top priority)
/// TO DO: Add documentation for configuration files. Paths should be configurable.
/// 2. /opt/triton/etc/runtime.toml /// 2. /opt/triton/etc/runtime.toml
/// 3. /opt/triton/defaults/runtime.toml (lowest priority) /// 3. /opt/triton/defaults/runtime.toml (lowest priority)
/// ///
/// Environment variables are prefixed with `TRITON_RUNTIME_` /// Environment variables are prefixed with `TRD_RUNTIME_`
pub fn from_settings() -> Result<RuntimeConfig> { pub fn from_settings() -> Result<RuntimeConfig> {
let config: RuntimeConfig = Self::figment().extract()?; let config: RuntimeConfig = Self::figment().extract()?;
config.validate()?; config.validate()?;
...@@ -101,7 +109,7 @@ impl RuntimeConfig { ...@@ -101,7 +109,7 @@ impl RuntimeConfig {
pub fn single_threaded() -> Self { pub fn single_threaded() -> Self {
RuntimeConfig { RuntimeConfig {
max_worker_threads: 1, num_worker_threads: 1,
max_blocking_threads: 1, max_blocking_threads: 1,
} }
} }
...@@ -109,7 +117,7 @@ impl RuntimeConfig { ...@@ -109,7 +117,7 @@ impl RuntimeConfig {
/// Create a new default runtime configuration /// Create a new default runtime configuration
pub(crate) fn create_runtime(&self) -> Result<tokio::runtime::Runtime> { pub(crate) fn create_runtime(&self) -> Result<tokio::runtime::Runtime> {
Ok(tokio::runtime::Builder::new_multi_thread() Ok(tokio::runtime::Builder::new_multi_thread()
.worker_threads(self.max_worker_threads) .worker_threads(self.num_worker_threads)
.max_blocking_threads(self.max_blocking_threads) .max_blocking_threads(self.max_blocking_threads)
.enable_all() .enable_all()
.build()?) .build()?)
...@@ -119,7 +127,7 @@ impl RuntimeConfig { ...@@ -119,7 +127,7 @@ impl RuntimeConfig {
impl Default for RuntimeConfig { impl Default for RuntimeConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_worker_threads: 16, num_worker_threads: 16,
max_blocking_threads: 16, max_blocking_threads: 16,
} }
} }
...@@ -161,3 +169,68 @@ pub fn jsonl_logging_enabled() -> bool { ...@@ -161,3 +169,68 @@ pub fn jsonl_logging_enabled() -> bool {
pub fn disable_ansi_logging() -> bool { pub fn disable_ansi_logging() -> bool {
env_is_truthy("TRD_SDK_DISABLE_ANSI_LOGGING") env_is_truthy("TRD_SDK_DISABLE_ANSI_LOGGING")
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_runtime_config_with_env_vars() -> Result<()> {
temp_env::with_vars(
vec![
("TRD_RUNTIME_NUM_WORKER_THREADS", Some("24")),
("TRD_RUNTIME_MAX_BLOCKING_THREADS", Some("32")),
],
|| {
let config = RuntimeConfig::from_settings()?;
assert_eq!(config.num_worker_threads, 24);
assert_eq!(config.max_blocking_threads, 32);
Ok(())
},
)
}
#[test]
fn test_runtime_config_defaults() -> Result<()> {
temp_env::with_vars(
vec![
("TRD_RUNTIME_NUM_WORKER_THREADS", None::<&str>),
("TRD_RUNTIME_MAX_BLOCKING_THREADS", Some("")),
],
|| {
let config = RuntimeConfig::from_settings()?;
let default_config = RuntimeConfig::default();
assert_eq!(config.num_worker_threads, default_config.num_worker_threads);
assert_eq!(
config.max_blocking_threads,
default_config.max_blocking_threads
);
Ok(())
},
)
}
#[test]
fn test_runtime_config_rejects_invalid_thread_count() -> Result<()> {
temp_env::with_vars(
vec![
("TRD_RUNTIME_NUM_WORKER_THREADS", Some("0")),
("TRD_RUNTIME_MAX_BLOCKING_THREADS", Some("0")),
],
|| {
let result = RuntimeConfig::from_settings();
assert!(result.is_err());
if let Err(e) = result {
assert!(e
.to_string()
.contains("num_worker_threads: Validation error"));
assert!(e
.to_string()
.contains("max_blocking_threads: Validation error"));
}
Ok(())
},
)
}
}
...@@ -25,10 +25,10 @@ ...@@ -25,10 +25,10 @@
//! the signal handler used to trap `SIGINT` and `SIGTERM` signals and trigger a graceful shutdown. //! the signal handler used to trap `SIGINT` and `SIGTERM` signals and trigger a graceful shutdown.
//! //!
//! On termination, the user application is given a graceful shutdown period of controlled by //! On termination, the user application is given a graceful shutdown period of controlled by
//! the [TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT] environment variable. If the application does not //! the [TRD_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT] environment variable. If the application does not
//! shutdown in time, the worker will terminate the application with an exit code of 911. //! shutdown in time, the worker will terminate the application with an exit code of 911.
//! //!
//! The default values of `TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT` differ between the development //! The default values of [TRD_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT] differ between the development
//! and release builds. In development, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG] and //! and release builds. In development, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG] and
//! in release, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE]. //! in release, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE].
...@@ -45,10 +45,10 @@ static INIT: OnceCell<Mutex<Option<tokio::task::JoinHandle<Result<()>>>>> = Once ...@@ -45,10 +45,10 @@ static INIT: OnceCell<Mutex<Option<tokio::task::JoinHandle<Result<()>>>>> = Once
const SHUTDOWN_MESSAGE: &str = const SHUTDOWN_MESSAGE: &str =
"Application received shutdown signal; attempting to gracefully shutdown"; "Application received shutdown signal; attempting to gracefully shutdown";
const SHUTDOWN_TIMEOUT_MESSAGE: &str = const SHUTDOWN_TIMEOUT_MESSAGE: &str =
"Use TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT to control the graceful shutdown timeout"; "Use TRD_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT to control the graceful shutdown timeout";
/// Environment variable to control the graceful shutdown timeout /// Environment variable to control the graceful shutdown timeout
pub const TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT: &str = "TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT"; pub const TRD_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT: &str = "TRD_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT";
/// Default graceful shutdown timeout in seconds in debug mode /// Default graceful shutdown timeout in seconds in debug mode
pub const DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG: u64 = 5; pub const DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG: u64 = 5;
...@@ -105,7 +105,7 @@ impl Worker { ...@@ -105,7 +105,7 @@ impl Worker {
let primary = runtime.primary(); let primary = runtime.primary();
let secondary = runtime.secondary.clone(); let secondary = runtime.secondary.clone();
let timeout = std::env::var(TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT) let timeout = std::env::var(TRD_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT)
.ok() .ok()
.and_then(|s| s.parse::<u64>().ok()) .and_then(|s| s.parse::<u64>().ok())
.unwrap_or({ .unwrap_or({
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment