Unverified commit 3f9c3ffe, authored by Yan Ru Pei, committed by GitHub
Browse files

fix: etcd.rs - linear increasing watch with number of requests (#1081)


Signed-off-by: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Co-authored-by: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Co-authored-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Ryan Olson <ryanolson@users.noreply.github.com>
parent 4eae238f
......@@ -184,7 +184,7 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
match &self.inner.client.instances {
match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
let instance_id = self.chooser.find_best_match(&request.token_ids).await?;
......
......@@ -46,7 +46,7 @@ use derive_getters::Getters;
use educe::Educe;
use serde::{Deserialize, Serialize};
use service::EndpointStatsHandler;
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, hash::Hash, sync::Arc};
use validator::{Validate, ValidationError};
mod client;
......@@ -123,6 +123,24 @@ pub struct Component {
is_static: bool,
}
/// Hashes a `Component` by its namespace name, its own name and its static flag, in that
/// order. These are the same three values the `PartialEq` impl compares, so two components
/// that compare equal feed identical data to the hasher.
impl Hash for Component {
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        // Identity is the namespace *name*, not the namespace value itself.
        let namespace_name = self.namespace.name();
        Hash::hash(&namespace_name, state);
        Hash::hash(&self.name, state);
        Hash::hash(&self.is_static, state);
    }
}
/// Two components are equal when their namespace names, their own names and their static
/// flags all match. The comparison stops at the first field that differs.
impl PartialEq for Component {
    fn eq(&self, other: &Self) -> bool {
        if self.namespace.name() != other.namespace.name() {
            return false;
        }
        if self.name != other.name {
            return false;
        }
        self.is_static == other.is_static
    }
}

/// Marker impl: lets `Component` satisfy `Eq` bounds on top of the `PartialEq` above.
impl Eq for Component {}
impl std::fmt::Display for Component {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}", self.namespace.name(), self.name)
......@@ -238,6 +256,24 @@ pub struct Endpoint {
is_static: bool,
}
/// Hashes an `Endpoint` by its owning component, its own name and its static flag, in that
/// order. These are the same three values the `PartialEq` impl compares.
impl Hash for Endpoint {
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        Hash::hash(&self.component, state);
        Hash::hash(&self.name, state);
        Hash::hash(&self.is_static, state);
    }
}
/// Two endpoints are equal when their components, their names and their static flags all
/// match. The comparison stops at the first field that differs.
impl PartialEq for Endpoint {
    fn eq(&self, other: &Self) -> bool {
        if self.component != other.component {
            return false;
        }
        if self.name != other.name {
            return false;
        }
        self.is_static == other.is_static
    }
}

/// Marker impl: lets `Endpoint` satisfy `Eq` bounds on top of the `PartialEq` above.
impl Eq for Endpoint {}
impl DistributedRuntimeProvider for Endpoint {
fn drt(&self) -> &DistributedRuntime {
self.component.drt()
......
......@@ -25,7 +25,10 @@ use std::sync::{
};
use tokio::{net::unix::pipe::Receiver, sync::Mutex};
use crate::{pipeline::async_trait, transports::etcd::WatchEvent};
use crate::{
pipeline::async_trait,
transports::etcd::{Client as EtcdClient, WatchEvent},
};
use super::*;
......@@ -53,7 +56,7 @@ pub struct Client {
// This is me
pub endpoint: Endpoint,
// These are the remotes I know about
pub instances: InstanceSource,
pub instance_source: Arc<InstanceSource>,
}
#[derive(Clone, Debug)]
......@@ -67,7 +70,7 @@ impl Client {
/// Builds a `Client` for `endpoint` whose instance source is `InstanceSource::Static`.
///
/// Nothing is awaited or looked up here, so the returned `Result` is always `Ok`.
// NOTE(review): both an `instances:` and an `instance_source:` initializer appear below. They
// look like the pre-rename and post-rename sides of the same change — confirm that only the
// one matching the `Client` struct definition is kept.
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client {
endpoint,
instances: InstanceSource::Static,
instance_source: Arc::new(InstanceSource::Static),
})
}
......@@ -77,6 +80,74 @@ impl Client {
let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
};
let instance_source =
Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
Ok(Client {
endpoint,
instance_source,
})
}
/// Returns the path of the endpoint this client is bound to (delegates to the endpoint's
/// `path()`).
pub fn path(&self) -> String {
self.endpoint.path()
}
/// The root etcd path we watch to discover new instances to route to.
///
/// Delegates to the endpoint's `etcd_root()`.
pub fn etcd_root(&self) -> String {
self.endpoint.etcd_root()
}
/// Snapshot of the instances currently known for this endpoint.
///
/// A `Static` source yields an empty list. A `Dynamic` source yields a clone of the value
/// currently held by its receiver.
pub fn instances(&self) -> Vec<Instance> {
    if let InstanceSource::Dynamic(receiver) = &*self.instance_source {
        return receiver.borrow().clone();
    }
    Vec::new()
}
/// Ids of the instances returned by [`Self::instances`], in the same order.
pub fn instance_ids(&self) -> Vec<i64> {
    let instances = self.instances();
    let mut ids = Vec::with_capacity(instances.len());
    for instance in instances {
        ids.push(instance.id());
    }
    ids
}
/// Wait until at least one `Instance` is available for this endpoint and return the list.
///
/// A `Static` source never waits: it returns an empty list immediately. A `Dynamic` source
/// is polled through a private clone of its receiver, so waiting here does not disturb the
/// "seen" state of the shared one.
///
/// # Errors
/// Propagates the error returned by the receiver's `changed()` call
/// (presumably raised once the sending side is gone — confirm).
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
    let InstanceSource::Dynamic(mut receiver) = self.instance_source.as_ref().clone() else {
        return Ok(Vec::new());
    };
    loop {
        // Copy the current value out so the borrow is released before any await.
        let snapshot = receiver.borrow_and_update().to_vec();
        if !snapshot.is_empty() {
            return Ok(snapshot);
        }
        // Still empty: sleep until the value changes, then look again.
        receiver.changed().await?;
    }
}
/// Is this component known at startup rather than discovered via etcd?
///
/// True exactly when the instance source is `InstanceSource::Static`.
pub fn is_static(&self) -> bool {
    match &*self.instance_source {
        InstanceSource::Static => true,
        InstanceSource::Dynamic(_) => false,
    }
}
async fn get_or_create_dynamic_instance_source(
etcd_client: &EtcdClient,
endpoint: &Endpoint,
) -> Result<Arc<InstanceSource>> {
let drt = endpoint.drt();
let instance_sources = drt.instance_sources();
let mut instance_sources = instance_sources.lock().await;
if let Some(instance_source) = instance_sources.get(endpoint) {
if let Some(instance_source) = instance_source.upgrade() {
return Ok(instance_source);
} else {
instance_sources.remove(endpoint);
}
}
let prefix_watcher = etcd_client
.kv_get_and_watch_prefix(endpoint.etcd_root())
.await?;
......@@ -146,51 +217,8 @@ impl Client {
let _ = watch_tx.send(vec![]);
});
Ok(Client {
endpoint,
instances: InstanceSource::Dynamic(watch_rx),
})
}
/// Returns the path of the endpoint this client is bound to (delegates to the endpoint's
/// `path()`).
pub fn path(&self) -> String {
self.endpoint.path()
}
/// The root etcd path we watch to discover new instances to route to.
///
/// Delegates to the endpoint's `etcd_root()`.
pub fn etcd_root(&self) -> String {
self.endpoint.etcd_root()
}
/// Snapshot of the instances currently known for this endpoint.
///
/// A `Static` source yields an empty list. A `Dynamic` source yields a clone of the value
/// currently held by its receiver.
pub fn instances(&self) -> Vec<Instance> {
    let InstanceSource::Dynamic(receiver) = &self.instances else {
        return Vec::new();
    };
    receiver.borrow().clone()
}
/// Ids of the instances returned by [`Self::instances`], in the same order.
pub fn instance_ids(&self) -> Vec<i64> {
    let known = self.instances();
    let ids = known.into_iter().map(|instance| instance.id());
    ids.collect()
}
/// Wait until at least one `Instance` is available for this endpoint and return the list.
///
/// A `Static` source never waits: it returns an empty list immediately. A `Dynamic` source
/// is polled through a private clone of its receiver.
///
/// # Errors
/// Propagates the error returned by the receiver's `changed()` call
/// (presumably raised once the sending side is gone — confirm).
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
    match self.instances.clone() {
        InstanceSource::Static => Ok(Vec::new()),
        InstanceSource::Dynamic(mut receiver) => loop {
            // Copy the current value out so the borrow is released before any await.
            let snapshot = receiver.borrow_and_update().to_vec();
            if !snapshot.is_empty() {
                break Ok(snapshot);
            }
            // Still empty: sleep until the value changes, then look again.
            receiver.changed().await?;
        },
    }
}
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.instances, InstanceSource::Static)
let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx));
instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
Ok(instance_source)
}
}
......@@ -15,17 +15,19 @@
pub use crate::component::Component;
use crate::{
component::{self, ComponentBuilder, Namespace},
component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
discovery::DiscoveryClient,
service::ServiceClient,
transports::{etcd, nats, tcp},
ErrorContext,
};
use super::{error, Arc, DistributedRuntime, OnceCell, Result, Runtime, OK};
use super::{error, Arc, DistributedRuntime, OnceCell, Result, Runtime, Weak, OK};
use derive_getters::Dissolve;
use figment::error;
use std::collections::HashMap;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
impl DistributedRuntime {
......@@ -70,6 +72,7 @@ impl DistributedRuntime {
tcp_server: Arc::new(OnceCell::new()),
component_registry: component::Registry::new(),
is_static,
instance_sources: Arc::new(Mutex::new(HashMap::new())),
})
}
......@@ -155,6 +158,10 @@ impl DistributedRuntime {
/// Returns a `CancellationToken` obtained from the underlying runtime's `child_token()`.
pub fn child_token(&self) -> CancellationToken {
self.runtime.child_token()
}
/// Hands out another handle to the runtime's map from `Endpoint` to a weakly-held
/// `InstanceSource`. Only the `Arc` is cloned; the map itself is shared with the caller.
pub fn instance_sources(&self) -> Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>> {
    Arc::clone(&self.instance_sources)
}
}
#[derive(Dissolve)]
......
......@@ -18,7 +18,11 @@
#![allow(dead_code)]
#![allow(unused_imports)]
use std::sync::{Arc, Mutex};
use std::{
collections::HashMap,
sync::{Arc, Weak},
};
use tokio::sync::Mutex;
pub use anyhow::{
anyhow as error, bail as raise, Context as ErrorContext, Error, Ok as OK, Result,
......@@ -50,6 +54,8 @@ pub use futures::stream;
pub use tokio_util::sync::CancellationToken;
pub use worker::Worker;
use component::{Endpoint, InstanceSource};
/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
#[derive(Clone)]
enum RuntimeType {
......@@ -88,4 +94,6 @@ pub struct DistributedRuntime {
// Will only have static components that are not discoverable via etcd; they must be known at
// startup. Will not start etcd.
is_static: bool,
instance_sources: Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>>,
}
......@@ -188,7 +188,7 @@ where
U: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match &self.client.instances {
match self.client.instance_source.as_ref() {
InstanceSource::Static => self.r#static(request).await,
InstanceSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
......
......@@ -28,8 +28,8 @@ use etcd_client::{
Certificate, Compare, CompareOp, DeleteOptions, GetOptions, Identity, PutOptions, PutResponse,
TlsOptions, Txn, TxnOp, TxnOpResponse, WatchOptions, Watcher,
};
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
use tokio::time::{interval, Duration};
mod lease;
use lease::*;
......@@ -345,40 +345,46 @@ impl Client {
self.runtime.secondary().spawn(async move {
for kv in kvs {
if tx.send(WatchEvent::Put(kv)).await.is_err() {
// receiver is closed
break;
// receiver is already closed
return;
}
}
while let Some(Ok(response)) = watch_stream.next().await {
loop {
tokio::select! {
maybe_resp = watch_stream.next() => {
// Early return for None or Err cases
let Some(Ok(response)) = maybe_resp else {
tracing::info!("kv watch stream closed");
return;
};
// Process events
for event in response.events() {
// Extract the KeyValue if it exists
let Some(kv) = event.kv() else {
continue; // Skip events with no KV
};
// Handle based on event type
match event.event_type() {
etcd_client::EventType::Put => {
if let Some(kv) = event.kv() {
if tx.is_closed() {
// Receiver no longer interested, expected.
break;
}
if let Err(err) = tx.send(WatchEvent::Put(kv.clone())).await {
tracing::error!(
"kv watcher error forwarding WatchEvent::Put: {err}"
);
// receiver is closed
break;
}
tracing::error!("kv watcher error forwarding WatchEvent::Put: {err}");
return;
}
}
etcd_client::EventType::Delete => {
if let Some(kv) = event.kv() {
if tx.is_closed() {
break;
}
if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() {
// receiver is closed
break;
return;
}
}
}
}
}
_ = tx.closed() => {
tracing::debug!("no more receivers, stopping watcher");
return;
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment