Unverified Commit b273eda2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: simplify RadixTree assuming consistent hashing across workers (#5605)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent feb6d272
......@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
use serde::{Deserialize, Serialize};
use std::{
cell::RefCell,
collections::{HashMap, VecDeque},
collections::{HashMap, HashSet, VecDeque},
iter,
rc::Rc,
sync::{Arc, Mutex, OnceLock},
......@@ -165,23 +165,40 @@ impl MaybeError for WorkerKvQueryResponse {
struct RadixBlock {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>,
/// A map of workers (with dp_rank) to their external sequence block hash for this block.
/// The external hash is preserved to speed up snapshotting.
workers: HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
/// The set of workers that have this block cached.
workers: HashSet<WorkerWithDpRank>,
/// The external sequence block hash for this block (None for root).
/// This is the same for all workers under the simplifying assumption.
block_hash: Option<ExternalSequenceBlockHash>,
/// A buffer of times that this block was last traversed
recent_uses: VecDeque<Instant>,
}
impl RadixBlock {
/// Create a new `RadixBlock`.
/// Create a new `RadixBlock` (used for root node).
///
/// ### Returns
///
/// A new `RadixBlock`.
/// A new `RadixBlock` with no block_hash.
pub fn new() -> Self {
Self {
children: HashMap::new(),
workers: HashMap::new(),
workers: HashSet::new(),
block_hash: None,
recent_uses: VecDeque::new(),
}
}
/// Create a new `RadixBlock` with a specific block hash.
///
/// ### Returns
///
/// A new `RadixBlock` with the given block_hash.
pub fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
block_hash: Some(block_hash),
recent_uses: VecDeque::new(),
}
}
......@@ -192,15 +209,10 @@ pub struct RadixTree {
/// This will only contain root blocks
root: SharedRadixBlock,
/// This is a global lookup table for all blocks which will let you jump into
/// the radix tree at any point
/// Lookup is best case O(1) and worst case O(N); however, even constant in-time
/// could be expensive if N is large
/// We should monitor the size of this table and consider using a proper radix tree.
/// Transitioning to a radix tree only would require a change in the messaging structure
/// as the entire prefix would need to be sent. Alternatively, we could use block_depth
/// integers to indicate how many blocks to skip and use a radix/prefix tree at each level.
/// Per-worker lookup table for O(1) block access.
/// Maps worker -> (block_hash -> block).
lookup: HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time buffer the radix tree should check when considering frequence of block accesses
expiration_duration: Option<Duration>,
}
......@@ -288,7 +300,7 @@ impl RadixTree {
current_borrow.children.get(block_hash).cloned()
};
if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.keys());
scores.update_scores(block.borrow().workers.iter());
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
......@@ -352,9 +364,7 @@ impl RadixTree {
match op {
KvCacheEventData::Stored(op) => {
// find the parent block - if the parent exists it must be on our worker, if not,
// we check the radix tree's root to find it.
// this is the single most expensive lookup
// find the parent block from this worker's lookup
let mut current = match op.parent_hash {
Some(parent) => match worker_lookup.get(&parent) {
Some(current) => current.clone(),
......@@ -376,13 +386,27 @@ impl RadixTree {
for block_data in op.blocks {
let mut parent_mut = current.borrow_mut();
let child = match parent_mut.children.get(&block_data.tokens_hash) {
Some(block) => block.clone(),
Some(block) => {
// Verify our simplifying assumption: block_hash is uniform across workers
if block.borrow().block_hash != Some(block_data.block_hash) {
tracing::warn!(
expected = ?block_data.block_hash,
actual = ?block.borrow().block_hash,
"block_hash mismatch: sequence hashes should be uniform across workers"
);
}
block.clone()
}
None => {
// create new block - automatically added to the lookup table
// create new block or reuse existing from worker's lookup
let new_block = worker_lookup
.get(&block_data.block_hash)
.cloned()
.unwrap_or_else(|| Rc::new(RefCell::new(RadixBlock::new())));
.unwrap_or_else(|| {
Rc::new(RefCell::new(RadixBlock::with_hash(
block_data.block_hash,
)))
});
// insert into radix tree
parent_mut
......@@ -411,11 +435,11 @@ impl RadixTree {
}
};
// add our worker to the block with its external hash
child_mut.workers.insert(worker, block_data.block_hash);
// add our worker to the block
child_mut.workers.insert(worker);
}
// add the block to the worker_id lookup table
// add the block to the worker's lookup table
worker_lookup.insert(block_data.block_hash, child.clone());
// drop child so we can shift current to this block
......@@ -426,15 +450,9 @@ impl RadixTree {
Ok(())
}
KvCacheEventData::Removed(remove) => {
// tracing::trace!(id, "KV Remove Operation: {:?}", op);
// let mut worker_lookup = self.lookup.get(&worker_id).expect("Worker not found");
let mut kv_cache_err: Option<KvCacheEventError> = None;
for block in remove.block_hashes {
// entry in radix tree
// a small optimization would be to get the next block from the reduced set of children
// in order to apply this optimization, we would need to know the list of blocks is always sorted
// by parent -> child relationship
// lookup block in worker's table
let entry = match worker_lookup.get(&block) {
Some(entry) => entry.clone(),
None => {
......@@ -461,14 +479,10 @@ impl RadixTree {
// if no workers are using this block, that is true for all children
guard.children.clear();
}
// remove the block from the lookup table
// remove the block from the worker's lookup table
worker_lookup.remove(&block);
}
if let Some(err) = kv_cache_err {
Err(err)
} else {
Ok(())
}
kv_cache_err.map_or(Ok(()), Err)
}
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker.worker_id);
......@@ -491,13 +505,13 @@ impl RadixTree {
for worker in workers {
if let Some((worker_key, blocks)) = self.lookup.remove_entry(&worker) {
blocks.iter().for_each(|(_, block)| {
for (_, block) in blocks {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});
}
if keep_worker {
// Re-insert worker with empty blocks map to keep it tracked
......@@ -528,15 +542,6 @@ impl RadixTree {
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
// BFS queue entry: (current_block, parent_hashes_per_worker, tokens_hash)
// parent_hashes_per_worker maps WorkerWithDpRank -> ExternalSequenceBlockHash
// Using Rc to avoid cloning the HashMap for each child
type BfsQueueEntry = (
SharedRadixBlock,
Rc<HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>>,
LocalBlockHash,
);
tracing::debug!(
"Dumping radix tree as events (contains information about {:?} workers)",
self.lookup.len()
......@@ -545,62 +550,49 @@ impl RadixTree {
let mut events = Vec::new();
let mut event_id = 0u64;
let mut queue: VecDeque<BfsQueueEntry> = VecDeque::new();
// Queue entries: (current_block, parent_hash, tokens_hash)
let mut queue = VecDeque::new();
// Process root's children first
let root_borrow = self.root.borrow();
let empty_parent_hashes = Rc::new(HashMap::new());
for (tokens_hash, child_block) in &root_borrow.children {
queue.push_back((
child_block.clone(),
empty_parent_hashes.clone(),
*tokens_hash,
));
queue.push_back((child_block.clone(), None, *tokens_hash));
}
drop(root_borrow);
while let Some((current_block, parent_hashes, tokens_hash)) = queue.pop_front() {
while let Some((current_block, parent_hash, tokens_hash)) = queue.pop_front() {
let current_borrow = current_block.borrow();
// Map of this block's external hashes per worker (for children to use as parent)
let mut current_external_hashes = HashMap::new();
// Get this block's hash (same for all workers)
let block_hash = current_borrow
.block_hash
.expect("non-root block must have block_hash");
// For each worker that has this block
for (worker_id, external_hash) in &current_borrow.workers {
// Get the correct parent hash for this worker
let parent_hash = parent_hashes.get(worker_id).copied();
for worker in &current_borrow.workers {
// Create a store event for this worker
let event = RouterEvent {
worker_id: worker_id.worker_id,
worker_id: worker.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash: *external_hash,
block_hash,
mm_extra_info: None,
tokens_hash,
}],
}),
dp_rank: worker_id.dp_rank,
dp_rank: worker.dp_rank,
},
};
events.push(event);
event_id += 1;
// Track this block's external hash for this worker
current_external_hashes.insert(*worker_id, *external_hash);
}
// Enqueue children with shared parent hashes (Rc avoids cloning HashMap)
let parent_hashes_rc = Rc::new(current_external_hashes);
// Enqueue children with this block's hash as their parent
for (child_tokens_hash, child_block) in &current_borrow.children {
queue.push_back((
child_block.clone(),
parent_hashes_rc.clone(),
*child_tokens_hash,
));
queue.push_back((child_block.clone(), Some(block_hash), *child_tokens_hash));
}
}
......@@ -2384,17 +2376,6 @@ mod tests {
.len(),
2
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap()
.borrow()
.workers
.len(),
2
);
}
#[test]
......@@ -3173,19 +3154,19 @@ mod tests {
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
.contains(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
.contains(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
.contains(&WorkerWithDpRank::from_worker_id(worker_2))
);
// Remove worker_0
......@@ -3211,19 +3192,19 @@ mod tests {
!block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
.contains(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
.contains(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
.contains(&WorkerWithDpRank::from_worker_id(worker_2))
);
// Verify that blocks with no remaining workers have their children cleared
......@@ -3239,7 +3220,7 @@ mod tests {
block_2
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
.contains(&WorkerWithDpRank::from_worker_id(worker_1))
);
// Verify match results no longer include worker_0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment