Unverified Commit b273eda2 authored by Yan Ru Pei, committed by GitHub
Browse files

chore: simplify RadixTree assuming consistent hashing across workers (#5605)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent feb6d272
...@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts}; ...@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
cell::RefCell, cell::RefCell,
collections::{HashMap, VecDeque}, collections::{HashMap, HashSet, VecDeque},
iter, iter,
rc::Rc, rc::Rc,
sync::{Arc, Mutex, OnceLock}, sync::{Arc, Mutex, OnceLock},
...@@ -165,23 +165,40 @@ impl MaybeError for WorkerKvQueryResponse { ...@@ -165,23 +165,40 @@ impl MaybeError for WorkerKvQueryResponse {
struct RadixBlock { struct RadixBlock {
/// A map of child blocks, keyed by their local block hash. /// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>, children: HashMap<LocalBlockHash, SharedRadixBlock>,
/// A map of workers (with dp_rank) to their external sequence block hash for this block. /// The set of workers that have this block cached.
/// The external hash is preserved to speed up snapshotting. workers: HashSet<WorkerWithDpRank>,
workers: HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>, /// The external sequence block hash for this block (None for root).
/// This is the same for all workers under the simplifying assumption.
block_hash: Option<ExternalSequenceBlockHash>,
/// A buffer of times that this block was last traversed /// A buffer of times that this block was last traversed
recent_uses: VecDeque<Instant>, recent_uses: VecDeque<Instant>,
} }
impl RadixBlock { impl RadixBlock {
/// Create a new `RadixBlock`. /// Create a new `RadixBlock` (used for root node).
/// ///
/// ### Returns /// ### Returns
/// ///
/// A new `RadixBlock`. /// A new `RadixBlock` with no block_hash.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
children: HashMap::new(), children: HashMap::new(),
workers: HashMap::new(), workers: HashSet::new(),
block_hash: None,
recent_uses: VecDeque::new(),
}
}
/// Create a new `RadixBlock` with a specific block hash.
///
/// ### Returns
///
/// A new `RadixBlock` with the given block_hash.
pub fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
block_hash: Some(block_hash),
recent_uses: VecDeque::new(), recent_uses: VecDeque::new(),
} }
} }
...@@ -192,15 +209,10 @@ pub struct RadixTree { ...@@ -192,15 +209,10 @@ pub struct RadixTree {
/// This will only contain root blocks /// This will only contain root blocks
root: SharedRadixBlock, root: SharedRadixBlock,
/// This is a global lookup table for all blocks which will let you jump into /// Per-worker lookup table for O(1) block access.
/// the radix tree at any point /// Maps worker -> (block_hash -> block).
/// Lookup is best case O(1) and worst case O(N); however, even constant in-time
/// could be expensive if N is large
/// We should monitor the size of this table and consider using a proper radix tree.
/// Transitioning to a radix tree only would require a change in the messaging structure
/// as the entire prefix would need to be sent. Alternatively, we could use block_depth
/// integers to indicate how many blocks to skip and use a radix/prefix tree at each level.
lookup: HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>, lookup: HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time buffer the radix tree should check when considering frequency of block accesses /// The time buffer the radix tree should check when considering frequency of block accesses
expiration_duration: Option<Duration>, expiration_duration: Option<Duration>,
} }
...@@ -288,7 +300,7 @@ impl RadixTree { ...@@ -288,7 +300,7 @@ impl RadixTree {
current_borrow.children.get(block_hash).cloned() current_borrow.children.get(block_hash).cloned()
}; };
if let Some(block) = next_block { if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.keys()); scores.update_scores(block.borrow().workers.iter());
if let Some(expiration_duration) = self.expiration_duration { if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut(); let mut block_mut = block.borrow_mut();
...@@ -352,9 +364,7 @@ impl RadixTree { ...@@ -352,9 +364,7 @@ impl RadixTree {
match op { match op {
KvCacheEventData::Stored(op) => { KvCacheEventData::Stored(op) => {
// find the parent block - if the parent exists it must be on our worker, if not, // find the parent block from this worker's lookup
// we check the radix tree's root to find it.
// this is the single most expensive lookup
let mut current = match op.parent_hash { let mut current = match op.parent_hash {
Some(parent) => match worker_lookup.get(&parent) { Some(parent) => match worker_lookup.get(&parent) {
Some(current) => current.clone(), Some(current) => current.clone(),
...@@ -376,13 +386,27 @@ impl RadixTree { ...@@ -376,13 +386,27 @@ impl RadixTree {
for block_data in op.blocks { for block_data in op.blocks {
let mut parent_mut = current.borrow_mut(); let mut parent_mut = current.borrow_mut();
let child = match parent_mut.children.get(&block_data.tokens_hash) { let child = match parent_mut.children.get(&block_data.tokens_hash) {
Some(block) => block.clone(), Some(block) => {
// Verify our simplifying assumption: block_hash is uniform across workers
if block.borrow().block_hash != Some(block_data.block_hash) {
tracing::warn!(
expected = ?block_data.block_hash,
actual = ?block.borrow().block_hash,
"block_hash mismatch: sequence hashes should be uniform across workers"
);
}
block.clone()
}
None => { None => {
// create new block - automatically added to the lookup table // create new block or reuse existing from worker's lookup
let new_block = worker_lookup let new_block = worker_lookup
.get(&block_data.block_hash) .get(&block_data.block_hash)
.cloned() .cloned()
.unwrap_or_else(|| Rc::new(RefCell::new(RadixBlock::new()))); .unwrap_or_else(|| {
Rc::new(RefCell::new(RadixBlock::with_hash(
block_data.block_hash,
)))
});
// insert into radix tree // insert into radix tree
parent_mut parent_mut
...@@ -411,11 +435,11 @@ impl RadixTree { ...@@ -411,11 +435,11 @@ impl RadixTree {
} }
}; };
// add our worker to the block with its external hash // add our worker to the block
child_mut.workers.insert(worker, block_data.block_hash); child_mut.workers.insert(worker);
} }
// add the block to the worker_id lookup table // add the block to the worker's lookup table
worker_lookup.insert(block_data.block_hash, child.clone()); worker_lookup.insert(block_data.block_hash, child.clone());
// drop child so we can shift current to this block // drop child so we can shift current to this block
...@@ -426,15 +450,9 @@ impl RadixTree { ...@@ -426,15 +450,9 @@ impl RadixTree {
Ok(()) Ok(())
} }
KvCacheEventData::Removed(remove) => { KvCacheEventData::Removed(remove) => {
// tracing::trace!(id, "KV Remove Operation: {:?}", op);
// let mut worker_lookup = self.lookup.get(&worker_id).expect("Worker not found");
let mut kv_cache_err: Option<KvCacheEventError> = None; let mut kv_cache_err: Option<KvCacheEventError> = None;
for block in remove.block_hashes { for block in remove.block_hashes {
// entry in radix tree // lookup block in worker's table
// a small optimization would be to get the next block from the reduced set of children
// in order to apply this optimization, we would need to know the list of blocks is always sorted
// by parent -> child relationship
let entry = match worker_lookup.get(&block) { let entry = match worker_lookup.get(&block) {
Some(entry) => entry.clone(), Some(entry) => entry.clone(),
None => { None => {
...@@ -461,14 +479,10 @@ impl RadixTree { ...@@ -461,14 +479,10 @@ impl RadixTree {
// if no workers are using this block, that is true for all children // if no workers are using this block, that is true for all children
guard.children.clear(); guard.children.clear();
} }
// remove the block from the lookup table // remove the block from the worker's lookup table
worker_lookup.remove(&block); worker_lookup.remove(&block);
} }
if let Some(err) = kv_cache_err { kv_cache_err.map_or(Ok(()), Err)
Err(err)
} else {
Ok(())
}
} }
KvCacheEventData::Cleared => { KvCacheEventData::Cleared => {
self.clear_all_blocks(worker.worker_id); self.clear_all_blocks(worker.worker_id);
...@@ -491,13 +505,13 @@ impl RadixTree { ...@@ -491,13 +505,13 @@ impl RadixTree {
for worker in workers { for worker in workers {
if let Some((worker_key, blocks)) = self.lookup.remove_entry(&worker) { if let Some((worker_key, blocks)) = self.lookup.remove_entry(&worker) {
blocks.iter().for_each(|(_, block)| { for (_, block) in blocks {
block.borrow_mut().workers.remove(&worker); block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children // If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() { if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear(); block.borrow_mut().children.clear();
} }
}); }
if keep_worker { if keep_worker {
// Re-insert worker with empty blocks map to keep it tracked // Re-insert worker with empty blocks map to keep it tracked
...@@ -528,15 +542,6 @@ impl RadixTree { ...@@ -528,15 +542,6 @@ impl RadixTree {
/// Uses BFS traversal to ensure that the tree reconstruction is unique, /// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost. /// though the exact event ordering will be lost.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> { pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
// BFS queue entry: (current_block, parent_hashes_per_worker, tokens_hash)
// parent_hashes_per_worker maps WorkerWithDpRank -> ExternalSequenceBlockHash
// Using Rc to avoid cloning the HashMap for each child
type BfsQueueEntry = (
SharedRadixBlock,
Rc<HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>>,
LocalBlockHash,
);
tracing::debug!( tracing::debug!(
"Dumping radix tree as events (contains information about {:?} workers)", "Dumping radix tree as events (contains information about {:?} workers)",
self.lookup.len() self.lookup.len()
...@@ -545,62 +550,49 @@ impl RadixTree { ...@@ -545,62 +550,49 @@ impl RadixTree {
let mut events = Vec::new(); let mut events = Vec::new();
let mut event_id = 0u64; let mut event_id = 0u64;
let mut queue: VecDeque<BfsQueueEntry> = VecDeque::new(); // Queue entries: (current_block, parent_hash, tokens_hash)
let mut queue = VecDeque::new();
// Process root's children first // Process root's children first
let root_borrow = self.root.borrow(); let root_borrow = self.root.borrow();
let empty_parent_hashes = Rc::new(HashMap::new());
for (tokens_hash, child_block) in &root_borrow.children { for (tokens_hash, child_block) in &root_borrow.children {
queue.push_back(( queue.push_back((child_block.clone(), None, *tokens_hash));
child_block.clone(),
empty_parent_hashes.clone(),
*tokens_hash,
));
} }
drop(root_borrow); drop(root_borrow);
while let Some((current_block, parent_hashes, tokens_hash)) = queue.pop_front() { while let Some((current_block, parent_hash, tokens_hash)) = queue.pop_front() {
let current_borrow = current_block.borrow(); let current_borrow = current_block.borrow();
// Map of this block's external hashes per worker (for children to use as parent) // Get this block's hash (same for all workers)
let mut current_external_hashes = HashMap::new(); let block_hash = current_borrow
.block_hash
.expect("non-root block must have block_hash");
// For each worker that has this block // For each worker that has this block
for (worker_id, external_hash) in &current_borrow.workers { for worker in &current_borrow.workers {
// Get the correct parent hash for this worker
let parent_hash = parent_hashes.get(worker_id).copied();
// Create a store event for this worker // Create a store event for this worker
let event = RouterEvent { let event = RouterEvent {
worker_id: worker_id.worker_id, worker_id: worker.worker_id,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash, parent_hash,
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: *external_hash, block_hash,
mm_extra_info: None, mm_extra_info: None,
tokens_hash, tokens_hash,
}], }],
}), }),
dp_rank: worker_id.dp_rank, dp_rank: worker.dp_rank,
}, },
}; };
events.push(event); events.push(event);
event_id += 1; event_id += 1;
// Track this block's external hash for this worker
current_external_hashes.insert(*worker_id, *external_hash);
} }
// Enqueue children with shared parent hashes (Rc avoids cloning HashMap) // Enqueue children with this block's hash as their parent
let parent_hashes_rc = Rc::new(current_external_hashes);
for (child_tokens_hash, child_block) in &current_borrow.children { for (child_tokens_hash, child_block) in &current_borrow.children {
queue.push_back(( queue.push_back((child_block.clone(), Some(block_hash), *child_tokens_hash));
child_block.clone(),
parent_hashes_rc.clone(),
*child_tokens_hash,
));
} }
} }
...@@ -2384,17 +2376,6 @@ mod tests { ...@@ -2384,17 +2376,6 @@ mod tests {
.len(), .len(),
2 2
); );
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap()
.borrow()
.workers
.len(),
2
);
} }
#[test] #[test]
...@@ -3173,19 +3154,19 @@ mod tests { ...@@ -3173,19 +3154,19 @@ mod tests {
block_1 block_1
.borrow() .borrow()
.workers .workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)) .contains(&WorkerWithDpRank::from_worker_id(worker_0))
); );
assert!( assert!(
block_1 block_1
.borrow() .borrow()
.workers .workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)) .contains(&WorkerWithDpRank::from_worker_id(worker_1))
); );
assert!( assert!(
block_1 block_1
.borrow() .borrow()
.workers .workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)) .contains(&WorkerWithDpRank::from_worker_id(worker_2))
); );
// Remove worker_0 // Remove worker_0
...@@ -3211,19 +3192,19 @@ mod tests { ...@@ -3211,19 +3192,19 @@ mod tests {
!block_1 !block_1
.borrow() .borrow()
.workers .workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)) .contains(&WorkerWithDpRank::from_worker_id(worker_0))
); );
assert!( assert!(
block_1 block_1
.borrow() .borrow()
.workers .workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)) .contains(&WorkerWithDpRank::from_worker_id(worker_1))
); );
assert!( assert!(
block_1 block_1
.borrow() .borrow()
.workers .workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)) .contains(&WorkerWithDpRank::from_worker_id(worker_2))
); );
// Verify that blocks with no remaining workers have their children cleared // Verify that blocks with no remaining workers have their children cleared
...@@ -3239,7 +3220,7 @@ mod tests { ...@@ -3239,7 +3220,7 @@ mod tests {
block_2 block_2
.borrow() .borrow()
.workers .workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)) .contains(&WorkerWithDpRank::from_worker_id(worker_1))
); );
// Verify match results no longer include worker_0 // Verify match results no longer include worker_0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment