Unverified Commit 0652441f authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix(kv-router): prune stale compressed-tree children (#8127)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 67bc6a1f
...@@ -643,6 +643,9 @@ impl SyncIndexer for ConcurrentRadixTree { ...@@ -643,6 +643,9 @@ impl SyncIndexer for ConcurrentRadixTree {
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => { WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank); self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank);
} }
WorkerTask::CleanupStaleChildren => {
self.run_cleanup_task();
}
WorkerTask::DumpEvents(_sender) => { WorkerTask::DumpEvents(_sender) => {
// Handled directly via dump_events() on the shared tree. // Handled directly via dump_events() on the shared tree.
// Should not be reached, but respond with empty to avoid blocking. // Should not be reached, but respond with empty to avoid blocking.
......
...@@ -59,13 +59,14 @@ ...@@ -59,13 +59,14 @@
//! - `new_with_frequency()` is not provided //! - `new_with_frequency()` is not provided
//! - `find_matches` does not populate `OverlapScores.frequencies` //! - `find_matches` does not populate `OverlapScores.frequencies`
use std::sync::Arc; use std::sync::{Arc, Weak};
use std::time::Instant;
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet}; use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use super::{EventKind, KvIndexerMetrics, SyncIndexer, WorkerTask}; use super::{EventKind, KvIndexerMetrics, SyncIndexer, WorkerTask};
use crate::protocols::*; use crate::protocols::*;
...@@ -86,6 +87,8 @@ type SharedNode = Arc<RwLock<Node>>; ...@@ -86,6 +87,8 @@ type SharedNode = Arc<RwLock<Node>>;
/// stored here, keeping the map compact and correct across concurrent splits. /// stored here, keeping the map compact and correct across concurrent splits.
type WorkerLookup = FxHashMap<ExternalSequenceBlockHash, SharedNode>; type WorkerLookup = FxHashMap<ExternalSequenceBlockHash, SharedNode>;
const CLEANUP_INTERVAL_MS: u64 = 5 * 60 * 1000;
/// A node in the concurrent radix tree. /// A node in the concurrent radix tree.
/// ///
/// Stores a compressed edge with per-worker match indices. Workers with full coverage /// Stores a compressed edge with per-worker match indices. Workers with full coverage
...@@ -237,12 +240,71 @@ struct RemoveOutcome { ...@@ -237,12 +240,71 @@ struct RemoveOutcome {
stale_hashes: Vec<ExternalSequenceBlockHash>, stale_hashes: Vec<ExternalSequenceBlockHash>,
} }
struct CleanupEdge {
parent: Weak<RwLock<Node>>,
key: LocalBlockHash,
child: Weak<RwLock<Node>>,
}
struct CleanupState {
clock_origin: Instant,
last_cleanup_elapsed_ms: AtomicU64,
scheduled: AtomicBool,
}
impl CleanupState {
fn new() -> Self {
Self {
clock_origin: Instant::now(),
last_cleanup_elapsed_ms: AtomicU64::new(0),
scheduled: AtomicBool::new(false),
}
}
fn elapsed_ms(&self) -> u64 {
self.clock_origin.elapsed().as_millis() as u64
}
fn try_schedule(&self) -> bool {
let now_ms = self.elapsed_ms();
let last_ms = self.last_cleanup_elapsed_ms.load(Ordering::Relaxed);
if now_ms.saturating_sub(last_ms) < CLEANUP_INTERVAL_MS {
return false;
}
self.scheduled
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
}
fn cancel(&self) {
self.scheduled.store(false, Ordering::Release);
}
}
struct CleanupGuard<'a> {
state: &'a CleanupState,
completed_elapsed_ms: Option<u64>,
}
impl Drop for CleanupGuard<'_> {
fn drop(&mut self) {
if let Some(elapsed_ms) = self.completed_elapsed_ms {
self.state
.last_cleanup_elapsed_ms
.store(elapsed_ms, Ordering::Relaxed);
}
self.state.scheduled.store(false, Ordering::Release);
}
}
/// Thread-safe radix tree (compressed trie) for concurrent KV cache lookups. /// Thread-safe radix tree (compressed trie) for concurrent KV cache lookups.
pub struct ConcurrentRadixTreeCompressed { pub struct ConcurrentRadixTreeCompressed {
/// The root of the radix tree. Has an empty edge and only contains children. /// The root of the radix tree. Has an empty edge and only contains children.
root: SharedNode, root: SharedNode,
tree_sizes: DashMap<WorkerWithDpRank, AtomicUsize, FxBuildHasher>, tree_sizes: DashMap<WorkerWithDpRank, AtomicUsize, FxBuildHasher>,
cleanup: CleanupState,
} }
impl Default for ConcurrentRadixTreeCompressed { impl Default for ConcurrentRadixTreeCompressed {
...@@ -274,9 +336,73 @@ impl ConcurrentRadixTreeCompressed { ...@@ -274,9 +336,73 @@ impl ConcurrentRadixTreeCompressed {
Self { Self {
root: Arc::new(RwLock::new(Node::new())), root: Arc::new(RwLock::new(Node::new())),
tree_sizes: DashMap::with_hasher(FxBuildHasher), tree_sizes: DashMap::with_hasher(FxBuildHasher),
cleanup: CleanupState::new(),
}
}
fn cleanup_stale_children(&self) {
let mut queue = VecDeque::from([self.root.clone()]);
let mut edges = Vec::new();
while let Some(parent) = queue.pop_front() {
let guard = parent.read();
for (&key, child) in &guard.children {
queue.push_back(child.clone());
edges.push(CleanupEdge {
parent: Arc::downgrade(&parent),
key,
child: Arc::downgrade(child),
});
}
}
for edge in edges.into_iter().rev() {
let (Some(parent), Some(child)) = (edge.parent.upgrade(), edge.child.upgrade()) else {
continue;
};
let mut parent_guard = parent.write();
let Some(current) = parent_guard.children.get(&edge.key) else {
continue;
};
if !Arc::ptr_eq(current, &child) {
continue;
}
let Some(child_guard) = child.try_write() else {
continue;
};
if child_guard.has_any_workers() || !child_guard.children.is_empty() {
continue;
}
if Arc::strong_count(&child) != 2 {
continue;
}
parent_guard.children.remove(&edge.key);
drop(child_guard);
} }
} }
#[cfg(test)]
pub(crate) fn raw_child_edge_count(&self) -> usize {
let mut queue = VecDeque::from([self.root.clone()]);
let mut count = 0usize;
while let Some(node) = queue.pop_front() {
let guard = node.read();
count += guard.children.len();
queue.extend(guard.children.values().cloned());
}
count
}
#[cfg(test)]
pub(crate) fn run_cleanup_for_test(&self) {
self.cleanup_stale_children();
}
// ------------------------------------------------------------------ // ------------------------------------------------------------------
// Lookup resolution helpers // Lookup resolution helpers
// ------------------------------------------------------------------ // ------------------------------------------------------------------
...@@ -1223,6 +1349,9 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed { ...@@ -1223,6 +1349,9 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed {
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => { WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank); self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank);
} }
WorkerTask::CleanupStaleChildren => {
self.run_cleanup_task();
}
WorkerTask::DumpEvents(_sender) => { WorkerTask::DumpEvents(_sender) => {
let _ = _sender.send(Ok(Vec::new())); let _ = _sender.send(Ok(Vec::new()));
} }
...@@ -1240,6 +1369,24 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed { ...@@ -1240,6 +1369,24 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed {
self.find_matches_impl(sequence, early_exit) self.find_matches_impl(sequence, early_exit)
} }
fn try_schedule_cleanup(&self) -> bool {
self.cleanup.try_schedule()
}
fn cancel_scheduled_cleanup(&self) {
self.cleanup.cancel();
}
fn run_cleanup_task(&self) {
let mut cleanup_guard = CleanupGuard {
state: &self.cleanup,
completed_elapsed_ms: None,
};
self.cleanup_stale_children();
cleanup_guard.completed_elapsed_ms = Some(self.cleanup.elapsed_ms());
}
fn dump_events(&self) -> Option<Vec<RouterEvent>> { fn dump_events(&self) -> Option<Vec<RouterEvent>> {
Some(self.dump_tree_as_events()) Some(self.dump_tree_as_events())
} }
......
...@@ -166,6 +166,9 @@ impl SyncIndexer for PositionalIndexer { ...@@ -166,6 +166,9 @@ impl SyncIndexer for PositionalIndexer {
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => { WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank_impl(&mut worker_blocks, worker_id, dp_rank); self.remove_worker_dp_rank_impl(&mut worker_blocks, worker_id, dp_rank);
} }
WorkerTask::CleanupStaleChildren => {
self.run_cleanup_task();
}
WorkerTask::DumpEvents(sender) => { WorkerTask::DumpEvents(sender) => {
let events = self.dump_events(&worker_blocks); let events = self.dump_events(&worker_blocks);
if let Err(e) = sender.send(Ok(events)) { if let Err(e) = sender.send(Ok(events)) {
......
...@@ -460,6 +460,41 @@ mod interface_tests { ...@@ -460,6 +460,41 @@ mod interface_tests {
); );
} }
#[tokio::test]
async fn test_concurrent_compressed_cleanup_prunes_dead_children_under_live_prefix() {
let index = ThreadPoolIndexer::new(ConcurrentRadixTreeCompressed::new(), 1, 32);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[4, 5]))
.await;
index
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[6, 7]))
.await;
flush_and_settle(&index).await;
index
.apply_event(make_remove_event_with_parent(0, &[1, 2, 3], &[4, 5]))
.await;
index
.apply_event(make_remove_event_with_parent(0, &[1, 2, 3], &[6, 7]))
.await;
flush_and_settle(&index).await;
let expected_snapshot = vec![make_store_event(0, &[1, 2, 3])];
assert_eq!(snapshot_tree(&index).await, expected_snapshot);
assert_eq!(index.backend().raw_child_edge_count(), 3);
index.backend().run_cleanup_for_test();
assert_eq!(index.backend().raw_child_edge_count(), 1);
assert_eq!(
snapshot_tree(&index).await,
vec![make_store_event(0, &[1, 2, 3])]
);
assert_score(&index, &[1, 2, 3], WorkerWithDpRank::new(0, 0), 3).await;
}
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_partial_match(variant: &str) { async fn test_partial_match(variant: &str) {
......
...@@ -147,6 +147,23 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> { ...@@ -147,6 +147,23 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
tokio::time::sleep(Duration::from_millis(1)).await; tokio::time::sleep(Duration::from_millis(1)).await;
} }
} }
fn maybe_enqueue_cleanup(&self, thread_idx: usize) {
if !self.backend.try_schedule_cleanup() {
return;
}
if let Err(e) =
self.worker_event_channels[thread_idx].send(WorkerTask::CleanupStaleChildren)
{
self.backend.cancel_scheduled_cleanup();
tracing::error!(
"Failed to send cleanup task to worker thread {}: {:?}",
thread_idx,
e
);
}
}
} }
impl<T: SyncIndexer> Drop for ThreadPoolIndexer<T> { impl<T: SyncIndexer> Drop for ThreadPoolIndexer<T> {
...@@ -217,7 +234,10 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -217,7 +234,10 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
thread_idx, thread_idx,
e e
); );
return;
} }
self.maybe_enqueue_cleanup(thread_idx);
} }
async fn remove_worker(&self, worker_id: WorkerId) { async fn remove_worker(&self, worker_id: WorkerId) {
...@@ -234,13 +254,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -234,13 +254,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
idx, idx,
e e
); );
return;
} }
self.maybe_enqueue_cleanup(idx);
} }
None => { None => {
// Worker was never assigned a thread - broadcast to all // Worker was never assigned a thread - broadcast to all
for channel in &self.worker_event_channels { for channel in &self.worker_event_channels {
let _ = channel.send(WorkerTask::RemoveWorker(worker_id)); let _ = channel.send(WorkerTask::RemoveWorker(worker_id));
} }
self.maybe_enqueue_cleanup(0);
} }
} }
} }
...@@ -251,6 +275,7 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -251,6 +275,7 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
for channel in &self.worker_event_channels { for channel in &self.worker_event_channels {
let _ = channel.send(WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank)); let _ = channel.send(WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank));
} }
self.maybe_enqueue_cleanup(0);
} }
fn shutdown(&self) { fn shutdown(&self) {
......
...@@ -118,6 +118,17 @@ pub trait SyncIndexer: Send + Sync + 'static { ...@@ -118,6 +118,17 @@ pub trait SyncIndexer: Send + Sync + 'static {
/// Find matches for a sequence of block hashes. /// Find matches for a sequence of block hashes.
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores; fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores;
/// Returns true when a maintenance task should be enqueued.
fn try_schedule_cleanup(&self) -> bool {
false
}
/// Rolls back a scheduled cleanup when enqueueing the task fails.
fn cancel_scheduled_cleanup(&self) {}
/// Executes a maintenance task on a worker thread.
fn run_cleanup_task(&self) {}
/// Dump events directly from the shared structure, bypassing worker channels. /// Dump events directly from the shared structure, bypassing worker channels.
/// Returns `Some(events)` for backends whose tree state is fully shared (e.g. /// Returns `Some(events)` for backends whose tree state is fully shared (e.g.
/// ConcurrentRadixTree). Returns `None` for backends that keep per-thread /// ConcurrentRadixTree). Returns `None` for backends that keep per-thread
......
...@@ -297,6 +297,8 @@ pub enum WorkerTask { ...@@ -297,6 +297,8 @@ pub enum WorkerTask {
RemoveWorker(WorkerId), RemoveWorker(WorkerId),
/// Remove a single dp_rank for a worker. /// Remove a single dp_rank for a worker.
RemoveWorkerDpRank(WorkerId, DpRank), RemoveWorkerDpRank(WorkerId, DpRank),
/// Best-effort maintenance task for shared-state backends.
CleanupStaleChildren,
DumpEvents(oneshot::Sender<anyhow::Result<Vec<RouterEvent>>>), DumpEvents(oneshot::Sender<anyhow::Result<Vec<RouterEvent>>>),
Terminate, Terminate,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment