Unverified Commit 937398cf authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Flash Indexer (#5785)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarjthomson04 <jothomson@nvidia.com>
Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Signed-off-by: default avatarJanelle Cai <jcai18@mit.edu>
Co-authored-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent de27efe6
...@@ -153,7 +153,11 @@ impl RadixTree { ...@@ -153,7 +153,11 @@ impl RadixTree {
/// An `OverlapScores` representing the match scores. /// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores { pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new(); let mut scores = OverlapScores::new();
let mut current = self.root.clone();
if sequence.is_empty() {
return scores;
}
let now = Instant::now(); let now = Instant::now();
tracing::trace!( tracing::trace!(
...@@ -161,17 +165,110 @@ impl RadixTree { ...@@ -161,17 +165,110 @@ impl RadixTree {
sequence.iter().map(|h| h.0).collect::<Vec<_>>() sequence.iter().map(|h| h.0).collect::<Vec<_>>()
); );
for (idx, block_hash) in sequence.iter().enumerate() { // Get first child from root.
let first_child = {
let current_borrow = self.root.borrow();
current_borrow.children.get(&sequence[0]).cloned()
};
let Some(first_child) = first_child else {
return scores;
};
// Initialize active worker set from first child.
let (mut active, mut active_count) = {
let borrow = first_child.borrow();
(borrow.workers.clone(), borrow.workers.len())
};
// Frequency tracking for first child.
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = first_child.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
if active.is_empty() {
return scores;
}
if early_exit && active_count == 1 {
for worker in &active {
scores.scores.insert(*worker, 1);
}
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
return scores;
}
let mut current = first_child;
let mut matched_depth = 1u32;
// Traverse remaining levels. In a clean tree, workers at a child node
// are always a subset of the parent (along the same path), so:
// - workers can only drop out, never join, as we descend
// - if child.workers.len() == active_count, the sets are identical
//
// However, because apply_event(Removed) does NOT cascade to descendants,
// a child may transiently have MORE workers than its parent (stale
// entries from an ancestor remove whose descendant remove events
// haven't arrived yet). We detect this via child_count > active_count
// and fall back to a full membership check.
for (idx, item) in sequence.iter().enumerate().skip(1) {
let next_block = { let next_block = {
let current_borrow = current.borrow(); let current_borrow = current.borrow();
current_borrow.children.get(block_hash).cloned() current_borrow.children.get(item).cloned()
};
let Some(block) = next_block else {
break;
}; };
if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.iter());
{
let borrow = block.borrow();
let child_count = borrow.workers.len();
if child_count < active_count {
// Workers dropped out. Record scores for those that left.
// Score = matched_depth (number of nodes they were present at).
for worker in &active {
if !borrow.workers.contains(worker) {
scores.scores.insert(*worker, matched_depth);
}
}
active.clone_from(&borrow.workers);
active_count = child_count;
} else if child_count > active_count {
// Stale entries: child retains workers already removed from
// an ancestor. Fall back to full membership check.
active.retain(|w| {
if borrow.workers.contains(w) {
true
} else {
scores.scores.insert(*w, matched_depth);
false
}
});
active_count = active.len();
}
}
// Frequency tracking (always runs when enabled, independent of dropout).
if let Some(expiration_duration) = self.expiration_duration { if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut(); let mut block_mut = block.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() { while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration { if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front(); block_mut.recent_uses.pop_front();
...@@ -183,24 +280,27 @@ impl RadixTree { ...@@ -183,24 +280,27 @@ impl RadixTree {
block_mut.recent_uses.push_back(now); block_mut.recent_uses.push_back(now);
} }
if early_exit && block.borrow().workers.len() == 1 { if active_count == 0 {
break; break;
} }
current = block; if early_exit && active_count == 1 {
} else { matched_depth = (idx + 1) as u32;
tracing::trace!(
"RadixTree::find_matches: block not found at index {} for hash {}",
idx,
block_hash.0
);
break; break;
} }
current = block;
matched_depth = (idx + 1) as u32;
}
// Record scores for workers that survived through the deepest matched level.
for worker in &active {
scores.scores.insert(*worker, matched_depth);
} }
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores); tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores // Populate tree sizes for all workers that have scores.
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
let tree_size = self let tree_size = self
.lookup .lookup
...@@ -250,8 +350,19 @@ impl RadixTree { ...@@ -250,8 +350,19 @@ impl RadixTree {
None => self.root.clone(), None => self.root.clone(),
}; };
let mut needs_worker_insert = false;
// In each iteration we lock the parent and insert the worker
// deferred from the previous iteration, avoiding a second
// borrow on the same block.
for block_data in op.blocks { for block_data in op.blocks {
let mut parent_mut = current.borrow_mut(); let mut parent_mut = current.borrow_mut();
if needs_worker_insert {
parent_mut.workers.insert(worker);
}
needs_worker_insert = true;
let child = match parent_mut.children.get(&block_data.tokens_hash) { let child = match parent_mut.children.get(&block_data.tokens_hash) {
Some(block) => { Some(block) => {
// Verify our simplifying assumption: block_hash is uniform across workers // Verify our simplifying assumption: block_hash is uniform across workers
...@@ -265,7 +376,6 @@ impl RadixTree { ...@@ -265,7 +376,6 @@ impl RadixTree {
block.clone() block.clone()
} }
None => { None => {
// create new block or reuse existing from worker's lookup
let new_block = worker_lookup let new_block = worker_lookup
.get(&block_data.block_hash) .get(&block_data.block_hash)
.cloned() .cloned()
...@@ -275,7 +385,6 @@ impl RadixTree { ...@@ -275,7 +385,6 @@ impl RadixTree {
))) )))
}); });
// insert into radix tree
parent_mut parent_mut
.children .children
.insert(block_data.tokens_hash, new_block.clone()); .insert(block_data.tokens_hash, new_block.clone());
...@@ -284,13 +393,9 @@ impl RadixTree { ...@@ -284,13 +393,9 @@ impl RadixTree {
} }
}; };
// Update child and check for self referential blocks // Self-reference check: try_borrow_mut will fail if child
{ // is the same Rc as current (parent_mut holds a mutable borrow).
// Try to borrow the child mutably - if it fails, it's already borrowed if child.try_borrow_mut().is_err() {
// which means a self referencing block.
let mut child_mut = match child.try_borrow_mut() {
Ok(b) => b,
Err(_) => {
tracing::warn!( tracing::warn!(
worker_id = worker.worker_id.to_string(), worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank, dp_rank = worker.dp_rank,
...@@ -300,20 +405,18 @@ impl RadixTree { ...@@ -300,20 +405,18 @@ impl RadixTree {
); );
return Err(KvCacheEventError::InvalidBlockSequence); return Err(KvCacheEventError::InvalidBlockSequence);
} }
};
// add our worker to the block
child_mut.workers.insert(worker);
}
// add the block to the worker's lookup table
worker_lookup.insert(block_data.block_hash, child.clone()); worker_lookup.insert(block_data.block_hash, child.clone());
// drop child so we can shift current to this block
drop(parent_mut); drop(parent_mut);
current = child; current = child;
} }
// Insert worker into the last child.
if needs_worker_insert {
current.borrow_mut().workers.insert(worker);
}
Ok(()) Ok(())
} }
KvCacheEventData::Removed(remove) => { KvCacheEventData::Removed(remove) => {
...@@ -474,64 +577,8 @@ impl RadixTree { ...@@ -474,64 +577,8 @@ impl RadixTree {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::{ use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, use crate::test_utils::{create_remove_event, create_store_event};
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
};
/// Creates blocks with artificial hash mapping (hash * 100) for testing RadixTree internals.
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
})
.collect()
}
fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
}
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
}
#[test] #[test]
fn test_radix_tree() { fn test_radix_tree() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared test utilities for radix tree tests.
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId,
};
/// Creates blocks with artificial hash mapping (hash * 100) for testing.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
})
.collect()
}
pub fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
pub fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
}
pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment