Unverified commit cf630bf7, authored by Graham King, committed by GitHub
Browse files

refactor: Make the Runtime and DistributedRuntime fields private (#4193)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 0e623146
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use anyhow::Context; use anyhow::{Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt}; use futures::{Stream, TryStreamExt};
use serde::{Deserialize, Serialize};
use super::*; use crate::component::Component;
use crate::traits::DistributedRuntimeProvider;
use crate::traits::events::{EventPublisher, EventSubscriber}; use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait] #[async_trait]
...@@ -71,6 +72,8 @@ impl EventSubscriber for Component { ...@@ -71,6 +72,8 @@ impl EventSubscriber for Component {
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{DistributedRuntime, Runtime};
use super::*; use super::*;
// todo - make a distributed runtime fixture // todo - make a distributed runtime fixture
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
pub use async_nats::service::endpoint::Stats as EndpointStats;
use derive_builder::Builder;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use educe::Educe;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::storage::key_value_store; use crate::{
component::{Endpoint, Instance, TransportType, service::EndpointStatsHandler},
use super::*; pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint},
storage::key_value_store,
pub use async_nats::service::endpoint::Stats as EndpointStats; traits::DistributedRuntimeProvider,
};
#[derive(Educe, Builder, Dissolve)] #[derive(Educe, Builder, Dissolve)]
#[educe(Debug)] #[educe(Debug)]
...@@ -72,21 +79,20 @@ impl EndpointConfigBuilder { ...@@ -72,21 +79,20 @@ impl EndpointConfigBuilder {
let service_name = endpoint.component.service_name(); let service_name = endpoint.component.service_name();
// acquire the registry lock
let registry = endpoint.drt().component_registry.inner.lock().await;
let metrics_labels: Option<Vec<(&str, &str)>> = metrics_labels let metrics_labels: Option<Vec<(&str, &str)>> = metrics_labels
.as_ref() .as_ref()
.map(|v| v.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect()); .map(|v| v.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
// Add metrics to the handler. The endpoint provides additional information to the handler. // Add metrics to the handler. The endpoint provides additional information to the handler.
handler.add_metrics(&endpoint, metrics_labels.as_deref())?; handler.add_metrics(&endpoint, metrics_labels.as_deref())?;
let registry = endpoint.drt().component_registry().inner.lock().await;
// get the group // get the group
let group = registry let group = registry
.services .services
.get(&service_name) .get(&service_name)
.map(|service| service.group(endpoint.component.service_name())) .map(|service| service.group(endpoint.component.service_name()))
.ok_or(error!("Service not found"))?; .ok_or(anyhow::anyhow!("Service not found"))?;
// get the stats handler map // get the stats handler map
let handler_map = registry let handler_map = registry
...@@ -118,7 +124,7 @@ impl EndpointConfigBuilder { ...@@ -118,7 +124,7 @@ impl EndpointConfigBuilder {
let namespace_name = endpoint.component.namespace.name.clone(); let namespace_name = endpoint.component.namespace.name.clone();
let component_name = endpoint.component.name.clone(); let component_name = endpoint.component.name.clone();
let endpoint_name = endpoint.name.clone(); let endpoint_name = endpoint.name.clone();
let system_health = endpoint.drt().system_health.clone(); let system_health = endpoint.drt().system_health();
let subject = endpoint.subject_to(connection_id); let subject = endpoint.subject_to(connection_id);
// Register health check target in SystemHealth if provided // Register health check target in SystemHealth if provided
...@@ -213,9 +219,9 @@ impl EndpointConfigBuilder { ...@@ -213,9 +219,9 @@ impl EndpointConfigBuilder {
"Unable to register service for discovery" "Unable to register service for discovery"
); );
endpoint_shutdown_token.cancel(); endpoint_shutdown_token.cancel();
return Err(error!( anyhow::bail!(
"Unable to register service for discovery. Check discovery service status" "Unable to register service for discovery. Check discovery service status"
)); );
} }
task.await??; task.await??;
......
...@@ -2,12 +2,16 @@ ...@@ -2,12 +2,16 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use anyhow::Context; use anyhow::Context;
use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt}; use futures::{Stream, TryStreamExt};
use serde::Deserialize;
use serde::Serialize;
use super::*; use crate::component::Namespace;
use crate::metrics::{MetricsHierarchy, MetricsRegistry}; use crate::metrics::{MetricsHierarchy, MetricsRegistry};
use crate::traits::DistributedRuntimeProvider;
use crate::traits::events::{EventPublisher, EventSubscriber}; use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait] #[async_trait]
...@@ -99,6 +103,8 @@ impl MetricsHierarchy for Namespace { ...@@ -99,6 +103,8 @@ impl MetricsHierarchy for Namespace {
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{DistributedRuntime, Runtime};
use super::*; use super::*;
// todo - make a distributed runtime fixture // todo - make a distributed runtime fixture
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{Component, Registry, RegistryInner, Result}; use anyhow::Result;
use async_once_cell::OnceCell; use async_once_cell::OnceCell;
use std::{ use std::{
collections::HashMap, collections::HashMap,
...@@ -9,6 +9,8 @@ use std::{ ...@@ -9,6 +9,8 @@ use std::{
}; };
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::component::{Registry, RegistryInner};
impl Default for Registry { impl Default for Registry {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
...@@ -22,66 +24,3 @@ impl Registry { ...@@ -22,66 +24,3 @@ impl Registry {
} }
} }
} }
// impl ComponentRegistry {
// pub fn new() -> Self {
// Self {
// clients: Arc::new(Mutex::new(HashMap::new())),
// }
// }
// pub async fn get_or_create(&mut self, component: Component) -> Result<Arc<Client>> {
// // Lock the clients HashMap for thread-safe access
// let mut guard = self.clients.lock().await;
// // Check if the component already exists in the registry
// if let Some(weak) = guard.get(&component.slug()) {
// // Attempt to upgrade the Weak pointer
// if let Some(client) = weak.upgrade() {
// return Ok(client);
// }
// }
// // Fallback: Create a new Client
// let client = component.client().await?;
// // Insert a Weak reference to the new client into the map
// guard.insert(component.slug(), Arc::downgrade(&client));
// Ok(client)
// }
// }
// #[derive(Clone)]
// pub struct ServiceRegistry {
// clients: Arc<Mutex<HashMap<String, Arc<Service>>>>,
// }
// impl ServiceRegistry {
// pub fn new() -> Self {
// Self {
// clients: Arc::new(Mutex::new(HashMap::new())),
// }
// }
// pub async fn get_or_create(&mut self, component: Component) -> Result<Arc<Client>> {
// // Lock the clients HashMap for thread-safe access
// let mut guard = self.clients.lock().await;
// // Check if the component already exists in the registry
// if let Some(weak) = guard.get(&component.slug()) {
// // Attempt to upgrade the Weak pointer
// if let Some(client) = weak.upgrade() {
// return Ok(client);
// }
// }
// // Fallback: Create a new Client
// let client = component.client().await?;
// // Insert a Weak reference to the new client into the map
// guard.insert(component.slug(), Arc::downgrade(&client));
// Ok(client)
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use figment::{ use figment::{
Figment, Figment,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent};
use crate::{CancellationToken, Result};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use super::{ use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream, Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
}; };
use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent};
const INSTANCES_BUCKET: &str = "v1/instances"; const INSTANCES_BUCKET: &str = "v1/instances";
const MODELS_BUCKET: &str = "v1/mdc"; const MODELS_BUCKET: &str = "v1/mdc";
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
use super::{ use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream, Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
}; };
use crate::{CancellationToken, Result}; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use tokio_util::sync::CancellationToken;
/// Shared in-memory registry for mock discovery /// Shared in-memory registry for mock discovery
#[derive(Clone, Default)] #[derive(Clone, Default)]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::CancellationToken; use anyhow::Result;
use crate::Result;
use crate::component::TransportType;
use async_trait::async_trait; use async_trait::async_trait;
use futures::Stream; use futures::Stream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::pin::Pin; use std::pin::Pin;
use tokio_util::sync::CancellationToken;
mod mock; mod mock;
pub use mock::{MockDiscovery, SharedMockRegistry}; pub use mock::{MockDiscovery, SharedMockRegistry};
mod kv_store; mod kv_store;
pub use kv_store::KVStoreDiscovery; pub use kv_store::KVStoreDiscovery;
pub mod utils; pub mod utils;
use crate::component::TransportType;
pub use utils::watch_and_extract_field; pub use utils::watch_and_extract_field;
/// Query key for prefix-based discovery queries /// Query key for prefix-based discovery queries
...@@ -85,7 +83,7 @@ impl DiscoverySpec { ...@@ -85,7 +83,7 @@ impl DiscoverySpec {
component: String, component: String,
endpoint: String, endpoint: String,
card: &T, card: &T,
) -> crate::Result<Self> ) -> Result<Self>
where where
T: Serialize, T: Serialize,
{ {
...@@ -158,14 +156,14 @@ impl DiscoveryInstance { ...@@ -158,14 +156,14 @@ impl DiscoveryInstance {
/// Deserializes the model JSON into the specified type T /// Deserializes the model JSON into the specified type T
/// Returns an error if this is not a Model instance or if deserialization fails /// Returns an error if this is not a Model instance or if deserialization fails
pub fn deserialize_model<T>(&self) -> crate::Result<T> pub fn deserialize_model<T>(&self) -> Result<T>
where where
T: for<'de> Deserialize<'de>, T: for<'de> Deserialize<'de>,
{ {
match self { match self {
Self::Model { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?), Self::Model { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
Self::Endpoint(_) => { Self::Endpoint(_) => {
crate::raise!("Cannot deserialize model from Endpoint instance") anyhow::bail!("Cannot deserialize model from Endpoint instance")
} }
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub use crate::component::Component; use crate::component::Component;
use crate::pipeline::PipelineError;
use crate::storage::key_value_store::{ use crate::storage::key_value_store::{
EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect, EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect,
MemoryStore, MemoryStore,
}; };
use crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{ use crate::{
ErrorContext,
component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace}, component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
discovery::Discovery, discovery::Discovery,
metrics::PrometheusUpdateCallback, metrics::PrometheusUpdateCallback,
...@@ -16,17 +16,59 @@ use crate::{ ...@@ -16,17 +16,59 @@ use crate::{
service::ServiceClient, service::ServiceClient,
transports::{etcd, nats, tcp}, transports::{etcd, nats, tcp},
}; };
use crate::{discovery, system_status_server, transports};
use super::utils::GracefulShutdownTracker; use super::utils::GracefulShutdownTracker;
use super::{Arc, DistributedRuntime, OK, OnceCell, Result, Runtime, SystemHealth, Weak, error}; use crate::SystemHealth;
use std::sync::OnceLock; use crate::runtime::Runtime;
use async_once_cell::OnceCell;
use std::sync::{Arc, OnceLock, Weak};
use anyhow::Result;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use figment::error; use figment::error;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
/// Distributed [Runtime] which provides access to shared resources across the cluster; this
/// includes communication protocols and transports.
///
/// All fields are private. The type derives `Clone`, and every piece of shared state below is
/// held behind an `Arc`, so clones refer to the same underlying servers, registries and health
/// state rather than copying them.
#[derive(Clone)]
pub struct DistributedRuntime {
// Local runtime.
runtime: Runtime,
// We might consider a unified transport manager here.
// Both transport clients are optional: either may be `None`.
etcd_client: Option<transports::etcd::Client>,
nats_client: Option<transports::nats::Client>,
// Key-value store manager.
store: KeyValueStoreManager,
// TCP stream server, held in a `OnceCell` so it is set at most once after construction.
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
// System status server info, held in a `OnceLock` so it is set at most once.
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
// Service discovery client.
discovery_client: Arc<dyn discovery::Discovery>,
// Local registry for components.
// The registry allows us to share runtime resources across instances of the same component
// object. Take for example two instances of a client to the same remote component: the registry
// allows us to use a single endpoint watcher for both clients, which keeps the number of
// background tasks watching specific paths in etcd to a minimum.
component_registry: component::Registry,
// Will only have static components that are not discoverable via etcd; they must be known at
// startup. Will not start etcd.
is_static: bool,
// Weak references to instance sources, keyed by endpoint, behind an async (tokio) mutex.
instance_sources: Arc<tokio::sync::Mutex<HashMap<Endpoint, Weak<InstanceSource>>>>,
// Health status, guarded by a synchronous `parking_lot` mutex.
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
// This hierarchy's own metrics registry.
metrics_registry: MetricsRegistry,
}
impl MetricsHierarchy for DistributedRuntime { impl MetricsHierarchy for DistributedRuntime {
fn basename(&self) -> String { fn basename(&self) -> String {
"".to_string() // drt has no basename. Basename only begins with the Namespace. "".to_string() // drt has no basename. Basename only begins with the Namespace.
...@@ -229,6 +271,17 @@ impl DistributedRuntime { ...@@ -229,6 +271,17 @@ impl DistributedRuntime {
self.runtime.primary_token() self.runtime.primary_token()
} }
// TODO: Don't hand out pointers, instead have methods to use the registry in friendly ways
// (without being aware of async locks and so on)
/// Borrows this runtime's component [`component::Registry`].
///
/// The returned reference lives no longer than the borrow of `self`. Callers currently take
/// the registry's `inner` async lock themselves, which is what the TODO above wants to hide.
pub fn component_registry(&self) -> &component::Registry {
&self.component_registry
}
// TODO: Don't hand out pointers, instead provide system health related services.
/// Returns another `Arc` handle to the shared, mutex-guarded [`SystemHealth`].
///
/// Only the reference count is bumped; the handle points at the same `SystemHealth` that the
/// runtime itself holds, so callers already own the `Arc` and need not clone it again.
pub fn system_health(&self) -> Arc<parking_lot::Mutex<SystemHealth>> {
    Arc::clone(&self.system_health)
}
pub fn connection_id(&self) -> u64 { pub fn connection_id(&self) -> u64 {
self.store.connection_id() self.store.connection_id()
} }
...@@ -258,7 +311,7 @@ impl DistributedRuntime { ...@@ -258,7 +311,7 @@ impl DistributedRuntime {
.get_or_try_init(async move { .get_or_try_init(async move {
let options = tcp::server::ServerOptions::default(); let options = tcp::server::ServerOptions::default();
let server = tcp::server::TcpStreamServer::new(options).await?; let server = tcp::server::TcpStreamServer::new(options).await?;
OK(server) Ok::<_, PipelineError>(server)
}) })
.await? .await?
.clone()) .clone())
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::component::{Client, Component, Endpoint, Instance}; use crate::component::{Client, Component, Endpoint, Instance};
use crate::config::HealthStatus;
use crate::pipeline::PushRouter; use crate::pipeline::PushRouter;
use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn}; use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
use crate::protocols::annotated::Annotated; use crate::protocols::annotated::Annotated;
use crate::protocols::maybe_error::MaybeError; use crate::protocols::maybe_error::MaybeError;
use crate::{DistributedRuntime, HealthStatus, SystemHealth}; use crate::{DistributedRuntime, SystemHealth};
use futures::StreamExt; use futures::StreamExt;
use parking_lot::Mutex; use parking_lot::Mutex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -99,10 +100,7 @@ impl HealthCheckManager { ...@@ -99,10 +100,7 @@ impl HealthCheckManager {
/// Start the health check manager by spawning per-endpoint monitoring tasks /// Start the health check manager by spawning per-endpoint monitoring tasks
pub async fn start(self: Arc<Self>) -> anyhow::Result<()> { pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
// Get all registered endpoints at startup // Get all registered endpoints at startup
let targets = { let targets = self.drt.system_health().lock().get_health_check_targets();
let system_health = self.drt.system_health.lock();
system_health.get_health_check_targets()
};
info!( info!(
"Starting health check tasks for {} endpoints with canary_wait_time: {:?}", "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
...@@ -131,12 +129,12 @@ impl HealthCheckManager { ...@@ -131,12 +129,12 @@ impl HealthCheckManager {
let endpoint_subject_clone = endpoint_subject.clone(); let endpoint_subject_clone = endpoint_subject.clone();
// Get the endpoint-specific notifier // Get the endpoint-specific notifier
let notifier = { let notifier = self
let system_health = self.drt.system_health.lock(); .drt
system_health .system_health()
.lock()
.get_endpoint_health_check_notifier(&endpoint_subject) .get_endpoint_health_check_notifier(&endpoint_subject)
.expect("Notifier should exist for registered endpoint") .expect("Notifier should exist for registered endpoint");
};
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let endpoint_subject = endpoint_subject_clone; let endpoint_subject = endpoint_subject_clone;
...@@ -150,10 +148,7 @@ impl HealthCheckManager { ...@@ -150,10 +148,7 @@ impl HealthCheckManager {
info!("Canary timer expired for {}, sending health check", endpoint_subject); info!("Canary timer expired for {}, sending health check", endpoint_subject);
// Get the health check payload for this endpoint // Get the health check payload for this endpoint
let target = { let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);
let system_health = manager.drt.system_health.lock();
system_health.get_health_check_target(&endpoint_subject)
};
if let Some(target) = target { if let Some(target) = target {
if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await { if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
...@@ -197,12 +192,14 @@ impl HealthCheckManager { ...@@ -197,12 +192,14 @@ impl HealthCheckManager {
let manager = self.clone(); let manager = self.clone();
// Get the receiver (can only be taken once) // Get the receiver (can only be taken once)
let mut rx = { let mut rx = manager
let system_health = manager.drt.system_health.lock(); .drt
system_health.take_new_endpoint_receiver().ok_or_else(|| { .system_health()
.lock()
.take_new_endpoint_receiver()
.ok_or_else(|| {
anyhow::anyhow!("Endpoint receiver already taken - this should only be called once") anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
})? })?;
};
tokio::spawn(async move { tokio::spawn(async move {
info!("Starting dynamic endpoint discovery monitor with channel-based notifications"); info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
...@@ -246,14 +243,14 @@ impl HealthCheckManager { ...@@ -246,14 +243,14 @@ impl HealthCheckManager {
endpoint_subject: &str, endpoint_subject: &str,
payload: &serde_json::Value, payload: &serde_json::Value,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let target = { let target = self
let system_health = self.drt.system_health.lock(); .drt
system_health .system_health()
.lock()
.get_health_check_target(endpoint_subject) .get_health_check_target(endpoint_subject)
.ok_or_else(|| { .ok_or_else(|| {
anyhow::anyhow!("No health check target found for {}", endpoint_subject) anyhow::anyhow!("No health check target found for {}", endpoint_subject)
})? })?;
};
debug!( debug!(
"Sending health check to {} (instance_id: {})", "Sending health check to {} (instance_id: {})",
...@@ -274,7 +271,7 @@ impl HealthCheckManager { ...@@ -274,7 +271,7 @@ impl HealthCheckManager {
let request: SingleIn<serde_json::Value> = Context::new(payload.clone()); let request: SingleIn<serde_json::Value> = Context::new(payload.clone());
// Clone what we need for the spawned task // Clone what we need for the spawned task
let system_health = self.drt.system_health.clone(); let system_health = self.drt.system_health().clone();
let endpoint_subject_owned = endpoint_subject.to_string(); let endpoint_subject_owned = endpoint_subject.to_string();
let instance_id = target.instance.instance_id; let instance_id = target.instance.instance_id;
let timeout = self.config.request_timeout; let timeout = self.config.request_timeout;
...@@ -364,18 +361,16 @@ pub async fn get_health_check_status( ...@@ -364,18 +361,16 @@ pub async fn get_health_check_status(
drt: &DistributedRuntime, drt: &DistributedRuntime,
) -> anyhow::Result<serde_json::Value> { ) -> anyhow::Result<serde_json::Value> {
// Get endpoints list from SystemHealth // Get endpoints list from SystemHealth
let endpoint_subjects: Vec<String> = { let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
let system_health = drt.system_health.lock();
system_health.get_health_check_endpoints()
};
let mut endpoint_statuses = HashMap::new(); let mut endpoint_statuses = HashMap::new();
// Check each endpoint's health status // Check each endpoint's health status
{ {
let system_health = drt.system_health.lock(); let system_health = drt.system_health();
let system_health_lock = system_health.lock();
for endpoint_subject in &endpoint_subjects { for endpoint_subject in &endpoint_subjects {
let health_status = system_health let health_status = system_health_lock
.get_endpoint_health_status(endpoint_subject) .get_endpoint_health_status(endpoint_subject)
.unwrap_or(HealthStatus::NotReady); .unwrap_or(HealthStatus::NotReady);
...@@ -408,7 +403,6 @@ pub async fn get_health_check_status( ...@@ -408,7 +403,6 @@ pub async fn get_health_check_status(
#[cfg(all(test, feature = "integration"))] #[cfg(all(test, feature = "integration"))]
mod integration_tests { mod integration_tests {
use super::*; use super::*;
use crate::HealthStatus;
use crate::distributed::distributed_test_utils::create_test_drt_async; use crate::distributed::distributed_test_utils::create_test_drt_async;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
...@@ -429,8 +423,6 @@ mod integration_tests { ...@@ -429,8 +423,6 @@ mod integration_tests {
assert_eq!(manager.config.canary_wait_time, canary_wait_time); assert_eq!(manager.config.canary_wait_time, canary_wait_time);
assert_eq!(manager.config.request_timeout, request_timeout); assert_eq!(manager.config.request_timeout, request_timeout);
assert!(Arc::ptr_eq(&manager.drt.system_health, &drt.system_health));
} }
#[tokio::test] #[tokio::test]
...@@ -443,7 +435,7 @@ mod integration_tests { ...@@ -443,7 +435,7 @@ mod integration_tests {
"_health_check": true "_health_check": true
}); });
drt.system_health.lock().register_health_check_target( drt.system_health().lock().register_health_check_target(
endpoint, endpoint,
crate::component::Instance { crate::component::Instance {
component: "test_component".to_string(), component: "test_component".to_string(),
...@@ -456,7 +448,7 @@ mod integration_tests { ...@@ -456,7 +448,7 @@ mod integration_tests {
); );
let retrieved = drt let retrieved = drt
.system_health .system_health()
.lock() .lock()
.get_health_check_target(endpoint) .get_health_check_target(endpoint)
.map(|t| t.payload); .map(|t| t.payload);
...@@ -464,7 +456,7 @@ mod integration_tests { ...@@ -464,7 +456,7 @@ mod integration_tests {
assert_eq!(retrieved.unwrap(), payload); assert_eq!(retrieved.unwrap(), payload);
// Verify endpoint appears in the list // Verify endpoint appears in the list
let endpoints = drt.system_health.lock().get_health_check_endpoints(); let endpoints = drt.system_health().lock().get_health_check_endpoints();
assert!(endpoints.contains(&endpoint.to_string())); assert!(endpoints.contains(&endpoint.to_string()));
} }
...@@ -478,7 +470,7 @@ mod integration_tests { ...@@ -478,7 +470,7 @@ mod integration_tests {
"prompt": format!("test{}", i), "prompt": format!("test{}", i),
"_health_check": true "_health_check": true
}); });
drt.system_health.lock().register_health_check_target( drt.system_health().lock().register_health_check_target(
&endpoint, &endpoint,
crate::component::Instance { crate::component::Instance {
component: "test_component".to_string(), component: "test_component".to_string(),
...@@ -521,7 +513,7 @@ mod integration_tests { ...@@ -521,7 +513,7 @@ mod integration_tests {
}); });
// Register the endpoint // Register the endpoint
drt.system_health.lock().register_health_check_target( drt.system_health().lock().register_health_check_target(
endpoint, endpoint,
crate::component::Instance { crate::component::Instance {
component: "test_component".to_string(), component: "test_component".to_string(),
...@@ -535,7 +527,7 @@ mod integration_tests { ...@@ -535,7 +527,7 @@ mod integration_tests {
// Verify that a notifier was created for this endpoint // Verify that a notifier was created for this endpoint
let notifier = drt let notifier = drt
.system_health .system_health()
.lock() .lock()
.get_endpoint_health_check_notifier(endpoint); .get_endpoint_health_check_notifier(endpoint);
...@@ -551,7 +543,7 @@ mod integration_tests { ...@@ -551,7 +543,7 @@ mod integration_tests {
// Initially, the endpoint should be NotReady (default after registration) // Initially, the endpoint should be NotReady (default after registration)
let status = drt let status = drt
.system_health .system_health()
.lock() .lock()
.get_endpoint_health_status(endpoint); .get_endpoint_health_status(endpoint);
assert_eq!(status, Some(HealthStatus::NotReady)); assert_eq!(status, Some(HealthStatus::NotReady));
......
...@@ -6,17 +6,6 @@ ...@@ -6,17 +6,6 @@
#![allow(dead_code)] #![allow(dead_code)]
#![allow(unused_imports)] #![allow(unused_imports)]
use std::{
collections::HashMap,
sync::{Arc, OnceLock, Weak},
};
pub use anyhow::{
Context as ErrorContext, Error, Ok as OK, Result, anyhow as error, bail as raise,
};
use async_once_cell::OnceCell;
pub mod config; pub mod config;
pub use config::RuntimeConfig; pub use config::RuntimeConfig;
...@@ -27,6 +16,7 @@ pub mod engine; ...@@ -27,6 +16,7 @@ pub mod engine;
pub mod health_check; pub mod health_check;
pub mod system_status_server; pub mod system_status_server;
pub use system_status_server::SystemStatusServerInfo; pub use system_status_server::SystemStatusServerInfo;
pub mod distributed;
pub mod instances; pub mod instances;
pub mod logging; pub mod logging;
pub mod metrics; pub mod metrics;
...@@ -44,77 +34,10 @@ pub mod transports; ...@@ -44,77 +34,10 @@ pub mod transports;
pub mod utils; pub mod utils;
pub mod worker; pub mod worker;
pub mod distributed; pub use distributed::{DistributedRuntime, distributed_test_utils};
pub use distributed::distributed_test_utils;
pub use futures::stream; pub use futures::stream;
pub use metrics::MetricsRegistry; pub use metrics::MetricsRegistry;
pub use runtime::Runtime;
pub use system_health::{HealthCheckTarget, SystemHealth}; pub use system_health::{HealthCheckTarget, SystemHealth};
pub use tokio_util::sync::CancellationToken; pub use tokio_util::sync::CancellationToken;
pub use worker::Worker; pub use worker::Worker;
use crate::{
metrics::prometheus_names::distributed_runtime,
storage::key_value_store::{KeyValueStore, KeyValueStoreManager},
};
use component::{Endpoint, InstanceSource};
use utils::GracefulShutdownTracker;
use config::HealthStatus;
/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
///
/// Derives `Clone`; cloning either variant clones the `Arc` or the `Handle`, not the runtime.
#[derive(Clone)]
enum RuntimeType {
// A Tokio runtime object held behind an `Arc`.
Shared(Arc<tokio::runtime::Runtime>),
// Only a handle to a Tokio runtime.
// NOTE(review): presumably a runtime created and owned by the caller — confirm.
External(tokio::runtime::Handle),
}
/// Local [Runtime] which provides access to shared resources local to the physical node/machine.
///
/// Derives `Clone`; the `Arc`-wrapped fields are shared between clones rather than copied.
#[derive(Debug, Clone)]
pub struct Runtime {
// Identifier string for this runtime, shared between clones.
id: Arc<String>,
// The two Tokio runtimes this `Runtime` is built on.
// NOTE(review): how work is split between primary and secondary is not visible here — confirm.
primary: RuntimeType,
secondary: RuntimeType,
// Cancellation tokens.
// NOTE(review): presumably one covers the whole runtime and one only endpoints — confirm.
cancellation_token: CancellationToken,
endpoint_shutdown_token: CancellationToken,
// Shared tracker used for graceful shutdown.
graceful_shutdown_tracker: Arc<GracefulShutdownTracker>,
// Optional compute pool; `None` when not configured.
compute_pool: Option<Arc<compute::ComputePool>>,
// Optional semaphore.
// NOTE(review): by its name it limits concurrent `block_in_place` calls — confirm.
block_in_place_permits: Option<Arc<tokio::sync::Semaphore>>,
}
/// Distributed [Runtime] which provides access to shared resources across the cluster; this
/// includes communication protocols and transports.
///
/// The type derives `Clone`, and every piece of shared state below is held behind an `Arc`, so
/// clones refer to the same underlying servers, registries and health state.
#[derive(Clone)]
pub struct DistributedRuntime {
// Local runtime.
runtime: Runtime,
// We might consider a unified transport manager here.
// Both transport clients are optional: either may be `None`.
etcd_client: Option<transports::etcd::Client>,
nats_client: Option<transports::nats::Client>,
// Key-value store manager.
store: KeyValueStoreManager,
// TCP stream server, held in a `OnceCell` so it is set at most once after construction.
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
// System status server info, held in a `OnceLock` so it is set at most once.
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
// Service discovery interface.
discovery_client: Arc<dyn discovery::Discovery>,
// Local registry for components.
// The registry allows us to share runtime resources across instances of the same component
// object. Take for example two instances of a client to the same remote component: the registry
// allows us to use a single endpoint watcher for both clients, which keeps the number of
// background tasks watching specific paths in etcd to a minimum.
component_registry: component::Registry,
// Will only have static components that are not discoverable via etcd; they must be known at
// startup. Will not start etcd.
is_static: bool,
// Weak references to instance sources, keyed by endpoint, behind an async (tokio) mutex.
instance_sources: Arc<tokio::sync::Mutex<HashMap<Endpoint, Weak<InstanceSource>>>>,
// Health status, guarded by a synchronous `parking_lot` mutex.
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
// This hierarchy's own metrics registry.
metrics_registry: MetricsRegistry,
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use async_nats::client::Client; use std::sync::Arc;
use async_nats::{HeaderMap, HeaderValue};
use tracing as log;
use super::*; use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
use crate::logging::DistributedTraceContext; use crate::logging::DistributedTraceContext;
use crate::logging::get_distributed_tracing_context; use crate::logging::get_distributed_tracing_context;
use crate::logging::inject_otel_context_into_nats_headers; use crate::logging::inject_otel_context_into_nats_headers;
use crate::{Result, protocols::maybe_error::MaybeError}; use crate::pipeline::network::ConnectionInfo;
use crate::pipeline::network::NetworkStreamWrapper;
use crate::pipeline::network::PendingConnections;
use crate::pipeline::network::ResponseService;
use crate::pipeline::network::STREAM_ERR_MSG;
use crate::pipeline::network::StreamOptions;
use crate::pipeline::network::TwoPartCodec;
use crate::pipeline::network::codec::TwoPartMessage;
use crate::pipeline::network::tcp;
use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
use crate::protocols::maybe_error::MaybeError;
use anyhow::{Error, Result};
use async_nats::client::Client;
use async_nats::{HeaderMap, HeaderValue};
use serde::Deserialize;
use serde::Serialize;
use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream}; use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
use tracing::Instrument; use tracing::Instrument;
...@@ -70,7 +84,7 @@ impl AddressedPushRouter { ...@@ -70,7 +84,7 @@ impl AddressedPushRouter {
} }
} }
#[async_trait] #[async_trait::async_trait]
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
where where
T: Data + Serialize, T: Data + Serialize,
...@@ -123,7 +137,7 @@ where ...@@ -123,7 +137,7 @@ where
let ctrl = serde_json::to_vec(&control_message)?; let ctrl = serde_json::to_vec(&control_message)?;
let data = serde_json::to_vec(&request)?; let data = serde_json::to_vec(&request)?;
log::trace!( tracing::trace!(
request_id, request_id,
"packaging two-part message; ctrl: {} bytes, data: {} bytes", "packaging two-part message; ctrl: {} bytes, data: {} bytes",
ctrl.len(), ctrl.len(),
...@@ -140,7 +154,7 @@ where ...@@ -140,7 +154,7 @@ where
// TRANSPORT ABSTRACT REQUIRED - END HERE // TRANSPORT ABSTRACT REQUIRED - END HERE
log::trace!(request_id, "enqueueing two-part message to nats"); tracing::trace!(request_id, "enqueueing two-part message to nats");
// Insert Trace Context into Headers // Insert Trace Context into Headers
// Enables span to be created in push_endpoint before // Enables span to be created in push_endpoint before
...@@ -168,7 +182,7 @@ where ...@@ -168,7 +182,7 @@ where
.request_with_headers(address.to_string(), headers, buffer) .request_with_headers(address.to_string(), headers, buffer)
.await?; .await?;
log::trace!(request_id, "awaiting transport handshake"); tracing::trace!(request_id, "awaiting transport handshake");
let response_stream = response_stream_provider let response_stream = response_stream_provider
.await .await
.map_err(|_| PipelineError::DetachedStreamReceiver)? .map_err(|_| PipelineError::DetachedStreamReceiver)?
...@@ -206,7 +220,7 @@ where ...@@ -206,7 +220,7 @@ where
Err(err) => { Err(err) => {
// legacy log print // legacy log print
let json_str = String::from_utf8_lossy(&res_bytes); let json_str = String::from_utf8_lossy(&res_bytes);
log::warn!(%err, %json_str, "Failed deserializing JSON to response"); tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(Error::new(err).into())) Some(U::from_err(Error::new(err).into()))
} }
...@@ -218,11 +232,11 @@ where ...@@ -218,11 +232,11 @@ where
// Gracefully end the stream if 'stop_generating()' was called. Do NOT check for // Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
// 'is_killed()' here because it implies the stream ended abnormally which should be // 'is_killed()' here because it implies the stream ended abnormally which should be
// handled by the error branch below. // handled by the error branch below.
log::debug!("Request cancelled and then trying to read a response"); tracing::debug!("Request cancelled and then trying to read a response");
None None
} else { } else {
// stream ended unexpectedly // stream ended unexpectedly
log::debug!("{STREAM_ERR_MSG}"); tracing::debug!("{STREAM_ERR_MSG}");
Some(U::from_err(Error::msg(STREAM_ERR_MSG).into())) Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
} }
}); });
......
...@@ -19,7 +19,7 @@ use crate::pipeline::network::{ ...@@ -19,7 +19,7 @@ use crate::pipeline::network::{
codec::{TwoPartCodec, TwoPartMessage}, codec::{TwoPartCodec, TwoPartMessage},
tcp::StreamType, tcp::StreamType,
}; };
use crate::{ErrorContext, Result, error}; // Import SinkExt to use the `send` method use anyhow::{Context, Result, anyhow as error}; // Import SinkExt to use the `send` method
#[allow(dead_code)] #[allow(dead_code)]
pub struct TcpClient { pub struct TcpClient {
......
...@@ -16,25 +16,6 @@ use derive_builder::Builder; ...@@ -16,25 +16,6 @@ use derive_builder::Builder;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6}; use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
// Trait for IP address resolution - allows dependency injection for testing.
//
// Implementors supply the two lookups below; tests can provide a stub implementation instead of
// querying the real network interfaces.
pub trait IpResolver {
// Resolve the local IP address (presumably IPv4, as the counterpart of `local_ipv6` -- confirm).
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
// Resolve the local IPv6 address.
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
// Default implementation using the real local_ip_address crate.
// Stateless unit struct: each method simply forwards to the free function of the same name.
pub struct DefaultIpResolver;
impl IpResolver for DefaultIpResolver {
// Forwards to the free function `local_ip()`, returning its result unchanged.
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
// Forwards to the free function `local_ipv6()`, returning its result unchanged.
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::{ use tokio::{
io::AsyncWriteExt, io::AsyncWriteExt,
...@@ -56,7 +37,26 @@ use crate::pipeline::{ ...@@ -56,7 +37,26 @@ use crate::pipeline::{
tcp::StreamType, tcp::StreamType,
}, },
}; };
use crate::{ErrorContext, Result, error}; use anyhow::{Context, Result, anyhow as error};
// Trait for IP address resolution - allows dependency injection for testing.
//
// Implementors supply the two lookups below; tests can provide a stub implementation instead of
// querying the real network interfaces.
pub trait IpResolver {
// Resolve the local IP address (presumably IPv4, as the counterpart of `local_ipv6` -- confirm).
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
// Resolve the local IPv6 address.
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
// Default implementation using the real local_ip_address crate.
// Stateless unit struct: each method simply forwards to the free function of the same name.
pub struct DefaultIpResolver;
impl IpResolver for DefaultIpResolver {
// Forwards to the free function `local_ip()`, returning its result unchanged.
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
// Forwards to the free function `local_ipv6()`, returning its result unchanged.
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
#[allow(dead_code)] #[allow(dead_code)]
type ResponseType = TwoPartMessage; type ResponseType = TwoPartMessage;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::*;
use crate::Error; use anyhow::Error;
impl<Resp: PipelineIO> Default for SinkEdge<Resp> { impl<Resp: PipelineIO> Default for SinkEdge<Resp> {
fn default() -> Self { fn default() -> Self {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::*;
use crate::Error; use anyhow::Error;
impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> { impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> {
pub fn from_engine(engine: ServiceEngine<Req, Resp>) -> Arc<Self> { pub fn from_engine(engine: ServiceEngine<Req, Resp>) -> Arc<Self> {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::*;
use crate::Error; use anyhow::Error;
impl<Req: PipelineIO, Resp: PipelineIO> SegmentSink<Req, Resp> { impl<Req: PipelineIO, Resp: PipelineIO> SegmentSink<Req, Resp> {
pub fn new() -> Arc<Self> { pub fn new() -> Arc<Self> {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::maybe_error::MaybeError;
use crate::{Result, error}; use anyhow::{Result, anyhow as error};
use maybe_error::MaybeError; use serde::{Deserialize, Serialize};
pub trait AnnotationsProvider { pub trait AnnotationsProvider {
fn annotations(&self) -> Option<Vec<String>>; fn annotations(&self) -> Option<Vec<String>>;
......
...@@ -11,7 +11,7 @@ use std::{ ...@@ -11,7 +11,7 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
pub use crate::{Error, Result}; pub use anyhow::{Error, Result};
pub use async_trait::async_trait; pub use async_trait::async_trait;
pub use tokio::task::JoinHandle; pub use tokio::task::JoinHandle;
pub use tokio_util::sync::CancellationToken; pub use tokio_util::sync::CancellationToken;
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
//! private; however, for now we are exposing most objects as fully public while the API is maturing. //! private; however, for now we are exposing most objects as fully public while the API is maturing.
use super::utils::GracefulShutdownTracker; use super::utils::GracefulShutdownTracker;
use super::{Result, Runtime, RuntimeType, error}; use crate::{
use crate::config::{self, RuntimeConfig}; compute,
config::{self, RuntimeConfig},
};
use futures::Future; use futures::Future;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
...@@ -24,8 +26,28 @@ use tokio::{signal, sync::Mutex, task::JoinHandle}; ...@@ -24,8 +26,28 @@ use tokio::{signal, sync::Mutex, task::JoinHandle};
pub use tokio_util::sync::CancellationToken; pub use tokio_util::sync::CancellationToken;
/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
///
/// The variants differ in what they hold: `Shared` carries a reference-counted
/// `tokio::runtime::Runtime`, while `External` carries only a `tokio::runtime::Handle`.
#[derive(Clone)]
enum RuntimeType {
/// A Tokio runtime held behind an `Arc`, so clones of this value refer to the same runtime.
Shared(Arc<tokio::runtime::Runtime>),
/// A handle to a Tokio runtime; the runtime object itself is not stored in this variant.
External(tokio::runtime::Handle),
}
/// Local [Runtime] which provides access to shared resources local to the physical node/machine.
///
/// All fields are private; the struct is `Clone`, and most fields are either `Arc`-wrapped or
/// cheaply cloneable token/handle types.
// NOTE(review): `#[derive(Debug)]` needs `RuntimeType: Debug`, which the enum does not derive --
// presumably a manual `Debug` impl exists elsewhere; confirm.
#[derive(Debug, Clone)]
pub struct Runtime {
/// Identifier of this runtime instance, shared between clones via `Arc`.
id: Arc<String>,
/// Primary Tokio runtime.
primary: RuntimeType,
/// Secondary Tokio runtime.
// NOTE(review): how work is divided between primary and secondary is not visible here -- confirm.
secondary: RuntimeType,
/// Cancellation token for the runtime as a whole (presumably the main shutdown signal -- confirm).
cancellation_token: CancellationToken,
/// Separate cancellation token, by its name scoped to endpoint shutdown -- confirm against users.
endpoint_shutdown_token: CancellationToken,
/// Shared tracker used for graceful shutdown.
graceful_shutdown_tracker: Arc<GracefulShutdownTracker>,
/// Optional compute pool; `None` when no pool has been set up.
compute_pool: Option<Arc<compute::ComputePool>>,
/// Optional semaphore, by its name limiting concurrent `block_in_place` use -- confirm.
block_in_place_permits: Option<Arc<tokio::sync::Semaphore>>,
}
impl Runtime { impl Runtime {
fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> Result<Runtime> { fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> anyhow::Result<Runtime> {
// worker id // worker id
let id = Arc::new(uuid::Uuid::new_v4().to_string()); let id = Arc::new(uuid::Uuid::new_v4().to_string());
...@@ -65,7 +87,7 @@ impl Runtime { ...@@ -65,7 +87,7 @@ impl Runtime {
runtime: RuntimeType, runtime: RuntimeType,
secondary: Option<RuntimeType>, secondary: Option<RuntimeType>,
config: &RuntimeConfig, config: &RuntimeConfig,
) -> Result<Runtime> { ) -> anyhow::Result<Runtime> {
let mut rt = Self::new(runtime, secondary)?; let mut rt = Self::new(runtime, secondary)?;
// Create compute pool from configuration // Create compute pool from configuration
...@@ -123,7 +145,7 @@ impl Runtime { ...@@ -123,7 +145,7 @@ impl Runtime {
/// Initialize thread-local compute context on all worker threads using a barrier /// Initialize thread-local compute context on all worker threads using a barrier
/// This ensures every worker thread has its thread-local context initialized /// This ensures every worker thread has its thread-local context initialized
pub async fn initialize_all_thread_locals(&self) -> Result<()> { pub async fn initialize_all_thread_locals(&self) -> anyhow::Result<()> {
if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) { if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
// First, detect how many worker threads we actually have // First, detect how many worker threads we actually have
let num_workers = self.detect_worker_thread_count().await; let num_workers = self.detect_worker_thread_count().await;
...@@ -207,11 +229,11 @@ impl Runtime { ...@@ -207,11 +229,11 @@ impl Runtime {
count count
} }
pub fn from_current() -> Result<Runtime> { pub fn from_current() -> anyhow::Result<Runtime> {
Runtime::from_handle(tokio::runtime::Handle::current()) Runtime::from_handle(tokio::runtime::Handle::current())
} }
pub fn from_handle(handle: tokio::runtime::Handle) -> Result<Runtime> { pub fn from_handle(handle: tokio::runtime::Handle) -> anyhow::Result<Runtime> {
let primary = RuntimeType::External(handle.clone()); let primary = RuntimeType::External(handle.clone());
let secondary = RuntimeType::External(handle); let secondary = RuntimeType::External(handle);
Runtime::new(primary, Some(secondary)) Runtime::new(primary, Some(secondary))
...@@ -219,7 +241,7 @@ impl Runtime { ...@@ -219,7 +241,7 @@ impl Runtime {
/// Create a [`Runtime`] instance from the settings /// Create a [`Runtime`] instance from the settings
/// See [`config::RuntimeConfig::from_settings`] /// See [`config::RuntimeConfig::from_settings`]
pub fn from_settings() -> Result<Runtime> { pub fn from_settings() -> anyhow::Result<Runtime> {
let config = config::RuntimeConfig::from_settings()?; let config = config::RuntimeConfig::from_settings()?;
let runtime = Arc::new(config.create_runtime()?); let runtime = Arc::new(config.create_runtime()?);
let primary = RuntimeType::Shared(runtime.clone()); let primary = RuntimeType::Shared(runtime.clone());
...@@ -228,7 +250,7 @@ impl Runtime { ...@@ -228,7 +250,7 @@ impl Runtime {
} }
/// Create a [`Runtime`] with two single-threaded async tokio runtime /// Create a [`Runtime`] with two single-threaded async tokio runtime
pub fn single_threaded() -> Result<Runtime> { pub fn single_threaded() -> anyhow::Result<Runtime> {
let config = config::RuntimeConfig::single_threaded(); let config = config::RuntimeConfig::single_threaded();
let owned = RuntimeType::Shared(Arc::new(config.create_runtime()?)); let owned = RuntimeType::Shared(Arc::new(config.create_runtime()?));
Runtime::new(owned, None) Runtime::new(owned, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment