// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! System health monitoring and health check management use std::{ collections::HashMap, sync::{Arc, OnceLock}, time::Instant, }; use tokio::sync::mpsc; use crate::component; use crate::config::HealthStatus; use crate::metrics::prometheus_names::distributed_runtime; /// Health check target containing instance info and payload #[derive(Clone, Debug)] pub struct HealthCheckTarget { pub instance: component::Instance, pub payload: serde_json::Value, } /// Current Health Status /// If use_endpoint_health_status is set then /// initialize the endpoint_health hashmap to the /// starting health status #[derive(Clone)] pub struct SystemHealth { system_health: HealthStatus, endpoint_health: Arc>>, /// Maps endpoint subject to health check target (instance + payload) health_check_targets: Arc>>, /// Maps endpoint subject to its specific health check notifier health_check_notifiers: Arc>>>, /// Channel for new endpoint registrations /// This solves the race condition where HealthCheckManager starts before endpoints are registered /// Using a channel ensures no registrations are lost. new_endpoint_tx: mpsc::UnboundedSender, new_endpoint_rx: Arc>>>, use_endpoint_health_status: Vec, health_path: String, live_path: String, start_time: Instant, uptime_gauge: OnceLock, } impl SystemHealth { pub fn new( starting_health_status: HealthStatus, use_endpoint_health_status: Vec, health_path: String, live_path: String, ) -> Self { let mut endpoint_health = HashMap::new(); for endpoint in &use_endpoint_health_status { endpoint_health.insert(endpoint.clone(), starting_health_status.clone()); } // Create the channel for endpoint registration notifications let (tx, rx) = mpsc::unbounded_channel(); SystemHealth { system_health: starting_health_status, endpoint_health: Arc::new(std::sync::RwLock::new(endpoint_health)), health_check_targets: Arc::new(std::sync::RwLock::new(HashMap::new())), health_check_notifiers: Arc::new(std::sync::RwLock::new(HashMap::new())), new_endpoint_tx: tx, new_endpoint_rx: Arc::new(parking_lot::Mutex::new(Some(rx))), use_endpoint_health_status, health_path, live_path, start_time: Instant::now(), uptime_gauge: OnceLock::new(), } } pub fn set_health_status(&mut self, status: HealthStatus) { self.system_health = status; } pub fn set_endpoint_health_status(&self, endpoint: &str, status: HealthStatus) { let mut endpoint_health = self.endpoint_health.write().unwrap(); endpoint_health.insert(endpoint.to_string(), status); } /// Returns the overall health status and endpoint health statuses /// System health is determined by ALL endpoints that have registered health checks pub fn get_health_status(&self) -> (bool, HashMap) { let health_check_targets = self.health_check_targets.read().unwrap(); let endpoint_health = self.endpoint_health.read().unwrap(); let mut endpoints: HashMap = HashMap::new(); for (endpoint, status) in endpoint_health.iter() { endpoints.insert( endpoint.clone(), if *status == HealthStatus::Ready { "ready".to_string() } else { "notready".to_string() }, ); } let healthy = if !self.use_endpoint_health_status.is_empty() { self.use_endpoint_health_status.iter().all(|endpoint| { endpoint_health .get(endpoint) .is_some_and(|status| *status == HealthStatus::Ready) }) } else { // If we have registered health check targets, use them to determine health if !health_check_targets.is_empty() { health_check_targets .iter() .all(|(endpoint_subject, _target)| { endpoint_health .get(endpoint_subject) .is_some_and(|status| *status == HealthStatus::Ready) }) } else { // No health check targets registered, use simple system health self.system_health == HealthStatus::Ready } }; (healthy, endpoints) } /// Register a health check target for an endpoint pub fn register_health_check_target( &self, endpoint_subject: &str, instance: component::Instance, payload: serde_json::Value, ) { let key = endpoint_subject.to_owned(); // Atomically check+insert under a single write lock to avoid races. let inserted = { let mut targets = self.health_check_targets.write().unwrap(); match targets.entry(key.clone()) { std::collections::hash_map::Entry::Occupied(_) => false, std::collections::hash_map::Entry::Vacant(v) => { v.insert(HealthCheckTarget { instance, payload }); true } } }; if !inserted { tracing::warn!( "Attempted to re-register health check for endpoint '{}'; ignoring.", key ); return; } // Create and store a unique notifier for this endpoint (idempotent). { let mut notifiers = self.health_check_notifiers.write().unwrap(); notifiers .entry(key.clone()) .or_insert_with(|| Arc::new(tokio::sync::Notify::new())); } // Initialize endpoint health status conservatively to NotReady. { let mut endpoint_health = self.endpoint_health.write().unwrap(); endpoint_health .entry(key.clone()) .or_insert(HealthStatus::NotReady); } if let Err(e) = self.new_endpoint_tx.send(key.clone()) { tracing::error!( "Failed to send endpoint '{}' registration to health check manager: {}. \ Health checks will not be performed for this endpoint.", key, e ); } } /// Get all health check targets pub fn get_health_check_targets(&self) -> Vec<(String, HealthCheckTarget)> { let targets = self.health_check_targets.read().unwrap(); targets .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect() } /// Check if any health check targets are registered pub fn has_health_check_targets(&self) -> bool { let targets = self.health_check_targets.read().unwrap(); !targets.is_empty() } /// Get list of endpoints with health check targets pub fn get_health_check_endpoints(&self) -> Vec { let targets = self.health_check_targets.read().unwrap(); targets.keys().cloned().collect() } /// Get health check target for a specific endpoint pub fn get_health_check_target(&self, endpoint: &str) -> Option { let targets = self.health_check_targets.read().unwrap(); targets.get(endpoint).cloned() } /// Get the endpoint health status (Ready/NotReady) pub fn get_endpoint_health_status(&self, endpoint: &str) -> Option { let endpoint_health = self.endpoint_health.read().unwrap(); endpoint_health.get(endpoint).cloned() } /// Get the endpoint-specific health check notifier pub fn get_endpoint_health_check_notifier( &self, endpoint_subject: &str, ) -> Option> { let notifiers = self.health_check_notifiers.read().unwrap(); notifiers.get(endpoint_subject).cloned() } /// Take the receiver for new endpoint registrations (can only be called once) /// This is used by HealthCheckManager to receive notifications of new endpoints pub fn take_new_endpoint_receiver(&self) -> Option> { self.new_endpoint_rx.lock().take() } /// Initialize the uptime gauge using the provided metrics registry pub fn initialize_uptime_gauge( &self, registry: &T, ) -> anyhow::Result<()> { let gauge = registry.create_gauge( distributed_runtime::UPTIME_SECONDS, "Total uptime of the DistributedRuntime in seconds", &[], )?; self.uptime_gauge .set(gauge) .map_err(|_| anyhow::anyhow!("uptime_gauge already initialized"))?; Ok(()) } /// Get the current uptime as a Duration pub fn uptime(&self) -> std::time::Duration { self.start_time.elapsed() } /// Update the uptime gauge with the current uptime value pub fn update_uptime_gauge(&self) { if let Some(gauge) = self.uptime_gauge.get() { gauge.set(self.uptime().as_secs_f64()); } } /// Get the health check path pub fn health_path(&self) -> &str { &self.health_path } /// Get the liveness check path pub fn live_path(&self) -> &str { &self.live_path } }