client.rs 9.75 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::{Context, Result, anyhow as error};
use dashmap::DashMap;
use etcd_client::ConnectOptions;
use futures::future::{BoxFuture, FutureExt, Shared};
use parking_lot::RwLock;
use std::{sync::Arc, time::Duration};
use tokio::{sync::Mutex, time::sleep};

/// Type alias for the shared reconnection future
type ReconnectFuture = Shared<BoxFuture<'static, Result<(), Arc<anyhow::Error>>>>;

/// Manages ETCD client connections with reconnection support
#[derive(Clone)]
pub struct Client {
    /// The actual ETCD client, protected by RwLock for safe updates during reconnection
    /// WARNING: Do not recursively acquire a read lock when the current thread already holds one
    client: Arc<RwLock<etcd_client::Client>>,
    /// Configuration for connecting to ETCD
    etcd_urls: Arc<Vec<String>>,
    connect_options: Arc<Option<ConnectOptions>>,
    /// Tracks the current backoff duration and last successful connect time
    /// The Mutex ensures only one reconnect operation runs at a time
    backoff_state: Arc<Mutex<BackoffState>>,
    /// Shared reconnection futures for deduplication
    /// Only one reconnection happens at a time; concurrent callers share the future
    reconnect_pending: Arc<DashMap<(), ReconnectFuture>>,
}

impl Client {
    /// Create a new connector with an established connection
    pub async fn new(
        etcd_urls: Vec<String>,
        connect_options: Option<ConnectOptions>,
        initial_backoff: Duration,
        min_backoff: Duration,
        max_backoff: Duration,
    ) -> Result<Self> {
        // Connect to ETCD
        let client = Self::connect(&etcd_urls, &connect_options).await?;

        Ok(Self {
            client: Arc::new(RwLock::new(client)),
            etcd_urls: Arc::new(etcd_urls),
            connect_options: Arc::new(connect_options),
            backoff_state: Arc::new(Mutex::new(BackoffState::new(
                initial_backoff,
                min_backoff,
                max_backoff,
            ))),
            reconnect_pending: Arc::new(DashMap::new()),
        })
    }

    /// Connect to ETCD cluster
    async fn connect(
        etcd_urls: &[String],
        connect_options: &Option<ConnectOptions>,
    ) -> Result<etcd_client::Client> {
        etcd_client::Client::connect(etcd_urls.to_vec(), connect_options.clone())
            .await
            .with_context(|| {
                format!(
                    "Unable to connect to etcd server at {}. Check etcd server status",
                    etcd_urls.join(", ")
                )
            })
    }

    /// Get a clone of the current ETCD client
    pub fn get_client(&self) -> etcd_client::Client {
        self.client.read().clone()
    }

    /// Ensure the client is connected, triggering reconnection if needed.
    ///
    /// This method deduplicates concurrent reconnection attempts - only one
    /// reconnection happens at a time, with all callers sharing the same future.
    ///
    /// # Arguments
    /// * `deadline` - Deadline for reconnection attempts
    /// * `force` - If true, start reconnection even if not already in progress
    ///
    /// Returns Ok(()) if connected, Err if reconnection failed.
    pub async fn ensure_connected(&self, deadline: std::time::Instant, force: bool) -> Result<()> {
        // Check if reconnection already in progress
        if let Some(shared_future_ref) = self.reconnect_pending.get(&()) {
            let shared = shared_future_ref.clone();
            drop(shared_future_ref); // Release DashMap lock before await
            let result = shared.await.map_err(|e| anyhow::anyhow!("{}", e));
            if result.is_err() {
                // Clean up failed future so subsequent calls can retry
                self.reconnect_pending.remove(&());
            }
            return result;
        }

        // If not forced, assume we're connected (lightweight path)
        if !force {
            return Ok(());
        }

        // Start new reconnection (deduplicated)
        use dashmap::mapref::entry::Entry;
        let shared_future = match self.reconnect_pending.entry(()) {
            Entry::Occupied(entry) => {
                // Another thread started reconnection, use their future
                entry.get().clone()
            }
            Entry::Vacant(entry) => {
                // We're first, create the shared future
                let client = self.clone();
                let shared = async move { client.reconnect_impl(deadline).await.map_err(Arc::new) }
                    .boxed()
                    .shared();

                entry.insert(shared.clone());
                shared
            }
        };

        let result = shared_future.await.map_err(|e| anyhow::anyhow!("{}", e));
        if result.is_err() {
            // Clean up failed future so subsequent calls can retry
            self.reconnect_pending.remove(&());
        }
        result
    }

    /// Internal implementation of reconnection with retry logic.
    /// Respects the deadline and returns error if exceeded.
    ///
    /// Backoff behavior:
    /// - Starts at 0 (immediate reconnect) if this is the first reconnect or enough time has passed
    ///   since the last reconnect
    /// - Increments exponentially for continuous failures
    /// - Resets to 0 only when: this is a new call AND current_time > last_connect_time + residual_backoff
    ///
    /// The mutex ensures only one reconnect operation runs at a time globally
    async fn reconnect_impl(&self, deadline: std::time::Instant) -> Result<()> {
        let mut backoff_state = self.backoff_state.lock().await;

        tracing::warn!("Reconnecting to ETCD cluster at: {:?}", self.etcd_urls);
        backoff_state.attempt_reset();

        loop {
            backoff_state.apply_backoff(deadline).await;
            if std::time::Instant::now() >= deadline {
                // Clear the pending reconnection before returning error
                self.reconnect_pending.remove(&());
                return Err(error!(
                    "Unable to reconnect to ETCD cluster: deadline exceeded"
                ));
            }

            match Self::connect(&self.etcd_urls, &self.connect_options).await {
                Ok(new_client) => {
                    tracing::info!("Successfully reconnected to ETCD cluster");
                    // Update the client behind the lock
                    let mut client_guard = self.client.write();
                    *client_guard = new_client;

                    // Clear the pending reconnection
                    self.reconnect_pending.remove(&());

                    return Ok(());
                }
                Err(e) => {
                    tracing::warn!(
                        "Reconnection failed (remaining time: {:?}): {}",
                        deadline.saturating_duration_since(std::time::Instant::now()),
                        e
                    );
                }
            }
        }
    }

    /// Get the ETCD URLs
    #[allow(dead_code)]
    pub fn etcd_urls(&self) -> &[String] {
        &self.etcd_urls
    }

    /// Get the connection options
    #[allow(dead_code)]
    pub fn connect_options(&self) -> &Option<ConnectOptions> {
        &self.connect_options
    }
}

#[derive(Debug)]
struct BackoffState {
    /// Initial backoff duration for reconnection attempts
    pub initial_backoff: Duration,
    /// Minimum backoff duration for reconnection attempts
    pub min_backoff: Duration,
    /// Maximum backoff duration for reconnection attempts
    pub max_backoff: Duration,
    /// Current backoff duration (starts at 0 for immediate reconnect)
    current_backoff: Duration,
    /// Last time a connection establishment was attempted
    last_connect_attempt: std::time::Instant,
}

impl Default for BackoffState {
    fn default() -> Self {
        Self {
            initial_backoff: Duration::from_millis(500),
            min_backoff: Duration::from_millis(50),
            max_backoff: Duration::from_secs(5),
            current_backoff: Duration::ZERO,
            last_connect_attempt: std::time::Instant::now(),
        }
    }
}

impl BackoffState {
    /// Create a new BackoffState with custom parameters.
    pub fn new(initial_backoff: Duration, min_backoff: Duration, max_backoff: Duration) -> Self {
        Self {
            initial_backoff,
            min_backoff,
            max_backoff,
            current_backoff: Duration::ZERO,
            last_connect_attempt: std::time::Instant::now(),
        }
    }

    /// Reset backoff to 0 if enough time has passed since the last connection
    pub fn attempt_reset(&mut self) {
        if std::time::Instant::now() > self.last_connect_attempt + self.current_backoff {
            tracing::debug!("Resetting backoff to 0 (first reconnect or enough time has passed)");
            self.current_backoff = Duration::ZERO;
        }
    }

    /// Apply backoff and update backoff state for possible next connection attempt
    pub async fn apply_backoff(&mut self, deadline: std::time::Instant) {
        if self.current_backoff > Duration::ZERO {
            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
            let backoff = std::cmp::min(self.current_backoff, remaining / 2);
            let backoff = std::cmp::min(backoff, self.max_backoff);
            let backoff = std::cmp::max(backoff, self.min_backoff);
            self.current_backoff = backoff * 2;

            tracing::debug!(
                "Applying backoff of {:?} (remaining time: {:?})",
                backoff,
                remaining
            );
            sleep(backoff).await;
        } else {
            self.current_backoff = self.initial_backoff;
        }
        self.last_connect_attempt = std::time::Instant::now();
    }
}