Unverified commit 0652441f, authored by Yan Ru Pei and committed by GitHub
Browse files

fix(kv-router): prune stale compressed-tree children (#8127)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 67bc6a1f
......@@ -643,6 +643,9 @@ impl SyncIndexer for ConcurrentRadixTree {
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank);
}
WorkerTask::CleanupStaleChildren => {
self.run_cleanup_task();
}
WorkerTask::DumpEvents(_sender) => {
// Handled directly via dump_events() on the shared tree.
// Should not be reached, but respond with empty to avoid blocking.
......
......@@ -59,13 +59,14 @@
//! - `new_with_frequency()` is not provided
//! - `find_matches` does not populate `OverlapScores.frequencies`
use std::sync::Arc;
use std::sync::{Arc, Weak};
use std::time::Instant;
use dashmap::DashMap;
use parking_lot::RwLock;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use super::{EventKind, KvIndexerMetrics, SyncIndexer, WorkerTask};
use crate::protocols::*;
......@@ -86,6 +87,8 @@ type SharedNode = Arc<RwLock<Node>>;
/// stored here, keeping the map compact and correct across concurrent splits.
type WorkerLookup = FxHashMap<ExternalSequenceBlockHash, SharedNode>;
/// Minimum time, in milliseconds, that must pass after the last recorded cleanup
/// before `CleanupState::try_schedule` allows another one (five minutes).
const CLEANUP_INTERVAL_MS: u64 = 5 * 60 * 1000;
/// A node in the concurrent radix tree.
///
/// Stores a compressed edge with per-worker match indices. Workers with full coverage
......@@ -237,12 +240,71 @@ struct RemoveOutcome {
stale_hashes: Vec<ExternalSequenceBlockHash>,
}
/// A parent-to-child link recorded while `cleanup_stale_children` walks the tree.
///
/// Both ends are stored as `Weak` references, so a recorded edge adds nothing to
/// either node's strong count; the edge is re-validated before it is acted on.
struct CleanupEdge {
/// Node whose `children` map held the edge when it was recorded.
parent: Weak<RwLock<Node>>,
/// Key under which `child` was stored in the parent's `children` map.
key: LocalBlockHash,
/// Node the edge pointed to when it was recorded.
child: Weak<RwLock<Node>>,
}
/// Bookkeeping used to rate-limit the stale-child cleanup pass.
struct CleanupState {
/// Reference instant from which every elapsed-millisecond value is measured.
clock_origin: Instant,
/// Elapsed milliseconds (relative to `clock_origin`) written when a completed
/// `CleanupGuard` is dropped. Starts at 0.
last_cleanup_elapsed_ms: AtomicU64,
/// Set by a successful `try_schedule`; cleared by `cancel` or when a
/// `CleanupGuard` is dropped.
scheduled: AtomicBool,
}
impl CleanupState {
    /// Creates a state whose clock starts now, with no cleanup scheduled.
    ///
    /// The last-cleanup timestamp starts at 0, so the first cleanup becomes
    /// eligible only after a full `CLEANUP_INTERVAL_MS` has passed since construction.
    fn new() -> Self {
        Self {
            clock_origin: Instant::now(),
            last_cleanup_elapsed_ms: AtomicU64::new(0),
            scheduled: AtomicBool::new(false),
        }
    }

    /// Milliseconds elapsed since `clock_origin`, saturating at `u64::MAX`.
    fn elapsed_ms(&self) -> u64 {
        // `Duration::as_millis` returns `u128`; clamp before casting so the
        // conversion can never silently wrap.
        self.clock_origin
            .elapsed()
            .as_millis()
            .min(u128::from(u64::MAX)) as u64
    }

    /// Returns true once at least `CLEANUP_INTERVAL_MS` has passed since the
    /// last recorded cleanup.
    fn interval_elapsed(&self) -> bool {
        let now_ms = self.elapsed_ms();
        let last_ms = self.last_cleanup_elapsed_ms.load(Ordering::Relaxed);
        now_ms.saturating_sub(last_ms) >= CLEANUP_INTERVAL_MS
    }

    /// Attempts to claim the right to enqueue a cleanup task.
    ///
    /// Returns true only when the interval has elapsed and no other cleanup is
    /// currently scheduled. A caller that receives true must either run the
    /// cleanup (which clears the flag via `CleanupGuard`) or call `cancel`.
    fn try_schedule(&self) -> bool {
        // Cheap early-out that avoids touching the flag on the common path.
        if !self.interval_elapsed() {
            return false;
        }
        if self
            .scheduled
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            return false;
        }
        // A cleanup may have completed between the first interval check and the
        // compare-exchange. The Acquire above pairs with the Release store that
        // cleared the flag, so the timestamp written just before that store is
        // visible here; re-check it to avoid scheduling back-to-back passes.
        if !self.interval_elapsed() {
            self.cancel();
            return false;
        }
        true
    }

    /// Releases a claim obtained from `try_schedule` without running a cleanup.
    fn cancel(&self) {
        self.scheduled.store(false, Ordering::Release);
    }
}
/// Drop guard for one cleanup run: on drop it clears the `scheduled` flag of the
/// borrowed `CleanupState` and, if a completion time was set, records it.
struct CleanupGuard<'a> {
/// State whose flag (and optionally timestamp) is updated on drop.
state: &'a CleanupState,
/// Elapsed milliseconds to record as the last cleanup time. While this is
/// `None`, dropping the guard leaves the stored timestamp untouched.
completed_elapsed_ms: Option<u64>,
}
impl Drop for CleanupGuard<'_> {
    /// Publishes the completion timestamp (when one was set) and then marks the
    /// cleanup as no longer scheduled.
    fn drop(&mut self) {
        let state = self.state;
        // Write the timestamp first: the Release store on `scheduled` below is
        // what makes it visible to whoever next claims the flag.
        if let Some(finished_ms) = self.completed_elapsed_ms {
            state
                .last_cleanup_elapsed_ms
                .store(finished_ms, Ordering::Relaxed);
        }
        state.scheduled.store(false, Ordering::Release);
    }
}
/// Thread-safe radix tree (compressed trie) for concurrent KV cache lookups.
pub struct ConcurrentRadixTreeCompressed {
/// The root of the radix tree. Has an empty edge and only contains children.
root: SharedNode,
/// One atomic counter per `WorkerWithDpRank`.
/// NOTE(review): presumably the number of blocks each worker currently has in
/// the tree; confirm against the store/remove paths.
tree_sizes: DashMap<WorkerWithDpRank, AtomicUsize, FxBuildHasher>,
/// Scheduling state that rate-limits the stale-child cleanup pass.
cleanup: CleanupState,
}
impl Default for ConcurrentRadixTreeCompressed {
......@@ -274,9 +336,73 @@ impl ConcurrentRadixTreeCompressed {
Self {
root: Arc::new(RwLock::new(Node::new())),
tree_sizes: DashMap::with_hasher(FxBuildHasher),
cleanup: CleanupState::new(),
}
}
/// Prunes child nodes that have no workers, no children, and no outside holders.
///
/// Runs in two phases: a breadth-first walk that records every parent-to-child
/// edge under read locks, followed by a pass over those edges that re-validates
/// each one under the parent's write lock and unlinks the child if it is dead.
/// The pass is best-effort: any edge that changed, is contended, or is still
/// referenced is skipped.
fn cleanup_stale_children(&self) {
let mut queue = VecDeque::from([self.root.clone()]);
let mut edges = Vec::new();
// Phase 1: breadth-first walk collecting edges. Only `Weak` references are
// kept, so the recorded edges do not affect the strong counts checked below.
while let Some(parent) = queue.pop_front() {
let guard = parent.read();
for (&key, child) in &guard.children {
queue.push_back(child.clone());
edges.push(CleanupEdge {
parent: Arc::downgrade(&parent),
key,
child: Arc::downgrade(child),
});
}
}
// Phase 2: visit edges in reverse discovery order. A node's outgoing edges
// are recorded after the edge that leads to it, so reversing handles
// descendants first and a chain of empty nodes can be pruned in one pass.
for edge in edges.into_iter().rev() {
// Either end may have been dropped since phase 1.
let (Some(parent), Some(child)) = (edge.parent.upgrade(), edge.child.upgrade()) else {
continue;
};
let mut parent_guard = parent.write();
// Re-validate: the key must still map to the very same child node.
let Some(current) = parent_guard.children.get(&edge.key) else {
continue;
};
if !Arc::ptr_eq(current, &child) {
continue;
}
// Non-blocking: skip the child rather than wait if its lock is held.
let Some(child_guard) = child.try_write() else {
continue;
};
if child_guard.has_any_workers() || !child_guard.children.is_empty() {
continue;
}
// Exactly two strong references are expected: the parent's map entry and the
// local `child` upgraded above. Any other count means the node is still held
// somewhere else, so it is left in place.
if Arc::strong_count(&child) != 2 {
continue;
}
parent_guard.children.remove(&edge.key);
drop(child_guard);
}
}
/// Test-only helper: counts every entry in every node's `children` map reachable
/// from the root, whether or not the child still has workers.
#[cfg(test)]
pub(crate) fn raw_child_edge_count(&self) -> usize {
    let mut pending = VecDeque::new();
    pending.push_back(Arc::clone(&self.root));
    let mut total = 0usize;
    while let Some(current) = pending.pop_front() {
        let node = current.read();
        for child in node.children.values() {
            total += 1;
            pending.push_back(Arc::clone(child));
        }
    }
    total
}
/// Test-only entry point that runs `cleanup_stale_children` directly, without
/// going through the scheduling state or a worker thread.
#[cfg(test)]
pub(crate) fn run_cleanup_for_test(&self) {
self.cleanup_stale_children();
}
// ------------------------------------------------------------------
// Lookup resolution helpers
// ------------------------------------------------------------------
......@@ -1223,6 +1349,9 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed {
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank);
}
WorkerTask::CleanupStaleChildren => {
self.run_cleanup_task();
}
WorkerTask::DumpEvents(_sender) => {
let _ = _sender.send(Ok(Vec::new()));
}
......@@ -1240,6 +1369,24 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed {
self.find_matches_impl(sequence, early_exit)
}
/// Delegates to `CleanupState::try_schedule`; true means the caller has claimed
/// the cleanup slot and should enqueue the task.
fn try_schedule_cleanup(&self) -> bool {
self.cleanup.try_schedule()
}
/// Releases a slot claimed by `try_schedule_cleanup` without running a cleanup.
fn cancel_scheduled_cleanup(&self) {
self.cleanup.cancel();
}
/// Runs one stale-child cleanup pass and records when it finished.
fn run_cleanup_task(&self) {
    let state = &self.cleanup;
    // Created before the pass so the `scheduled` flag is cleared when the guard
    // drops, even if the pass unwinds; the completion time is filled in only
    // after the pass returns.
    let mut finish = CleanupGuard {
        completed_elapsed_ms: None,
        state,
    };
    self.cleanup_stale_children();
    finish.completed_elapsed_ms = Some(state.elapsed_ms());
}
/// Returns the events produced by `dump_tree_as_events`, always wrapped in `Some`.
fn dump_events(&self) -> Option<Vec<RouterEvent>> {
Some(self.dump_tree_as_events())
}
......
......@@ -166,6 +166,9 @@ impl SyncIndexer for PositionalIndexer {
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank_impl(&mut worker_blocks, worker_id, dp_rank);
}
WorkerTask::CleanupStaleChildren => {
self.run_cleanup_task();
}
WorkerTask::DumpEvents(sender) => {
let events = self.dump_events(&worker_blocks);
if let Err(e) = sender.send(Ok(events)) {
......
......@@ -460,6 +460,41 @@ mod interface_tests {
);
}
/// Stores a prefix with two child branches, removes both branches, and checks
/// that the emptied child edges remain until cleanup runs and are pruned by it
/// while the prefix stays intact.
#[tokio::test]
async fn test_concurrent_compressed_cleanup_prunes_dead_children_under_live_prefix() {
    let index = ThreadPoolIndexer::new(ConcurrentRadixTreeCompressed::new(), 1, 32);

    // Build the prefix [1, 2, 3] with two branches hanging off it.
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    index
        .apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[4, 5]))
        .await;
    index
        .apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[6, 7]))
        .await;
    flush_and_settle(&index).await;

    // Remove both branches again; only the prefix should remain in the snapshot.
    index
        .apply_event(make_remove_event_with_parent(0, &[1, 2, 3], &[4, 5]))
        .await;
    index
        .apply_event(make_remove_event_with_parent(0, &[1, 2, 3], &[6, 7]))
        .await;
    flush_and_settle(&index).await;

    let prefix_only = vec![make_store_event(0, &[1, 2, 3])];
    assert_eq!(snapshot_tree(&index).await, prefix_only);

    // The snapshot hides the emptied branches, but their edges are still present.
    let backend = index.backend();
    assert_eq!(backend.raw_child_edge_count(), 3);

    backend.run_cleanup_for_test();

    // After cleanup only the edge to the live prefix is left.
    assert_eq!(backend.raw_child_edge_count(), 1);
    assert_eq!(snapshot_tree(&index).await, prefix_only);
    assert_score(&index, &[1, 2, 3], WorkerWithDpRank::new(0, 0), 3).await;
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_partial_match(variant: &str) {
......
......@@ -147,6 +147,23 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
/// Enqueues a cleanup task on worker thread `thread_idx` if the backend reports
/// that one is due. If the send fails, the backend's claim is rolled back and
/// the failure is logged.
fn maybe_enqueue_cleanup(&self, thread_idx: usize) {
    if self.backend.try_schedule_cleanup() {
        let sent = self.worker_event_channels[thread_idx].send(WorkerTask::CleanupStaleChildren);
        if let Err(e) = sent {
            // Undo the claim so a later call can schedule the cleanup again.
            self.backend.cancel_scheduled_cleanup();
            tracing::error!(
                "Failed to send cleanup task to worker thread {}: {:?}",
                thread_idx,
                e
            );
        }
    }
}
}
impl<T: SyncIndexer> Drop for ThreadPoolIndexer<T> {
......@@ -217,7 +234,10 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
thread_idx,
e
);
return;
}
self.maybe_enqueue_cleanup(thread_idx);
}
async fn remove_worker(&self, worker_id: WorkerId) {
......@@ -234,13 +254,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
idx,
e
);
return;
}
self.maybe_enqueue_cleanup(idx);
}
None => {
// Worker was never assigned a thread - broadcast to all
for channel in &self.worker_event_channels {
let _ = channel.send(WorkerTask::RemoveWorker(worker_id));
}
self.maybe_enqueue_cleanup(0);
}
}
}
......@@ -251,6 +275,7 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
for channel in &self.worker_event_channels {
let _ = channel.send(WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank));
}
self.maybe_enqueue_cleanup(0);
}
fn shutdown(&self) {
......
......@@ -118,6 +118,17 @@ pub trait SyncIndexer: Send + Sync + 'static {
/// Find matches for a sequence of block hashes.
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores;
/// Returns true when a maintenance task should be enqueued.
///
/// The default implementation always returns `false`, so backends that do not
/// override it never have a cleanup task enqueued.
fn try_schedule_cleanup(&self) -> bool {
false
}
/// Rolls back a scheduled cleanup when enqueueing the task fails.
///
/// The default implementation does nothing.
fn cancel_scheduled_cleanup(&self) {}
/// Executes a maintenance task on a worker thread.
///
/// The default implementation does nothing.
fn run_cleanup_task(&self) {}
/// Dump events directly from the shared structure, bypassing worker channels.
/// Returns `Some(events)` for backends whose tree state is fully shared (e.g.
/// ConcurrentRadixTree). Returns `None` for backends that keep per-thread
......
......@@ -297,6 +297,8 @@ pub enum WorkerTask {
RemoveWorker(WorkerId),
/// Remove a single dp_rank for a worker.
RemoveWorkerDpRank(WorkerId, DpRank),
/// Best-effort maintenance task for shared-state backends.
CleanupStaleChildren,
DumpEvents(oneshot::Sender<anyhow::Result<Vec<RouterEvent>>>),
Terminate,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment