distributed.rs 17.2 KB
Newer Older
1
2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
Ryan Olson's avatar
Ryan Olson committed
3
4

pub use crate::component::Component;
5
use crate::storage::key_value_store::{EtcdStore, KeyValueStore, MemoryStore};
6
use crate::transports::nats::DRTNatsClientPrometheusMetrics;
Ryan Olson's avatar
Ryan Olson committed
7
use crate::{
8
    ErrorContext, PrometheusUpdateCallback,
9
    component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
Ryan Olson's avatar
Ryan Olson committed
10
    discovery::DiscoveryClient,
11
    metrics::MetricsRegistry,
Ryan Olson's avatar
Ryan Olson committed
12
13
14
15
    service::ServiceClient,
    transports::{etcd, nats, tcp},
};

16
use super::utils::GracefulShutdownTracker;
17
use super::{Arc, DistributedRuntime, OK, OnceCell, Result, Runtime, SystemHealth, Weak, error};
18
use std::sync::OnceLock;
Ryan Olson's avatar
Ryan Olson committed
19
20
21

use derive_getters::Dissolve;
use figment::error;
22
23
use std::collections::HashMap;
use tokio::sync::Mutex;
24
use tokio_util::sync::CancellationToken;
Ryan Olson's avatar
Ryan Olson committed
25

26
27
28
29
30
31
32
33
34
35
/// Metrics hierarchy placement for the distributed runtime, which sits at the root.
impl MetricsRegistry for DistributedRuntime {
    /// The runtime contributes no basename of its own; basenames only begin at the
    /// `Namespace` level.
    fn basename(&self) -> String {
        String::new()
    }

    /// The runtime is the root of the hierarchy, so it has no parents.
    fn parent_hierarchy(&self) -> Vec<String> {
        Vec::new()
    }
}

Ryan Olson's avatar
Ryan Olson committed
36
37
38
39
40
41
/// Opaque `Debug` output: only the type name is printed, never any field.
impl std::fmt::Debug for DistributedRuntime {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("DistributedRuntime")
    }
}

Ryan Olson's avatar
Ryan Olson committed
42
43
impl DistributedRuntime {
    pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
44
        let (etcd_config, nats_config, is_static) = config.dissolve();
Ryan Olson's avatar
Ryan Olson committed
45
46
47

        let runtime_clone = runtime.clone();

48
49
50
        let (etcd_client, store) = if is_static {
            let store: Arc<dyn KeyValueStore> = Arc::new(MemoryStore::new());
            (None, store)
51
        } else {
52
53
54
55
            let etcd_client = etcd::Client::new(etcd_config.clone(), runtime_clone).await?;
            let store: Arc<dyn KeyValueStore> = Arc::new(EtcdStore::new(etcd_client.clone()));

            (Some(etcd_client), store)
56
        };
Ryan Olson's avatar
Ryan Olson committed
57

58
        let nats_client = Some(nats_config.clone().connect().await?);
Ryan Olson's avatar
Ryan Olson committed
59

60
        // Start system status server for health and metrics if enabled in configuration
61
62
63
64
65
66
67
68
        let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default();
        // IMPORTANT: We must extract cancel_token from runtime BEFORE moving runtime into the struct below.
        // This is because after moving, runtime is no longer accessible in this scope (ownership rules).
        let cancel_token = if config.system_server_enabled() {
            Some(runtime.clone().child_token())
        } else {
            None
        };
69
70
        let starting_health_status = config.starting_health_status.clone();
        let use_endpoint_health_status = config.use_endpoint_health_status.clone();
71
72
        let health_endpoint_path = config.system_health_path.clone();
        let live_endpoint_path = config.system_live_path.clone();
73
        let system_health = Arc::new(parking_lot::Mutex::new(SystemHealth::new(
74
75
            starting_health_status,
            use_endpoint_health_status,
76
77
            health_endpoint_path,
            live_endpoint_path,
78
        )));
79

80
81
        let nats_client_for_metrics = nats_client.clone();

82
        let distributed_runtime = Self {
Ryan Olson's avatar
Ryan Olson committed
83
84
            runtime,
            etcd_client,
85
            store,
Ryan Olson's avatar
Ryan Olson committed
86
87
            nats_client,
            tcp_server: Arc::new(OnceCell::new()),
88
            system_status_server: Arc::new(OnceLock::new()),
Ryan Olson's avatar
Ryan Olson committed
89
            component_registry: component::Registry::new(),
90
            is_static,
91
            instance_sources: Arc::new(Mutex::new(HashMap::new())),
92
            hierarchy_to_metricsregistry: Arc::new(std::sync::RwLock::new(HashMap::<
93
                String,
94
                crate::MetricsRegistryEntry,
95
            >::new())),
96
            system_health,
97
98
        };

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
        if let Some(nats_client_for_metrics) = nats_client_for_metrics {
            let nats_client_metrics = DRTNatsClientPrometheusMetrics::new(
                &distributed_runtime,
                nats_client_for_metrics.client().clone(),
            )?;
            let mut drt_hierarchies = distributed_runtime.parent_hierarchy();
            drt_hierarchies.push(distributed_runtime.hierarchy());
            // Register a callback to update NATS client metrics
            let nats_client_callback = Arc::new({
                let nats_client_clone = nats_client_metrics.clone();
                move || {
                    nats_client_clone.set_from_client_stats();
                    Ok(())
                }
            });
            distributed_runtime
                .register_prometheus_update_callback(drt_hierarchies, nats_client_callback);
        }
117

118
119
120
121
122
123
        // Initialize the uptime gauge in SystemHealth
        distributed_runtime
            .system_health
            .lock()
            .initialize_uptime_gauge(&distributed_runtime)?;

124
        // Handle system status server initialization
125
        if let Some(cancel_token) = cancel_token {
126
            // System server is enabled - start both the state and HTTP server
127
128
129
            let host = config.system_host.clone();
            let port = config.system_port;

130
            // Start system status server (it creates SystemStatusState internally)
131
            match crate::system_status_server::spawn_system_status_server(
132
133
134
135
                &host,
                port,
                cancel_token,
                Arc::new(distributed_runtime.clone()),
136
137
138
            )
            .await
            {
139
                Ok((addr, handle)) => {
140
                    tracing::info!("System status server started successfully on {}", addr);
141

142
143
144
145
146
147
                    // Store system status server information
                    let system_status_server_info =
                        crate::system_status_server::SystemStatusServerInfo::new(
                            addr,
                            Some(handle),
                        );
148

149
                    // Initialize the system_status_server field
150
                    distributed_runtime
151
152
153
                        .system_status_server
                        .set(Arc::new(system_status_server_info))
                        .expect("System status server info should only be set once");
154
155
                }
                Err(e) => {
156
                    tracing::error!("System status server startup failed: {}", e);
157
                }
158
            }
159
        } else {
160
            // System server HTTP is disabled, but uptime metrics are still being tracked via SystemHealth
161
162
163
            tracing::debug!(
                "System status server HTTP endpoints disabled, but uptime metrics are being tracked"
            );
164
165
        }

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
        // Start health check manager if enabled
        if config.health_check_enabled {
            let health_check_config = crate::health_check::HealthCheckConfig {
                canary_wait_time: std::time::Duration::from_secs(config.canary_wait_time_secs),
                request_timeout: std::time::Duration::from_secs(
                    config.health_check_request_timeout_secs,
                ),
            };

            // Start the health check manager (spawns per-endpoint monitoring tasks)
            match crate::health_check::start_health_check_manager(
                distributed_runtime.clone(),
                Some(health_check_config),
            )
            .await
            {
                Ok(()) => tracing::info!(
                    "Health check manager started (canary_wait_time: {}s, request_timeout: {}s)",
                    config.canary_wait_time_secs,
                    config.health_check_request_timeout_secs
                ),
                Err(e) => tracing::error!("Health check manager failed to start: {}", e),
            }
        }

191
        Ok(distributed_runtime)
Ryan Olson's avatar
Ryan Olson committed
192
193
194
    }

    /// Creates a `DistributedRuntime` from a non-static `DistributedConfig`
    /// (`DistributedConfig::from_settings(false)`), i.e. with an etcd client.
    ///
    /// # Errors
    /// Propagates any error returned by `DistributedRuntime::new`.
    pub async fn from_settings(runtime: Runtime) -> Result<Self> {
        Self::new(runtime, DistributedConfig::from_settings(false)).await
    }

    // Call this if you are using static workers that do not need etcd-based discovery.
    pub async fn from_settings_without_discovery(runtime: Runtime) -> Result<Self> {
        let config = DistributedConfig::from_settings(true);
Ryan Olson's avatar
Ryan Olson committed
202
203
204
205
206
207
208
        Self::new(runtime, config).await
    }

    /// Borrows the underlying [`Runtime`] this distributed runtime was built on.
    pub fn runtime(&self) -> &Runtime {
        &self.runtime
    }

209
210
211
212
    /// Returns the primary [`CancellationToken`] of the underlying [`Runtime`].
    pub fn primary_token(&self) -> CancellationToken {
        self.runtime.primary_token()
    }

213
214
215
216
    /// The etcd lease all our components will be attached to.
    ///
    /// Returns `None` when there is no etcd client, which is the case for static workers.
    pub fn primary_lease(&self) -> Option<etcd::Lease> {
        let client = self.etcd_client.as_ref()?;
        Some(client.primary_lease())
    }

    /// Requests shutdown by delegating to the underlying [`Runtime`].
    pub fn shutdown(&self) {
        self.runtime.shutdown();
    }

    /// Create a [`Namespace`]
    pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
225
        Namespace::new(self.clone(), name.into(), self.is_static)
Ryan Olson's avatar
Ryan Olson committed
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
    }

    // /// Create a [`Component`]
    // pub fn component(
    //     &self,
    //     name: impl Into<String>,
    //     namespace: impl Into<String>,
    // ) -> Result<Component> {
    //     Ok(ComponentBuilder::from_runtime(self.clone())
    //         .name(name.into())
    //         .namespace(namespace.into())
    //         .build()?)
    // }

    pub(crate) fn discovery_client(&self, namespace: impl Into<String>) -> DiscoveryClient {
241
242
243
244
245
246
        DiscoveryClient::new(
            namespace.into(),
            self.etcd_client
                .clone()
                .expect("Attempt to get discovery_client on static DistributedRuntime"),
        )
Ryan Olson's avatar
Ryan Olson committed
247
248
    }

249
250
    pub(crate) fn service_client(&self) -> Option<ServiceClient> {
        self.nats_client().map(|nc| ServiceClient::new(nc.clone()))
Ryan Olson's avatar
Ryan Olson committed
251
252
    }

253
    pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
Ryan Olson's avatar
Ryan Olson committed
254
255
256
257
258
259
260
261
262
263
264
        Ok(self
            .tcp_server
            .get_or_try_init(async move {
                let options = tcp::server::ServerOptions::default();
                let server = tcp::server::TcpStreamServer::new(options).await?;
                OK(server)
            })
            .await?
            .clone())
    }

265
266
    /// Borrows the NATS client, if this runtime holds one.
    pub fn nats_client(&self) -> Option<&nats::Client> {
        self.nats_client.as_ref()
    }

269
270
271
272
273
    /// Get system status server information if available.
    ///
    /// Returns `None` until the info has been stored, which `new` only does after the
    /// system status server started successfully.
    pub fn system_status_server_info(
        &self,
    ) -> Option<Arc<crate::system_status_server::SystemStatusServerInfo>> {
        self.system_status_server.get().map(Arc::clone)
    }

276
    // todo(ryan): deprecate this as we move to Discovery traits and Component Identifiers
277
    pub fn etcd_client(&self) -> Option<etcd::Client> {
Ryan Olson's avatar
Ryan Olson committed
278
279
        self.etcd_client.clone()
    }
280

281
282
283
284
285
286
    /// An interface to store things. Will eventually replace `etcd_client`.
    /// Currently does key-value, but will grow to include whatever we need to store.
    ///
    /// The returned handle shares the same underlying store as this runtime.
    pub fn store(&self) -> Arc<dyn KeyValueStore> {
        Arc::clone(&self.store)
    }

287
288
289
    /// Returns a child [`CancellationToken`] obtained from the underlying [`Runtime`].
    pub fn child_token(&self) -> CancellationToken {
        self.runtime.child_token()
    }
290

291
292
293
294
    /// Returns the [`GracefulShutdownTracker`] owned by the underlying [`Runtime`].
    pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
        self.runtime.graceful_shutdown_tracker()
    }

295
296
297
    /// Returns a shared handle to the map from [`Endpoint`] to its weakly held
    /// [`InstanceSource`].
    pub fn instance_sources(&self) -> Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>> {
        Arc::clone(&self.instance_sources)
    }
298

299
300
    /// Add a Prometheus metric to a specific hierarchy's registry. Note that it is possible
    /// to register the same metric name multiple times, as long as the labels are different.
301
302
303
304
305
306
307
308
    pub fn add_prometheus_metric(
        &self,
        hierarchy: &str,
        prometheus_metric: Box<dyn prometheus::core::Collector>,
    ) -> anyhow::Result<()> {
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
        let entry = registries.entry(hierarchy.to_string()).or_default();

309
310
311
312
313
        // Try to register the metric
        entry
            .prometheus_registry
            .register(prometheus_metric)
            .map_err(|e| e.into())
314
315
    }

316
317
318
319
320
321
322
323
    /// Add a Prometheus update callback to the given hierarchies
    /// TODO: rename this to register_callback, once we move the the MetricsRegistry trait
    ///       out of the runtime, and make it into a composed module.
    pub fn register_prometheus_update_callback(
        &self,
        hierarchies: Vec<String>,
        callback: PrometheusUpdateCallback,
    ) {
324
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
325
        for hierarchy in &hierarchies {
326
            registries
327
                .entry(hierarchy.clone())
328
                .or_default()
329
                .add_prometheus_update_callback(callback.clone());
330
331
332
        }
    }

333
334
    /// Execute all Prometheus update callbacks for a given hierarchy and return their results
    pub fn execute_prometheus_update_callbacks(&self, hierarchy: &str) -> Vec<anyhow::Result<()>> {
335
336
337
338
339
        // Clone callbacks while holding read lock (fast operation)
        let callbacks = {
            let registries = self.hierarchy_to_metricsregistry.read().unwrap();
            registries
                .get(hierarchy)
340
                .map(|entry| entry.prometheus_update_callbacks.clone())
341
342
343
344
345
346
347
348
349
        }; // Read lock released here

        // Execute callbacks without holding the lock
        match callbacks {
            Some(callbacks) => callbacks.iter().map(|callback| callback()).collect(),
            None => Vec::new(),
        }
    }

350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
    /// Add a Prometheus exposition text callback that returns Prometheus text for the
    /// given hierarchies, creating the registry entry for a hierarchy if it does not
    /// exist yet.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    pub fn register_prometheus_expfmt_callback(
        &self,
        hierarchies: Vec<String>,
        callback: crate::PrometheusExpositionFormatCallback,
    ) {
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
        // `hierarchies` is owned, so each key is moved into the map instead of cloned.
        for hierarchy in hierarchies {
            registries
                .entry(hierarchy)
                .or_default()
                .add_prometheus_expfmt_callback(callback.clone());
        }
    }

365
366
367
368
369
    /// Get all registered hierarchy keys, in no particular order. Private because it
    /// is only used for testing.
    fn get_registered_hierarchies(&self) -> Vec<String> {
        self.hierarchy_to_metricsregistry
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }
Ryan Olson's avatar
Ryan Olson committed
370
371
372
373
374
375
}

/// Transport options and mode flag consumed by `DistributedRuntime::new`.
///
/// `Dissolve` provides `dissolve()`, which `new` uses to split the config into the
/// tuple `(etcd_config, nats_config, is_static)`.
#[derive(Dissolve)]
pub struct DistributedConfig {
    /// Options for creating the etcd client; only used when `is_static` is `false`.
    pub etcd_config: etcd::ClientOptions,
    /// Options for connecting to NATS.
    pub nats_config: nats::ClientOptions,
    /// When `true`, the runtime uses an in-memory store and creates no etcd client.
    pub is_static: bool,
}

impl DistributedConfig {
380
    pub fn from_settings(is_static: bool) -> DistributedConfig {
Ryan Olson's avatar
Ryan Olson committed
381
382
383
        DistributedConfig {
            etcd_config: etcd::ClientOptions::default(),
            nats_config: nats::ClientOptions::default(),
384
            is_static,
Ryan Olson's avatar
Ryan Olson committed
385
386
        }
    }
Ryan Olson's avatar
Ryan Olson committed
387
388
389
390
391

    pub fn for_cli() -> DistributedConfig {
        let mut config = DistributedConfig {
            etcd_config: etcd::ClientOptions::default(),
            nats_config: nats::ClientOptions::default(),
392
            is_static: false,
Ryan Olson's avatar
Ryan Olson committed
393
394
395
396
397
398
        };

        config.etcd_config.attach_lease = false;

        config
    }
Ryan Olson's avatar
Ryan Olson committed
399
}
400

401
pub mod distributed_test_utils {
    //! Common test helper functions for DistributedRuntime tests
    // TODO: Use in-memory DistributedRuntime for tests instead of full runtime when available.

    /// Creates a DRT instance for integration-only tests.
    ///
    /// Builds the `Runtime` with `from_current`, so it must be called from within an
    /// existing tokio runtime. Settings are read from environment variables inside
    /// `DistributedRuntime::from_settings_without_discovery`.
    ///
    /// # Panics
    /// Panics if the current runtime cannot be obtained or the distributed runtime
    /// fails to build.
    #[cfg(feature = "integration")]
    pub async fn create_test_drt_async() -> crate::DistributedRuntime {
        let runtime = crate::Runtime::from_current().unwrap();
        let drt = crate::DistributedRuntime::from_settings_without_discovery(runtime).await;
        drt.unwrap()
    }
}
416

417
#[cfg(all(test, feature = "integration"))]
mod tests {
    use super::distributed_test_utils::create_test_drt_async;
    use std::time::Duration;

    /// How long each test lets the runtime run before sampling its uptime.
    const DELAY: Duration = Duration::from_millis(50);

    /// Starts a DRT with `DYN_SYSTEM_ENABLED` set to `system_enabled`, waits for
    /// `DELAY`, and asserts that the reported uptime is at least that long.
    /// `mode` is only used in the success message ("enabled" / "disabled").
    async fn assert_uptime_after_delay(system_enabled: &'static str, mode: &'static str) {
        temp_env::async_with_vars(
            [("DYN_SYSTEM_ENABLED", Some(system_enabled))],
            async move {
                // Start a DRT
                let drt = create_test_drt_async().await;

                // Let some wall-clock time pass
                tokio::time::sleep(DELAY).await;

                // The lock guard is a temporary released at the end of this statement
                let uptime = drt.system_health.lock().uptime();
                assert!(
                    uptime >= DELAY,
                    "Expected uptime to be at least 50ms, but got {:?}",
                    uptime
                );

                println!(
                    "✓ DRT uptime test passed (system {}): uptime = {:?}",
                    mode, uptime
                );
            },
        )
        .await;
    }

    /// Uptime is tracked with the system status server disabled.
    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_disabled() {
        assert_uptime_after_delay("false", "disabled").await;
    }

    /// Uptime is tracked with the system status server enabled.
    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_enabled() {
        assert_uptime_after_delay("true", "enabled").await;
    }
}