Unverified commit cf630bf7, authored by Graham King, committed by GitHub
Browse files

refactor: Make the Runtime and DistributedRuntime fields private (#4193)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 0e623146
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Context;
use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt};
use serde::{Deserialize, Serialize};
use super::*;
use crate::component::Component;
use crate::traits::DistributedRuntimeProvider;
use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait]
......@@ -71,6 +72,8 @@ impl EventSubscriber for Component {
#[cfg(feature = "integration")]
#[cfg(test)]
mod tests {
use crate::{DistributedRuntime, Runtime};
use super::*;
// todo - make a distributed runtime fixture
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
pub use async_nats::service::endpoint::Stats as EndpointStats;
use derive_builder::Builder;
use derive_getters::Dissolve;
use educe::Educe;
use tokio_util::sync::CancellationToken;
use crate::storage::key_value_store;
use super::*;
pub use async_nats::service::endpoint::Stats as EndpointStats;
use crate::{
component::{Endpoint, Instance, TransportType, service::EndpointStatsHandler},
pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint},
storage::key_value_store,
traits::DistributedRuntimeProvider,
};
#[derive(Educe, Builder, Dissolve)]
#[educe(Debug)]
......@@ -72,21 +79,20 @@ impl EndpointConfigBuilder {
let service_name = endpoint.component.service_name();
// acquire the registry lock
let registry = endpoint.drt().component_registry.inner.lock().await;
let metrics_labels: Option<Vec<(&str, &str)>> = metrics_labels
.as_ref()
.map(|v| v.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
// Add metrics to the handler. The endpoint provides additional information to the handler.
handler.add_metrics(&endpoint, metrics_labels.as_deref())?;
let registry = endpoint.drt().component_registry().inner.lock().await;
// get the group
let group = registry
.services
.get(&service_name)
.map(|service| service.group(endpoint.component.service_name()))
.ok_or(error!("Service not found"))?;
.ok_or(anyhow::anyhow!("Service not found"))?;
// get the stats handler map
let handler_map = registry
......@@ -118,7 +124,7 @@ impl EndpointConfigBuilder {
let namespace_name = endpoint.component.namespace.name.clone();
let component_name = endpoint.component.name.clone();
let endpoint_name = endpoint.name.clone();
let system_health = endpoint.drt().system_health.clone();
let system_health = endpoint.drt().system_health();
let subject = endpoint.subject_to(connection_id);
// Register health check target in SystemHealth if provided
......@@ -213,9 +219,9 @@ impl EndpointConfigBuilder {
"Unable to register service for discovery"
);
endpoint_shutdown_token.cancel();
return Err(error!(
anyhow::bail!(
"Unable to register service for discovery. Check discovery service status"
));
);
}
task.await??;
......
......@@ -2,12 +2,16 @@
// SPDX-License-Identifier: Apache-2.0
use anyhow::Context;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt};
use serde::Deserialize;
use serde::Serialize;
use super::*;
use crate::component::Namespace;
use crate::metrics::{MetricsHierarchy, MetricsRegistry};
use crate::traits::DistributedRuntimeProvider;
use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait]
......@@ -99,6 +103,8 @@ impl MetricsHierarchy for Namespace {
#[cfg(feature = "integration")]
#[cfg(test)]
mod tests {
use crate::{DistributedRuntime, Runtime};
use super::*;
// todo - make a distributed runtime fixture
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::{Component, Registry, RegistryInner, Result};
use anyhow::Result;
use async_once_cell::OnceCell;
use std::{
collections::HashMap,
......@@ -9,6 +9,8 @@ use std::{
};
use tokio::sync::Mutex;
use crate::component::{Registry, RegistryInner};
impl Default for Registry {
fn default() -> Self {
Self::new()
......@@ -22,66 +24,3 @@ impl Registry {
}
}
}
// impl ComponentRegistry {
// pub fn new() -> Self {
// Self {
// clients: Arc::new(Mutex::new(HashMap::new())),
// }
// }
// pub async fn get_or_create(&mut self, component: Component) -> Result<Arc<Client>> {
// // Lock the clients HashMap for thread-safe access
// let mut guard = self.clients.lock().await;
// // Check if the component already exists in the registry
// if let Some(weak) = guard.get(&component.slug()) {
// // Attempt to upgrade the Weak pointer
// if let Some(client) = weak.upgrade() {
// return Ok(client);
// }
// }
// // Fallback: Create a new Client
// let client = component.client().await?;
// // Insert a Weak reference to the new client into the map
// guard.insert(component.slug(), Arc::downgrade(&client));
// Ok(client)
// }
// }
// #[derive(Clone)]
// pub struct ServiceRegistry {
// clients: Arc<Mutex<HashMap<String, Arc<Service>>>>,
// }
// impl ServiceRegistry {
// pub fn new() -> Self {
// Self {
// clients: Arc::new(Mutex::new(HashMap::new())),
// }
// }
// pub async fn get_or_create(&mut self, component: Component) -> Result<Arc<Client>> {
// // Lock the clients HashMap for thread-safe access
// let mut guard = self.clients.lock().await;
// // Check if the component already exists in the registry
// if let Some(weak) = guard.get(&component.slug()) {
// // Attempt to upgrade the Weak pointer
// if let Some(client) = weak.upgrade() {
// return Ok(client);
// }
// }
// // Fallback: Create a new Client
// let client = component.client().await?;
// // Insert a Weak reference to the new client into the map
// guard.insert(component.slug(), Arc::downgrade(&client));
// Ok(client)
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::Result;
use anyhow::Result;
use derive_builder::Builder;
use figment::{
Figment,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent};
use crate::{CancellationToken, Result};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use std::pin::Pin;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
};
use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent};
const INSTANCES_BUCKET: &str = "v1/instances";
const MODELS_BUCKET: &str = "v1/mdc";
......
......@@ -4,9 +4,10 @@
use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
};
use crate::{CancellationToken, Result};
use anyhow::Result;
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use tokio_util::sync::CancellationToken;
/// Shared in-memory registry for mock discovery
#[derive(Clone, Default)]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::CancellationToken;
use crate::Result;
use crate::component::TransportType;
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use tokio_util::sync::CancellationToken;
mod mock;
pub use mock::{MockDiscovery, SharedMockRegistry};
mod kv_store;
pub use kv_store::KVStoreDiscovery;
pub mod utils;
use crate::component::TransportType;
pub use utils::watch_and_extract_field;
/// Query key for prefix-based discovery queries
......@@ -85,7 +83,7 @@ impl DiscoverySpec {
component: String,
endpoint: String,
card: &T,
) -> crate::Result<Self>
) -> Result<Self>
where
T: Serialize,
{
......@@ -158,14 +156,14 @@ impl DiscoveryInstance {
/// Deserializes the model JSON into the specified type T
/// Returns an error if this is not a Model instance or if deserialization fails
pub fn deserialize_model<T>(&self) -> crate::Result<T>
pub fn deserialize_model<T>(&self) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
match self {
Self::Model { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
Self::Endpoint(_) => {
crate::raise!("Cannot deserialize model from Endpoint instance")
anyhow::bail!("Cannot deserialize model from Endpoint instance")
}
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub use crate::component::Component;
use crate::component::Component;
use crate::pipeline::PipelineError;
use crate::storage::key_value_store::{
EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect,
MemoryStore,
};
use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{
ErrorContext,
component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
discovery::Discovery,
metrics::PrometheusUpdateCallback,
......@@ -16,17 +16,59 @@ use crate::{
service::ServiceClient,
transports::{etcd, nats, tcp},
};
use crate::{discovery, system_status_server, transports};
use super::utils::GracefulShutdownTracker;
use super::{Arc, DistributedRuntime, OK, OnceCell, Result, Runtime, SystemHealth, Weak, error};
use std::sync::OnceLock;
use crate::SystemHealth;
use crate::runtime::Runtime;
use async_once_cell::OnceCell;
use std::sync::{Arc, OnceLock, Weak};
use anyhow::Result;
use derive_getters::Dissolve;
use figment::error;
use std::collections::HashMap;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
/// Distributed [Runtime] which provides access to shared resources across the cluster; this includes
/// communication protocols and transports.
///
/// All fields are private; callers go through accessor methods such as
/// `component_registry()`, `system_health()` and `connection_id()`.
#[derive(Clone)]
pub struct DistributedRuntime {
// Local runtime.
runtime: Runtime,
// Optional transport clients. We might consider a unified transport manager here.
etcd_client: Option<transports::etcd::Client>,
nats_client: Option<transports::nats::Client>,
// Key-value store manager; `connection_id()` delegates to it.
store: KeyValueStoreManager,
// Set-once (async) cell holding the shared TCP stream server.
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
// Set-once cell holding the system status server info.
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
// Service discovery client
discovery_client: Arc<dyn discovery::Discovery>,
// Local registry for components.
// The registry allows us to share runtime resources across instances of the same component object.
// Take for example two instances of a client to the same remote component: the registry allows us to use
// a single endpoint watcher for both clients, which keeps the number of background tasks watching specific
// paths in etcd to a minimum.
component_registry: component::Registry,
// Will only have static components that are not discoverable via etcd; they must be known at
// startup. Will not start etcd.
is_static: bool,
// Map from endpoint to a weak reference to its instance source, guarded by an async mutex.
instance_sources: Arc<tokio::sync::Mutex<HashMap<Endpoint, Weak<InstanceSource>>>>,
// Health status, shared behind a `parking_lot` mutex.
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
// This hierarchy's own metrics registry
metrics_registry: MetricsRegistry,
}
impl MetricsHierarchy for DistributedRuntime {
fn basename(&self) -> String {
"".to_string() // drt has no basename. Basename only begins with the Namespace.
......@@ -229,6 +271,17 @@ impl DistributedRuntime {
self.runtime.primary_token()
}
// TODO: Don't hand out pointers, instead have methods to use the registry in friendly ways
// (without being aware of async locks and so on)
/// Borrows the local component registry held by this runtime.
///
/// The returned reference exposes the registry itself, so callers remain
/// responsible for any locking the registry requires.
pub fn component_registry(&self) -> &component::Registry {
&self.component_registry
}
// TODO: Don't hand out pointers, instead provide system health related services.
/// Returns another shared handle to this runtime's [`SystemHealth`] state.
///
/// Only the `Arc` reference count is bumped; the returned handle points at the
/// same mutex-guarded value that this runtime holds.
pub fn system_health(&self) -> Arc<parking_lot::Mutex<SystemHealth>> {
    Arc::clone(&self.system_health)
}
/// Returns the connection id reported by the key-value store manager.
pub fn connection_id(&self) -> u64 {
self.store.connection_id()
}
......@@ -258,7 +311,7 @@ impl DistributedRuntime {
.get_or_try_init(async move {
let options = tcp::server::ServerOptions::default();
let server = tcp::server::TcpStreamServer::new(options).await?;
OK(server)
Ok::<_, PipelineError>(server)
})
.await?
.clone())
......
......@@ -2,11 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
use crate::component::{Client, Component, Endpoint, Instance};
use crate::config::HealthStatus;
use crate::pipeline::PushRouter;
use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
use crate::protocols::annotated::Annotated;
use crate::protocols::maybe_error::MaybeError;
use crate::{DistributedRuntime, HealthStatus, SystemHealth};
use crate::{DistributedRuntime, SystemHealth};
use futures::StreamExt;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
......@@ -99,10 +100,7 @@ impl HealthCheckManager {
/// Start the health check manager by spawning per-endpoint monitoring tasks
pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
// Get all registered endpoints at startup
let targets = {
let system_health = self.drt.system_health.lock();
system_health.get_health_check_targets()
};
let targets = self.drt.system_health().lock().get_health_check_targets();
info!(
"Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
......@@ -131,12 +129,12 @@ impl HealthCheckManager {
let endpoint_subject_clone = endpoint_subject.clone();
// Get the endpoint-specific notifier
let notifier = {
let system_health = self.drt.system_health.lock();
system_health
.get_endpoint_health_check_notifier(&endpoint_subject)
.expect("Notifier should exist for registered endpoint")
};
let notifier = self
.drt
.system_health()
.lock()
.get_endpoint_health_check_notifier(&endpoint_subject)
.expect("Notifier should exist for registered endpoint");
let task = tokio::spawn(async move {
let endpoint_subject = endpoint_subject_clone;
......@@ -150,10 +148,7 @@ impl HealthCheckManager {
info!("Canary timer expired for {}, sending health check", endpoint_subject);
// Get the health check payload for this endpoint
let target = {
let system_health = manager.drt.system_health.lock();
system_health.get_health_check_target(&endpoint_subject)
};
let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);
if let Some(target) = target {
if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
......@@ -197,12 +192,14 @@ impl HealthCheckManager {
let manager = self.clone();
// Get the receiver (can only be taken once)
let mut rx = {
let system_health = manager.drt.system_health.lock();
system_health.take_new_endpoint_receiver().ok_or_else(|| {
let mut rx = manager
.drt
.system_health()
.lock()
.take_new_endpoint_receiver()
.ok_or_else(|| {
anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
})?
};
})?;
tokio::spawn(async move {
info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
......@@ -246,14 +243,14 @@ impl HealthCheckManager {
endpoint_subject: &str,
payload: &serde_json::Value,
) -> anyhow::Result<()> {
let target = {
let system_health = self.drt.system_health.lock();
system_health
.get_health_check_target(endpoint_subject)
.ok_or_else(|| {
anyhow::anyhow!("No health check target found for {}", endpoint_subject)
})?
};
let target = self
.drt
.system_health()
.lock()
.get_health_check_target(endpoint_subject)
.ok_or_else(|| {
anyhow::anyhow!("No health check target found for {}", endpoint_subject)
})?;
debug!(
"Sending health check to {} (instance_id: {})",
......@@ -274,7 +271,7 @@ impl HealthCheckManager {
let request: SingleIn<serde_json::Value> = Context::new(payload.clone());
// Clone what we need for the spawned task
let system_health = self.drt.system_health.clone();
let system_health = self.drt.system_health().clone();
let endpoint_subject_owned = endpoint_subject.to_string();
let instance_id = target.instance.instance_id;
let timeout = self.config.request_timeout;
......@@ -364,18 +361,16 @@ pub async fn get_health_check_status(
drt: &DistributedRuntime,
) -> anyhow::Result<serde_json::Value> {
// Get endpoints list from SystemHealth
let endpoint_subjects: Vec<String> = {
let system_health = drt.system_health.lock();
system_health.get_health_check_endpoints()
};
let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
let mut endpoint_statuses = HashMap::new();
// Check each endpoint's health status
{
let system_health = drt.system_health.lock();
let system_health = drt.system_health();
let system_health_lock = system_health.lock();
for endpoint_subject in &endpoint_subjects {
let health_status = system_health
let health_status = system_health_lock
.get_endpoint_health_status(endpoint_subject)
.unwrap_or(HealthStatus::NotReady);
......@@ -408,7 +403,6 @@ pub async fn get_health_check_status(
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
use super::*;
use crate::HealthStatus;
use crate::distributed::distributed_test_utils::create_test_drt_async;
use std::sync::Arc;
use std::time::Duration;
......@@ -429,8 +423,6 @@ mod integration_tests {
assert_eq!(manager.config.canary_wait_time, canary_wait_time);
assert_eq!(manager.config.request_timeout, request_timeout);
assert!(Arc::ptr_eq(&manager.drt.system_health, &drt.system_health));
}
#[tokio::test]
......@@ -443,7 +435,7 @@ mod integration_tests {
"_health_check": true
});
drt.system_health.lock().register_health_check_target(
drt.system_health().lock().register_health_check_target(
endpoint,
crate::component::Instance {
component: "test_component".to_string(),
......@@ -456,7 +448,7 @@ mod integration_tests {
);
let retrieved = drt
.system_health
.system_health()
.lock()
.get_health_check_target(endpoint)
.map(|t| t.payload);
......@@ -464,7 +456,7 @@ mod integration_tests {
assert_eq!(retrieved.unwrap(), payload);
// Verify endpoint appears in the list
let endpoints = drt.system_health.lock().get_health_check_endpoints();
let endpoints = drt.system_health().lock().get_health_check_endpoints();
assert!(endpoints.contains(&endpoint.to_string()));
}
......@@ -478,7 +470,7 @@ mod integration_tests {
"prompt": format!("test{}", i),
"_health_check": true
});
drt.system_health.lock().register_health_check_target(
drt.system_health().lock().register_health_check_target(
&endpoint,
crate::component::Instance {
component: "test_component".to_string(),
......@@ -521,7 +513,7 @@ mod integration_tests {
});
// Register the endpoint
drt.system_health.lock().register_health_check_target(
drt.system_health().lock().register_health_check_target(
endpoint,
crate::component::Instance {
component: "test_component".to_string(),
......@@ -535,7 +527,7 @@ mod integration_tests {
// Verify that a notifier was created for this endpoint
let notifier = drt
.system_health
.system_health()
.lock()
.get_endpoint_health_check_notifier(endpoint);
......@@ -551,7 +543,7 @@ mod integration_tests {
// Initially, the endpoint should be NotReady (default after registration)
let status = drt
.system_health
.system_health()
.lock()
.get_endpoint_health_status(endpoint);
assert_eq!(status, Some(HealthStatus::NotReady));
......
......@@ -6,17 +6,6 @@
#![allow(dead_code)]
#![allow(unused_imports)]
use std::{
collections::HashMap,
sync::{Arc, OnceLock, Weak},
};
pub use anyhow::{
Context as ErrorContext, Error, Ok as OK, Result, anyhow as error, bail as raise,
};
use async_once_cell::OnceCell;
pub mod config;
pub use config::RuntimeConfig;
......@@ -27,6 +16,7 @@ pub mod engine;
pub mod health_check;
pub mod system_status_server;
pub use system_status_server::SystemStatusServerInfo;
pub mod distributed;
pub mod instances;
pub mod logging;
pub mod metrics;
......@@ -44,77 +34,10 @@ pub mod transports;
pub mod utils;
pub mod worker;
pub mod distributed;
pub use distributed::distributed_test_utils;
pub use distributed::{DistributedRuntime, distributed_test_utils};
pub use futures::stream;
pub use metrics::MetricsRegistry;
pub use runtime::Runtime;
pub use system_health::{HealthCheckTarget, SystemHealth};
pub use tokio_util::sync::CancellationToken;
pub use worker::Worker;
use crate::{
metrics::prometheus_names::distributed_runtime,
storage::key_value_store::{KeyValueStore, KeyValueStoreManager},
};
use component::{Endpoint, InstanceSource};
use utils::GracefulShutdownTracker;
use config::HealthStatus;
/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
#[derive(Clone)]
enum RuntimeType {
// A Tokio runtime held behind an `Arc`, so clones share the same runtime.
Shared(Arc<tokio::runtime::Runtime>),
// A handle to a Tokio runtime that is not held by this value.
External(tokio::runtime::Handle),
}
/// Local [Runtime] which provides access to shared resources local to the physical node/machine.
#[derive(Debug, Clone)]
pub struct Runtime {
// Identifier string for this runtime, shared via `Arc`.
id: Arc<String>,
// Primary and secondary Tokio runtimes; each is either shared or an external handle.
primary: RuntimeType,
secondary: RuntimeType,
// Cancellation token for this runtime.
cancellation_token: CancellationToken,
// Separate token for shutting down endpoints.
// NOTE(review): presumably cancelled independently of `cancellation_token` — confirm.
endpoint_shutdown_token: CancellationToken,
// Shared tracker used during graceful shutdown.
graceful_shutdown_tracker: Arc<GracefulShutdownTracker>,
// Optional compute pool.
compute_pool: Option<Arc<compute::ComputePool>>,
// Optional semaphore of block-in-place permits.
block_in_place_permits: Option<Arc<tokio::sync::Semaphore>>,
}
/// Distributed [Runtime] which provides access to shared resources across the cluster; this includes
/// communication protocols and transports.
#[derive(Clone)]
pub struct DistributedRuntime {
// Local runtime.
runtime: Runtime,
// Optional transport clients. We might consider a unified transport manager here.
etcd_client: Option<transports::etcd::Client>,
nats_client: Option<transports::nats::Client>,
// Key-value store manager.
store: KeyValueStoreManager,
// Set-once (async) cell holding the shared TCP stream server.
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
// Set-once cell holding the system status server info.
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
// Service discovery interface
discovery_client: Arc<dyn discovery::Discovery>,
// Local registry for components.
// The registry allows us to share runtime resources across instances of the same component object.
// Take for example two instances of a client to the same remote component: the registry allows us to use
// a single endpoint watcher for both clients, which keeps the number of background tasks watching specific
// paths in etcd to a minimum.
component_registry: component::Registry,
// Will only have static components that are not discoverable via etcd; they must be known at
// startup. Will not start etcd.
is_static: bool,
// Map from endpoint to a weak reference to its instance source, guarded by an async mutex.
instance_sources: Arc<tokio::sync::Mutex<HashMap<Endpoint, Weak<InstanceSource>>>>,
// Health status, shared behind a `parking_lot` mutex.
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
// This hierarchy's own metrics registry
metrics_registry: MetricsRegistry,
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use async_nats::client::Client;
use async_nats::{HeaderMap, HeaderValue};
use tracing as log;
use std::sync::Arc;
use super::*;
use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
use crate::logging::DistributedTraceContext;
use crate::logging::get_distributed_tracing_context;
use crate::logging::inject_otel_context_into_nats_headers;
use crate::{Result, protocols::maybe_error::MaybeError};
use crate::pipeline::network::ConnectionInfo;
use crate::pipeline::network::NetworkStreamWrapper;
use crate::pipeline::network::PendingConnections;
use crate::pipeline::network::ResponseService;
use crate::pipeline::network::STREAM_ERR_MSG;
use crate::pipeline::network::StreamOptions;
use crate::pipeline::network::TwoPartCodec;
use crate::pipeline::network::codec::TwoPartMessage;
use crate::pipeline::network::tcp;
use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
use crate::protocols::maybe_error::MaybeError;
use anyhow::{Error, Result};
use async_nats::client::Client;
use async_nats::{HeaderMap, HeaderValue};
use serde::Deserialize;
use serde::Serialize;
use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
use tracing::Instrument;
......@@ -70,7 +84,7 @@ impl AddressedPushRouter {
}
}
#[async_trait]
#[async_trait::async_trait]
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
where
T: Data + Serialize,
......@@ -123,7 +137,7 @@ where
let ctrl = serde_json::to_vec(&control_message)?;
let data = serde_json::to_vec(&request)?;
log::trace!(
tracing::trace!(
request_id,
"packaging two-part message; ctrl: {} bytes, data: {} bytes",
ctrl.len(),
......@@ -140,7 +154,7 @@ where
// TRANSPORT ABSTRACT REQUIRED - END HERE
log::trace!(request_id, "enqueueing two-part message to nats");
tracing::trace!(request_id, "enqueueing two-part message to nats");
// Insert Trace Context into Headers
// Enables span to be created in push_endpoint before
......@@ -168,7 +182,7 @@ where
.request_with_headers(address.to_string(), headers, buffer)
.await?;
log::trace!(request_id, "awaiting transport handshake");
tracing::trace!(request_id, "awaiting transport handshake");
let response_stream = response_stream_provider
.await
.map_err(|_| PipelineError::DetachedStreamReceiver)?
......@@ -206,7 +220,7 @@ where
Err(err) => {
// legacy log print
let json_str = String::from_utf8_lossy(&res_bytes);
log::warn!(%err, %json_str, "Failed deserializing JSON to response");
tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(Error::new(err).into()))
}
......@@ -218,11 +232,11 @@ where
// Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
// 'is_killed()' here because it implies the stream ended abnormally which should be
// handled by the error branch below.
log::debug!("Request cancelled and then trying to read a response");
tracing::debug!("Request cancelled and then trying to read a response");
None
} else {
// stream ended unexpectedly
log::debug!("{STREAM_ERR_MSG}");
tracing::debug!("{STREAM_ERR_MSG}");
Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
}
});
......
......@@ -19,7 +19,7 @@ use crate::pipeline::network::{
codec::{TwoPartCodec, TwoPartMessage},
tcp::StreamType,
};
use crate::{ErrorContext, Result, error}; // Import SinkExt to use the `send` method
use anyhow::{Context, Result, anyhow as error}; // Import SinkExt to use the `send` method
#[allow(dead_code)]
pub struct TcpClient {
......
......@@ -16,25 +16,6 @@ use derive_builder::Builder;
use futures::{SinkExt, StreamExt};
use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
// Trait for IP address resolution - allows dependency injection for testing
//
// Both methods return a single `std::net::IpAddr` or the `Error` type imported
// from the `local_ip_address` crate.
pub trait IpResolver {
// Resolve the local IP address.
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
// Resolve the local IPv6 address.
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
// Default implementation using the real local_ip_address crate
pub struct DefaultIpResolver;
// Each method forwards directly to the same-named free function imported from
// `local_ip_address`; no state is kept on the resolver.
impl IpResolver for DefaultIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
use serde::{Deserialize, Serialize};
use tokio::{
io::AsyncWriteExt,
......@@ -56,7 +37,26 @@ use crate::pipeline::{
tcp::StreamType,
},
};
use crate::{ErrorContext, Result, error};
use anyhow::{Context, Result, anyhow as error};
// Trait for IP address resolution - allows dependency injection for testing
//
// Both methods return a single `std::net::IpAddr` or the `Error` type imported
// from the `local_ip_address` crate.
pub trait IpResolver {
// Resolve the local IP address.
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
// Resolve the local IPv6 address.
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
// Default implementation using the real local_ip_address crate
pub struct DefaultIpResolver;
// Each method forwards directly to the same-named free function imported from
// `local_ip_address`; no state is kept on the resolver.
impl IpResolver for DefaultIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
#[allow(dead_code)]
type ResponseType = TwoPartMessage;
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::Error;
use anyhow::Error;
impl<Resp: PipelineIO> Default for SinkEdge<Resp> {
fn default() -> Self {
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::Error;
use anyhow::Error;
impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> {
pub fn from_engine(engine: ServiceEngine<Req, Resp>) -> Arc<Self> {
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::Error;
use anyhow::Error;
impl<Req: PipelineIO, Resp: PipelineIO> SegmentSink<Req, Resp> {
pub fn new() -> Arc<Self> {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::{Result, error};
use maybe_error::MaybeError;
use super::maybe_error::MaybeError;
use anyhow::{Result, anyhow as error};
use serde::{Deserialize, Serialize};
pub trait AnnotationsProvider {
fn annotations(&self) -> Option<Vec<String>>;
......
......@@ -11,7 +11,7 @@ use std::{
task::{Context, Poll},
};
pub use crate::{Error, Result};
pub use anyhow::{Error, Result};
pub use async_trait::async_trait;
pub use tokio::task::JoinHandle;
pub use tokio_util::sync::CancellationToken;
......
......@@ -14,8 +14,10 @@
//! private; however, for now we are exposing most objects as fully public while the API is maturing.
use super::utils::GracefulShutdownTracker;
use super::{Result, Runtime, RuntimeType, error};
use crate::config::{self, RuntimeConfig};
use crate::{
compute,
config::{self, RuntimeConfig},
};
use futures::Future;
use once_cell::sync::OnceCell;
......@@ -24,8 +26,28 @@ use tokio::{signal, sync::Mutex, task::JoinHandle};
pub use tokio_util::sync::CancellationToken;
/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
#[derive(Clone)]
enum RuntimeType {
// A Tokio runtime held behind an `Arc`, so clones share the same runtime.
Shared(Arc<tokio::runtime::Runtime>),
// A handle to a Tokio runtime that is not held by this value
// (e.g. `Handle::current()` in `from_current`).
External(tokio::runtime::Handle),
}
/// Local [Runtime] which provides access to shared resources local to the physical node/machine.
#[derive(Debug, Clone)]
pub struct Runtime {
// Worker id; `new` fills it with a freshly generated v4 UUID string.
id: Arc<String>,
// Primary and secondary Tokio runtimes; each is either shared or an external handle.
primary: RuntimeType,
secondary: RuntimeType,
// Cancellation token for this runtime.
cancellation_token: CancellationToken,
// Separate token for shutting down endpoints.
// NOTE(review): presumably cancelled independently of `cancellation_token` — confirm.
endpoint_shutdown_token: CancellationToken,
// Shared tracker used during graceful shutdown.
graceful_shutdown_tracker: Arc<GracefulShutdownTracker>,
// Optional compute pool; `initialize_all_thread_locals` only acts when this and
// `block_in_place_permits` are both present.
compute_pool: Option<Arc<compute::ComputePool>>,
// Optional semaphore of block-in-place permits (see `compute_pool`).
block_in_place_permits: Option<Arc<tokio::sync::Semaphore>>,
}
impl Runtime {
fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> Result<Runtime> {
fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> anyhow::Result<Runtime> {
// worker id
let id = Arc::new(uuid::Uuid::new_v4().to_string());
......@@ -65,7 +87,7 @@ impl Runtime {
runtime: RuntimeType,
secondary: Option<RuntimeType>,
config: &RuntimeConfig,
) -> Result<Runtime> {
) -> anyhow::Result<Runtime> {
let mut rt = Self::new(runtime, secondary)?;
// Create compute pool from configuration
......@@ -123,7 +145,7 @@ impl Runtime {
/// Initialize thread-local compute context on all worker threads using a barrier
/// This ensures every worker thread has its thread-local context initialized
pub async fn initialize_all_thread_locals(&self) -> Result<()> {
pub async fn initialize_all_thread_locals(&self) -> anyhow::Result<()> {
if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
// First, detect how many worker threads we actually have
let num_workers = self.detect_worker_thread_count().await;
......@@ -207,11 +229,11 @@ impl Runtime {
count
}
pub fn from_current() -> Result<Runtime> {
pub fn from_current() -> anyhow::Result<Runtime> {
Runtime::from_handle(tokio::runtime::Handle::current())
}
pub fn from_handle(handle: tokio::runtime::Handle) -> Result<Runtime> {
pub fn from_handle(handle: tokio::runtime::Handle) -> anyhow::Result<Runtime> {
let primary = RuntimeType::External(handle.clone());
let secondary = RuntimeType::External(handle);
Runtime::new(primary, Some(secondary))
......@@ -219,7 +241,7 @@ impl Runtime {
/// Create a [`Runtime`] instance from the settings
/// See [`config::RuntimeConfig::from_settings`]
pub fn from_settings() -> Result<Runtime> {
pub fn from_settings() -> anyhow::Result<Runtime> {
let config = config::RuntimeConfig::from_settings()?;
let runtime = Arc::new(config.create_runtime()?);
let primary = RuntimeType::Shared(runtime.clone());
......@@ -228,7 +250,7 @@ impl Runtime {
}
/// Create a [`Runtime`] with two single-threaded async tokio runtime
pub fn single_threaded() -> Result<Runtime> {
pub fn single_threaded() -> anyhow::Result<Runtime> {
let config = config::RuntimeConfig::single_threaded();
let owned = RuntimeType::Shared(Arc::new(config.create_runtime()?));
Runtime::new(owned, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment