"vllm/vscode:/vscode.git/clone" did not exist on "aaec845f8ed7f445c66ba0d28c84bec9d184f5ed"
Unverified Commit 937398cf authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Flash Indexer (#5785)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarjthomson04 <jothomson@nvidia.com>
Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Signed-off-by: default avatarJanelle Cai <jcai18@mit.edu>
Co-authored-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent de27efe6
......@@ -153,7 +153,11 @@ impl RadixTree {
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new();
let mut current = self.root.clone();
if sequence.is_empty() {
return scores;
}
let now = Instant::now();
tracing::trace!(
......@@ -161,46 +165,142 @@ impl RadixTree {
sequence.iter().map(|h| h.0).collect::<Vec<_>>()
);
for (idx, block_hash) in sequence.iter().enumerate() {
// Get first child from root.
let first_child = {
let current_borrow = self.root.borrow();
current_borrow.children.get(&sequence[0]).cloned()
};
let Some(first_child) = first_child else {
return scores;
};
// Initialize active worker set from first child.
let (mut active, mut active_count) = {
let borrow = first_child.borrow();
(borrow.workers.clone(), borrow.workers.len())
};
// Frequency tracking for first child.
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = first_child.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
if active.is_empty() {
return scores;
}
if early_exit && active_count == 1 {
for worker in &active {
scores.scores.insert(*worker, 1);
}
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
return scores;
}
let mut current = first_child;
let mut matched_depth = 1u32;
// Traverse remaining levels. In a clean tree, workers at a child node
// are always a subset of the parent (along the same path), so:
// - workers can only drop out, never join, as we descend
// - if child.workers.len() == active_count, the sets are identical
//
// However, because apply_event(Removed) does NOT cascade to descendants,
// a child may transiently have MORE workers than its parent (stale
// entries from an ancestor remove whose descendant remove events
// haven't arrived yet). We detect this via child_count > active_count
// and fall back to a full membership check.
for (idx, item) in sequence.iter().enumerate().skip(1) {
let next_block = {
let current_borrow = current.borrow();
current_borrow.children.get(block_hash).cloned()
current_borrow.children.get(item).cloned()
};
if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.iter());
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
let Some(block) = next_block else {
break;
};
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
{
let borrow = block.borrow();
let child_count = borrow.workers.len();
if child_count < active_count {
// Workers dropped out. Record scores for those that left.
// Score = matched_depth (number of nodes they were present at).
for worker in &active {
if !borrow.workers.contains(worker) {
scores.scores.insert(*worker, matched_depth);
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
active.clone_from(&borrow.workers);
active_count = child_count;
} else if child_count > active_count {
// Stale entries: child retains workers already removed from
// an ancestor. Fall back to full membership check.
active.retain(|w| {
if borrow.workers.contains(w) {
true
} else {
scores.scores.insert(*w, matched_depth);
false
}
});
active_count = active.len();
}
}
if early_exit && block.borrow().workers.len() == 1 {
break;
// Frequency tracking (always runs when enabled, independent of dropout).
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
current = block;
} else {
tracing::trace!(
"RadixTree::find_matches: block not found at index {} for hash {}",
idx,
block_hash.0
);
if active_count == 0 {
break;
}
if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32;
break;
}
current = block;
matched_depth = (idx + 1) as u32;
}
// Record scores for workers that survived through the deepest matched level.
for worker in &active {
scores.scores.insert(*worker, matched_depth);
}
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores
// Populate tree sizes for all workers that have scores.
for worker in scores.scores.keys() {
let tree_size = self
.lookup
......@@ -250,8 +350,19 @@ impl RadixTree {
None => self.root.clone(),
};
let mut needs_worker_insert = false;
// In each iteration we lock the parent and insert the worker
// deferred from the previous iteration, avoiding a second
// borrow on the same block.
for block_data in op.blocks {
let mut parent_mut = current.borrow_mut();
if needs_worker_insert {
parent_mut.workers.insert(worker);
}
needs_worker_insert = true;
let child = match parent_mut.children.get(&block_data.tokens_hash) {
Some(block) => {
// Verify our simplifying assumption: block_hash is uniform across workers
......@@ -265,7 +376,6 @@ impl RadixTree {
block.clone()
}
None => {
// create new block or reuse existing from worker's lookup
let new_block = worker_lookup
.get(&block_data.block_hash)
.cloned()
......@@ -275,7 +385,6 @@ impl RadixTree {
)))
});
// insert into radix tree
parent_mut
.children
.insert(block_data.tokens_hash, new_block.clone());
......@@ -284,36 +393,30 @@ impl RadixTree {
}
};
// Update child and check for self referential blocks
{
// Try to borrow the child mutably - if it fails, it's already borrowed
// which means a self referencing block.
let mut child_mut = match child.try_borrow_mut() {
Ok(b) => b,
Err(_) => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block_data.block_hash,
"Detected self referencing block in store event; rejecting sequence"
);
return Err(KvCacheEventError::InvalidBlockSequence);
}
};
// add our worker to the block
child_mut.workers.insert(worker);
// Self-reference check: try_borrow_mut will fail if child
// is the same Rc as current (parent_mut holds a mutable borrow).
if child.try_borrow_mut().is_err() {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block_data.block_hash,
"Detected self referencing block in store event; rejecting sequence"
);
return Err(KvCacheEventError::InvalidBlockSequence);
}
// add the block to the worker's lookup table
worker_lookup.insert(block_data.block_hash, child.clone());
// drop child so we can shift current to this block
drop(parent_mut);
current = child;
}
// Insert worker into the last child.
if needs_worker_insert {
current.borrow_mut().workers.insert(worker);
}
Ok(())
}
KvCacheEventData::Removed(remove) => {
......@@ -474,64 +577,8 @@ impl RadixTree {
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
};
/// Creates blocks with artificial hash mapping (hash * 100) for testing RadixTree internals.
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
})
.collect()
}
fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
}
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
}
use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use crate::test_utils::{create_remove_event, create_store_event};
#[test]
fn test_radix_tree() {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
......@@ -577,6 +578,26 @@ fn convert_event(
block_mm_infos,
..
} => {
// Reject self-referencing blocks: all block hashes (including parent) must be unique.
{
let mut seen = HashSet::with_capacity(block_hashes.len() + 1);
if let Some(parent) = parent_block_hash {
seen.insert(parent.into_u64());
}
let has_duplicate = block_hashes.iter().any(|h| !seen.insert(h.into_u64()));
if has_duplicate {
tracing::warn!(
event_id,
"Self-referencing block detected: duplicate hash in store event; dropping"
);
return KvCacheEvent {
event_id,
data: KvCacheEventData::Cleared,
dp_rank,
};
}
}
let num_block_tokens = vec![block_size as u64; block_hashes.len()];
let block_hashes_u64: Vec<u64> = block_hashes
.into_iter()
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment