Unverified Commit 11694273 authored by blarson-b10's avatar blarson-b10 Committed by GitHub
Browse files

perf: Improve performance of snapshot using a reverse lookup from block -> external hash (#3370)


Signed-off-by: default avatarBrian Larson <brian.larson@baseten.co>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 1ddb9b0c
......@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
use serde::{Deserialize, Serialize};
use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
collections::{HashMap, VecDeque},
iter,
rc::Rc,
sync::{Arc, OnceLock},
......@@ -200,8 +200,9 @@ impl RouterEvent {
struct RadixBlock {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>,
/// A set of worker IDs associated with this block.
workers: HashSet<WorkerId>,
/// A map of worker IDs to their external sequence block hash for this block.
/// The external hash is preserved to speed up snapshotting.
workers: HashMap<WorkerId, ExternalSequenceBlockHash>,
/// A buffer of times that this block was last traversed
recent_uses: VecDeque<Instant>,
}
......@@ -215,7 +216,7 @@ impl RadixBlock {
pub fn new() -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
workers: HashMap::new(),
recent_uses: VecDeque::new(),
}
}
......@@ -289,7 +290,7 @@ impl RadixTree {
current_borrow.children.get(block_hash).cloned()
};
if let Some(block) = next_block {
scores.update_scores(&block.borrow().workers);
scores.update_scores(block.borrow().workers.keys());
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
......@@ -380,8 +381,11 @@ impl RadixTree {
}
};
// add our worker_id to the block
block.borrow_mut().workers.insert(worker_id);
// add our worker_id to the block with its external hash
block
.borrow_mut()
.workers
.insert(worker_id, block_id.block_hash);
// add the block to the worker_id lookup table
worker_lookup.insert(block_id.block_hash, block.clone());
......@@ -417,7 +421,7 @@ impl RadixTree {
let mut guard = entry.borrow_mut();
guard.workers.remove(&worker_id);
if guard.workers.is_empty() {
// if no worker are using this block, that is true for all children
// if no workers are using this block, that is true for all children
guard.children.clear();
}
// remove the block from the lookup table
......@@ -436,6 +440,10 @@ impl RadixTree {
if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
blocks.iter().for_each(|(_, block)| {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});
}
}
......@@ -445,14 +453,18 @@ impl RadixTree {
if let Some(blocks) = self.lookup.get(&worker) {
let blocks_to_clear: Vec<_> = blocks.values().collect();
// Remove the worker from each block's workers set
// Remove the worker from each block's workers map
blocks_to_clear.iter().for_each(|block| {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});
// Clear the worker's blocks
if let Some(worker_blocks) = self.lookup.get_mut(&worker) {
worker_blocks.clear();
if let Some(worker_lookup) = self.lookup.get_mut(&worker) {
worker_lookup.clear();
}
}
}
......@@ -461,71 +473,68 @@ impl RadixTree {
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!(
"Dumping radix tree as events (contains information about {:?} workers)",
self.lookup.len()
);
let mut events = Vec::new();
let mut event_id = 0u64;
// BFS queue: (current_block, parent_external_hash, tokens_hash)
let mut queue = VecDeque::new();
// BFS queue: (current_block, parent_hashes_per_worker, tokens_hash)
// parent_hashes_per_worker maps WorkerId -> ExternalSequenceBlockHash
let mut queue: VecDeque<(
SharedRadixBlock,
HashMap<WorkerId, ExternalSequenceBlockHash>,
LocalBlockHash,
)> = VecDeque::new();
// Process root's children first
let root_borrow = self.root.borrow();
for (tokens_hash, child_block) in &root_borrow.children {
queue.push_back((child_block.clone(), None, *tokens_hash));
queue.push_back((child_block.clone(), HashMap::new(), *tokens_hash));
}
drop(root_borrow);
while let Some((current_block, parent_external_hash, tokens_hash)) = queue.pop_front() {
while let Some((current_block, parent_hashes, tokens_hash)) = queue.pop_front() {
let current_borrow = current_block.borrow();
// Closure to find external hash for a block in a worker's lookup
let find_external_hash = |worker_id: &WorkerId| {
self.lookup.get(worker_id).and_then(|worker_blocks| {
worker_blocks
.iter()
.find(|(_, block)| Rc::ptr_eq(block, &current_block))
.map(|(hash, _)| *hash)
})
};
// Map of this block's external hashes per worker (for children to use as parent)
let mut current_external_hashes = HashMap::new();
// For each worker that has this block
for worker_id in &current_borrow.workers {
// Find the external hash for this block from the worker's lookup
let external_hash = find_external_hash(worker_id);
if let Some(block_hash) = external_hash {
// Create a store event for this worker
let event = RouterEvent {
worker_id: *worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_external_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash,
tokens_hash,
}],
}),
},
};
events.push(event);
event_id += 1;
}
}
for (worker_id, external_hash) in &current_borrow.workers {
// Get the correct parent hash for this worker
let parent_hash = parent_hashes.get(worker_id).copied();
// Create a store event for this worker
let event = RouterEvent {
worker_id: *worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash: *external_hash,
tokens_hash,
}],
}),
},
};
events.push(event);
event_id += 1;
// Add children to queue for BFS traversal
// We need to find any external hash for this block to use as parent
let any_external_hash = if !current_borrow.workers.is_empty() {
current_borrow
.workers
.iter()
.next()
.and_then(find_external_hash)
} else {
None
};
// Track this block's external hash for this worker
current_external_hashes.insert(*worker_id, *external_hash);
}
// Enqueue children with per-worker parent hashes
for (child_tokens_hash, child_block) in &current_borrow.children {
queue.push_back((child_block.clone(), any_external_hash, *child_tokens_hash));
queue.push_back((
child_block.clone(),
current_external_hashes.clone(),
*child_tokens_hash,
));
}
}
......@@ -657,8 +666,11 @@ impl OverlapScores {
///
/// ### Arguments
///
/// * `workers` - A reference to a `HashSet` of `WorkerId`s.
pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
/// * `workers` - An iterator over `WorkerId` references.
pub fn update_scores<'a, I>(&mut self, workers: I)
where
I: IntoIterator<Item = &'a WorkerId>,
{
for worker in workers {
let score = self.scores.entry(*worker).or_insert(0);
*score += 1;
......@@ -2171,4 +2183,79 @@ mod tests {
1
);
}
#[test]
fn test_remove_worker_verifies_hash_removal() {
setup();
let mut trie = RadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;
// Add blocks for multiple workers
trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 0, vec![1, 4, 5], None))
.unwrap();
// Verify worker_0 has 3 blocks in lookup
assert_eq!(trie.lookup.get(&worker_0).unwrap().len(), 3);
// Verify that blocks have the correct workers
let block_1 = trie
.lookup
.get(&worker_0)
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
assert!(block_1.borrow().workers.contains_key(&worker_0));
assert!(block_1.borrow().workers.contains_key(&worker_1));
assert!(block_1.borrow().workers.contains_key(&worker_2));
// Remove worker_0
trie.remove_worker(worker_0);
// Verify worker_0 is completely removed from lookup table
assert!(!trie.lookup.contains_key(&worker_0));
assert_eq!(trie.lookup.len(), 2);
// Verify that worker_0's hash is removed from the workers set
let block_1 = trie
.lookup
.get(&worker_1)
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
assert!(!block_1.borrow().workers.contains_key(&worker_0));
assert!(block_1.borrow().workers.contains_key(&worker_1));
assert!(block_1.borrow().workers.contains_key(&worker_2));
// Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children
let block_2 = trie
.lookup
.get(&worker_1)
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap();
assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
assert!(block_2.borrow().workers.contains_key(&worker_1));
// Verify match results no longer include worker_0
let result = trie
.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
)
.scores;
assert_eq!(result.len(), 2);
assert!(!result.contains_key(&worker_0));
assert!(result.contains_key(&worker_1));
assert!(result.contains_key(&worker_2));
}
}
......@@ -365,6 +365,8 @@ async fn purge_then_snapshot(
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now();
// First, purge acknowledged messages from the stream
nats_queue.purge_acknowledged().await?;
......@@ -397,9 +399,10 @@ async fn purge_then_snapshot(
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
tracing::info!(
"Successfully uploaded radix tree snapshot with {} events to bucket {}",
"Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
events.len(),
resources.bucket_name
resources.bucket_name,
start_time.elapsed().as_millis()
);
Ok(())
......
......@@ -257,6 +257,7 @@ impl Client {
tokio::io::copy(&mut obj_reader, &mut buffer)
.await
.map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());
// Deserialize from bincode
let data = bincode::deserialize(&buffer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment