Unverified Commit aba60996 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

perf(router): Remove lock from router hot path (#1963)

parent b212103f
...@@ -156,6 +156,12 @@ dependencies = [ ...@@ -156,6 +156,12 @@ dependencies = [
"derive_arbitrary", "derive_arbitrary",
] ]
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]] [[package]]
name = "arrayref" name = "arrayref"
version = "0.3.9" version = "0.3.9"
...@@ -1869,6 +1875,7 @@ name = "dynamo-runtime" ...@@ -1869,6 +1875,7 @@ name = "dynamo-runtime"
version = "0.3.2" version = "0.3.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"arc-swap",
"assert_matches", "assert_matches",
"async-nats", "async-nats",
"async-once-cell", "async-once-cell",
......
...@@ -147,6 +147,12 @@ dependencies = [ ...@@ -147,6 +147,12 @@ dependencies = [
"derive_arbitrary", "derive_arbitrary",
] ]
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]] [[package]]
name = "arrayref" name = "arrayref"
version = "0.3.9" version = "0.3.9"
...@@ -1701,6 +1707,7 @@ name = "dynamo-runtime" ...@@ -1701,6 +1707,7 @@ name = "dynamo-runtime"
version = "0.3.2" version = "0.3.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"arc-swap",
"async-nats", "async-nats",
"async-once-cell", "async-once-cell",
"async-stream", "async-stream",
......
...@@ -62,6 +62,7 @@ url = { workspace = true } ...@@ -62,6 +62,7 @@ url = { workspace = true }
validator = { workspace = true } validator = { workspace = true }
xxhash-rust = { workspace = true } xxhash-rust = { workspace = true }
arc-swap = { version = "1" }
async-once-cell = { version = "0.5.4" } async-once-cell = { version = "0.5.4" }
educe = { version = "0.6.0" } educe = { version = "0.6.0" }
figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] } figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] }
......
...@@ -47,6 +47,12 @@ version = "1.0.98" ...@@ -47,6 +47,12 @@ version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]] [[package]]
name = "arrayref" name = "arrayref"
version = "0.3.9" version = "0.3.9"
...@@ -680,6 +686,7 @@ name = "dynamo-runtime" ...@@ -680,6 +686,7 @@ name = "dynamo-runtime"
version = "0.3.2" version = "0.3.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"arc-swap",
"async-nats", "async-nats",
"async-once-cell", "async-once-cell",
"async-stream", "async-stream",
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::pipeline::{ use crate::pipeline::{
AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode, AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
SingleIn, SingleIn,
}; };
use arc_swap::ArcSwap;
use rand::Rng; use rand::Rng;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::{ use std::sync::{
atomic::{AtomicU64, Ordering}, atomic::{AtomicU64, Ordering},
Arc, Arc, Mutex,
}; };
use tokio::{net::unix::pipe::Receiver, sync::Mutex}; use std::time::Instant;
use tokio::net::unix::pipe::Receiver;
use crate::{ use crate::{
pipeline::async_trait, pipeline::async_trait,
...@@ -58,7 +49,9 @@ pub struct Client { ...@@ -58,7 +49,9 @@ pub struct Client {
// These are the remotes I know about from watching etcd // These are the remotes I know about from watching etcd
pub instance_source: Arc<InstanceSource>, pub instance_source: Arc<InstanceSource>,
// These are the instances that are reported as down from sending rpc // These are the instances that are reported as down from sending rpc
instance_inhibited: Arc<Mutex<HashMap<i64, std::time::Instant>>>, instance_inhibited: Arc<Mutex<HashMap<i64, Instant>>>,
// The current active IDs
instance_cache: Arc<ArcSwap<Vec<i64>>>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
...@@ -76,11 +69,14 @@ impl Client { ...@@ -76,11 +69,14 @@ impl Client {
endpoint, endpoint,
instance_source: Arc::new(InstanceSource::Static), instance_source: Arc::new(InstanceSource::Static),
instance_inhibited: Arc::new(Mutex::new(HashMap::new())), instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
}) })
} }
// Client with auto-discover instances using etcd // Client with auto-discover instances using etcd
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> { pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
// create live endpoint watcher // create live endpoint watcher
let Some(etcd_client) = &endpoint.component.drt.etcd_client else { let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
anyhow::bail!("Attempt to create a dynamic client on a static endpoint"); anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
...@@ -89,11 +85,27 @@ impl Client { ...@@ -89,11 +85,27 @@ impl Client {
let instance_source = let instance_source =
Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?; Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
Ok(Client { let cancel_token = endpoint.drt().primary_token();
let client = Client {
endpoint, endpoint,
instance_source, instance_source,
instance_inhibited: Arc::new(Mutex::new(HashMap::new())), instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
}) instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
};
let instance_source_c = client.instance_source.clone();
let instance_inhibited_c = Arc::clone(&client.instance_inhibited);
let instance_cache_c = Arc::clone(&client.instance_cache);
tokio::task::spawn(async move {
while !cancel_token.is_cancelled() {
refresh_instances(&instance_source_c, &instance_inhibited_c, &instance_cache_c);
tokio::select! {
_ = cancel_token.cancelled() => {}
_ = tokio::time::sleep(INSTANCE_REFRESH_PERIOD) => {}
}
}
});
Ok(client)
} }
pub fn path(&self) -> String { pub fn path(&self) -> String {
...@@ -107,10 +119,7 @@ impl Client { ...@@ -107,10 +119,7 @@ impl Client {
/// Instances available from watching etcd /// Instances available from watching etcd
pub fn instances(&self) -> Vec<Instance> { pub fn instances(&self) -> Vec<Instance> {
match self.instance_source.as_ref() { instances_inner(self.instance_source.as_ref())
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
} }
pub fn instance_ids(&self) -> Vec<i64> { pub fn instance_ids(&self) -> Vec<i64> {
...@@ -135,48 +144,16 @@ impl Client { ...@@ -135,48 +144,16 @@ impl Client {
} }
/// Instances available from watching etcd minus those reported as down /// Instances available from watching etcd minus those reported as down
pub async fn instances_avail(&self) -> Vec<Instance> { pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
// TODO: Can we get the remaining TTL from the lease for the instance? self.instance_cache.load()
const ETCD_LEASE_TTL: u64 = 10; // seconds
let now = std::time::Instant::now();
let instances = self.instances();
let mut inhibited = self.instance_inhibited.lock().await;
// 1. Remove inhibited instances that are no longer in `self.instances()`
// 2. Remove inhibited instances that have expired
// 3. Only return instances that are not inhibited after removals
let mut new_inhibited = HashMap::<i64, std::time::Instant>::new();
let filtered = instances
.into_iter()
.filter_map(|instance| {
let id = instance.id();
if let Some(&timestamp) = inhibited.get(&id) {
if now.duration_since(timestamp).as_secs() > ETCD_LEASE_TTL {
tracing::debug!("instance {id} stale inhibition");
Some(instance)
} else {
tracing::debug!("instance {id} is inhibited");
new_inhibited.insert(id, timestamp);
None
}
} else {
tracing::debug!("instance {id} not inhibited");
Some(instance)
}
})
.collect();
*inhibited = new_inhibited;
filtered
} }
/// Mark an instance as down/unavailable /// Mark an instance as down/unavailable
pub async fn report_instance_down(&self, instance_id: i64) { pub fn report_instance_down(&self, instance_id: i64) {
let now = std::time::Instant::now(); self.instance_inhibited
.lock()
let mut inhibited = self.instance_inhibited.lock().await; .unwrap()
inhibited.insert(instance_id, now); .insert(instance_id, Instant::now());
tracing::debug!("inhibiting instance {instance_id}"); tracing::debug!("inhibiting instance {instance_id}");
} }
...@@ -276,3 +253,49 @@ impl Client { ...@@ -276,3 +253,49 @@ impl Client {
Ok(instance_source) Ok(instance_source)
} }
} }
/// Update the instance id cache
fn refresh_instances(
instance_source: &InstanceSource,
instance_inhibited: &Arc<Mutex<HashMap<i64, Instant>>>,
instance_cache: &Arc<ArcSwap<Vec<i64>>>,
) {
const ETCD_LEASE_TTL: u64 = 10; // seconds
// TODO: Can we get the remaining TTL from the lease for the instance?
let now = Instant::now();
let instances = instances_inner(instance_source);
let mut inhibited = instance_inhibited.lock().unwrap();
// 1. Remove inhibited instances that are no longer in `self.instances()`
// 2. Remove inhibited instances that have expired
// 3. Only return instances that are not inhibited after removals
let mut new_inhibited = HashMap::<i64, Instant>::new();
let filtered: Vec<i64> = instances
.into_iter()
.filter_map(|instance| {
let id = instance.id();
if let Some(&timestamp) = inhibited.get(&id) {
if now.duration_since(timestamp).as_secs() > ETCD_LEASE_TTL {
Some(id)
} else {
new_inhibited.insert(id, timestamp);
None
}
} else {
Some(id)
}
})
.collect();
*inhibited = new_inhibited;
instance_cache.store(Arc::new(filtered));
}
fn instances_inner(instance_source: &InstanceSource) -> Vec<Instance> {
match instance_source {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{AsyncEngineContextProvider, ResponseStream}; use super::{AsyncEngineContextProvider, ResponseStream};
use crate::{ use crate::{
...@@ -111,19 +99,18 @@ where ...@@ -111,19 +99,18 @@ where
/// Issue a request to the next available instance in a round-robin fashion /// Issue a request to the next available instance in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed); let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
let instance_id = { let instance_id = {
let instances = self.client.instances_avail().await; let instance_ids = self.client.instance_ids_avail();
let count = instances.len(); let count = instance_ids.len();
if count == 0 { if count == 0 {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}", "no instances found for endpoint {:?}",
self.client.endpoint.etcd_root() self.client.endpoint.etcd_root()
)); ));
} }
let offset = counter % count as u64; instance_ids[counter % count]
instances[offset as usize].id()
}; };
tracing::trace!("round robin router selected {instance_id}"); tracing::trace!("round robin router selected {instance_id}");
...@@ -134,17 +121,16 @@ where ...@@ -134,17 +121,16 @@ where
/// Issue a request to a random endpoint /// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let instance_id = { let instance_id = {
let instances = self.client.instances_avail().await; let instance_ids = self.client.instance_ids_avail();
let count = instances.len(); let count = instance_ids.len();
if count == 0 { if count == 0 {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}", "no instances found for endpoint {:?}",
self.client.endpoint.etcd_root() self.client.endpoint.etcd_root()
)); ));
} }
let counter = rand::rng().random::<u64>(); let counter = rand::rng().random::<u64>() as usize;
let offset = counter % count as u64; instance_ids[counter % count]
instances[offset as usize].id()
}; };
tracing::trace!("random router selected {instance_id}"); tracing::trace!("random router selected {instance_id}");
...@@ -158,10 +144,7 @@ where ...@@ -158,10 +144,7 @@ where
request: SingleIn<T>, request: SingleIn<T>,
instance_id: i64, instance_id: i64,
) -> anyhow::Result<ManyOut<U>> { ) -> anyhow::Result<ManyOut<U>> {
let found = { let found = self.client.instance_ids_avail().contains(&instance_id);
let instances = self.client.instances_avail().await;
instances.iter().any(|ep| ep.id() == instance_id)
};
if !found { if !found {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
...@@ -205,7 +188,7 @@ where ...@@ -205,7 +188,7 @@ where
} }
async move { async move {
if let Some((client, instance_id)) = report_instance_down { if let Some((client, instance_id)) = report_instance_down {
client.report_instance_down(instance_id).await; client.report_instance_down(instance_id);
} }
res res
} }
...@@ -215,7 +198,7 @@ where ...@@ -215,7 +198,7 @@ where
Err(err) => { Err(err) => {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() { if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
if matches!(req_err.kind(), NatsNoResponders) { if matches!(req_err.kind(), NatsNoResponders) {
self.client.report_instance_down(instance_id).await; self.client.report_instance_down(instance_id);
} }
} }
Err(err) Err(err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment