// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use anyhow::{Context, Result, anyhow as error}; use dashmap::DashMap; use etcd_client::ConnectOptions; use futures::future::{BoxFuture, FutureExt, Shared}; use parking_lot::RwLock; use std::{sync::Arc, time::Duration}; use tokio::{sync::Mutex, time::sleep}; /// Type alias for the shared reconnection future type ReconnectFuture = Shared>>>; /// Manages ETCD client connections with reconnection support #[derive(Clone)] pub struct Client { /// The actual ETCD client, protected by RwLock for safe updates during reconnection /// WARNING: Do not recursively acquire a read lock when the current thread already holds one client: Arc>, /// Configuration for connecting to ETCD etcd_urls: Arc>, connect_options: Arc>, /// Tracks the current backoff duration and last successful connect time /// The Mutex ensures only one reconnect operation runs at a time backoff_state: Arc>, /// Shared reconnection futures for deduplication /// Only one reconnection happens at a time; concurrent callers share the future reconnect_pending: Arc>, } impl Client { /// Create a new connector with an established connection pub async fn new( etcd_urls: Vec, connect_options: Option, initial_backoff: Duration, min_backoff: Duration, max_backoff: Duration, ) -> Result { // Connect to ETCD let client = Self::connect(&etcd_urls, &connect_options).await?; Ok(Self { client: Arc::new(RwLock::new(client)), etcd_urls: Arc::new(etcd_urls), connect_options: Arc::new(connect_options), backoff_state: Arc::new(Mutex::new(BackoffState::new( initial_backoff, min_backoff, max_backoff, ))), reconnect_pending: Arc::new(DashMap::new()), }) } /// Connect to ETCD cluster async fn connect( etcd_urls: &[String], connect_options: &Option, ) -> Result { etcd_client::Client::connect(etcd_urls.to_vec(), connect_options.clone()) .await .with_context(|| { format!( "Unable to connect to etcd server at {}. Check etcd server status", etcd_urls.join(", ") ) }) } /// Get a clone of the current ETCD client pub fn get_client(&self) -> etcd_client::Client { self.client.read().clone() } /// Ensure the client is connected, triggering reconnection if needed. /// /// This method deduplicates concurrent reconnection attempts - only one /// reconnection happens at a time, with all callers sharing the same future. /// /// # Arguments /// * `deadline` - Deadline for reconnection attempts /// * `force` - If true, start reconnection even if not already in progress /// /// Returns Ok(()) if connected, Err if reconnection failed. pub async fn ensure_connected(&self, deadline: std::time::Instant, force: bool) -> Result<()> { // Check if reconnection already in progress if let Some(shared_future_ref) = self.reconnect_pending.get(&()) { let shared = shared_future_ref.clone(); drop(shared_future_ref); // Release DashMap lock before await let result = shared.await.map_err(|e| anyhow::anyhow!("{}", e)); if result.is_err() { // Clean up failed future so subsequent calls can retry self.reconnect_pending.remove(&()); } return result; } // If not forced, assume we're connected (lightweight path) if !force { return Ok(()); } // Start new reconnection (deduplicated) use dashmap::mapref::entry::Entry; let shared_future = match self.reconnect_pending.entry(()) { Entry::Occupied(entry) => { // Another thread started reconnection, use their future entry.get().clone() } Entry::Vacant(entry) => { // We're first, create the shared future let client = self.clone(); let shared = async move { client.reconnect_impl(deadline).await.map_err(Arc::new) } .boxed() .shared(); entry.insert(shared.clone()); shared } }; let result = shared_future.await.map_err(|e| anyhow::anyhow!("{}", e)); if result.is_err() { // Clean up failed future so subsequent calls can retry self.reconnect_pending.remove(&()); } result } /// Internal implementation of reconnection with retry logic. /// Respects the deadline and returns error if exceeded. /// /// Backoff behavior: /// - Starts at 0 (immediate reconnect) if this is the first reconnect or enough time has passed /// since the last reconnect /// - Increments exponentially for continuous failures /// - Resets to 0 only when: this is a new call AND current_time > last_connect_time + residual_backoff /// /// The mutex ensures only one reconnect operation runs at a time globally async fn reconnect_impl(&self, deadline: std::time::Instant) -> Result<()> { let mut backoff_state = self.backoff_state.lock().await; tracing::warn!("Reconnecting to ETCD cluster at: {:?}", self.etcd_urls); backoff_state.attempt_reset(); loop { backoff_state.apply_backoff(deadline).await; if std::time::Instant::now() >= deadline { // Clear the pending reconnection before returning error self.reconnect_pending.remove(&()); return Err(error!( "Unable to reconnect to ETCD cluster: deadline exceeded" )); } match Self::connect(&self.etcd_urls, &self.connect_options).await { Ok(new_client) => { tracing::info!("Successfully reconnected to ETCD cluster"); // Update the client behind the lock let mut client_guard = self.client.write(); *client_guard = new_client; // Clear the pending reconnection self.reconnect_pending.remove(&()); return Ok(()); } Err(e) => { tracing::warn!( "Reconnection failed (remaining time: {:?}): {}", deadline.saturating_duration_since(std::time::Instant::now()), e ); } } } } /// Get the ETCD URLs #[allow(dead_code)] pub fn etcd_urls(&self) -> &[String] { &self.etcd_urls } /// Get the connection options #[allow(dead_code)] pub fn connect_options(&self) -> &Option { &self.connect_options } } #[derive(Debug)] struct BackoffState { /// Initial backoff duration for reconnection attempts pub initial_backoff: Duration, /// Minimum backoff duration for reconnection attempts pub min_backoff: Duration, /// Maximum backoff duration for reconnection attempts pub max_backoff: Duration, /// Current backoff duration (starts at 0 for immediate reconnect) current_backoff: Duration, /// Last time a connection establishment was attempted last_connect_attempt: std::time::Instant, } impl Default for BackoffState { fn default() -> Self { Self { initial_backoff: Duration::from_millis(500), min_backoff: Duration::from_millis(50), max_backoff: Duration::from_secs(5), current_backoff: Duration::ZERO, last_connect_attempt: std::time::Instant::now(), } } } impl BackoffState { /// Create a new BackoffState with custom parameters. pub fn new(initial_backoff: Duration, min_backoff: Duration, max_backoff: Duration) -> Self { Self { initial_backoff, min_backoff, max_backoff, current_backoff: Duration::ZERO, last_connect_attempt: std::time::Instant::now(), } } /// Reset backoff to 0 if enough time has passed since the last connection pub fn attempt_reset(&mut self) { if std::time::Instant::now() > self.last_connect_attempt + self.current_backoff { tracing::debug!("Resetting backoff to 0 (first reconnect or enough time has passed)"); self.current_backoff = Duration::ZERO; } } /// Apply backoff and update backoff state for possible next connection attempt pub async fn apply_backoff(&mut self, deadline: std::time::Instant) { if self.current_backoff > Duration::ZERO { let remaining = deadline.saturating_duration_since(std::time::Instant::now()); let backoff = std::cmp::min(self.current_backoff, remaining / 2); let backoff = std::cmp::min(backoff, self.max_backoff); let backoff = std::cmp::max(backoff, self.min_backoff); self.current_backoff = backoff * 2; tracing::debug!( "Applying backoff of {:?} (remaining time: {:?})", backoff, remaining ); sleep(backoff).await; } else { self.current_backoff = self.initial_backoff; } self.last_connect_attempt = std::time::Instant::now(); } }