Unverified commit 2fc65ad8, authored by Yan Ru Pei and committed by GitHub
Browse files

feat: dump radix tree as router events (#2057)

parent 13d3cc13
...@@ -26,8 +26,8 @@ use tokio_util::sync::CancellationToken; ...@@ -26,8 +26,8 @@ use tokio_util::sync::CancellationToken;
use crate::tokens::TokenBlockSequence; use crate::tokens::TokenBlockSequence;
use crate::kv_router::indexer::{ use crate::kv_router::indexer::{
compute_block_hash_for_seq, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores,
WorkerId, RadixTree, WorkerId,
}; };
use crate::kv_router::protocols::{ use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
...@@ -172,6 +172,8 @@ pub struct ApproxKvIndexer { ...@@ -172,6 +172,8 @@ pub struct ApproxKvIndexer {
route_tx: mpsc::Sender<RouterResult>, route_tx: mpsc::Sender<RouterResult>,
/// A sender for remove worker requests. /// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for dump requests.
dump_tx: mpsc::Sender<DumpRequest>,
/// A handle to the background task managing the KV store. /// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>, task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle. /// The size of the KV block this indexer can handle.
...@@ -183,6 +185,7 @@ impl ApproxKvIndexer { ...@@ -183,6 +185,7 @@ impl ApproxKvIndexer {
let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048); let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048);
let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048); let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048);
let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
let cancel_clone = token.clone(); let cancel_clone = token.clone();
let task = std::thread::spawn(move || { let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread // create a new tokio runtime which will only perform work on a single thread
...@@ -240,6 +243,10 @@ impl ApproxKvIndexer { ...@@ -240,6 +243,10 @@ impl ApproxKvIndexer {
Some(worker) = remove_worker_rx.recv() => { Some(worker) = remove_worker_rx.recv() => {
trie.remove_worker(worker); trie.remove_worker(worker);
} }
Some(dump_req) = dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
_ = expiry_fut => { _ = expiry_fut => {
let expired = timer_manager.pop_expired(); let expired = timer_manager.pop_expired();
...@@ -278,6 +285,7 @@ impl ApproxKvIndexer { ...@@ -278,6 +285,7 @@ impl ApproxKvIndexer {
match_tx, match_tx,
route_tx, route_tx,
remove_worker_tx, remove_worker_tx,
dump_tx,
task: once, task: once,
kv_block_size, kv_block_size,
} }
...@@ -355,6 +363,20 @@ impl KvIndexerInterface for ApproxKvIndexer { ...@@ -355,6 +363,20 @@ impl KvIndexerInterface for ApproxKvIndexer {
self.remove_worker_tx.send(worker).await.unwrap(); self.remove_worker_tx.send(worker).await.unwrap();
} }
/// Ask the background indexer task for a dump of its radix tree.
///
/// A `DumpRequest` carrying a one-shot reply channel is queued on `dump_tx`;
/// the events sent back on that channel are returned.
///
/// # Errors
///
/// * `KvRouterError::IndexerOffline` if the request cannot be queued.
/// * `KvRouterError::IndexerDroppedRequest` if the reply channel is closed
///   before a response arrives.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
    let (reply_tx, reply_rx) = oneshot::channel();

    self.dump_tx
        .send(DumpRequest { resp: reply_tx })
        .await
        .map_err(|e| {
            tracing::error!("Failed to send dump request: {:?}", e);
            KvRouterError::IndexerOffline
        })?;

    reply_rx
        .await
        .map_err(|_| KvRouterError::IndexerDroppedRequest)
}
fn shutdown(&mut self) { fn shutdown(&mut self) {
self.cancel.cancel(); self.cancel.cancel();
if let Some(task) = self.task.take() { if let Some(task) = self.task.take() {
......
...@@ -403,6 +403,81 @@ impl RadixTree { ...@@ -403,6 +403,81 @@ impl RadixTree {
} }
} }
} }
/// Dump the radix tree as a series of `RouterEvent`s that can reconstruct the tree.
///
/// The tree is walked breadth-first starting from the root's children, so the
/// events for a block are always emitted before the events of its descendants.
/// For every `(block, worker)` pair that has an external hash registered in
/// `self.lookup`, one `Stored` event containing exactly that single block is
/// produced. Event ids are assigned sequentially from 0 in emission order; the
/// ids and ordering of the events that originally built the tree are not
/// preserved.
///
/// # Returns
///
/// The generated events. The vector is empty when no block has a worker with
/// a registered external hash.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
    // Reverse index (worker, block pointer) -> external hash, built once up front.
    // Without it every (block, worker) pair would need a linear scan over the
    // worker's whole lookup table, making the dump quadratic in the tree size.
    // `or_insert` keeps the first hash seen in iteration order, which is the same
    // entry a linear `find` over the worker's table would return.
    let mut external_hash_of = std::collections::HashMap::new();
    for (worker_id, worker_blocks) in &self.lookup {
        for (hash, block) in worker_blocks {
            external_hash_of
                .entry((*worker_id, Rc::as_ptr(block)))
                .or_insert(*hash);
        }
    }

    let mut events = Vec::new();
    let mut event_id = 0u64;

    // BFS queue: (current_block, parent_external_hash, tokens_hash)
    let mut queue = VecDeque::new();

    // Seed the queue with the root's children; they have no parent hash.
    {
        let root_borrow = self.root.borrow();
        for (tokens_hash, child_block) in &root_borrow.children {
            queue.push_back((child_block.clone(), None, *tokens_hash));
        }
    }

    while let Some((current_block, parent_external_hash, tokens_hash)) = queue.pop_front() {
        let current_borrow = current_block.borrow();
        let block_ptr = Rc::as_ptr(&current_block);

        // External hash under which `worker_id` registered the current block, if any.
        let find_external_hash = |worker_id: &WorkerId| {
            external_hash_of.get(&(*worker_id, block_ptr)).copied()
        };

        // Emit one single-block store event per worker holding this block.
        for worker_id in &current_borrow.workers {
            if let Some(block_hash) = find_external_hash(worker_id) {
                events.push(RouterEvent {
                    worker_id: *worker_id,
                    event: KvCacheEvent {
                        event_id,
                        data: KvCacheEventData::Stored(KvCacheStoreData {
                            parent_hash: parent_external_hash,
                            blocks: vec![KvCacheStoredBlockData {
                                block_hash,
                                tokens_hash,
                            }],
                        }),
                    },
                });
                event_id += 1;
            }
        }

        // Children reference this block through an external hash. Take the first
        // one any of its workers can provide instead of giving up when the first
        // worker in the set has none; this stays `None` if no worker has one.
        // NOTE(review): a single hash is shared by the child events of all
        // workers, which assumes workers agree on the external hash of a shared
        // block -- confirm against the event producers.
        let any_external_hash = current_borrow.workers.iter().find_map(find_external_hash);

        for (child_tokens_hash, child_block) in &current_borrow.children {
            queue.push_back((child_block.clone(), any_external_hash, *child_tokens_hash));
        }
    }

    events
}
} }
/// Scores representing the overlap of workers. /// Scores representing the overlap of workers.
...@@ -466,6 +541,12 @@ pub struct MatchRequest { ...@@ -466,6 +541,12 @@ pub struct MatchRequest {
resp: oneshot::Sender<OverlapScores>, resp: oneshot::Sender<OverlapScores>,
} }
/// A request asking an indexer's background task to dump its radix tree as events.
///
/// The task that receives the request replies by sending the result of
/// `RadixTree::dump_tree_as_events` through `resp`.
pub struct DumpRequest {
/// One-shot channel on which the dumped events are sent back to the requester.
pub resp: oneshot::Sender<Vec<RouterEvent>>,
}
#[async_trait] #[async_trait]
pub trait KvIndexerInterface { pub trait KvIndexerInterface {
/// Find matches for a given sequence of `LocalBlockHash`es. /// Find matches for a given sequence of `LocalBlockHash`es.
...@@ -512,6 +593,13 @@ pub trait KvIndexerInterface { ...@@ -512,6 +593,13 @@ pub trait KvIndexerInterface {
/// Shutdown the KV Indexer. /// Shutdown the KV Indexer.
fn shutdown(&mut self); fn shutdown(&mut self);
/// Dump the entire tree as RouterEvents.
///
/// ### Returns
///
/// A vector of RouterEvents representing the current state of the tree.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>;
} }
/// The KV Indexer, managing the KV store and handling events and match requests. /// The KV Indexer, managing the KV store and handling events and match requests.
...@@ -524,6 +612,8 @@ pub struct KvIndexer { ...@@ -524,6 +612,8 @@ pub struct KvIndexer {
match_tx: mpsc::Sender<MatchRequest>, match_tx: mpsc::Sender<MatchRequest>,
/// A sender for remove worker requests. /// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for dump requests.
dump_tx: mpsc::Sender<DumpRequest>,
/// A handle to the background task managing the KV store. /// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>, task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle. /// The size of the KV block this indexer can handle.
...@@ -549,6 +639,7 @@ impl KvIndexer { ...@@ -549,6 +639,7 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let cancel_clone = token.clone(); let cancel_clone = token.clone();
let task = std::thread::spawn(move || { let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread // create a new tokio runtime which will only perform work on a single thread
...@@ -566,6 +657,7 @@ impl KvIndexer { ...@@ -566,6 +657,7 @@ impl KvIndexer {
let mut match_rx = match_rx; let mut match_rx = match_rx;
let mut event_rx = event_rx; let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx; let mut remove_worker_rx = remove_worker_rx;
let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration); let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop { loop {
tokio::select! { tokio::select! {
...@@ -580,6 +672,11 @@ impl KvIndexer { ...@@ -580,6 +672,11 @@ impl KvIndexer {
let _ = req.resp.send(matches); let _ = req.resp.send(matches);
} }
Some(dump_req) = dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
_ = cancel.cancelled() => { _ = cancel.cancelled() => {
tracing::debug!("KvCacheIndexer progress loop shutting down"); tracing::debug!("KvCacheIndexer progress loop shutting down");
return; return;
...@@ -606,6 +703,7 @@ impl KvIndexer { ...@@ -606,6 +703,7 @@ impl KvIndexer {
event_tx, event_tx,
match_tx, match_tx,
remove_worker_tx, remove_worker_tx,
dump_tx,
task: once, task: once,
kv_block_size, kv_block_size,
} }
...@@ -683,6 +781,20 @@ impl KvIndexerInterface for KvIndexer { ...@@ -683,6 +781,20 @@ impl KvIndexerInterface for KvIndexer {
task.join().expect("Failed to join kv indexer task"); task.join().expect("Failed to join kv indexer task");
} }
} }
/// Request a dump of the indexer's radix tree from its background task.
///
/// Queues a `DumpRequest` on `dump_tx` and waits on the request's one-shot
/// channel for the resulting events.
///
/// # Errors
///
/// * `KvRouterError::IndexerOffline` if the request cannot be queued.
/// * `KvRouterError::IndexerDroppedRequest` if the reply channel is closed
///   before a response arrives.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
    let (events_tx, events_rx) = oneshot::channel();
    let request = DumpRequest { resp: events_tx };

    let queued = self.dump_tx.send(request).await;
    queued.map_err(|e| {
        tracing::error!("Failed to send dump request: {:?}", e);
        KvRouterError::IndexerOffline
    })?;

    let reply = events_rx.await;
    reply.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
...@@ -692,7 +804,20 @@ pub struct ShardedMatchRequest { ...@@ -692,7 +804,20 @@ pub struct ShardedMatchRequest {
resp: mpsc::Sender<OverlapScores>, resp: mpsc::Sender<OverlapScores>,
} }
/// The KV Indexer, managing the KV store and handling events and match requests. /// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
///
/// ## Sharding Strategy
/// - Each worker is **permanently assigned** to a single shard on first event
/// - All KV blocks from a worker exist only in that worker's assigned shard
/// - New workers are assigned to the shard with the fewest workers (load balancing)
///
/// ## Operation
/// - **Events**: Routed directly to the worker's assigned shard
/// - **Match requests**: Broadcast to all shards (scatter-gather pattern)
/// - **Threading**: Each shard runs in its own thread with a single-threaded runtime
///
/// This design ensures no cross-shard synchronization for writes while enabling
/// parallel processing and better scalability.
pub struct KvIndexerSharded { pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown. /// A `CancellationToken` for managing shutdown.
cancel: CancellationToken, cancel: CancellationToken,
...@@ -704,6 +829,7 @@ pub struct KvIndexerSharded { ...@@ -704,6 +829,7 @@ pub struct KvIndexerSharded {
event_tx: Vec<mpsc::Sender<RouterEvent>>, event_tx: Vec<mpsc::Sender<RouterEvent>>,
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>, request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>, remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
dump_tx: Vec<mpsc::Sender<DumpRequest>>,
tasks: Vec<JoinHandle<()>>, tasks: Vec<JoinHandle<()>>,
} }
...@@ -730,6 +856,7 @@ impl KvIndexerSharded { ...@@ -730,6 +856,7 @@ impl KvIndexerSharded {
let mut event_tx = Vec::new(); let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new(); let mut remove_worker_tx = Vec::new();
let mut dump_tx = Vec::new(); // Add dump channels
let mut tasks = Vec::new(); let mut tasks = Vec::new();
let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576); let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
...@@ -738,11 +865,13 @@ impl KvIndexerSharded { ...@@ -738,11 +865,13 @@ impl KvIndexerSharded {
let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048); let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
let (shard_remove_worker_tx, mut shard_remove_worker_rx) = let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
mpsc::channel::<WorkerId>(16); mpsc::channel::<WorkerId>(16);
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16); // Add dump channel
let mut shard_broadcast_rx = request_broadcast_tx.subscribe(); let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
let cancel = token.clone(); let cancel = token.clone();
event_tx.push(shard_event_tx); event_tx.push(shard_event_tx);
remove_worker_tx.push(shard_remove_worker_tx); remove_worker_tx.push(shard_remove_worker_tx);
dump_tx.push(shard_dump_tx); // Store dump sender
let runtime = tokio::runtime::Builder::new_multi_thread() let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1) .worker_threads(1)
...@@ -771,6 +900,11 @@ impl KvIndexerSharded { ...@@ -771,6 +900,11 @@ impl KvIndexerSharded {
} }
} }
Some(dump_req) = shard_dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
_ = cancel.cancelled() => { _ = cancel.cancelled() => {
tracing::trace!("KvCacheIndexer progress loop shutting down"); tracing::trace!("KvCacheIndexer progress loop shutting down");
return; return;
...@@ -798,6 +932,7 @@ impl KvIndexerSharded { ...@@ -798,6 +932,7 @@ impl KvIndexerSharded {
event_tx, event_tx,
request_broadcast_tx, request_broadcast_tx,
remove_worker_tx, remove_worker_tx,
dump_tx, // Add dump_tx field
tasks, tasks,
} }
} }
...@@ -905,6 +1040,35 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -905,6 +1040,35 @@ impl KvIndexerInterface for KvIndexerSharded {
self.tasks.pop().unwrap().join().unwrap(); self.tasks.pop().unwrap().join().unwrap();
} }
} }
/// Dump every shard's radix tree and return all events in one vector.
///
/// A `DumpRequest` is first queued on every shard's dump channel; only after
/// all requests are queued are the replies awaited, in shard order, and their
/// events appended to the result.
///
/// # Errors
///
/// * `KvRouterError::IndexerOffline` if a request cannot be queued on a shard.
/// * `KvRouterError::IndexerDroppedRequest` if a shard's reply channel is
///   closed before a response arrives.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
    // Scatter: hand each shard its own one-shot reply channel.
    let mut pending = Vec::new();
    for shard_dump_tx in &self.dump_tx {
        let (reply_tx, reply_rx) = oneshot::channel();
        shard_dump_tx
            .send(DumpRequest { resp: reply_tx })
            .await
            .map_err(|e| {
                tracing::error!("Failed to send dump request to shard: {:?}", e);
                KvRouterError::IndexerOffline
            })?;
        pending.push(reply_rx);
    }

    // Gather: concatenate the per-shard dumps in shard order.
    let mut all_events = Vec::new();
    for reply_rx in pending {
        let shard_events = reply_rx
            .await
            .map_err(|_| KvRouterError::IndexerDroppedRequest)?;
        all_events.extend(shard_events);
    }

    Ok(all_events)
}
} }
#[cfg(test)] #[cfg(test)]
...@@ -1559,4 +1723,177 @@ mod tests { ...@@ -1559,4 +1723,177 @@ mod tests {
let overlap_scores: OverlapScores = Default::default(); let overlap_scores: OverlapScores = Default::default();
assert!(overlap_scores.scores.is_empty()); assert!(overlap_scores.scores.is_empty());
} }
/// Round-trip test for `dump_events`: builds a sharded indexer from a handful of
/// store events, dumps it, replays the dump into a fresh sharded indexer, and
/// checks that (a) the second indexer dumps the same set of events and (b) both
/// indexers return identical overlap scores for several probe sequences.
#[tokio::test]
async fn test_dump_tree_as_events_round_trip() {
setup();
// Configuration
let kv_block_size = 32;
let num_shards = 2;
// Build a non-trivial indexer with events
let token1 = CancellationToken::new();
let mut original_indexer = KvIndexerSharded::new(token1.clone(), num_shards, kv_block_size);
let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;
// Apply events to the original indexer.
// NOTE(review): the `ExternalSequenceBlockHash(100)` parents below presumably refer
// to the first block stored by the earlier events -- verify against `create_store_event`.
original_indexer
.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.await;
original_indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
original_indexer
.apply_event(create_store_event(
worker_1,
2,
vec![4, 5],
Some(ExternalSequenceBlockHash(100)),
))
.await;
original_indexer
.apply_event(create_store_event(worker_2, 3, vec![6, 7], None))
.await;
original_indexer
.apply_event(create_store_event(
worker_0,
4,
vec![4],
Some(ExternalSequenceBlockHash(100)),
))
.await;
// Allow some time for events to be processed before dumping; the test relies on
// this fixed delay rather than an explicit completion signal.
tokio::time::sleep(Duration::from_millis(50)).await;
// Dump the original indexer
let dump1 = original_indexer.dump_events().await.unwrap();
println!("Dumped {} events", dump1.len());
// Create a new indexer and apply all dumped events
let token2 = CancellationToken::new();
let mut reconstructed_indexer =
KvIndexerSharded::new(token2.clone(), num_shards, kv_block_size);
for event in &dump1 {
reconstructed_indexer.apply_event(event.clone()).await;
}
// Allow some time for events to be processed
tokio::time::sleep(Duration::from_millis(50)).await;
// Dump the reconstructed indexer
let dump2 = reconstructed_indexer.dump_events().await.unwrap();
// Sort both dumps for comparison (order might differ due to HashMap iteration and sharding)
let mut sorted_dump1 = dump1.clone();
let mut sorted_dump2 = dump2.clone();
// Sort by (worker_id, tokens_hash, parent_hash); event ids are deliberately not
// part of the key. Missing values and non-`Stored` events fall back to 0.
let sort_key = |event: &RouterEvent| {
if let KvCacheEventData::Stored(ref data) = event.event.data {
(
event.worker_id,
data.blocks.first().map(|b| b.tokens_hash.0).unwrap_or(0),
data.parent_hash.map(|h| h.0).unwrap_or(0),
)
} else {
(event.worker_id, 0, 0)
}
};
sorted_dump1.sort_by_key(sort_key);
sorted_dump2.sort_by_key(sort_key);
// Verify the dumps have the same length
assert_eq!(
sorted_dump1.len(),
sorted_dump2.len(),
"Dumps have different lengths: {} vs {}",
sorted_dump1.len(),
sorted_dump2.len()
);
// Verify each event matches field by field (event ids are not compared)
for (i, (event1, event2)) in sorted_dump1.iter().zip(sorted_dump2.iter()).enumerate() {
assert_eq!(
event1.worker_id, event2.worker_id,
"Event {} worker_id mismatch",
i
);
if let (KvCacheEventData::Stored(data1), KvCacheEventData::Stored(data2)) =
(&event1.event.data, &event2.event.data)
{
assert_eq!(
data1.parent_hash, data2.parent_hash,
"Event {} parent_hash mismatch",
i
);
assert_eq!(
data1.blocks.len(),
data2.blocks.len(),
"Event {} blocks length mismatch",
i
);
for (j, (block1, block2)) in
data1.blocks.iter().zip(data2.blocks.iter()).enumerate()
{
assert_eq!(
block1.tokens_hash, block2.tokens_hash,
"Event {} block {} tokens_hash mismatch",
i, j
);
assert_eq!(
block1.block_hash, block2.block_hash,
"Event {} block {} block_hash mismatch",
i, j
);
}
} else {
// A dump is expected to contain only `Stored` events.
panic!("Expected Stored events in both dumps");
}
}
// Also verify that both indexers produce the same match results
for test_seq in [
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
vec![LocalBlockHash(1), LocalBlockHash(4), LocalBlockHash(5)],
vec![LocalBlockHash(6), LocalBlockHash(7)],
vec![LocalBlockHash(1)],
] {
let scores1 = original_indexer
.find_matches(test_seq.clone())
.await
.unwrap();
let scores2 = reconstructed_indexer
.find_matches(test_seq.clone())
.await
.unwrap();
// Sort the scores by key so the comparison does not depend on map iteration order
let mut scores1_sorted: Vec<_> = scores1.scores.iter().collect();
let mut scores2_sorted: Vec<_> = scores2.scores.iter().collect();
scores1_sorted.sort_by_key(|(k, _)| *k);
scores2_sorted.sort_by_key(|(k, _)| *k);
assert_eq!(
scores1_sorted, scores2_sorted,
"Match scores differ for sequence {:?}",
test_seq
);
}
// Clean up
original_indexer.shutdown();
reconstructed_indexer.shutdown();
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment