Commit 07724eb9 authored by Neelay Shah, committed by GitHub

chore: remove unused lib/discovery crate (#5266)


Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
parent 538d3035
@@ -14,7 +14,6 @@ members = [
"lib/bindings/c",
"lib/bindings/python/codegen",
"lib/engines/*",
"lib/discovery",
"lib/config",
]
# Exclude certain packages that are slow to build and we don't ship as flagship
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "dynamo-discovery"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
[dependencies]
# Core dependencies (always present)
anyhow = "1.0"
async-trait = "0.1"
bytes = { version = "1.8", features = ["serde"] }
dashmap = "6.1"
derive_builder = { workspace = true }
figment = { version = "0.10", features = ["toml", "yaml", "env"] }
futures = { version = "0.3" }
parking_lot = { version = "0.12" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.43", features = ["full"] }
tokio-util = "0.7"
thiserror = "2.0"
tracing = "0.1"
uuid = { version = "1.11", features = ["v4", "serde"] }
xxhash-rust = { version = "0.8", features = ["xxh3"] }
validator = { workspace = true }
# HTTP service dependencies (optional)
axum = { version = "0.8", optional = true }
tower = { version = "0.5", optional = true }
tower-http = { version = "0.6", features = ["trace"], optional = true }
hyper = { version = "1.5", optional = true }
reqwest = { version = "0.12", features = ["json"], optional = true }
tokio-stream = { version = "0.1", optional = true }
futures-util = { version = "0.3", optional = true }
# Etcd dependencies (optional)
etcd-client = { version = "0.17", optional = true }
tonic = { version = "0.14", optional = true }
# Libp2p dependencies (optional) - Match am-core versions
libp2p = { version = "0.56", default-features = false, features = [
"tcp",
"noise",
"yamux",
"macros",
"tokio",
"pnet",
], optional = true }
libp2p-kad = { version = "0.48", optional = true }
libp2p-mdns = { version = "0.48", features = ["tokio"], optional = true }
libp2p-swarm = { version = "0.47", optional = true }
libp2p-identity = { version = "0.2", optional = true }
blake2 = { version = "0.10", optional = true }
[dev-dependencies]
tokio = { version = "1.43", features = ["full", "test-util"] }
clap = { version = "4", features = ["derive"] }
tempfile = "3"
testcontainers = "0.25"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[features]
default = ["p2p", "etcd"]
etcd = ["dep:etcd-client", "dep:tonic"]
p2p = ["dep:libp2p", "dep:libp2p-kad", "dep:libp2p-mdns", "dep:libp2p-swarm", "dep:libp2p-identity", "dep:blake2"]
full = ["etcd", "p2p"]
testing-etcd = [] # Enable etcd tests (default on, disable with --no-default-features)
integration-etcd = [] # Feature flag for integration tests with real etcd (requires Docker)
# http-service = ["dep:axum", "dep:tower", "dep:tower-http", "dep:hyper", "dep:reqwest", "dep:tracing-subscriber", "dep:tokio-stream", "dep:futures-util"]
# dynamo-discovery (lib/discovery)
A small, capability-driven discovery layer for the Dynamo runtime. The core idea is to separate “what the application needs to discover” from “how a particular backend provides it,” and to provide a thin manager that composes multiple backends with caching and a concise, stable API.
## Philosophy
- Discovery is application-specific. Each application defines discovery traits that describe the information it needs (e.g., peers, topics, shards, services) and the operations required to work with that information.
- Systems are concrete implementations. A system (e.g., etcd, libp2p, an HTTP microservice, S3, NATS) implements one or more of the discovery traits. Different systems have different capabilities; not every system can implement every trait or policy.
- Managers orchestrate and cache. A manager owns the logic to coordinate multiple systems that implement the same trait, deduplicate concurrent lookups, maintain a local cache, and expose a clean public API tailored for the runtime.
This division lets you grow capabilities without coupling the runtime to any one backend. Traits define the contract; systems provide the plumbing; managers keep the runtime simple and fast.
## Core Concepts
- Discovery traits
  - Define what the application wants to discover and the related operations.
  - Include both public-facing operations (what the runtime calls) and internal operations (used for registration, consistency checks, etc.).
- Systems
  - Backend-specific code that implements one or more discovery traits (and only the parts they can support).
  - Example systems: `etcd` (centralized + TTL), `libp2p` (DHT), an HTTP service client, S3, NATS, in-memory.
  - A system may expose just a subset of traits based on its capability.
- Managers
  - Constructed with one or more system implementations of a trait.
  - Provide a concise, stable public API, while handling caching, coalescing, retries, and capability differences behind the scenes.
  - Allow you to mix-and-match systems for resilience and performance (e.g., fast in-memory cache + remote etcd).
## Capability Model
- Traits describe behavior; systems opt into the parts they can implement.
- The `DiscoverySystem` abstraction can vend one or more trait implementations. If a system cannot support a trait, it simply does not provide it.
- Managers accept a set of trait implementations and will use whatever is provided, with graceful fallback rules (e.g., local cache first, then remote sources).
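The capability model above can be sketched with plain `std` types. This is a minimal illustration, not the crate's code: `PeerLookup`, `System`, `InMemorySystem`, `BlobOnlySystem`, and `collect_sources` are hypothetical names, and the lookup is synchronous to keep the example short. The point it shows is that a system returns `None` for a trait it cannot support, and the wiring code simply keeps whatever implementations it is given.

```rust
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;

/// Simplified, synchronous stand-in for a peer discovery trait (hypothetical).
pub trait PeerLookup: Debug + Send + Sync {
    fn lookup(&self, worker_id: u64) -> Option<String>;
}

/// A backend vends only the traits it can support (hypothetical).
pub trait System {
    /// `None` means "this backend cannot do peer lookup".
    fn peer_lookup(&self) -> Option<Arc<dyn PeerLookup>>;
}

#[derive(Debug, Default)]
pub struct InMemoryPeers {
    peers: HashMap<u64, String>,
}

impl PeerLookup for InMemoryPeers {
    fn lookup(&self, worker_id: u64) -> Option<String> {
        self.peers.get(&worker_id).cloned()
    }
}

/// Backend that supports peer lookup.
pub struct InMemorySystem {
    peers: Arc<InMemoryPeers>,
}

impl System for InMemorySystem {
    fn peer_lookup(&self) -> Option<Arc<dyn PeerLookup>> {
        let lookup: Arc<dyn PeerLookup> = self.peers.clone();
        Some(lookup)
    }
}

/// Backend with no peer-lookup capability (for example, a blob store).
pub struct BlobOnlySystem;

impl System for BlobOnlySystem {
    fn peer_lookup(&self) -> Option<Arc<dyn PeerLookup>> {
        None
    }
}

/// Keep whatever implementations the configured systems provide.
pub fn collect_sources(systems: &[Box<dyn System>]) -> Vec<Arc<dyn PeerLookup>> {
    systems.iter().filter_map(|s| s.peer_lookup()).collect()
}

/// Builds two systems, only one of which supports peer lookup.
pub fn example_sources() -> Vec<Arc<dyn PeerLookup>> {
    let mut peers = HashMap::new();
    peers.insert(7u64, "tcp://127.0.0.1:5555".to_string());
    let systems: Vec<Box<dyn System>> = vec![
        Box::new(InMemorySystem {
            peers: Arc::new(InMemoryPeers { peers }),
        }),
        Box::new(BlobOnlySystem),
    ];
    collect_sources(&systems)
}
```

`example_sources()` yields a single source, because `BlobOnlySystem` opts out of the trait rather than providing a stub that fails at call time.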
## Example: Peer Discovery
The peer discovery trait is used by the runtime to translate identifiers into addresses and to manage lifecycle around registration.
- Trait methods (conceptual):
  - `discover_by_worker_id(worker_id) -> PeerInfo`
  - `discover_by_instance_id(instance_id) -> PeerInfo`
  - `register_instance(instance_id, address) -> ()`
  - `unregister_instance(instance_id) -> ()`
- Manager API (public vs. internal):
  - Public: discovery queries
    - `discover_by_worker_id(worker_id)`
    - `discover_by_instance_id(instance_id)`
  - Internal: lifecycle
    - Registration and unregistration are handled by the manager when it is constructed (register the local peer) and during shutdown/cleanup. These are not exposed as public manager methods.
- Why hide registration on the manager?
  - Keeps the runtime call surface minimal and intentional.
  - Enforces consistent lifecycle semantics (checksums, collisions, TTLs) in one place.
  - Avoids leaking backend mechanics into the runtime path.
### How the Manager Works (at a glance)
- On construction, the manager registers the local peer in its local cache and in all configured remote systems that support the peer discovery trait.
- On lookup, it consults the local cache first, then queries remotes if needed. Concurrent lookups for the same key are coalesced into a shared query. Successful remote results are cached locally for future fast paths.
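The lookup order described above (local cache, then remotes, then cache the hit) can be sketched as follows. This is a simplified, single-threaded illustration under stated assumptions: `Manager`, `Remote`, and `QueryError` are hypothetical stand-ins, addresses are plain strings, remote registration of the local peer is omitted, and request coalescing is left out. The real manager's source is not shown in this diff, so the error-folding policy here (report a backend error only if no source found the peer) is an assumption.

```rust
use std::cell::Cell;
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryError {
    NotFound,
    Backend(String),
}

/// Hypothetical remote source; counts calls so the fast path is observable.
pub struct Remote {
    pub peers: HashMap<u64, String>,
    pub calls: Cell<usize>,
}

impl Remote {
    fn lookup(&self, worker_id: u64) -> Result<String, QueryError> {
        self.calls.set(self.calls.get() + 1);
        self.peers
            .get(&worker_id)
            .cloned()
            .ok_or(QueryError::NotFound)
    }
}

pub struct Manager {
    cache: HashMap<u64, String>,
    remotes: Vec<Remote>,
}

impl Manager {
    /// Seeds the local cache with the local peer at construction time.
    pub fn new(local_id: u64, local_addr: &str, remotes: Vec<Remote>) -> Self {
        let mut cache = HashMap::new();
        cache.insert(local_id, local_addr.to_string());
        Self { cache, remotes }
    }

    pub fn discover(&mut self, worker_id: u64) -> Result<String, QueryError> {
        // Fast path: answer from the local cache.
        if let Some(addr) = self.cache.get(&worker_id) {
            return Ok(addr.clone());
        }
        // Slow path: ask each remote in order and cache the first hit.
        let mut last_err = QueryError::NotFound;
        for remote in &self.remotes {
            match remote.lookup(worker_id) {
                Ok(addr) => {
                    self.cache.insert(worker_id, addr.clone());
                    return Ok(addr);
                }
                Err(QueryError::NotFound) => {}
                Err(e) => last_err = e,
            }
        }
        Err(last_err)
    }

    /// Total number of remote queries issued so far.
    pub fn remote_calls(&self) -> usize {
        self.remotes.iter().map(|r| r.calls.get()).sum()
    }
}
```

Looking up the same remote peer twice issues one remote query; the second call is served from the cache.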
## Typical Wiring
- Choose your systems and build them (e.g., Etcd with TTL, Libp2p, HTTP client, or an in-memory source for tests).
- Extract the trait implementations the runtime needs (e.g., `PeerDiscovery`).
- Create a manager with the local peer and a list of trait impls:
```rust
use std::sync::Arc;
// The package is `dynamo-discovery`, so the library crate is `dynamo_discovery`.
// `PeerDiscoveryManager` is re-exported from `peer`; the `manager` module is private.
use dynamo_discovery::peer::{InstanceId, PeerDiscoveryManager, PeerInfo, WorkerAddress};
use dynamo_discovery::systems::DiscoverySystem; // e.g., etcd system builds this

# async fn example(system: Arc<dyn DiscoverySystem>) -> anyhow::Result<()> {
let local_instance = InstanceId::new_v4();
let local_address = WorkerAddress::from_bytes(b"tcp://127.0.0.1:5555".as_slice());
let local_peer = PeerInfo::new(local_instance, local_address);

// Get one or more implementations of the peer discovery trait
let mut sources = Vec::new();
if let Some(peer_disc) = system.peer_discovery() {
    sources.push(peer_disc);
}

// Build the manager that orchestrates cache + remotes
let manager = PeerDiscoveryManager::new(local_peer, sources).await?;

// Look up a peer by worker_id or instance_id
// (The manager will hit local cache first, then remotes as needed.)
let _maybe = manager.discover_by_worker_id(local_instance.worker_id()).await;
let _maybe2 = manager.discover_by_instance_id(local_instance).await;
# Ok(())
# }
```
Note: The manager deliberately keeps registration/unregistration internal. If your application lifecycle requires explicit registration timing, do that by constructing the manager at the appropriate point in startup, and let it handle registration with all configured systems.
## Extending the Crate
- Add a new discovery trait when the application needs to discover a new kind of thing (e.g., shard ownership). Keep the trait small and precise.
- Implement the trait in one or more systems. It’s fine if only some systems can implement it.
- Add a manager for the trait if you need composition, caching, or a slimmer public API for the runtime.
- Keep trait-level semantics strict and documented. Managers can hide backend-specific details while enforcing common policies (e.g., collision detection, address checksums, TTLs).
## Notes on Consistency and Errors
- Systems may enforce additional policies (e.g., TTL expiry in etcd, collision detection, checksum validation). Managers use these and surface simple success/not-found/backend-error semantics to the runtime.
- Local caches accelerate the common path and are populated opportunistically from successful remote lookups.
- Concurrent lookups are deduplicated to reduce load on remote systems.
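One way to deduplicate concurrent lookups is a "single-flight" table: the first caller for a key runs the query, and callers that arrive while it is in flight wait for its result. The sketch below uses blocking `std` primitives for brevity; `SingleFlight`, `Lookup`, and `Slot` are hypothetical names. The crate itself is async (its etcd client shares a `futures` `Shared` future for reconnects), so treat this as an illustration of the idea rather than the manager's implementation. A query that panics would leave followers waiting; a production version would need to handle that.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex};

/// Outcome of one lookup: an address, or an error message.
pub type Lookup = Result<String, String>;

/// Shared slot that the leader fills and followers wait on.
type Slot = Arc<(Mutex<Option<Lookup>>, Condvar)>;

/// Coalesces concurrent lookups for the same key into one query.
#[derive(Default)]
pub struct SingleFlight {
    in_flight: Mutex<HashMap<u64, Slot>>,
}

impl SingleFlight {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn run<F>(&self, key: u64, query: F) -> Lookup
    where
        F: FnOnce() -> Lookup,
    {
        // Join an in-flight query for this key, or become its leader.
        let (slot, leader) = {
            let mut map = self.in_flight.lock().unwrap();
            let existing = map.get(&key).cloned();
            match existing {
                Some(slot) => (slot, false),
                None => {
                    let slot: Slot = Arc::new((Mutex::new(None), Condvar::new()));
                    map.insert(key, slot.clone());
                    (slot, true)
                }
            }
        };

        if leader {
            let result = query();
            // Publish the result, retire the key, then wake the followers.
            *slot.0.lock().unwrap() = Some(result.clone());
            self.in_flight.lock().unwrap().remove(&key);
            slot.1.notify_all();
            result
        } else {
            // Followers block until the leader has stored a result.
            let mut guard = slot.0.lock().unwrap();
            while guard.is_none() {
                guard = slot.1.wait(guard).unwrap();
            }
            (*guard).clone().expect("leader stored a result")
        }
    }
}
```

Because the key is retired once the leader finishes, a later call for the same key starts a fresh query; pairing this with a local cache is what turns repeated lookups into cache hits.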
## Available Systems (examples)
- Etcd-backed system (centralized, TTL-based, keep-alive, transactional collision detection).
- Libp2p-backed system (decentralized DHT).
- HTTP service client.
- In-memory (useful for tests and single-node scenarios).
Not all systems will implement every trait; use the manager’s composition to mix what you need.
## Why This Design?
- Keeps the runtime portable: swap discovery backends without changing call sites.
- Embraces partial capability: wire up the systems that can do the job, skip the rest.
- Minimizes API surface for the runtime: managers expose only the operations the runtime actually needs, while handling lifecycle internally.
- Encourages small traits and pluggable systems, so the crate can evolve without lock-in.
If you’re adding a new trait or system, keep the trait narrowly scoped, stick to clear semantics, and lean on the manager to integrate, cache, and present a clean API to the rest of the runtime.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod peer;
pub mod systems;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Address types for peer discovery.
//!
//! This module provides types for representing worker addresses and peer information:
//! - [`WorkerAddress`]: Opaque byte representation of a peer's network address
//! - [`PeerInfo`]: Combined instance ID and worker address for a discovered peer
//!
//! These types are intentionally transport-agnostic, storing addresses as opaque bytes.
//! The interpretation of these bytes is left to the active message runtime.
use super::{InstanceId, WorkerId};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::fmt;
use xxhash_rust::xxh3::xxh3_64;
/// Opaque worker address for discovery.
///
/// This is a transport-agnostic representation of a peer's network address.
/// The bytes are opaque to discovery and are interpreted by the active message runtime.
///
/// # Checksum
///
/// WorkerAddress implements a checksum via xxh3_64 for quick comparison during
/// re-registration validation.
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct WorkerAddress(Bytes);
impl WorkerAddress {
/// Create a new WorkerAddress from bytes.
pub fn from_bytes(bytes: impl Into<Bytes>) -> Self {
Self(bytes.into())
}
/// Get the underlying bytes.
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
/// Get the bytes as a Bytes object.
pub fn to_bytes(&self) -> Bytes {
self.0.clone()
}
/// Compute a checksum of this address for validation.
///
/// This is used to quickly check if an address has changed during re-registration.
pub fn checksum(&self) -> u64 {
xxh3_64(self.as_bytes())
}
}
impl fmt::Debug for WorkerAddress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("WorkerAddress")
.field(&format_args!(
"len={}, xxh3_64=0x{:016x}",
self.0.len(),
self.checksum()
))
.finish()
}
}
impl fmt::Display for WorkerAddress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "WorkerAddress(xxh3_64=0x{:016x})", self.checksum())
}
}
/// Peer information combining instance ID and worker address.
///
/// This is the primary type returned by discovery lookups. It contains everything
/// needed to connect to and identify a peer.
///
/// # Example
///
/// ```
/// use dynamo_discovery::peer::{InstanceId, WorkerAddress, PeerInfo};
/// use bytes::Bytes;
///
/// let instance_id = InstanceId::new_v4();
/// let address = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
/// let peer_info = PeerInfo::new(instance_id, address);
///
/// assert_eq!(peer_info.instance_id(), instance_id);
/// assert_eq!(peer_info.worker_id(), instance_id.worker_id());
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PeerInfo {
/// The instance ID of the peer
pub instance_id: InstanceId,
/// The worker address for connecting to the peer
pub worker_address: WorkerAddress,
}
impl PeerInfo {
/// Create a new PeerInfo.
pub fn new(instance_id: InstanceId, worker_address: WorkerAddress) -> Self {
Self {
instance_id,
worker_address,
}
}
/// Get the instance ID.
pub fn instance_id(&self) -> InstanceId {
self.instance_id
}
/// Get the worker ID (derived from instance ID).
pub fn worker_id(&self) -> WorkerId {
self.instance_id.worker_id()
}
/// Get a reference to the worker address.
pub fn worker_address(&self) -> &WorkerAddress {
&self.worker_address
}
/// Get the worker address checksum for validation.
pub fn address_checksum(&self) -> u64 {
self.worker_address.checksum()
}
/// Consume self and return the worker address.
pub fn into_address(self) -> WorkerAddress {
self.worker_address
}
/// Decompose into instance ID and worker address.
pub fn into_parts(self) -> (InstanceId, WorkerAddress) {
(self.instance_id, self.worker_address)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_worker_address_creation() {
let bytes = Bytes::from_static(b"tcp://127.0.0.1:5555");
let address = WorkerAddress::from_bytes(bytes.clone());
assert_eq!(address.as_bytes(), bytes.as_ref());
assert_eq!(address.to_bytes(), bytes);
}
#[test]
fn test_worker_address_checksum() {
let address1 = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let address2 = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let address3 = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:6666"));
// Same bytes = same checksum
assert_eq!(address1.checksum(), address2.checksum());
// Different bytes = (likely) different checksum
assert_ne!(address1.checksum(), address3.checksum());
}
#[test]
fn test_worker_address_equality() {
let address1 = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let address2 = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let address3 = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:6666"));
assert_eq!(address1, address2);
assert_ne!(address1, address3);
}
#[test]
fn test_worker_address_debug() {
let address = WorkerAddress::from_bytes(Bytes::from_static(b"test"));
let debug_str = format!("{:?}", address);
assert!(debug_str.contains("WorkerAddress"));
assert!(debug_str.contains("len=4"));
assert!(debug_str.contains("xxh3_64="));
}
#[test]
fn test_peer_info_creation() {
let instance_id = InstanceId::new_v4();
let address = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let peer_info = PeerInfo::new(instance_id, address.clone());
assert_eq!(peer_info.instance_id(), instance_id);
assert_eq!(peer_info.worker_id(), instance_id.worker_id());
assert_eq!(peer_info.worker_address(), &address);
}
#[test]
fn test_peer_info_checksum() {
let instance_id = InstanceId::new_v4();
let address = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let peer_info = PeerInfo::new(instance_id, address.clone());
assert_eq!(peer_info.address_checksum(), address.checksum());
}
#[test]
fn test_peer_info_into_address() {
let instance_id = InstanceId::new_v4();
let address = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let peer_info = PeerInfo::new(instance_id, address.clone());
let extracted_address = peer_info.into_address();
assert_eq!(extracted_address, address);
}
#[test]
fn test_peer_info_into_parts() {
let instance_id = InstanceId::new_v4();
let address = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let peer_info = PeerInfo::new(instance_id, address.clone());
let (extracted_id, extracted_address) = peer_info.into_parts();
assert_eq!(extracted_id, instance_id);
assert_eq!(extracted_address, address);
}
#[test]
fn test_peer_info_serde() {
let instance_id = InstanceId::new_v4();
let address = WorkerAddress::from_bytes(Bytes::from_static(b"tcp://127.0.0.1:5555"));
let peer_info = PeerInfo::new(instance_id, address);
// Serialize to JSON
let json = serde_json::to_string(&peer_info).unwrap();
// Deserialize back
let deserialized: PeerInfo = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.instance_id(), instance_id);
assert_eq!(deserialized.worker_id(), instance_id.worker_id());
assert_eq!(
deserialized.worker_address().as_bytes(),
b"tcp://127.0.0.1:5555"
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Identity types for the active message system.
//!
//! This module provides strongly-typed wrappers for instance and worker identifiers:
//! - [`InstanceId`]: Unique runtime instance identifier (wraps UUID)
//! - [`WorkerId`]: Deterministic 64-bit worker identifier derived from InstanceId
//!
//! # Design Principles
//!
//! 1. **Type Safety**: InstanceId cannot be confused with message IDs or other UUIDs
//! 2. **Deterministic Derivation**: WorkerId is always computed from InstanceId (xxh3_64 hash)
//! 3. **Single Source of Truth**: InstanceId is the primary identifier, WorkerId is derived
//!
//! # Example
//!
//! ```ignore
//! // InstanceIds are created internally by the runtime
//! // and obtained from ActiveMessageClient.instance_id()
//!
//! use dynamo_am::api::identity::{InstanceId, WorkerId};
//!
//! # fn get_instance_id() -> InstanceId { unimplemented!() }
//! let instance_id = get_instance_id(); // From ActiveMessageClient
//!
//! // Derive worker ID automatically
//! let worker_id: WorkerId = instance_id.worker_id();
//!
//! // WorkerId is deterministic
//! assert_eq!(worker_id, instance_id.worker_id());
//! ```
use serde::{Deserialize, Serialize};
use std::fmt;
use uuid::Uuid;
use xxhash_rust::xxh3::xxh3_64;
/// Unique identifier for a runtime instance.
///
/// This is a UUID-based identifier that uniquely identifies a running instance
/// of the active message runtime. It is used for:
/// - Transport-level addressing
/// - Discovery registration
/// - Routing table management
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct InstanceId(Uuid);
impl InstanceId {
/// Create a new random v4 InstanceId.
///
/// This is exposed for testing and special cases. In production, use
/// [`InstanceFactory::create()`] instead.
pub fn new_v4() -> Self {
Self(Uuid::new_v4())
}
/// Create an InstanceId from a UUID.
pub fn from_uuid(uuid: Uuid) -> Self {
Self(uuid)
}
/// Create an InstanceId from raw bytes.
pub fn from_bytes(bytes: [u8; 16]) -> Self {
Self(Uuid::from_bytes(bytes))
}
/// Derive the deterministic WorkerId from this InstanceId.
///
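/// WorkerId is computed as the xxh3_64 hash of the UUID bytes, so a given
/// InstanceId always maps to the same WorkerId. The mapping is not injective:
/// two distinct InstanceIds can hash to the same 64-bit value, which
/// registration reports as a worker ID collision.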
pub fn worker_id(&self) -> WorkerId {
WorkerId::from(self)
}
/// Get a reference to the underlying UUID.
pub fn as_uuid(&self) -> &Uuid {
&self.0
}
/// Get the underlying UUID as a u128.
pub fn as_u128(&self) -> u128 {
self.0.as_u128()
}
/// Get the underlying UUID as bytes.
pub fn as_bytes(&self) -> &[u8; 16] {
self.0.as_bytes()
}
}
impl fmt::Display for InstanceId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<Uuid> for InstanceId {
fn from(uuid: Uuid) -> Self {
Self(uuid)
}
}
impl From<InstanceId> for Uuid {
fn from(id: InstanceId) -> Self {
id.0
}
}
impl AsRef<Uuid> for InstanceId {
fn as_ref(&self) -> &Uuid {
&self.0
}
}
/// Deterministic 64-bit worker identifier derived from InstanceId.
///
/// WorkerId is used in:
/// - [`crate::event::EventHandle`]: Embedded in the u128 event handle (64 bits)
/// - [`crate::event::EventRoutingTable`]: Maps worker_id → instance_id for event routing
/// - Discovery systems: Lookup key for peer information
///
/// WorkerId is **always derived** from InstanceId using xxh3_64 hash.
/// This ensures consistency across the system without needing to store both values.
///
/// # Example
///
/// ```ignore
/// use dynamo_am::api::identity::{InstanceId, WorkerId};
///
/// # fn get_instance_id() -> InstanceId { unimplemented!() }
/// let instance_id = get_instance_id(); // From ActiveMessageClient
/// let worker_id = instance_id.worker_id();
///
/// // WorkerId is deterministic
/// assert_eq!(worker_id, instance_id.worker_id());
/// ```
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
#[serde(transparent)]
pub struct WorkerId(u64);
impl WorkerId {
/// Create a WorkerId from a raw u64 value.
///
/// This is used when decoding WorkerIds from event handles or wire formats.
/// External users should always derive WorkerId via `instance_id.worker_id()`.
pub fn from_u64(value: u64) -> Self {
Self(value)
}
/// Get the underlying u64 value.
#[inline(always)]
pub fn as_u64(&self) -> u64 {
self.0
}
}
impl fmt::Display for WorkerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<&InstanceId> for WorkerId {
/// Derive WorkerId from InstanceId using xxh3_64 hash.
///
/// This is the canonical way to compute WorkerId - it should never be
/// constructed any other way to ensure consistency.
fn from(id: &InstanceId) -> Self {
Self(xxh3_64(id.as_uuid().as_bytes()))
}
}
impl From<InstanceId> for WorkerId {
fn from(id: InstanceId) -> Self {
Self::from(&id)
}
}
impl From<WorkerId> for u64 {
fn from(id: WorkerId) -> Self {
id.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_instance_id_creation() {
let id1 = InstanceId::new_v4();
let id2 = InstanceId::new_v4();
// Different instances have different IDs
assert_ne!(id1, id2);
// Can convert to/from UUID
let uuid: Uuid = id1.into();
let id3 = InstanceId::from(uuid);
assert_eq!(id1, id3);
}
#[test]
fn test_worker_id_deterministic() {
let instance_id = InstanceId::new_v4();
// WorkerId is deterministic
let worker_id1 = instance_id.worker_id();
let worker_id2 = instance_id.worker_id();
assert_eq!(worker_id1, worker_id2);
// Different instances (almost certainly) have different worker IDs
let other_instance = InstanceId::new_v4();
let other_worker = other_instance.worker_id();
assert_ne!(worker_id1, other_worker);
}
#[test]
fn test_worker_id_from_conversion() {
let instance_id = InstanceId::new_v4();
// Both From implementations work
let worker_id1 = WorkerId::from(&instance_id);
let worker_id2 = WorkerId::from(instance_id);
assert_eq!(worker_id1, worker_id2);
// Matches .worker_id() method
assert_eq!(worker_id1, instance_id.worker_id());
}
#[test]
fn test_instance_id_display() {
let instance_id = InstanceId::new_v4();
let display = format!("{}", instance_id);
let uuid_display = format!("{}", instance_id.as_uuid());
assert_eq!(display, uuid_display);
}
#[test]
fn test_worker_id_display() {
let instance_id = InstanceId::new_v4();
let worker_id = instance_id.worker_id();
let display = format!("{}", worker_id);
let u64_display = format!("{}", worker_id.as_u64());
assert_eq!(display, u64_display);
}
#[test]
fn test_instance_id_serde() {
let instance_id = InstanceId::new_v4();
// Serialize as JSON
let json = serde_json::to_string(&instance_id).unwrap();
// Should be a plain UUID string
let uuid_json = serde_json::to_string(instance_id.as_uuid()).unwrap();
assert_eq!(json, uuid_json);
// Deserialize back
let deserialized: InstanceId = serde_json::from_str(&json).unwrap();
assert_eq!(instance_id, deserialized);
}
#[test]
fn test_worker_id_serde() {
let worker_id = InstanceId::new_v4().worker_id();
// Serialize as JSON
let json = serde_json::to_string(&worker_id).unwrap();
// Should be a plain u64
let u64_json = serde_json::to_string(&worker_id.as_u64()).unwrap();
assert_eq!(json, u64_json);
// Deserialize back
let deserialized: WorkerId = serde_json::from_str(&json).unwrap();
assert_eq!(worker_id, deserialized);
}
#[test]
fn test_instance_id_as_methods() {
let uuid = Uuid::new_v4();
let instance_id = InstanceId::from_uuid(uuid);
assert_eq!(instance_id.as_uuid(), &uuid);
assert_eq!(instance_id.as_u128(), uuid.as_u128());
assert_eq!(instance_id.as_bytes(), uuid.as_bytes());
}
#[test]
fn test_instance_id_from_bytes() {
let uuid = Uuid::new_v4();
let bytes = *uuid.as_bytes();
let instance_id = InstanceId::from_bytes(bytes);
assert_eq!(instance_id.as_uuid(), &uuid);
}
#[test]
fn test_worker_id_u64_conversion() {
let instance_id = InstanceId::new_v4();
let worker_id = instance_id.worker_id();
let raw_u64 = worker_id.as_u64();
let reconstructed = WorkerId::from_u64(raw_u64);
assert_eq!(worker_id, reconstructed);
}
}
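The derivation in the identity module above can be illustrated with `std` alone. This is a sketch under stated assumptions: `DefaultHasher` stands in for `xxh3_64` (which is not in the standard library and produces different values), and `InstanceId` wraps raw bytes instead of a `Uuid`. What carries over is the shape: the 64-bit ID is always recomputed from the 128-bit one, so it is deterministic, and since 128 bits are folded into 64 it cannot be collision-free.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

/// Stand-in for the UUID-backed instance identifier.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct InstanceId([u8; 16]);

/// 64-bit identifier that is only ever derived from an `InstanceId`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct WorkerId(u64);

impl InstanceId {
    pub fn from_bytes(bytes: [u8; 16]) -> Self {
        Self(bytes)
    }

    /// Derive the worker ID; the same input always gives the same output.
    pub fn worker_id(&self) -> WorkerId {
        WorkerId::from(self)
    }
}

impl From<&InstanceId> for WorkerId {
    fn from(id: &InstanceId) -> Self {
        // Assumption: DefaultHasher replaces the crate's xxh3_64 here.
        let mut hasher = DefaultHasher::new();
        hasher.write(&id.0);
        WorkerId(hasher.finish())
    }
}

impl WorkerId {
    pub fn as_u64(&self) -> u64 {
        self.0
    }
}
```

Storing only the `InstanceId` and deriving the `WorkerId` on demand avoids the two values drifting apart; the cost is that a registry has to detect the rare case where two instances derive the same worker ID.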
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use crate::peer::{
DiscoveryError, DiscoveryQueryError, InstanceId, PeerInfo, WorkerAddress, WorkerId,
};
#[derive(Debug, Default, Clone)]
pub struct LocalPeerDiscovery {
inner: Arc<RwLock<LocalPeerDiscoveryInner>>,
}
#[derive(Debug, Default, Clone)]
struct LocalPeerDiscoveryInner {
by_worker_id: HashMap<WorkerId, InstanceId>,
by_instance_id: HashMap<InstanceId, PeerInfo>,
}
impl LocalPeerDiscovery {
pub fn discover_by_worker_id(
&self,
worker_id: WorkerId,
) -> Result<PeerInfo, DiscoveryQueryError> {
let state = self.inner.read();
let by_worker_id = state.by_worker_id.get(&worker_id);
if let Some(instance_id) = by_worker_id {
let peer_info = state.by_instance_id.get(instance_id);
if let Some(peer_info) = peer_info {
return Ok(peer_info.clone());
}
}
Err(DiscoveryQueryError::NotFound)
}
pub fn discover_by_instance_id(
&self,
instance_id: InstanceId,
) -> Result<PeerInfo, DiscoveryQueryError> {
let state = self.inner.read();
let by_instance_id = state.by_instance_id.get(&instance_id);
if let Some(peer_info) = by_instance_id {
return Ok(peer_info.clone());
}
Err(DiscoveryQueryError::NotFound)
}
pub fn register_instance(
&self,
instance_id: InstanceId,
worker_address: WorkerAddress,
) -> Result<(), DiscoveryError> {
let mut state = self.inner.write();
// Validate no worker_id collision
let worker_id = instance_id.worker_id();
if let Some(existing_instance) = state.by_worker_id.get(&worker_id)
&& *existing_instance != instance_id
{
return Err(DiscoveryError::WorkerIdCollision(
worker_id,
*existing_instance,
instance_id,
));
}
// Fail-fast for any duplicate registration attempt
if let Some(existing_peer_info) = state.by_instance_id.get(&instance_id) {
// Check if it's the same address (idempotent attempt) or different
if existing_peer_info.address_checksum() == worker_address.checksum() {
// Duplicate registration with same address - fail to detect bugs
return Err(DiscoveryError::InstanceAlreadyRegistered(instance_id));
} else {
// Re-registration with different address - fail with checksum mismatch
return Err(DiscoveryError::ChecksumMismatch(
instance_id,
existing_peer_info.address_checksum(),
worker_address.checksum(),
));
}
}
// Register peer
let peer_info = PeerInfo::new(instance_id, worker_address);
state.by_worker_id.insert(worker_id, instance_id);
state.by_instance_id.insert(instance_id, peer_info);
Ok(())
}
#[expect(dead_code)]
pub fn unregister_instance(&self, instance_id: InstanceId) -> Result<(), DiscoveryError> {
let mut state = self.inner.write();
state.by_worker_id.remove(&instance_id.worker_id());
state.by_instance_id.remove(&instance_id);
Ok(())
}
}
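The registration rules in `LocalPeerDiscovery::register_instance` above produce three distinct failures. The following sketch replays them with simplified types so each outcome can be triggered on purpose. Assumptions: `Registry` and `RegisterError` are hypothetical names, instance IDs are `u128` values, addresses are compared byte-for-byte instead of by checksum, and `worker_id` truncates to the low 64 bits, a deliberately weak stand-in for `xxh3_64` that makes a collision easy to construct.

```rust
use std::collections::HashMap;

#[derive(Debug, PartialEq, Eq)]
pub enum RegisterError {
    /// A different instance already owns the derived worker ID.
    WorkerIdCollision,
    /// Same instance, same address: a duplicate registration.
    AlreadyRegistered,
    /// Same instance, different address.
    AddressChanged,
}

/// Deliberately weak derivation (low 64 bits) so collisions are easy to build.
pub fn worker_id(instance_id: u128) -> u64 {
    instance_id as u64
}

#[derive(Default)]
pub struct Registry {
    by_worker_id: HashMap<u64, u128>,
    by_instance_id: HashMap<u128, Vec<u8>>,
}

impl Registry {
    pub fn register(&mut self, instance_id: u128, address: &[u8]) -> Result<(), RegisterError> {
        let wid = worker_id(instance_id);
        // 1. Reject a worker ID that belongs to another instance.
        if let Some(existing) = self.by_worker_id.get(&wid) {
            if *existing != instance_id {
                return Err(RegisterError::WorkerIdCollision);
            }
        }
        // 2. Re-registering is always an error; the variant records
        //    whether the address matched.
        if let Some(existing) = self.by_instance_id.get(&instance_id) {
            return Err(if existing.as_slice() == address {
                RegisterError::AlreadyRegistered
            } else {
                RegisterError::AddressChanged
            });
        }
        // 3. First registration: update both indexes together.
        self.by_worker_id.insert(wid, instance_id);
        self.by_instance_id.insert(instance_id, address.to_vec());
        Ok(())
    }

    pub fn unregister(&mut self, instance_id: u128) {
        self.by_worker_id.remove(&worker_id(instance_id));
        self.by_instance_id.remove(&instance_id);
    }
}
```

Failing even the same-address case, rather than treating it as idempotent, mirrors the source's choice to surface double registration as a bug instead of silently accepting it.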
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Peer discovery for the Dynamo Active Message system.
use anyhow::Result;
use futures::future::BoxFuture;
use std::fmt;
use std::sync::Arc;
mod address;
mod identity;
mod manager;
pub use address::{PeerInfo, WorkerAddress};
pub use identity::{InstanceId, WorkerId};
pub use manager::PeerDiscoveryManager;
/// Error type for discovery operations.
#[derive(Debug, thiserror::Error)]
pub enum DiscoveryError {
/// Worker ID collision detected - same worker_id registered to different instance
#[error(
"Worker ID collision: worker_id {0} already registered to instance {1}, attempted to register to {2}"
)]
WorkerIdCollision(WorkerId, InstanceId, InstanceId),
/// Address checksum mismatch during re-registration
#[error("Address checksum mismatch for instance {0}: existing=0x{1:016x}, new=0x{2:016x}")]
ChecksumMismatch(InstanceId, u64, u64),
/// Instance already registered - duplicate registration detected
#[error("Instance {0} is already registered")]
InstanceAlreadyRegistered(InstanceId),
/// Backend-specific error
#[error("Backend error: {0}")]
Backend(#[from] anyhow::Error),
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum DiscoveryQueryError {
#[error("Not found")]
NotFound,
#[error("Backend error: {0}")]
Backend(Arc<anyhow::Error>),
}
pub type AwaitableQueryResult = BoxFuture<'static, Result<PeerInfo, DiscoveryQueryError>>;
pub type AwaitableRegisterResult = BoxFuture<'static, Result<(), DiscoveryError>>;
/// Trait for discovering [`PeerInfo`] by [`WorkerId`] or [`InstanceId`].
pub trait PeerDiscovery: Send + Sync + fmt::Debug {
/// Lookup peer by worker_id.
fn discover_by_worker_id(&self, worker_id: WorkerId) -> AwaitableQueryResult;
/// Lookup peer by instance_id.
fn discover_by_instance_id(&self, instance_id: InstanceId) -> AwaitableQueryResult;
/// Register this peer in the discovery system.
fn register_instance(
&self,
instance_id: InstanceId,
worker_address: WorkerAddress,
) -> AwaitableRegisterResult;
/// Unregister this peer from the discovery system.
fn unregister_instance(&self, instance_id: InstanceId) -> AwaitableRegisterResult;
}
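The `PeerDiscovery` trait above returns boxed `'static` futures instead of using `async fn`, which keeps it usable as `dyn PeerDiscovery`. A backend therefore has to move owned data into each future. The sketch below shows that pattern with `std` only. Assumptions: `PeerLookup`, `InMemory`, and `QueryError` are simplified hypothetical stand-ins, the local `BoxFuture` alias imitates `futures::future::BoxFuture<'static, T>`, and `block_on` is a busy-polling helper for demonstration, not a replacement for the tokio runtime the crate depends on.

```rust
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

/// Local stand-in for `futures::future::BoxFuture<'static, T>`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryError {
    NotFound,
}

/// Object-safe lookup trait in the same style as `PeerDiscovery`.
pub trait PeerLookup: Send + Sync {
    fn discover_by_worker_id(&self, worker_id: u64) -> BoxFuture<Result<String, QueryError>>;
}

#[derive(Default, Clone)]
pub struct InMemory {
    peers: Arc<Mutex<HashMap<u64, String>>>,
}

impl InMemory {
    pub fn insert(&self, worker_id: u64, addr: &str) {
        self.peers.lock().unwrap().insert(worker_id, addr.to_string());
    }
}

impl PeerLookup for InMemory {
    fn discover_by_worker_id(&self, worker_id: u64) -> BoxFuture<Result<String, QueryError>> {
        // Clone the handle so the future owns its data and can be 'static.
        let peers = self.peers.clone();
        Box::pin(async move {
            let found = peers.lock().unwrap().get(&worker_id).cloned();
            found.ok_or(QueryError::NotFound)
        })
    }
}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Demonstration-only executor: polls in a loop until the future is ready.
pub fn block_on<T>(mut fut: BoxFuture<T>) -> T {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(value) = fut.as_mut().poll(&mut cx) {
            return value;
        }
        std::thread::yield_now();
    }
}
```

With this shape a caller can hold `Arc<dyn PeerLookup>` values from several backends in one list and await them uniformly, which is what lets a manager treat its sources interchangeably.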
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::{Context, Result, anyhow as error};
use dashmap::DashMap;
use etcd_client::ConnectOptions;
use futures::future::{BoxFuture, FutureExt, Shared};
use parking_lot::RwLock;
use std::{sync::Arc, time::Duration};
use tokio::{sync::Mutex, time::sleep};
/// Type alias for the shared reconnection future
type ReconnectFuture = Shared<BoxFuture<'static, Result<(), Arc<anyhow::Error>>>>;
/// Manages ETCD client connections with reconnection support
#[derive(Clone)]
pub struct Client {
/// The actual ETCD client, protected by RwLock for safe updates during reconnection
/// WARNING: Do not recursively acquire a read lock when the current thread already holds one
client: Arc<RwLock<etcd_client::Client>>,
/// Configuration for connecting to ETCD
etcd_urls: Arc<Vec<String>>,
connect_options: Arc<Option<ConnectOptions>>,
/// Tracks the current backoff duration and last successful connect time
/// The Mutex ensures only one reconnect operation runs at a time
backoff_state: Arc<Mutex<BackoffState>>,
/// Shared reconnection futures for deduplication
/// Only one reconnection happens at a time; concurrent callers share the future
reconnect_pending: Arc<DashMap<(), ReconnectFuture>>,
}
impl Client {
/// Create a new client with an established connection
pub async fn new(
etcd_urls: Vec<String>,
connect_options: Option<ConnectOptions>,
initial_backoff: Duration,
min_backoff: Duration,
max_backoff: Duration,
) -> Result<Self> {
// Connect to ETCD
let client = Self::connect(&etcd_urls, &connect_options).await?;
Ok(Self {
client: Arc::new(RwLock::new(client)),
etcd_urls: Arc::new(etcd_urls),
connect_options: Arc::new(connect_options),
backoff_state: Arc::new(Mutex::new(BackoffState::new(
initial_backoff,
min_backoff,
max_backoff,
))),
reconnect_pending: Arc::new(DashMap::new()),
})
}
/// Connect to ETCD cluster
async fn connect(
etcd_urls: &[String],
connect_options: &Option<ConnectOptions>,
) -> Result<etcd_client::Client> {
etcd_client::Client::connect(etcd_urls.to_vec(), connect_options.clone())
.await
.with_context(|| {
format!(
"Unable to connect to etcd server at {}. Check etcd server status",
etcd_urls.join(", ")
)
})
}
/// Get a clone of the current ETCD client
pub fn get_client(&self) -> etcd_client::Client {
self.client.read().clone()
}
/// Ensure the client is connected, triggering reconnection if needed.
///
/// This method deduplicates concurrent reconnection attempts - only one
/// reconnection happens at a time, with all callers sharing the same future.
///
/// # Arguments
/// * `deadline` - Deadline for reconnection attempts
/// * `force` - If true, start a reconnection when none is in progress; if false, only wait for one that is already running
///
/// Returns Ok(()) if connected, Err if reconnection failed.
pub async fn ensure_connected(&self, deadline: std::time::Instant, force: bool) -> Result<()> {
// Check if reconnection already in progress
if let Some(shared_future_ref) = self.reconnect_pending.get(&()) {
let shared = shared_future_ref.clone();
drop(shared_future_ref); // Release DashMap lock before await
let result = shared.await.map_err(|e| anyhow::anyhow!("{}", e));
if result.is_err() {
// Clean up failed future so subsequent calls can retry
self.reconnect_pending.remove(&());
}
return result;
}
// If not forced, assume we're connected (lightweight path)
if !force {
return Ok(());
}
// Start new reconnection (deduplicated)
use dashmap::mapref::entry::Entry;
let shared_future = match self.reconnect_pending.entry(()) {
Entry::Occupied(entry) => {
// Another thread started reconnection, use their future
entry.get().clone()
}
Entry::Vacant(entry) => {
// We're first, create the shared future
let client = self.clone();
let shared = async move { client.reconnect_impl(deadline).await.map_err(Arc::new) }
.boxed()
.shared();
entry.insert(shared.clone());
shared
}
};
let result = shared_future.await.map_err(|e| anyhow::anyhow!("{}", e));
if result.is_err() {
// Clean up failed future so subsequent calls can retry
self.reconnect_pending.remove(&());
}
result
}
/// Internal implementation of reconnection with retry logic.
/// Respects the deadline and returns error if exceeded.
///
/// Backoff behavior:
/// - Starts at 0 (immediate reconnect) if this is the first reconnect or enough time has passed
/// since the last reconnect
/// - Increments exponentially for continuous failures
/// - Resets to 0 only when: this is a new call AND current_time > last_connect_time + residual_backoff
///
/// The mutex ensures only one reconnect operation runs at a time globally
async fn reconnect_impl(&self, deadline: std::time::Instant) -> Result<()> {
let mut backoff_state = self.backoff_state.lock().await;
tracing::warn!("Reconnecting to ETCD cluster at: {:?}", self.etcd_urls);
backoff_state.attempt_reset();
loop {
backoff_state.apply_backoff(deadline).await;
if std::time::Instant::now() >= deadline {
// Clear the pending reconnection before returning error
self.reconnect_pending.remove(&());
return Err(error!(
"Unable to reconnect to ETCD cluster: deadline exceeded"
));
}
match Self::connect(&self.etcd_urls, &self.connect_options).await {
Ok(new_client) => {
tracing::info!("Successfully reconnected to ETCD cluster");
// Update the client behind the lock
let mut client_guard = self.client.write();
*client_guard = new_client;
// Clear the pending reconnection
self.reconnect_pending.remove(&());
return Ok(());
}
Err(e) => {
tracing::warn!(
"Reconnection failed (remaining time: {:?}): {}",
deadline.saturating_duration_since(std::time::Instant::now()),
e
);
}
}
}
}
/// Get the ETCD URLs
#[allow(dead_code)]
pub fn etcd_urls(&self) -> &[String] {
&self.etcd_urls
}
/// Get the connection options
#[allow(dead_code)]
pub fn connect_options(&self) -> &Option<ConnectOptions> {
&self.connect_options
}
}
#[derive(Debug)]
struct BackoffState {
/// Initial backoff duration for reconnection attempts
pub initial_backoff: Duration,
/// Minimum backoff duration for reconnection attempts
pub min_backoff: Duration,
/// Maximum backoff duration for reconnection attempts
pub max_backoff: Duration,
/// Current backoff duration (starts at 0 for immediate reconnect)
current_backoff: Duration,
/// Last time a connection establishment was attempted
last_connect_attempt: std::time::Instant,
}
impl Default for BackoffState {
fn default() -> Self {
Self {
initial_backoff: Duration::from_millis(500),
min_backoff: Duration::from_millis(50),
max_backoff: Duration::from_secs(5),
current_backoff: Duration::ZERO,
last_connect_attempt: std::time::Instant::now(),
}
}
}
impl BackoffState {
/// Create a new BackoffState with custom parameters.
pub fn new(initial_backoff: Duration, min_backoff: Duration, max_backoff: Duration) -> Self {
Self {
initial_backoff,
min_backoff,
max_backoff,
current_backoff: Duration::ZERO,
last_connect_attempt: std::time::Instant::now(),
}
}
/// Reset backoff to 0 if enough time has passed since the last connection attempt
pub fn attempt_reset(&mut self) {
if std::time::Instant::now() > self.last_connect_attempt + self.current_backoff {
tracing::debug!("Resetting backoff to 0 (first reconnect or enough time has passed)");
self.current_backoff = Duration::ZERO;
}
}
/// Apply backoff and update backoff state for possible next connection attempt
pub async fn apply_backoff(&mut self, deadline: std::time::Instant) {
if self.current_backoff > Duration::ZERO {
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
let backoff = std::cmp::min(self.current_backoff, remaining / 2);
let backoff = std::cmp::min(backoff, self.max_backoff);
let backoff = std::cmp::max(backoff, self.min_backoff);
self.current_backoff = backoff * 2;
tracing::debug!(
"Applying backoff of {:?} (remaining time: {:?})",
backoff,
remaining
);
sleep(backoff).await;
} else {
self.current_backoff = self.initial_backoff;
}
self.last_connect_attempt = std::time::Instant::now();
}
}
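The backoff schedule produced by `apply_backoff` can be easier to follow without the sleeping and timestamps. The following is a minimal sketch, not the crate's code: `BackoffSketch` and `next_delay` are hypothetical names, and the remaining time is held constant at 60 seconds for illustration, whereas the original recomputes it from the deadline on every call.

```rust
use std::time::Duration;

/// Hypothetical stand-in for `BackoffState`, with sleeping and timestamps removed.
struct BackoffSketch {
    initial: Duration,
    min: Duration,
    max: Duration,
    current: Duration,
}

impl BackoffSketch {
    /// Returns the delay that would be slept before the next connection attempt,
    /// given the time remaining until the deadline.
    fn next_delay(&mut self, remaining: Duration) -> Duration {
        if self.current > Duration::ZERO {
            // Cap at half the remaining time and at `max`, then floor at `min`.
            let delay = self.current.min(remaining / 2).min(self.max).max(self.min);
            // Double for the next failure.
            self.current = delay * 2;
            delay
        } else {
            // First attempt: no wait, but arm the backoff for the next failure.
            self.current = self.initial;
            Duration::ZERO
        }
    }
}

fn main() {
    // Same values as `BackoffState::default()` in the source.
    let mut backoff = BackoffSketch {
        initial: Duration::from_millis(500),
        min: Duration::from_millis(50),
        max: Duration::from_secs(5),
        current: Duration::ZERO,
    };
    let remaining = Duration::from_secs(60); // assumption: fixed for illustration
    let delays: Vec<Duration> = (0..6).map(|_| backoff.next_delay(remaining)).collect();
    println!("{:?}", delays);
}
```

With these values the delays are 0, 500 ms, 1 s, 2 s, 4 s, and then 5 s once the `max` cap applies. Because the floor at `min` is applied last, a delay can exceed half of the remaining time when the deadline is close.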
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Error classification for etcd operations.
//!
//! Categorizes etcd errors into reconnectable, expected, or fatal conditions
//! to enable smart retry logic.
use std::fmt;
use tonic::Code;
/// Errors that indicate a connection issue requiring reconnection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReconnectableError {
/// Connection to etcd server was closed
ConnectionClosed,
/// Operation timed out
Timeout,
/// Service unavailable (etcd server down or unreachable)
Unavailable,
/// Lease was not found (may have expired during disconnect)
LeaseNotFound,
}
impl fmt::Display for ReconnectableError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ConnectionClosed => write!(f, "connection closed"),
Self::Timeout => write!(f, "operation timed out"),
Self::Unavailable => write!(f, "service unavailable"),
Self::LeaseNotFound => write!(f, "lease not found"),
}
}
}
/// Classification of etcd errors for determining retry strategy.
#[derive(Debug)]
pub(crate) enum EtcdErrorClass {
/// Error should trigger reconnection and retry
Reconnectable(ReconnectableError),
/// Expected condition (key not found) - not an error
NotFound,
/// Fatal error that cannot be recovered by reconnecting
Fatal(anyhow::Error),
}
/// Classify an etcd error to determine appropriate handling.
///
/// # Classification Strategy
///
/// - **Reconnectable**: Connection/transport errors that can be fixed by reconnecting
/// - **NotFound**: Key doesn't exist (expected condition for queries)
/// - **Fatal**: All other errors (permissions, invalid request, etc.)
pub(crate) fn classify_error(err: etcd_client::Error) -> EtcdErrorClass {
// Match on the structured error variant and gRPC status code; the message text
// is inspected only to tell a missing lease from a missing key
match err {
etcd_client::Error::GRpcStatus(status) => {
// Classify based on gRPC status code
match status.code() {
Code::NotFound => {
// Check if it's a lease not found or key not found
let msg = status.message().to_lowercase();
if msg.contains("lease") {
EtcdErrorClass::Reconnectable(ReconnectableError::LeaseNotFound)
} else {
// Key not found is expected, not an error
EtcdErrorClass::NotFound
}
}
Code::Unavailable => EtcdErrorClass::Reconnectable(ReconnectableError::Unavailable),
Code::DeadlineExceeded => {
EtcdErrorClass::Reconnectable(ReconnectableError::Timeout)
}
Code::Cancelled | Code::Aborted => {
// Connection-related cancellations
EtcdErrorClass::Reconnectable(ReconnectableError::ConnectionClosed)
}
_ => {
// All other gRPC errors are fatal
EtcdErrorClass::Fatal(anyhow::anyhow!(
"gRPC error: {} (code: {:?})",
status.message(),
status.code()
))
}
}
}
etcd_client::Error::TransportError(_) => {
// Transport errors are reconnectable
EtcdErrorClass::Reconnectable(ReconnectableError::Unavailable)
}
etcd_client::Error::IoError(_) => {
// I/O errors are reconnectable
EtcdErrorClass::Reconnectable(ReconnectableError::ConnectionClosed)
}
_ => {
// All other errors (LeaseKeepAliveError, etc.) are fatal
EtcdErrorClass::Fatal(err.into())
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Resilient keep-alive task for etcd leases.
//!
//! Handles periodic keep-alive requests to prevent lease expiration,
//! with automatic reconnection and recovery on failure.
use crate::systems::etcd::client::Client;
use crate::systems::etcd::lease::LeaseState;
use anyhow::{Context, Result};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Background task that keeps an etcd lease alive.
///
/// # Resilience Strategy
///
/// - Acquires client and starts keep-alive stream
/// - Uses stream until failure (does NOT hold client lock)
/// - On failure: triggers reconnection, reacquires client, restarts
/// - Respects shutdown signal for clean termination
pub struct KeepAliveTask {
client: Arc<Client>,
lease_state: Arc<RwLock<LeaseState>>,
ttl: Duration,
shutdown: CancellationToken,
}
impl KeepAliveTask {
/// Create a new keep-alive task.
pub fn new(
client: Arc<Client>,
lease_state: Arc<RwLock<LeaseState>>,
ttl: Duration,
shutdown: CancellationToken,
) -> Self {
Self {
client,
lease_state,
ttl,
shutdown,
}
}
/// Spawn the keep-alive task as a background tokio task.
pub fn spawn(self) -> JoinHandle<()> {
tokio::spawn(async move {
tracing::debug!("Keep-alive task starting");
loop {
// Check for shutdown signal
if self.shutdown.is_cancelled() {
tracing::debug!("Keep-alive task shutting down");
break;
}
// Run keep-alive loop with automatic recovery
if let Err(e) = self.run_keep_alive_loop().await {
tracing::error!("Keep-alive loop failed: {}", e);
// Trigger reconnection before restarting (force=true)
let deadline = std::time::Instant::now() + Duration::from_secs(30);
if let Err(e) = self.client.ensure_connected(deadline, true).await {
tracing::error!("Failed to reconnect after keep-alive failure: {}", e);
// Wait before retry to avoid tight loop
tokio::time::sleep(Duration::from_secs(5)).await;
} else {
tracing::info!("Reconnected successfully, restarting keep-alive");
}
}
}
tracing::debug!("Keep-alive task exited");
})
}
/// Run the keep-alive loop until failure or shutdown.
///
/// # Strategy
///
/// 1. Get lease ID from state
/// 2. Acquire client and start keep-alive stream (brief lock)
/// 3. Release client lock
/// 4. Use keeper/stream handles until they fail
/// 5. On failure, return error (outer loop handles reconnection)
async fn run_keep_alive_loop(&self) -> Result<()> {
// Get current lease ID
let lease_id = self
.lease_state
.read()
.lease_id()
.ok_or_else(|| anyhow::anyhow!("No lease ID available"))?;
tracing::debug!("Starting keep-alive loop for lease {}", lease_id);
// Acquire client and start keep-alive stream (brief lock acquisition)
let mut client = self.client.get_client();
let (mut keeper, mut stream) = client
.lease_keep_alive(lease_id)
.await
.context("Failed to start lease keep-alive stream")?;
// The client lock was already released when get_client() returned its clone;
// from here on we only use the keeper/stream handles
// Calculate sleep interval (TTL / 3, with minimum of 1 second)
let sleep_interval = Duration::from_secs((self.ttl.as_secs() / 3).max(1));
loop {
// Check for messages from the stream
tokio::select! {
// Shutdown signal
_ = self.shutdown.cancelled() => {
tracing::debug!("Keep-alive loop received shutdown signal");
return Ok(());
}
// Keep-alive response from etcd
msg = stream.message() => {
match msg {
Ok(Some(_resp)) => {
tracing::trace!("Received keep-alive response for lease {}", lease_id);
// Successful keep-alive, continue
}
Ok(None) => {
tracing::warn!("Keep-alive stream closed for lease {}", lease_id);
return Err(anyhow::anyhow!("Keep-alive stream closed"));
}
Err(e) => {
tracing::warn!("Keep-alive stream error for lease {}: {}", lease_id, e);
return Err(e.into());
}
}
}
}
// Wait before sending next keep-alive
tokio::select! {
_ = self.shutdown.cancelled() => {
tracing::debug!("Keep-alive loop received shutdown signal during sleep");
return Ok(());
}
_ = tokio::time::sleep(sleep_interval) => {
// Time to send next keep-alive
}
}
// Send keep-alive request
if let Err(e) = keeper.keep_alive().await {
tracing::warn!("Failed to send keep-alive for lease {}: {}", lease_id, e);
return Err(e.into());
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_keep_alive_task_creation() {
// Exercising the task itself requires a running etcd instance, so this
// only checks the interval arithmetic for a 60-second TTL
let ttl = Duration::from_secs(60);
let sleep_interval = (ttl.as_secs() / 3).max(1);
assert_eq!(sleep_interval, 20);
}
#[test]
fn test_sleep_interval_calculation() {
// Test sleep interval calculation
let ttl = Duration::from_secs(60);
let interval = (ttl.as_secs() / 3).max(1);
assert_eq!(interval, 20);
let ttl = Duration::from_secs(10);
let interval = (ttl.as_secs() / 3).max(1);
assert_eq!(interval, 3);
let ttl = Duration::from_secs(2);
let interval = (ttl.as_secs() / 3).max(1);
assert_eq!(interval, 1); // Minimum of 1 second
}
}
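The tests above repeat the interval arithmetic inline instead of calling shared code. A minimal sketch of the same calculation as a standalone function follows; `keep_alive_interval` is a hypothetical helper name and does not exist in the crate.

```rust
use std::time::Duration;

/// Hypothetical helper: one third of the TTL in whole seconds, never below one second.
fn keep_alive_interval(ttl: Duration) -> Duration {
    Duration::from_secs((ttl.as_secs() / 3).max(1))
}

fn main() {
    for secs in [60, 10, 2] {
        let ttl = Duration::from_secs(secs);
        println!("ttl {:?} -> interval {:?}", ttl, keep_alive_interval(ttl));
    }
    // `as_secs()` drops the fractional part: 5.9s is read as 5s, and 5 / 3 = 1.
    println!("{:?}", keep_alive_interval(Duration::from_millis(5900)));
}
```

The integer division means that TTLs below six seconds all map to the one-second minimum.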
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Lease management for etcd peer discovery.
//!
//! Handles lease creation, validation, and renewal. Attempts to reuse
//! existing leases when reconnecting to avoid unnecessary re-registration.
//!
//! # Lease Revocation Limitation
//!
//! **IMPORTANT**: If an etcd lease is revoked manually or expires (for
//! example, during an extended network partition), etcd automatically
//! deletes all keys attached to that lease. This is an **unrecoverable**
//! state in the current implementation because:
//!
//! 1. The system does not track which keys were published under a lease
//! 2. When a lease is revoked, we create a new lease but cannot republish
//! the deleted keys
//! 3. All peer registrations made with the old lease are permanently lost
//!
//! **Mitigation**: The keep-alive mechanism maintains the lease actively,
//! reducing the chance of expiration. However, extended network partitions
//! or manual lease revocation will result in lost registrations that require
//! application-level re-registration.
use anyhow::{Context, Result};
use std::time::{Duration, Instant};
use tonic::Code;
/// Result of checking lease validity.
///
/// Provides clear information about why a lease is valid or invalid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LeaseValidityState {
/// Lease is valid with the specified remaining TTL in seconds
Valid { remaining_ttl: i64 },
/// Lease is treated as expired: remaining TTL is at or below 1/3 of the configured TTL
Expired,
/// Lease was not found on the etcd server
NotFound,
/// Failed to check lease validity (network error, etc.)
CheckFailed(String),
}
impl LeaseValidityState {
/// Returns true if the lease is valid and can be reused.
#[allow(dead_code)]
pub fn is_valid(&self) -> bool {
matches!(self, LeaseValidityState::Valid { .. })
}
}
/// State tracking for an etcd lease.
#[derive(Debug)]
pub struct LeaseState {
/// Current lease ID, if one exists
lease_id: Option<i64>,
/// When the lease was created
created_at: Option<Instant>,
/// Lease TTL duration
ttl: Duration,
}
impl LeaseState {
/// Create a new lease state with the specified TTL.
pub fn new(ttl: Duration) -> Self {
Self {
lease_id: None,
created_at: None,
ttl,
}
}
/// Get the current lease ID if one exists.
pub fn lease_id(&self) -> Option<i64> {
self.lease_id
}
/// Get the lease TTL.
#[allow(dead_code)]
pub fn ttl(&self) -> Duration {
self.ttl
}
/// Ensure a valid lease exists, reusing the current one if still valid
/// or creating a new one if expired/not found.
///
/// # Strategy
///
/// 1. If we have a lease ID, check if it's still valid (TTL > 1/3 remaining)
/// 2. If valid, return the existing lease ID
/// 3. If invalid or not found, create a new lease
///
/// This allows us to survive transient disconnections without losing
/// our registrations, while still creating a new lease if needed.
pub async fn ensure_lease(&mut self, client: &mut etcd_client::Client) -> Result<i64> {
// Try to reuse existing lease if it's still valid
if let Some(lease_id) = self.lease_id {
match self.check_lease_validity(client, lease_id).await {
LeaseValidityState::Valid { remaining_ttl } => {
tracing::debug!(
"Reusing existing lease ID: {} (remaining TTL: {}s)",
lease_id,
remaining_ttl
);
return Ok(lease_id);
}
LeaseValidityState::Expired => {
tracing::debug!("Existing lease {} expired, creating new lease", lease_id);
}
LeaseValidityState::NotFound => {
// CRITICAL: When a lease is not found (revoked), all keys associated
// with it are already deleted by etcd. Creating a new lease will NOT
// restore those keys. This is an unrecoverable state - the caller
// must re-register all instances.
tracing::warn!(
"Existing lease {} not found on server (revoked). All keys associated \
with this lease have been deleted. Creating new lease - caller must \
re-register instances.",
lease_id
);
}
LeaseValidityState::CheckFailed(err) => {
tracing::warn!(
"Failed to check lease {} validity: {}, creating new lease",
lease_id,
err
);
}
}
}
// Create new lease
self.create_new_lease(client).await
}
/// Check if a lease is still valid (has > 1/3 of TTL remaining).
///
/// Returns a `LeaseValidityState` that provides clear information about
/// the lease status.
async fn check_lease_validity(
&self,
client: &mut etcd_client::Client,
lease_id: i64,
) -> LeaseValidityState {
// Try to get lease TTL
let resp = match client.lease_time_to_live(lease_id, None).await {
Ok(resp) => resp,
Err(e) => {
// Use structured error matching instead of fragile string matching
match e {
etcd_client::Error::GRpcStatus(status) => {
match status.code() {
Code::NotFound => {
// Lease not found on the server
return LeaseValidityState::NotFound;
}
_ => {
// Other gRPC errors
return LeaseValidityState::CheckFailed(format!(
"gRPC error: {} (code: {:?})",
status.message(),
status.code()
));
}
}
}
_ => {
// Non-gRPC errors (transport, IO, etc.)
return LeaseValidityState::CheckFailed(e.to_string());
}
}
}
};
let remaining_ttl = resp.ttl();
// TTL of 0 or negative means the lease is already gone
if remaining_ttl <= 0 {
return LeaseValidityState::NotFound;
}
// Consider lease valid if it has more than 1/3 of original TTL remaining
let min_ttl = (self.ttl.as_secs() as i64) / 3;
if remaining_ttl > min_ttl {
LeaseValidityState::Valid { remaining_ttl }
} else {
LeaseValidityState::Expired
}
}
/// Create a new lease with the configured TTL.
async fn create_new_lease(&mut self, client: &mut etcd_client::Client) -> Result<i64> {
let ttl_secs = self.ttl.as_secs() as i64;
let resp = client
.lease_grant(ttl_secs, None)
.await
.context("Failed to create new lease")?;
let lease_id = resp.id();
tracing::info!("Created new lease ID: {} (TTL: {}s)", lease_id, ttl_secs);
self.lease_id = Some(lease_id);
self.created_at = Some(Instant::now());
Ok(lease_id)
}
/// Clear the current lease state (e.g., after failed reconnection).
#[allow(dead_code)]
pub fn clear(&mut self) {
self.lease_id = None;
self.created_at = None;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_lease_state_creation() {
let ttl = Duration::from_secs(60);
let state = LeaseState::new(ttl);
assert_eq!(state.lease_id(), None);
assert_eq!(state.ttl(), ttl);
}
#[test]
fn test_lease_state_clear() {
let mut state = LeaseState::new(Duration::from_secs(60));
state.lease_id = Some(12345);
state.created_at = Some(Instant::now());
state.clear();
assert_eq!(state.lease_id(), None);
assert_eq!(state.created_at, None);
}
}
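The reuse threshold applied by `check_lease_validity` can be isolated from the network call. The following is a minimal sketch under that assumption: `Validity` and `classify_remaining` are hypothetical names, and the `CheckFailed` case is omitted because no request is made here.

```rust
use std::time::Duration;

/// Hypothetical, simplified mirror of `LeaseValidityState` (no `CheckFailed`).
#[derive(Debug, PartialEq, Eq)]
enum Validity {
    Valid { remaining_ttl: i64 },
    Expired,
    NotFound,
}

/// Classify the remaining TTL (in seconds, as reported by the server)
/// against the configured TTL.
fn classify_remaining(configured: Duration, remaining_ttl: i64) -> Validity {
    // A TTL of zero or below means the lease is already gone.
    if remaining_ttl <= 0 {
        return Validity::NotFound;
    }
    // Reuse the lease only while more than one third of the TTL remains.
    let min_ttl = (configured.as_secs() as i64) / 3;
    if remaining_ttl > min_ttl {
        Validity::Valid { remaining_ttl }
    } else {
        Validity::Expired
    }
}

fn main() {
    let ttl = Duration::from_secs(60);
    for remaining in [-1, 0, 20, 21, 60] {
        println!("{remaining:>3}s -> {:?}", classify_remaining(ttl, remaining));
    }
}
```

For a 60-second TTL the boundary sits at 20 seconds: 21 seconds remaining is still reusable, while exactly 20 is treated as expired and triggers a new lease.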
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Operation execution with automatic retry and reconnection.
//!
//! Wraps etcd operations to handle transient connection failures transparently.
use crate::peer::{DiscoveryError, DiscoveryQueryError};
use crate::systems::etcd::client::Client;
use crate::systems::etcd::error::{EtcdErrorClass, classify_error};
use anyhow::Result;
use futures::future::BoxFuture;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Executes etcd operations with automatic reconnection on transient errors.
#[derive(Clone)]
pub struct OperationExecutor {
client: Arc<Client>,
default_timeout: Duration,
max_retries: u32,
}
impl OperationExecutor {
/// Create a new operation executor.
pub fn new(client: Arc<Client>, default_timeout: Duration, max_retries: u32) -> Self {
Self {
client,
default_timeout,
max_retries,
}
}
/// Execute a query operation with automatic retry on reconnectable errors.
///
/// # Arguments
///
/// * `op` - Function that performs the etcd operation given a client
///
/// # Returns
///
/// * `Ok(T)` - Operation succeeded
/// * `Err(DiscoveryQueryError::NotFound)` - Key not found (expected)
/// * `Err(DiscoveryQueryError::Backend)` - Fatal error, failed reconnection, retry limit reached, or timeout
///
/// # Behavior
///
/// 1. Acquire client (brief RwLock read)
/// 2. Execute operation
/// 3. On reconnectable error:
/// - Trigger reconnection via `ensure_connected()`
/// - Retry operation
/// 4. On NotFound: return DiscoveryQueryError::NotFound
/// 5. On Fatal error: return DiscoveryQueryError::Backend
pub async fn execute_query<F, T>(&self, op: F) -> Result<T, DiscoveryQueryError>
where
F: Fn(etcd_client::Client) -> BoxFuture<'static, Result<T, etcd_client::Error>>,
{
let deadline = Instant::now() + self.default_timeout;
let mut retry_count = 0;
loop {
// Check deadline
if Instant::now() >= deadline {
return Err(DiscoveryQueryError::Backend(Arc::new(anyhow::anyhow!(
"Operation timed out after {:?}",
self.default_timeout
))));
}
// Await any in-progress reconnection (lightweight check)
if let Err(e) = self.client.ensure_connected(deadline, false).await {
return Err(DiscoveryQueryError::Backend(Arc::new(e)));
}
// Acquire client (brief lock)
let client = self.client.get_client();
// Execute operation
match op(client).await {
Ok(result) => {
return Ok(result);
}
Err(err) => {
// Classify the error to determine action
match classify_error(err) {
EtcdErrorClass::Reconnectable(kind) => {
retry_count += 1;
if retry_count >= self.max_retries {
tracing::error!(
"Max retries ({}) exceeded for reconnectable error: {:?}",
self.max_retries,
kind
);
return Err(DiscoveryQueryError::Backend(Arc::new(
anyhow::anyhow!("Max retries exceeded: {}", kind),
)));
}
tracing::debug!(
"Reconnectable error (attempt {}/{}): {:?}, retrying...",
retry_count,
self.max_retries,
kind
);
// Trigger reconnection (force=true)
if let Err(e) = self.client.ensure_connected(deadline, true).await {
tracing::error!("Failed to reconnect: {}", e);
return Err(DiscoveryQueryError::Backend(Arc::new(e)));
}
// Loop will retry operation
continue;
}
EtcdErrorClass::NotFound => {
return Err(DiscoveryQueryError::NotFound);
}
EtcdErrorClass::Fatal(e) => {
return Err(DiscoveryQueryError::Backend(Arc::new(e)));
}
}
}
}
}
}
/// Execute a write operation (register/unregister) with automatic retry.
///
/// Similar to `execute_query`, but returns `DiscoveryError` and treats a
/// NotFound result as success (e.g., deleting a key that does not exist).
pub async fn execute_write<F>(&self, op: F) -> Result<(), DiscoveryError>
where
F: Fn(etcd_client::Client) -> BoxFuture<'static, Result<(), etcd_client::Error>>,
{
let deadline = Instant::now() + self.default_timeout;
let mut retry_count = 0;
loop {
// Check deadline
if Instant::now() >= deadline {
return Err(DiscoveryError::Backend(anyhow::anyhow!(
"Operation timed out after {:?}",
self.default_timeout
)));
}
// Await any in-progress reconnection (lightweight check)
if let Err(e) = self.client.ensure_connected(deadline, false).await {
return Err(DiscoveryError::Backend(e));
}
// Acquire client (brief lock)
let client = self.client.get_client();
// Execute operation
match op(client).await {
Ok(()) => {
return Ok(());
}
Err(err) => {
// Classify the error to determine action
match classify_error(err) {
EtcdErrorClass::Reconnectable(kind) => {
retry_count += 1;
if retry_count >= self.max_retries {
tracing::error!(
"Max retries ({}) exceeded for reconnectable error: {:?}",
self.max_retries,
kind
);
return Err(DiscoveryError::Backend(anyhow::anyhow!(
"Max retries exceeded: {}",
kind
)));
}
tracing::debug!(
"Reconnectable error (attempt {}/{}): {:?}, retrying...",
retry_count,
self.max_retries,
kind
);
// Trigger reconnection (force=true)
if let Err(e) = self.client.ensure_connected(deadline, true).await {
tracing::error!("Failed to reconnect: {}", e);
return Err(DiscoveryError::Backend(e));
}
// Loop will retry operation
continue;
}
EtcdErrorClass::NotFound => {
// For writes, NotFound might be valid (e.g., deleting non-existent key)
// Treat as success
tracing::debug!("Write operation: key not found (treating as success)");
return Ok(());
}
EtcdErrorClass::Fatal(e) => {
return Err(DiscoveryError::Backend(e));
}
}
}
}
}
}
/// Get the underlying client reference.
#[allow(dead_code)]
pub fn client(&self) -> &Arc<Client> {
&self.client
}
}
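The control flow shared by `execute_query` and `execute_write` can be sketched synchronously, without etcd or an async runtime. This is an illustration under stated assumptions, not the crate's implementation: `ErrorClass` and `QueryError` are hypothetical stand-ins for `EtcdErrorClass` and `DiscoveryQueryError`, the operation returns an already classified error, and the reconnection step is reduced to a comment.

```rust
use std::time::{Duration, Instant};

/// Hypothetical stand-in for `EtcdErrorClass`.
#[derive(Debug)]
enum ErrorClass {
    Reconnectable,
    NotFound,
    Fatal(String),
}

/// Hypothetical stand-in for `DiscoveryQueryError`.
#[derive(Debug, PartialEq, Eq)]
enum QueryError {
    NotFound,
    Backend(String),
}

/// Retries `op` on reconnectable errors until it succeeds, the retry
/// limit is reached, or the deadline passes.
fn execute_query<T>(
    timeout: Duration,
    max_retries: u32,
    mut op: impl FnMut() -> Result<T, ErrorClass>,
) -> Result<T, QueryError> {
    let deadline = Instant::now() + timeout;
    let mut retry_count = 0u32;
    loop {
        if Instant::now() >= deadline {
            return Err(QueryError::Backend(format!("timed out after {timeout:?}")));
        }
        match op() {
            Ok(value) => return Ok(value),
            Err(ErrorClass::Reconnectable) => {
                retry_count += 1;
                if retry_count >= max_retries {
                    return Err(QueryError::Backend("max retries exceeded".to_string()));
                }
                // The original forces a reconnection here before looping again.
            }
            Err(ErrorClass::NotFound) => return Err(QueryError::NotFound),
            Err(ErrorClass::Fatal(msg)) => return Err(QueryError::Backend(msg)),
        }
    }
}

fn main() {
    // Fails twice with a reconnectable error, then succeeds on the third call.
    let mut calls = 0;
    let result = execute_query(Duration::from_secs(1), 3, || {
        calls += 1;
        if calls < 3 {
            Err(ErrorClass::Reconnectable)
        } else {
            Ok(calls)
        }
    });
    println!("{result:?} after {calls} calls");
}
```

Because the counter is checked with `>=` after it is incremented, `max_retries` bounds the total number of attempts: with a value of 3, the operation runs at most three times.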