cleanup.rs 5.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Shared stale-child cleanup machinery for rooted tree structures.
//!
//! Provides a throttled, one-in-flight sweep that unlinks empty child nodes
//! from their parent. It is used by [`ConcurrentRadixTree`](super::concurrent_radix_tree),
//! [`ConcurrentRadixTreeCompressed`](super::concurrent_radix_tree_compressed)
//! and the sequence-side
//! [`PromptMembershipTrie`](super::sequences::prompt_membership_trie::PromptMembershipTrie),
//! each of which embeds a [`CleanupState`] and implements [`CleanableNode`]
//! for its node type.
//!
//! # Sweep semantics
//!
//! [`sweep_stale_children`] is a reverse-BFS prune:
//! - BFS from the root under read locks, collecting `(parent_weak, key, child_weak)` edges.
//! - Iterate edges deepest-first so children are swept before parents.
//! - For each edge: upgrade weaks, take the parent write lock, verify the
//!   child pointer still matches, `try_write` the child, and unlink only when
//!   the child has no workers, no children, and `Arc::strong_count == 2`
//!   (parent map ref + our local upgrade). The strong-count gate is what
//!   prevents reclaiming a node that a concurrent `find_matches` is currently
//!   traversing — such edges are skipped and retried on the next sweep.

use std::collections::VecDeque;
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::Instant;

use parking_lot::RwLock;
use rustc_hash::FxHashMap;

/// Minimum time, in milliseconds, that must pass after a completed sweep
/// before `CleanupState::try_schedule` grants another one (five minutes).
pub const CLEANUP_INTERVAL_MS: u64 = 300_000;

/// A tree node that [`sweep_stale_children`] is able to prune.
///
/// Implementors expose their child map plus enough state for the sweep to
/// decide whether a child is an empty leaf that may be unlinked.
pub trait CleanableNode: Sized + Send + Sync + 'static {
    /// Key type of the children map (for example `LocalBlockHash` or
    /// `SequenceHash`).
    type ChildKey: Copy + Eq + Hash + Send + Sync + 'static;

    /// Borrows this node's children, keyed by the first element of each edge.
    fn children(&self) -> &FxHashMap<Self::ChildKey, Arc<RwLock<Self>>>;

    /// Detaches the child stored under `key` from this node.
    ///
    /// The sweep only calls this while holding this node's write lock and
    /// after confirming `key` still maps to the child it intends to remove.
    fn remove_child(&mut self, key: &Self::ChildKey);

    /// Reports whether worker state still pins this node; a pinned node is
    /// never unlinked by the sweep.
    fn has_any_workers(&self) -> bool;
}

/// Throttle plus single-flight latch for [`sweep_stale_children`].
///
/// [`try_schedule`](Self::try_schedule) grants permission to sweep only when
/// [`CLEANUP_INTERVAL_MS`] have passed since the last sweep recorded as
/// completed by a [`CleanupGuard`], and only while no other sweep holds the
/// latch.
pub struct CleanupState {
    /// Origin of the millisecond clock used for every timestamp below.
    clock_origin: Instant,
    /// `elapsed_ms()` reading of the last completed sweep; 0 until one completes.
    last_cleanup_elapsed_ms: AtomicU64,
    /// Latch: `true` while a sweep is scheduled or running.
    scheduled: AtomicBool,
}

impl CleanupState {
    /// Creates a state whose clock starts now.
    ///
    /// The last-cleanup timestamp starts at 0, so the first sweep becomes
    /// eligible only once `CLEANUP_INTERVAL_MS` has elapsed after construction.
    pub fn new() -> Self {
        Self {
            clock_origin: Instant::now(),
            last_cleanup_elapsed_ms: AtomicU64::new(0),
            scheduled: AtomicBool::new(false),
        }
    }

    /// Milliseconds elapsed since this state was created, saturating at
    /// `u64::MAX` instead of silently truncating the `u128` from `as_millis`.
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.clock_origin.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// Tries to claim the right to run one sweep.
    ///
    /// Returns `true` only if at least `CLEANUP_INTERVAL_MS` have passed since
    /// the last completed sweep **and** the latch was free. On `true` the caller
    /// owns the latch and must release it, either by dropping a [`CleanupGuard`]
    /// or by calling [`cancel`](Self::cancel). On `false` the latch is left as
    /// it was found.
    pub fn try_schedule(&self) -> bool {
        let interval_elapsed = || {
            let now_ms = self.elapsed_ms();
            let last_ms = self.last_cleanup_elapsed_ms.load(Ordering::Relaxed);
            now_ms.saturating_sub(last_ms) >= CLEANUP_INTERVAL_MS
        };

        // Cheap pre-check: the common "too soon" path never touches the latch.
        if !interval_elapsed() {
            return false;
        }

        if self
            .scheduled
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            return false;
        }

        // Re-check while holding the latch. A sweep may have completed between
        // the pre-check and the CAS; without this we would grant an immediate,
        // redundant second sweep. The Acquire CAS synchronizes with the Release
        // store of `scheduled` in `CleanupGuard::drop`, which is sequenced after
        // that guard's timestamp store, so this load observes the fresh value.
        if !interval_elapsed() {
            self.scheduled.store(false, Ordering::Release);
            return false;
        }

        true
    }

    /// Releases the latch without recording a completed sweep; the throttle
    /// timestamp is left unchanged.
    pub fn cancel(&self) {
        self.scheduled.store(false, Ordering::Release);
    }
}

impl Default for CleanupState {
    fn default() -> Self {
        Self::new()
    }
}

/// RAII handle for one scheduled sweep.
///
/// Dropping it releases the `CleanupState` latch. If
/// [`CleanupGuard::mark_completed`] was called, dropping it first records the
/// completion time that throttles the next sweep. Discarding the guard
/// immediately (e.g. `CleanupGuard::new(&state);`) would release the latch
/// before the sweep runs, hence `#[must_use]`.
#[must_use = "dropping a CleanupGuard immediately releases the cleanup latch"]
pub struct CleanupGuard<'a> {
    /// State whose latch and completion timestamp this guard updates on drop.
    state: &'a CleanupState,
    /// Completion time set by `mark_completed`; `None` means the sweep did not
    /// finish and the throttle timestamp is left unchanged on drop.
    completed_elapsed_ms: Option<u64>,
}

impl<'a> CleanupGuard<'a> {
    /// Wraps `state` for the duration of one sweep.
    ///
    /// The sweep starts out not completed, so dropping the guard without
    /// calling [`mark_completed`](Self::mark_completed) only releases the latch.
    pub fn new(state: &'a CleanupState) -> Self {
        let completed_elapsed_ms = None;
        Self {
            completed_elapsed_ms,
            state,
        }
    }

    /// Stamps the current `CleanupState` clock reading as this sweep's
    /// completion time; it is published to the state when the guard drops.
    pub fn mark_completed(&mut self) {
        let finished_at = self.state.elapsed_ms();
        self.completed_elapsed_ms = Some(finished_at);
    }
}

impl Drop for CleanupGuard<'_> {
    /// Publishes the completion time (if one was marked), then frees the latch.
    ///
    /// The timestamp store is deliberately ordered before the `Release` store
    /// of `scheduled`. A thread whose `Acquire` compare-exchange in
    /// `try_schedule` reads that released latch is therefore guaranteed to see
    /// this timestamp on any subsequent load.
    fn drop(&mut self) {
        let state = self.state;
        if let Some(finished_at) = self.completed_elapsed_ms {
            state
                .last_cleanup_elapsed_ms
                .store(finished_at, Ordering::Relaxed);
        }
        state.scheduled.store(false, Ordering::Release);
    }
}

/// One `parent --key--> child` edge recorded by the collection phase of
/// [`sweep_stale_children`].
///
/// Both endpoints are held as `Weak`, so recording edges does not add to the
/// nodes' strong counts that the prune phase inspects.
struct CleanupEdge<N>
where
    N: CleanableNode,
{
    /// Node whose children map held this edge when it was recorded.
    parent: Weak<RwLock<N>>,
    /// Key of the edge in the parent's children map.
    key: N::ChildKey,
    /// Child node the key pointed to when recorded.
    child: Weak<RwLock<N>>,
}

/// Reverse-BFS sweep that unlinks empty, unreferenced leaf nodes from the tree.
///
/// The sweep runs in two phases.
///
/// 1. **Collect.** It walks the tree breadth-first from `root`. Each node's
///    read lock is held only while its children are enumerated, and every
///    `(parent, key, child)` edge is recorded with `Weak` handles.
/// 2. **Prune.** It visits the edges in reverse discovery order, so a node's
///    child edges are handled before the edge that attaches the node itself.
///    A chain of nodes that becomes empty can therefore be removed in a
///    single call.
///
/// An edge is unlinked only when every one of these holds while the parent's
/// write lock is held:
/// - the parent still maps `key` to this exact child;
/// - the child's write lock can be taken without blocking;
/// - the child has no workers and no children;
/// - `Arc::strong_count(child) == 2`.
///
/// Edges failing any check are skipped; this call does not retry them.
pub fn sweep_stale_children<N: CleanableNode>(root: &Arc<RwLock<N>>) {
    let mut queue: VecDeque<Arc<RwLock<N>>> = VecDeque::from([root.clone()]);
    let mut edges: Vec<CleanupEdge<N>> = Vec::new();

    // Phase 1: collect edges. `guard` is dropped at the end of each iteration,
    // so only one node is read-locked at a time.
    while let Some(parent) = queue.pop_front() {
        let guard = parent.read();
        for (&key, child) in guard.children() {
            // The strong clone keeps `child` alive until it is visited; the
            // recorded edge itself holds only `Weak`s.
            queue.push_back(child.clone());
            edges.push(CleanupEdge {
                parent: Arc::downgrade(&parent),
                key,
                child: Arc::downgrade(child),
            });
        }
    }
    // The queue is now drained, so this sweep holds no strong handles to any
    // non-root node going into phase 2.

    // Phase 2: prune deepest-first (reverse BFS discovery order).
    for edge in edges.into_iter().rev() {
        // Either endpoint already freed: nothing to unlink.
        let (Some(parent), Some(child)) = (edge.parent.upgrade(), edge.child.upgrade()) else {
            continue;
        };

        // Exclusive parent lock: the children map cannot be read or changed by
        // anyone else while we decide about this edge.
        let mut parent_guard = parent.write();
        // The edge may have been removed or replaced since phase 1; only act on
        // the exact `Arc` that was observed.
        let still_attached = parent_guard
            .children()
            .get(&edge.key)
            .is_some_and(|current| Arc::ptr_eq(current, &child));
        if !still_attached {
            continue;
        }

        // Non-blocking: a child that is currently locked is skipped rather than
        // waited on while the parent lock is held.
        let Some(child_guard) = child.try_write() else {
            continue;
        };
        // Only prune truly empty leaves.
        if child_guard.has_any_workers() || !child_guard.children().is_empty() {
            continue;
        }
        // 2 == the parent's map entry + our `child` upgrade above. Any extra
        // strong handle means another holder (e.g. a concurrent traversal) may
        // still be using this node, so leave it for a later sweep.
        if Arc::strong_count(&child) != 2 {
            continue;
        }

        // Drop the parent's reference; `child` is now the only strong handle
        // this sweep knows of.
        parent_guard.remove_child(&edge.key);
        // Release the child lock explicitly after the unlink; `parent_guard`
        // is released when it goes out of scope at the end of this iteration.
        drop(child_guard);
    }
}