Unverified Commit 04e05f8c authored by Yan Ru Pei, committed by GitHub
Browse files

fix(kv-router): default to compressed concurrent tree (#7874)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent a98b4d1f
...@@ -97,10 +97,10 @@ struct Node { ...@@ -97,10 +97,10 @@ struct Node {
edge: Vec<(LocalBlockHash, ExternalSequenceBlockHash)>, edge: Vec<(LocalBlockHash, ExternalSequenceBlockHash)>,
/// Reverse index: `ExternalSequenceBlockHash` → position in `edge`. /// Reverse index: `ExternalSequenceBlockHash` → position in `edge`.
/// Provides O(1) position lookup during removal, avoiding a linear scan. /// Provides O(1) position lookup during removal, avoiding a linear scan.
edge_index: FxHashMap<ExternalSequenceBlockHash, u16>, edge_index: FxHashMap<ExternalSequenceBlockHash, usize>,
/// Workers with partial edge coverage. `worker_cutoffs[w] = k` means worker `w` /// Workers with partial edge coverage. `worker_cutoffs[w] = k` means worker `w`
/// has cached `edge[0..k]`, where `0 < k < edge.len()`. /// has cached `edge[0..k]`, where `0 < k < edge.len()`.
worker_cutoffs: FxHashMap<WorkerWithDpRank, u16>, worker_cutoffs: FxHashMap<WorkerWithDpRank, usize>,
/// Workers with full edge coverage (match index == edge.len()). /// Workers with full edge coverage (match index == edge.len()).
full_edge_workers: FxHashSet<WorkerWithDpRank>, full_edge_workers: FxHashSet<WorkerWithDpRank>,
/// Child nodes, keyed by the first `LocalBlockHash` of the child's edge. /// Child nodes, keyed by the first `LocalBlockHash` of the child's edge.
...@@ -247,13 +247,13 @@ impl ConcurrentRadixTreeCompressed { ...@@ -247,13 +247,13 @@ impl ConcurrentRadixTreeCompressed {
let suffix_edge = node.edge.split_off(pos); let suffix_edge = node.edge.split_off(pos);
let suffix_first_local = suffix_edge[0].0; let suffix_first_local = suffix_edge[0].0;
let prefix_len = pos as u16; let prefix_len = pos;
// Build suffix edge_index (positions reindexed from 0). // Build suffix edge_index (positions reindexed from 0).
let mut suffix_edge_index = let mut suffix_edge_index =
FxHashMap::with_capacity_and_hasher(suffix_edge.len(), FxBuildHasher); FxHashMap::with_capacity_and_hasher(suffix_edge.len(), FxBuildHasher);
for (i, &(_, h)) in suffix_edge.iter().enumerate() { for (i, &(_, h)) in suffix_edge.iter().enumerate() {
suffix_edge_index.insert(h, i as u16); suffix_edge_index.insert(h, i);
} }
// Remove suffix hashes from the prefix edge_index. // Remove suffix hashes from the prefix edge_index.
for &(_, h) in &suffix_edge { for &(_, h) in &suffix_edge {
...@@ -318,7 +318,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -318,7 +318,7 @@ impl ConcurrentRadixTreeCompressed {
} }
for (&w, &k) in &guard.worker_cutoffs { for (&w, &k) in &guard.worker_cutoffs {
if let Some(wl) = lookup.get_mut(&w) { if let Some(wl) = lookup.get_mut(&w) {
for &(_, h) in &guard.edge[..k as usize] { for &(_, h) in &guard.edge[..k] {
wl.insert(h, split.suffix.clone()); wl.insert(h, split.suffix.clone());
} }
} }
...@@ -390,7 +390,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -390,7 +390,7 @@ impl ConcurrentRadixTreeCompressed {
active = guard.full_edge_workers.clone(); active = guard.full_edge_workers.clone();
active_count = active.len(); active_count = active.len();
for (&w, &k) in &guard.worker_cutoffs { for (&w, &k) in &guard.worker_cutoffs {
let contribution = (k as usize).min(edge_match_len) as u32; let contribution = k.min(edge_match_len) as u32;
if contribution > 0 { if contribution > 0 {
scores.scores.insert(w, contribution); scores.scores.insert(w, contribution);
} }
...@@ -404,7 +404,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -404,7 +404,7 @@ impl ConcurrentRadixTreeCompressed {
if guard.full_edge_workers.contains(w) { if guard.full_edge_workers.contains(w) {
true true
} else if let Some(&k) = guard.worker_cutoffs.get(w) { } else if let Some(&k) = guard.worker_cutoffs.get(w) {
let effective = (k as usize).min(edge_match_len) as u32; let effective = k.min(edge_match_len) as u32;
scores.scores.insert(*w, prev_depth + effective); scores.scores.insert(*w, prev_depth + effective);
false false
} else { } else {
...@@ -535,18 +535,12 @@ impl ConcurrentRadixTreeCompressed { ...@@ -535,18 +535,12 @@ impl ConcurrentRadixTreeCompressed {
// stale entry in the lookup map. // stale entry in the lookup map.
{ {
let guard = node.read(); let guard = node.read();
if let Some(&pos_u16) = guard.edge_index.get(&parent_hash) { if let Some(&pos) = guard.edge_index.get(&parent_hash) {
let pos = pos_u16 as usize;
let is_full = guard.full_edge_workers.contains(&worker); let is_full = guard.full_edge_workers.contains(&worker);
let cutoff = if is_full { let cutoff = if is_full {
guard.edge.len() guard.edge.len()
} else { } else {
guard guard.worker_cutoffs.get(&worker).copied().unwrap_or(0)
.worker_cutoffs
.get(&worker)
.copied()
.map(|k| k as usize)
.unwrap_or(0)
}; };
if pos >= cutoff { if pos >= cutoff {
tracing::warn!( tracing::warn!(
...@@ -665,7 +659,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -665,7 +659,7 @@ impl ConcurrentRadixTreeCompressed {
let mut edge_index = let mut edge_index =
FxHashMap::with_capacity_and_hasher(edge.len(), FxBuildHasher); FxHashMap::with_capacity_and_hasher(edge.len(), FxBuildHasher);
for (i, &(_, h)) in edge.iter().enumerate() { for (i, &(_, h)) in edge.iter().enumerate() {
edge_index.insert(h, i as u16); edge_index.insert(h, i);
} }
let mut full_edge_workers = let mut full_edge_workers =
FxHashSet::with_capacity_and_hasher(1, FxBuildHasher); FxHashSet::with_capacity_and_hasher(1, FxBuildHasher);
...@@ -734,7 +728,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -734,7 +728,7 @@ impl ConcurrentRadixTreeCompressed {
let mut edge_index = let mut edge_index =
FxHashMap::with_capacity_and_hasher(edge.len(), FxBuildHasher); FxHashMap::with_capacity_and_hasher(edge.len(), FxBuildHasher);
for (i, &(_, h)) in edge.iter().enumerate() { for (i, &(_, h)) in edge.iter().enumerate() {
edge_index.insert(h, i as u16); edge_index.insert(h, i);
} }
let mut full_edge_workers = let mut full_edge_workers =
FxHashSet::with_capacity_and_hasher(1, FxBuildHasher); FxHashSet::with_capacity_and_hasher(1, FxBuildHasher);
...@@ -855,21 +849,14 @@ impl ConcurrentRadixTreeCompressed { ...@@ -855,21 +849,14 @@ impl ConcurrentRadixTreeCompressed {
match guard.edge_index.get(&block_hash).copied() { match guard.edge_index.get(&block_hash).copied() {
None => None, // stale: hash moved to a child None => None, // stale: hash moved to a child
Some(pos_u16) => { Some(pos) => {
let pos = pos_u16 as usize;
// Determine the worker's current match index. // Determine the worker's current match index.
// Use 0 as sentinel for "not tracked" → pos >= 0 is always true → no-op. // Use 0 as sentinel for "not tracked" → pos >= 0 is always true → no-op.
let is_full = guard.full_edge_workers.contains(&worker); let is_full = guard.full_edge_workers.contains(&worker);
let current_cutoff = if is_full { let current_cutoff = if is_full {
guard.edge.len() guard.edge.len()
} else { } else {
guard guard.worker_cutoffs.get(&worker).copied().unwrap_or(0)
.worker_cutoffs
.get(&worker)
.copied()
.map(|k| k as usize)
.unwrap_or(0)
}; };
if pos >= current_cutoff { if pos >= current_cutoff {
...@@ -891,7 +878,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -891,7 +878,7 @@ impl ConcurrentRadixTreeCompressed {
if is_full { if is_full {
guard.full_edge_workers.remove(&worker); guard.full_edge_workers.remove(&worker);
} }
guard.worker_cutoffs.insert(worker, new_cutoff as u16); guard.worker_cutoffs.insert(worker, new_cutoff);
} }
if !guard.has_any_workers() { if !guard.has_any_workers() {
...@@ -1146,7 +1133,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -1146,7 +1133,7 @@ impl ConcurrentRadixTreeCompressed {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash, parent_hash,
blocks: full_blocks[..k as usize].to_vec(), blocks: full_blocks[..k].to_vec(),
}), }),
dp_rank: worker.dp_rank, dp_rank: worker.dp_rank,
}, },
......
...@@ -6,7 +6,7 @@ use std::sync::Arc; ...@@ -6,7 +6,7 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::ConcurrentRadixTree; use crate::ConcurrentRadixTreeCompressed;
use crate::ThreadPoolIndexer; use crate::ThreadPoolIndexer;
use crate::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics}; use crate::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use crate::protocols::{LocalBlockHash, OverlapScores, RouterEvent, WorkerId}; use crate::protocols::{LocalBlockHash, OverlapScores, RouterEvent, WorkerId};
...@@ -14,7 +14,7 @@ use crate::protocols::{LocalBlockHash, OverlapScores, RouterEvent, WorkerId}; ...@@ -14,7 +14,7 @@ use crate::protocols::{LocalBlockHash, OverlapScores, RouterEvent, WorkerId};
#[derive(Clone)] #[derive(Clone)]
pub enum Indexer { pub enum Indexer {
Single(KvIndexer), Single(KvIndexer),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>), Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>),
} }
impl Indexer { impl Indexer {
...@@ -57,7 +57,7 @@ impl Indexer { ...@@ -57,7 +57,7 @@ impl Indexer {
pub fn create_indexer(block_size: u32, num_threads: usize) -> Indexer { pub fn create_indexer(block_size: u32, num_threads: usize) -> Indexer {
if num_threads > 1 { if num_threads > 1 {
Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new( Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(), ConcurrentRadixTreeCompressed::new(),
num_threads, num_threads,
block_size, block_size,
))) )))
......
...@@ -8,7 +8,7 @@ use anyhow::Result; ...@@ -8,7 +8,7 @@ use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use dynamo_kv_router::{ use dynamo_kv_router::{
ConcurrentRadixTree, ThreadPoolIndexer, ConcurrentRadixTreeCompressed, ThreadPoolIndexer,
approx::PruneConfig, approx::PruneConfig,
config::KvRouterConfig, config::KvRouterConfig,
indexer::{ indexer::{
...@@ -74,7 +74,7 @@ impl RemoteIndexer { ...@@ -74,7 +74,7 @@ impl RemoteIndexer {
#[derive(Clone)] #[derive(Clone)]
pub enum Indexer { pub enum Indexer {
KvIndexer(KvIndexer), KvIndexer(KvIndexer),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>), Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>),
Remote(Arc<RemoteIndexer>), Remote(Arc<RemoteIndexer>),
None, None,
} }
...@@ -124,7 +124,7 @@ impl Indexer { ...@@ -124,7 +124,7 @@ impl Indexer {
if kv_router_config.router_event_threads > 1 { if kv_router_config.router_event_threads > 1 {
return Ok(Self::Concurrent(Arc::new(ThreadPoolIndexer::new( return Ok(Self::Concurrent(Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(), ConcurrentRadixTreeCompressed::new(),
kv_router_config.router_event_threads as usize, kv_router_config.router_event_threads as usize,
block_size, block_size,
)))); ))));
......
...@@ -1690,30 +1690,40 @@ def _test_router_decisions( ...@@ -1690,30 +1690,40 @@ def _test_router_decisions(
events_by_key[key] = [] events_by_key[key] = []
events_by_key[key].append(event) events_by_key[key].append(event)
def count_stored_blocks(events: list[Any]) -> int:
    """Return the total number of blocks carried by "stored" events.

    Each element of ``events`` is expected to be a mapping shaped like
    ``{"event": {"data": {"stored": {"blocks": [...]}}}}``. Elements that
    have no ``stored`` payload contribute nothing to the total.

    Levels that are present but are not mappings (an explicit null, or a
    bare string tag in place of an object) are treated as "no stored
    payload" rather than raising ``AttributeError``.
    NOTE(review): whether the dump ever emits such payloads depends on how
    the producer serializes non-stored events -- confirm against the
    event schema.

    Args:
        events: Event records to inspect.

    Returns:
        Sum of ``len(blocks)`` over every record with a ``stored`` payload.
    """
    total = 0
    for event in events:
        # Walk event -> "event" -> "data" -> "stored", stopping as soon as a
        # level is missing or is not a mapping.
        node = event
        for key in ("event", "data", "stored"):
            node = node.get(key) if isinstance(node, dict) else None
            if node is None:
                break
        if not isinstance(node, dict):
            continue
        # A missing or null "blocks" entry counts as zero blocks.
        total += len(node.get("blocks") or [])
    return total
logger.info( logger.info(
f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_key.items()]}" "Stored blocks by (worker_id, dp_rank): "
f"{[(key, count_stored_blocks(evts)) for key, evts in events_by_key.items()]}"
) )
# Worker a key: 5 events (A, B from req1; C, D from req2; F from req4) # Worker a key: 5 stored blocks (A, B from req1; C, D from req2; F from req4)
worker_a_key = (worker_a_id, dp_rank_a if dp_rank_a is not None else 0) worker_a_key = (worker_a_id, dp_rank_a if dp_rank_a is not None else 0)
worker_a_events = len(events_by_key.get(worker_a_key, [])) worker_a_blocks = count_stored_blocks(events_by_key.get(worker_a_key, []))
assert worker_a_events == 5, ( assert worker_a_blocks == 5, (
f"Expected worker_a {worker_a_key} to have 5 events (A,B + C,D + F), " f"Expected worker_a {worker_a_key} to have 5 stored blocks (A,B + C,D + F), "
f"but found {worker_a_events}" f"but found {worker_a_blocks}"
) )
# Worker b key: 4 events (A, C, E from req3; G from req5) # Worker b key: 4 stored blocks (A, C, E from req3; G from req5)
worker_b_key = (worker_b_id, dp_rank_b if dp_rank_b is not None else 0) worker_b_key = (worker_b_id, dp_rank_b if dp_rank_b is not None else 0)
worker_b_events = len(events_by_key.get(worker_b_key, [])) worker_b_blocks = count_stored_blocks(events_by_key.get(worker_b_key, []))
assert worker_b_events == 4, ( assert worker_b_blocks == 4, (
f"Expected worker_b {worker_b_key} to have 4 events (A,C,E + G), " f"Expected worker_b {worker_b_key} to have 4 stored blocks (A,C,E + G), "
f"but found {worker_b_events}" f"but found {worker_b_blocks}"
) )
logger.info( logger.info(
f"Successfully verified cross-worker routing: " f"Successfully verified cross-worker routing: "
f"worker_a {worker_a_key} has {worker_a_events} events, " f"worker_a {worker_a_key} has {worker_a_blocks} stored blocks, "
f"worker_b {worker_b_key} has {worker_b_events} events" f"worker_b {worker_b_key} has {worker_b_blocks} stored blocks"
) )
# Verify standalone indexer scores via HTTP POST /query # Verify standalone indexer scores via HTTP POST /query
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment