Unverified commit 04e05f8c, authored by Yan Ru Pei, committed by GitHub
Browse files

fix(kv-router): default to compressed concurrent tree (#7874)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent a98b4d1f
......@@ -97,10 +97,10 @@ struct Node {
edge: Vec<(LocalBlockHash, ExternalSequenceBlockHash)>,
/// Reverse index: `ExternalSequenceBlockHash` → position in `edge`.
/// Provides O(1) position lookup during removal, avoiding a linear scan.
edge_index: FxHashMap<ExternalSequenceBlockHash, u16>,
edge_index: FxHashMap<ExternalSequenceBlockHash, usize>,
/// Workers with partial edge coverage. `worker_cutoffs[w] = k` means worker `w`
/// has cached `edge[0..k]`, where `0 < k < edge.len()`.
worker_cutoffs: FxHashMap<WorkerWithDpRank, u16>,
worker_cutoffs: FxHashMap<WorkerWithDpRank, usize>,
/// Workers with full edge coverage (match index == edge.len()).
full_edge_workers: FxHashSet<WorkerWithDpRank>,
/// Child nodes, keyed by the first `LocalBlockHash` of the child's edge.
......@@ -247,13 +247,13 @@ impl ConcurrentRadixTreeCompressed {
let suffix_edge = node.edge.split_off(pos);
let suffix_first_local = suffix_edge[0].0;
let prefix_len = pos as u16;
let prefix_len = pos;
// Build suffix edge_index (positions reindexed from 0).
let mut suffix_edge_index =
FxHashMap::with_capacity_and_hasher(suffix_edge.len(), FxBuildHasher);
for (i, &(_, h)) in suffix_edge.iter().enumerate() {
suffix_edge_index.insert(h, i as u16);
suffix_edge_index.insert(h, i);
}
// Remove suffix hashes from the prefix edge_index.
for &(_, h) in &suffix_edge {
......@@ -318,7 +318,7 @@ impl ConcurrentRadixTreeCompressed {
}
for (&w, &k) in &guard.worker_cutoffs {
if let Some(wl) = lookup.get_mut(&w) {
for &(_, h) in &guard.edge[..k as usize] {
for &(_, h) in &guard.edge[..k] {
wl.insert(h, split.suffix.clone());
}
}
......@@ -390,7 +390,7 @@ impl ConcurrentRadixTreeCompressed {
active = guard.full_edge_workers.clone();
active_count = active.len();
for (&w, &k) in &guard.worker_cutoffs {
let contribution = (k as usize).min(edge_match_len) as u32;
let contribution = k.min(edge_match_len) as u32;
if contribution > 0 {
scores.scores.insert(w, contribution);
}
......@@ -404,7 +404,7 @@ impl ConcurrentRadixTreeCompressed {
if guard.full_edge_workers.contains(w) {
true
} else if let Some(&k) = guard.worker_cutoffs.get(w) {
let effective = (k as usize).min(edge_match_len) as u32;
let effective = k.min(edge_match_len) as u32;
scores.scores.insert(*w, prev_depth + effective);
false
} else {
......@@ -535,18 +535,12 @@ impl ConcurrentRadixTreeCompressed {
// stale entry in the lookup map.
{
let guard = node.read();
if let Some(&pos_u16) = guard.edge_index.get(&parent_hash) {
let pos = pos_u16 as usize;
if let Some(&pos) = guard.edge_index.get(&parent_hash) {
let is_full = guard.full_edge_workers.contains(&worker);
let cutoff = if is_full {
guard.edge.len()
} else {
guard
.worker_cutoffs
.get(&worker)
.copied()
.map(|k| k as usize)
.unwrap_or(0)
guard.worker_cutoffs.get(&worker).copied().unwrap_or(0)
};
if pos >= cutoff {
tracing::warn!(
......@@ -665,7 +659,7 @@ impl ConcurrentRadixTreeCompressed {
let mut edge_index =
FxHashMap::with_capacity_and_hasher(edge.len(), FxBuildHasher);
for (i, &(_, h)) in edge.iter().enumerate() {
edge_index.insert(h, i as u16);
edge_index.insert(h, i);
}
let mut full_edge_workers =
FxHashSet::with_capacity_and_hasher(1, FxBuildHasher);
......@@ -734,7 +728,7 @@ impl ConcurrentRadixTreeCompressed {
let mut edge_index =
FxHashMap::with_capacity_and_hasher(edge.len(), FxBuildHasher);
for (i, &(_, h)) in edge.iter().enumerate() {
edge_index.insert(h, i as u16);
edge_index.insert(h, i);
}
let mut full_edge_workers =
FxHashSet::with_capacity_and_hasher(1, FxBuildHasher);
......@@ -855,21 +849,14 @@ impl ConcurrentRadixTreeCompressed {
match guard.edge_index.get(&block_hash).copied() {
None => None, // stale: hash moved to a child
Some(pos_u16) => {
let pos = pos_u16 as usize;
Some(pos) => {
// Determine the worker's current match index.
// Use 0 as sentinel for "not tracked" → pos >= 0 is always true → no-op.
let is_full = guard.full_edge_workers.contains(&worker);
let current_cutoff = if is_full {
guard.edge.len()
} else {
guard
.worker_cutoffs
.get(&worker)
.copied()
.map(|k| k as usize)
.unwrap_or(0)
guard.worker_cutoffs.get(&worker).copied().unwrap_or(0)
};
if pos >= current_cutoff {
......@@ -891,7 +878,7 @@ impl ConcurrentRadixTreeCompressed {
if is_full {
guard.full_edge_workers.remove(&worker);
}
guard.worker_cutoffs.insert(worker, new_cutoff as u16);
guard.worker_cutoffs.insert(worker, new_cutoff);
}
if !guard.has_any_workers() {
......@@ -1146,7 +1133,7 @@ impl ConcurrentRadixTreeCompressed {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: full_blocks[..k as usize].to_vec(),
blocks: full_blocks[..k].to_vec(),
}),
dp_rank: worker.dp_rank,
},
......
......@@ -6,7 +6,7 @@ use std::sync::Arc;
use anyhow::Result;
use tokio_util::sync::CancellationToken;
use crate::ConcurrentRadixTree;
use crate::ConcurrentRadixTreeCompressed;
use crate::ThreadPoolIndexer;
use crate::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use crate::protocols::{LocalBlockHash, OverlapScores, RouterEvent, WorkerId};
......@@ -14,7 +14,7 @@ use crate::protocols::{LocalBlockHash, OverlapScores, RouterEvent, WorkerId};
#[derive(Clone)]
pub enum Indexer {
Single(KvIndexer),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>),
}
impl Indexer {
......@@ -57,7 +57,7 @@ impl Indexer {
pub fn create_indexer(block_size: u32, num_threads: usize) -> Indexer {
if num_threads > 1 {
Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
ConcurrentRadixTreeCompressed::new(),
num_threads,
block_size,
)))
......
......@@ -8,7 +8,7 @@ use anyhow::Result;
use futures::StreamExt;
use dynamo_kv_router::{
ConcurrentRadixTree, ThreadPoolIndexer,
ConcurrentRadixTreeCompressed, ThreadPoolIndexer,
approx::PruneConfig,
config::KvRouterConfig,
indexer::{
......@@ -74,7 +74,7 @@ impl RemoteIndexer {
#[derive(Clone)]
pub enum Indexer {
KvIndexer(KvIndexer),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>),
Remote(Arc<RemoteIndexer>),
None,
}
......@@ -124,7 +124,7 @@ impl Indexer {
if kv_router_config.router_event_threads > 1 {
return Ok(Self::Concurrent(Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
ConcurrentRadixTreeCompressed::new(),
kv_router_config.router_event_threads as usize,
block_size,
))));
......
......@@ -1690,30 +1690,40 @@ def _test_router_decisions(
events_by_key[key] = []
events_by_key[key].append(event)
def count_stored_blocks(events: list[Any]) -> int:
    """Return the total number of blocks carried by "stored" events.

    Each item is looked up as ``item["event"]["data"]["stored"]["blocks"]``.
    Items without a "stored" payload contribute nothing. The lookup is
    tolerant: a level that is missing, null, or not a dict ends the descent
    for that item instead of raising, and a null ``blocks`` value counts as
    empty.

    NOTE(review): non-"stored" event kinds presumably serialize ``data`` as
    something other than a dict holding "stored" (possibly a bare string or
    null) -- confirm against the event schema.

    Args:
        events: Decoded event records (dicts, as produced by JSON decoding).

    Returns:
        The sum of ``len(blocks)`` over every record with a "stored" payload.
    """

    def child(node: Any, key: str) -> Any:
        # Only dicts can be descended into; anything else ends the lookup.
        return node.get(key) if isinstance(node, dict) else None

    total = 0
    for event in events:
        stored = child(child(child(event, "event"), "data"), "stored")
        if stored is None:
            continue
        # `or []` also covers an explicit null "blocks" value.
        total += len(child(stored, "blocks") or [])
    return total
logger.info(
f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_key.items()]}"
"Stored blocks by (worker_id, dp_rank): "
f"{[(key, count_stored_blocks(evts)) for key, evts in events_by_key.items()]}"
)
# Worker a key: 5 events (A, B from req1; C, D from req2; F from req4)
# Worker a key: 5 stored blocks (A, B from req1; C, D from req2; F from req4)
worker_a_key = (worker_a_id, dp_rank_a if dp_rank_a is not None else 0)
worker_a_events = len(events_by_key.get(worker_a_key, []))
assert worker_a_events == 5, (
f"Expected worker_a {worker_a_key} to have 5 events (A,B + C,D + F), "
f"but found {worker_a_events}"
worker_a_blocks = count_stored_blocks(events_by_key.get(worker_a_key, []))
assert worker_a_blocks == 5, (
f"Expected worker_a {worker_a_key} to have 5 stored blocks (A,B + C,D + F), "
f"but found {worker_a_blocks}"
)
# Worker b key: 4 events (A, C, E from req3; G from req5)
# Worker b key: 4 stored blocks (A, C, E from req3; G from req5)
worker_b_key = (worker_b_id, dp_rank_b if dp_rank_b is not None else 0)
worker_b_events = len(events_by_key.get(worker_b_key, []))
assert worker_b_events == 4, (
f"Expected worker_b {worker_b_key} to have 4 events (A,C,E + G), "
f"but found {worker_b_events}"
worker_b_blocks = count_stored_blocks(events_by_key.get(worker_b_key, []))
assert worker_b_blocks == 4, (
f"Expected worker_b {worker_b_key} to have 4 stored blocks (A,C,E + G), "
f"but found {worker_b_blocks}"
)
logger.info(
f"Successfully verified cross-worker routing: "
f"worker_a {worker_a_key} has {worker_a_events} events, "
f"worker_b {worker_b_key} has {worker_b_events} events"
f"worker_a {worker_a_key} has {worker_a_blocks} stored blocks, "
f"worker_b {worker_b_key} has {worker_b_blocks} stored blocks"
)
# Verify standalone indexer scores via HTTP POST /query
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment