/* * Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ use crate::pipeline::{ network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter}, AsyncEngine, Data, ManyOut, SingleIn, }; use rand::Rng; use std::collections::HashMap; use std::sync::{ atomic::{AtomicU64, Ordering}, Arc, }; use tokio::{net::unix::pipe::Receiver, sync::Mutex}; use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error}; use super::*; /// Each state will be have a nonce associated with it /// The state will be emitted in a watch channel, so we can observe the /// critical state transitions. enum MapState { /// The map is empty; value = nonce Empty(u64), /// The map is not-empty; values are (nonce, count) NonEmpty(u64, u64), /// The watcher has finished, no more events will be emitted Finished, } enum EndpointEvent { Put(String, i64), Delete(String), } #[derive(Clone)] pub struct Client { endpoint: Endpoint, router: PushRouter, watch_rx: tokio::sync::watch::Receiver>, counter: Arc, } impl Client where T: Data + Serialize, U: Data + for<'de> Deserialize<'de>, { pub(crate) async fn new(endpoint: Endpoint) -> Result { let router = AddressedPushRouter::new( endpoint.component.drt.nats_client.client().clone(), endpoint.component.drt.tcp_server().await?, )?; // create live endpoint watcher let prefix_watcher = endpoint .component .drt .etcd_client .kv_get_and_watch_prefix(endpoint.etcd_path()) .await?; let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve(); let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]); let secondary = endpoint.component.drt.runtime.secondary().clone(); // this task should be included in the registry // currently this is created once per client, but this object/task should only be instantiated // once per worker/instance secondary.spawn(async move { log::debug!("Starting endpoint watcher for prefix: {}", prefix); let mut map = HashMap::new(); loop { let kv_event = tokio::select! { _ = watch_tx.closed() => { log::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix); break; } kv_event = kv_event_rx.recv() => { match kv_event { Some(kv_event) => kv_event, None => { log::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix); break; } } } }; match kv_event { WatchEvent::Put(kv) => { let key = String::from_utf8(kv.key().to_vec()); let val = serde_json::from_slice::(kv.value()); if let (Ok(key), Ok(val)) = (key, val) { map.insert(key.clone(), val.lease_id); } else { log::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix); break; } } WatchEvent::Delete(kv) => { match String::from_utf8(kv.key().to_vec()) { Ok(key) => { map.remove(&key); } Err(_) => { log::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix); break; } } } } let endpoint_ids: Vec = map.values().cloned().collect(); if watch_tx.send(endpoint_ids).is_err() { log::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix); break; } } log::debug!("Completed endpoint watcher for prefix: {}", prefix); let _ = watch_tx.send(vec![]); }); Ok(Client { endpoint, router, watch_rx, counter: Arc::new(AtomicU64::new(0)), }) } pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver> { &self.watch_rx } /// Wait for at least one [`Endpoint`] to be available pub async fn wait_for_endpoints(&self) -> Result<()> { let mut rx = self.watch_rx.clone(); // wait for there to be 1 or more endpoints loop { if rx.borrow_and_update().is_empty() { rx.changed().await?; } else { break; } } Ok(()) } /// Issue a request to the next available endpoint in a round-robin fashion pub async fn round_robin(&self, request: SingleIn) -> Result> { let counter = self.counter.fetch_add(1, Ordering::Relaxed); let endpoint_id = { let endpoints = self.watch_rx.borrow(); let count = endpoints.len(); if count == 0 { return Err(error!( "no endpoints found for endpoint {:?}", self.endpoint.etcd_path() )); } let offset = counter % count as u64; endpoints[offset as usize] }; let subject = self.endpoint.subject(endpoint_id); let request = request.map(|req| AddressedRequest::new(req, subject)); self.router.generate(request).await } /// Issue a request to a random endpoint pub async fn random(&self, request: SingleIn) -> Result> { let endpoint_id = { let endpoints = self.watch_rx.borrow(); let count = endpoints.len(); if count == 0 { return Err(error!( "no endpoints found for endpoint {:?}", self.endpoint.etcd_path() )); } let counter = rand::thread_rng().gen::(); let offset = counter % count as u64; endpoints[offset as usize] }; let subject = self.endpoint.subject(endpoint_id); let request = request.map(|req| AddressedRequest::new(req, subject)); self.router.generate(request).await } /// Issue a request to a specific endpoint pub async fn direct(&self, request: SingleIn, endpoint_id: i64) -> Result> { let found = { let endpoints = self.watch_rx.borrow(); endpoints.contains(&endpoint_id) }; if !found { return Err(error!( "endpoint_id={} not found for endpoint {:?}", endpoint_id, self.endpoint.etcd_path() )); } let subject = self.endpoint.subject(endpoint_id); let request = request.map(|req| AddressedRequest::new(req, subject)); self.router.generate(request).await } } #[async_trait] impl AsyncEngine, ManyOut, Error> for Client where T: Data + Serialize, U: Data + for<'de> Deserialize<'de>, { async fn generate(&self, request: SingleIn) -> Result, Error> { self.random(request).await } }