client.rs 10.7 KB
Newer Older
1
2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
Ryan Olson's avatar
Ryan Olson committed
3

4
5
6
7
8
9
use crate::{
    pipeline::{
        AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
        SingleIn,
    },
    storage::key_value_store::{KeyValueStoreManager, WatchEvent},
Ryan Olson's avatar
Ryan Olson committed
10
};
11
use arc_swap::ArcSwap;
Ryan Olson's avatar
Ryan Olson committed
12
use std::collections::HashMap;
13
use std::sync::Arc;
14
use tokio::net::unix::pipe::Receiver;
Ryan Olson's avatar
Ryan Olson committed
15

16
use crate::{pipeline::async_trait, transports::etcd::Client as EtcdClient};
Ryan Olson's avatar
Ryan Olson committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

use super::*;

/// Each state has a nonce associated with it.
/// The state will be emitted in a watch channel, so we can observe the
/// critical state transitions.
enum MapState {
    /// The map is empty; value = nonce
    Empty(u64),

    /// The map is not-empty; values are (nonce, count)
    NonEmpty(u64, u64),

    /// The watcher has finished, no more events will be emitted
    Finished,
}

/// A change to the set of known endpoint instances.
#[derive(Debug, Clone, PartialEq, Eq)]
enum EndpointEvent {
    /// An entry was added or updated: `(key, id)`.
    // NOTE(review): the `u64` is presumably the instance id — confirm against producers.
    Put(String, u64),

    /// The entry stored under the given key was removed.
    Delete(String),
}

39
40
41
42
#[derive(Clone, Debug)]
pub struct Client {
    // This is me
    pub endpoint: Endpoint,
43
    // These are the remotes I know about from watching etcd
44
    pub instance_source: Arc<InstanceSource>,
45
    // These are the instance source ids less those reported as down from sending rpc
46
    instance_avail: Arc<ArcSwap<Vec<u64>>>,
47
    // These are the instance source ids less those reported as busy (above threshold)
48
    instance_free: Arc<ArcSwap<Vec<u64>>>,
49
50
51
}

#[derive(Clone, Debug)]
52
pub enum InstanceSource {
53
    Static,
54
    Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
Ryan Olson's avatar
Ryan Olson committed
55
56
}

57
impl Client {
58
59
60
61
    // Client will only talk to a single static endpoint
    pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
        Ok(Client {
            endpoint,
62
            instance_source: Arc::new(InstanceSource::Static),
63
            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
64
            instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
65
66
        })
    }
Ryan Olson's avatar
Ryan Olson committed
67

68
    // Client with auto-discover instances using etcd
69
    pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
70
71
        const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);

Ryan Olson's avatar
Ryan Olson committed
72
        // create live endpoint watcher
73
        let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
74

75
        let client = Client {
76
            endpoint,
77
            instance_source: instance_source.clone(),
78
            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
79
            instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
80
        };
81
        client.monitor_instance_source();
82
        Ok(client)
83
84
85
86
87
88
89
90
91
92
93
    }

    /// Path of the endpoint this client targets, as returned by the endpoint's own `path()`.
    pub fn path(&self) -> String {
        self.endpoint.path()
    }

    /// The root etcd path we watch to discover new instances to route to.
    ///
    /// Delegates to the endpoint's `etcd_root()`.
    pub fn etcd_root(&self) -> String {
        self.endpoint.etcd_root()
    }

94
    /// Instances available from watching etcd
95
    pub fn instances(&self) -> Vec<Instance> {
96
97
98
99
        match self.instance_source.as_ref() {
            InstanceSource::Static => vec![],
            InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
        }
100
101
    }

102
    pub fn instance_ids(&self) -> Vec<u64> {
103
104
105
        self.instances().into_iter().map(|ep| ep.id()).collect()
    }

106
    /// Current snapshot of the instance ids considered available, i.e. the known ids
    /// less those reported as down via [`Client::report_instance_down`].
    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
        self.instance_avail.load()
    }

110
    /// Current snapshot of the instance ids considered free, i.e. the known ids less
    /// those last reported as busy via [`Client::update_free_instances`].
    pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
        self.instance_free.load()
    }

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    /// Wait for at least one Instance to be available for this Endpoint.
    ///
    /// A static client has nothing to wait for and resolves immediately with an empty
    /// list. A dynamic client resolves with the first non-empty list observed on the
    /// watch channel.
    ///
    /// # Errors
    /// Returns the watch channel's error if waiting for a change fails (the sending side
    /// is gone) while the list is still empty.
    pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
        let InstanceSource::Dynamic(watch_rx) = self.instance_source.as_ref() else {
            return Ok(vec![]);
        };

        // Private receiver so that marking values as seen does not affect other holders.
        let mut watch_rx = watch_rx.clone();
        loop {
            let current = watch_rx.borrow_and_update().to_vec();
            if !current.is_empty() {
                return Ok(current);
            }
            watch_rx.changed().await?;
        }
    }

131
132
133
    /// Is this component know at startup and not discovered via etcd?
    pub fn is_static(&self) -> bool {
        matches!(self.instance_source.as_ref(), InstanceSource::Static)
134
135
136
    }

    /// Mark an instance as down/unavailable
137
    pub fn report_instance_down(&self, instance_id: u64) {
138
139
140
141
142
143
        let filtered = self
            .instance_ids_avail()
            .iter()
            .filter_map(|&id| if id == instance_id { None } else { Some(id) })
            .collect::<Vec<_>>();
        self.instance_avail.store(Arc::new(filtered));
144
145
146
147

        tracing::debug!("inhibiting instance {instance_id}");
    }

148
    /// Update the set of free instances based on busy instance IDs
149
    pub fn update_free_instances(&self, busy_instance_ids: &[u64]) {
150
        let all_instance_ids = self.instance_ids();
151
        let free_ids: Vec<u64> = all_instance_ids
152
153
154
155
156
157
            .into_iter()
            .filter(|id| !busy_instance_ids.contains(id))
            .collect();
        self.instance_free.store(Arc::new(free_ids));
    }

158
159
160
161
162
163
164
165
166
167
168
169
170
    /// Monitor the ETCD instance source and update instance_avail.
    fn monitor_instance_source(&self) {
        let cancel_token = self.endpoint.drt().primary_token();
        let client = self.clone();
        tokio::task::spawn(async move {
            let mut rx = match client.instance_source.as_ref() {
                InstanceSource::Static => {
                    tracing::error!("Static instance source is not watchable");
                    return;
                }
                InstanceSource::Dynamic(rx) => rx.clone(),
            };
            while !cancel_token.is_cancelled() {
171
                let instance_ids: Vec<u64> = rx
172
173
174
175
                    .borrow_and_update()
                    .iter()
                    .map(|instance| instance.id())
                    .collect();
176
177
178
179

                // TODO: this resets both tracked available and free instances
                client.instance_avail.store(Arc::new(instance_ids.clone()));
                client.instance_free.store(Arc::new(instance_ids));
180
181
182
183
184
185
186
187
188

                tracing::debug!("instance source updated");

                if let Err(err) = rx.changed().await {
                    tracing::error!("The Sender is dropped: {}", err);
                    cancel_token.cancel();
                }
            }
        });
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
    }

    async fn get_or_create_dynamic_instance_source(
        endpoint: &Endpoint,
    ) -> Result<Arc<InstanceSource>> {
        let drt = endpoint.drt();
        let instance_sources = drt.instance_sources();
        let mut instance_sources = instance_sources.lock().await;

        if let Some(instance_source) = instance_sources.get(endpoint) {
            if let Some(instance_source) = instance_source.upgrade() {
                return Ok(instance_source);
            } else {
                instance_sources.remove(endpoint);
            }
        }

206
207
208
209
        let prefix = endpoint.etcd_root();
        let store = Arc::new(drt.store().clone());
        let (_, mut kv_event_rx) =
            store.watch(super::INSTANCE_ROOT_PATH, None, drt.primary_token());
Ryan Olson's avatar
Ryan Olson committed
210
211
212
213
214
215
216
217
        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);

        let secondary = endpoint.component.drt.runtime.secondary().clone();

        // this task should be included in the registry
        // currently this is created once per client, but this object/task should only be instantiated
        // once per worker/instance
        secondary.spawn(async move {
218
            tracing::debug!("Starting endpoint watcher for prefix: {prefix}");
Ryan Olson's avatar
Ryan Olson committed
219
220
221
222
223
            let mut map = HashMap::new();

            loop {
                let kv_event = tokio::select! {
                    _ = watch_tx.closed() => {
224
                        tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {prefix}");
Ryan Olson's avatar
Ryan Olson committed
225
226
227
228
229
230
                        break;
                    }
                    kv_event = kv_event_rx.recv() => {
                        match kv_event {
                            Some(kv_event) => kv_event,
                            None => {
231
                                tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {prefix}");
Ryan Olson's avatar
Ryan Olson committed
232
233
234
235
236
237
238
239
                                break;
                            }
                        }
                    }
                };

                match kv_event {
                    WatchEvent::Put(kv) => {
240
241
242
                        let key = kv.key_str();
                        if !key.starts_with(&prefix) {
                            continue;
Ryan Olson's avatar
Ryan Olson committed
243
                        }
244
245
246
247
248
249
250
251
252
253
254
255
256
                        let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
                            tracing::error!("WatchEvent::Put Key not in INSTANCE_ROOT_PATH. Should be impossible.");
                            continue;
                        };
                        if key.starts_with("/") {
                            key = &key[1..];
                        }

                        match serde_json::from_slice::<Instance>(kv.value()) {
                            Ok(val) => map.insert(key.to_string(), val),
                            Err(err) => {
                                tracing::error!(error = %err, prefix,
                                    "Unable to parse put endpoint event; shutting down endpoint watcher");
Ryan Olson's avatar
Ryan Olson committed
257
258
                                break;
                            }
259
260
261
262
263
264
265
266
267
268
269
270
271
                        };
                    }
                    WatchEvent::Delete(key) => {
                        let key = key.as_ref();
                        if !key.starts_with(&prefix) {
                            continue;
                        }
                        let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
                            tracing::error!("WatchEvent::Delete Key not in INSTANCE_ROOT_PATH. Should be impossible.");
                            continue;
                        };
                        if key.starts_with("/") {
                            key = &key[1..];
Ryan Olson's avatar
Ryan Olson committed
272
                        }
273
                        map.remove(key);
Ryan Olson's avatar
Ryan Olson committed
274
275
276
                    }
                }

277
                let instances: Vec<Instance> = map.values().cloned().collect();
Ryan Olson's avatar
Ryan Olson committed
278

279
                if watch_tx.send(instances).is_err() {
280
                    tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
Ryan Olson's avatar
Ryan Olson committed
281
282
283
284
285
                    break;
                }

            }

286
            tracing::debug!("Completed endpoint watcher for prefix: {prefix}");
Ryan Olson's avatar
Ryan Olson committed
287
288
289
            let _ = watch_tx.send(vec![]);
        });

290
291
292
        let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx));
        instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
        Ok(instance_source)
293
    }
Ryan Olson's avatar
Ryan Olson committed
294
}