Unverified Commit ed4d8068 authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

feat: radix tree implementation (#7459)

parent 585b4df7
...@@ -11,7 +11,9 @@ use dynamo_kv_router::indexer::{ ...@@ -11,7 +11,9 @@ use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded,
}; };
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, RouterEvent}; use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, RouterEvent};
use dynamo_kv_router::{ConcurrentRadixTree, PositionalIndexer, ThreadPoolIndexer}; use dynamo_kv_router::{
ConcurrentRadixTree, ConcurrentRadixTreeCompressed, PositionalIndexer, ThreadPoolIndexer,
};
use serde::Serialize; use serde::Serialize;
use std::sync::Arc; use std::sync::Arc;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
...@@ -47,6 +49,13 @@ enum IndexerArgs { ...@@ -47,6 +49,13 @@ enum IndexerArgs {
#[clap(long, default_value = "16")] #[clap(long, default_value = "16")]
num_event_workers: usize, num_event_workers: usize,
}, },
/// Compressed concurrent radix tree indexer (compressed edges).
ConcurrentRadixTreeCompressed {
/// Number of OS threads that consume and apply KV cache events.
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
} }
impl IndexerArgs { impl IndexerArgs {
...@@ -75,6 +84,13 @@ impl IndexerArgs { ...@@ -75,6 +84,13 @@ impl IndexerArgs {
IndexerArgs::ConcurrentRadixTree { num_event_workers } => Arc::new( IndexerArgs::ConcurrentRadixTree { num_event_workers } => Arc::new(
ThreadPoolIndexer::new(ConcurrentRadixTree::new(), num_event_workers, block_size), ThreadPoolIndexer::new(ConcurrentRadixTree::new(), num_event_workers, block_size),
), ),
IndexerArgs::ConcurrentRadixTreeCompressed { num_event_workers } => {
Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTreeCompressed::new(),
num_event_workers,
block_size,
))
}
} }
} }
...@@ -83,7 +99,10 @@ impl IndexerArgs { ...@@ -83,7 +99,10 @@ impl IndexerArgs {
} }
fn is_multi_threaded(name: &str) -> bool { fn is_multi_threaded(name: &str) -> bool {
matches!(name, "nested-map" | "concurrent-radix-tree") matches!(
name,
"nested-map" | "concurrent-radix-tree" | "concurrent-radix-tree-compressed"
)
} }
/// Construct an indexer from a short name string. /// Construct an indexer from a short name string.
...@@ -103,9 +122,12 @@ impl IndexerArgs { ...@@ -103,9 +122,12 @@ impl IndexerArgs {
"concurrent-radix-tree" => IndexerArgs::ConcurrentRadixTree { "concurrent-radix-tree" => IndexerArgs::ConcurrentRadixTree {
num_event_workers: nw, num_event_workers: nw,
}, },
"concurrent-radix-tree-compressed" => IndexerArgs::ConcurrentRadixTreeCompressed {
num_event_workers: nw,
},
_ => anyhow::bail!( _ => anyhow::bail!(
"Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \ "Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \
nested-map, concurrent-radix-tree", nested-map, concurrent-radix-tree, concurrent-radix-tree-compressed",
name name
), ),
}; };
...@@ -125,7 +147,8 @@ struct Args { ...@@ -125,7 +147,8 @@ struct Args {
/// Comma-separated list of indexer names to benchmark and compare on the /// Comma-separated list of indexer names to benchmark and compare on the
/// same plot. Overrides the subcommand indexer when present. Valid names: /// same plot. Overrides the subcommand indexer when present. Valid names:
/// radix-tree, radix-tree-sharded, nested-map, concurrent-radix-tree. /// radix-tree, radix-tree-sharded, nested-map, concurrent-radix-tree,
/// concurrent-radix-tree-compressed.
#[clap(long, value_delimiter = ',')] #[clap(long, value_delimiter = ',')]
compare: Vec<String>, compare: Vec<String>,
...@@ -536,6 +559,7 @@ async fn main() -> anyhow::Result<()> { ...@@ -536,6 +559,7 @@ async fn main() -> anyhow::Result<()> {
IndexerArgs::RadixTreeSharded { .. } => "radix-tree-sharded", IndexerArgs::RadixTreeSharded { .. } => "radix-tree-sharded",
IndexerArgs::NestedMap { .. } => "nested-map", IndexerArgs::NestedMap { .. } => "nested-map",
IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree", IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree",
IndexerArgs::ConcurrentRadixTreeCompressed { .. } => "concurrent-radix-tree-compressed",
}; };
vec![name.to_string()] vec![name.to_string()]
} else { } else {
......
...@@ -347,8 +347,6 @@ impl ConcurrentRadixTree { ...@@ -347,8 +347,6 @@ impl ConcurrentRadixTree {
let num_blocks_added = op.blocks.len(); let num_blocks_added = op.blocks.len();
// In each iteration, we lock the parent block and insert the worker into it from
// the previous iteration. This avoids locking a block twice.
for block_data in op.blocks { for block_data in op.blocks {
let child = { let child = {
let mut parent_guard = current.write(); let mut parent_guard = current.write();
...@@ -364,7 +362,6 @@ impl ConcurrentRadixTree { ...@@ -364,7 +362,6 @@ impl ConcurrentRadixTree {
// parent_guard is dropped at the end of this block // parent_guard is dropped at the end of this block
match parent_guard.children.get(&block_data.tokens_hash) { match parent_guard.children.get(&block_data.tokens_hash) {
Some(existing) => { Some(existing) => {
// Verify our simplifying assumption: block_hash is uniform across workers
{ {
let existing_guard = existing.read(); let existing_guard = existing.read();
if existing_guard.block_hash != Some(block_data.block_hash) { if existing_guard.block_hash != Some(block_data.block_hash) {
...@@ -410,8 +407,6 @@ impl ConcurrentRadixTree { ...@@ -410,8 +407,6 @@ impl ConcurrentRadixTree {
} }
} }
// Insert worker into the last child (not yet handled since there is
// no subsequent iteration to pick it up).
if needs_worker_insert { if needs_worker_insert {
current.write().workers.insert(worker); current.write().workers.insert(worker);
} }
...@@ -451,7 +446,6 @@ impl ConcurrentRadixTree { ...@@ -451,7 +446,6 @@ impl ConcurrentRadixTree {
continue; continue;
}; };
// Remove the worker from this block's worker set.
let mut guard = block.write(); let mut guard = block.write();
guard.workers.remove(&worker); guard.workers.remove(&worker);
if guard.workers.is_empty() { if guard.workers.is_empty() {
...@@ -569,7 +563,6 @@ impl ConcurrentRadixTree { ...@@ -569,7 +563,6 @@ impl ConcurrentRadixTree {
// Queue entries: (current_block, parent_hash, tokens_hash) // Queue entries: (current_block, parent_hash, tokens_hash)
let mut queue = VecDeque::new(); let mut queue = VecDeque::new();
// Process root's children first
{ {
let root_guard = self.root.read(); let root_guard = self.root.read();
for (tokens_hash, child_block) in &root_guard.children { for (tokens_hash, child_block) in &root_guard.children {
......
This diff is collapsed.
...@@ -40,6 +40,7 @@ mod traits; ...@@ -40,6 +40,7 @@ mod traits;
mod types; mod types;
pub mod concurrent_radix_tree; pub mod concurrent_radix_tree;
pub mod concurrent_radix_tree_compressed;
pub mod positional; pub mod positional;
pub mod pruning; pub mod pruning;
pub mod radix_tree; pub mod radix_tree;
......
...@@ -10,6 +10,7 @@ use tokio::time; ...@@ -10,6 +10,7 @@ use tokio::time;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use super::concurrent_radix_tree::ConcurrentRadixTree; use super::concurrent_radix_tree::ConcurrentRadixTree;
use super::concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
use super::positional::PositionalIndexer; use super::positional::PositionalIndexer;
use super::*; use super::*;
use crate::protocols::*; use crate::protocols::*;
...@@ -204,7 +205,10 @@ fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent { ...@@ -204,7 +205,10 @@ fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
#[template] #[template]
#[rstest] #[rstest]
fn indexer_template(#[values("single", "sharded", "flat", "concurrent")] variant: &str) {} fn indexer_template(
#[values("single", "sharded", "flat", "concurrent", "concurrent_compressed")] variant: &str,
) {
}
fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> { fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> {
let token = CancellationToken::new(); let token = CancellationToken::new();
...@@ -224,6 +228,11 @@ fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> { ...@@ -224,6 +228,11 @@ fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> {
4, 4,
kv_block_size, kv_block_size,
)), )),
"concurrent_compressed" => Box::new(ThreadPoolIndexer::new(
ConcurrentRadixTreeCompressed::new(),
4,
kv_block_size,
)),
_ => panic!("Unknown variant: {}", variant), _ => panic!("Unknown variant: {}", variant),
} }
} }
......
...@@ -123,6 +123,28 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> { ...@@ -123,6 +123,28 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
} }
} }
/// Shuts down the worker thread pool when the indexer is dropped.
///
/// Every worker is sent `WorkerTask::Terminate` and then joined, so that the
/// workers' `Arc<T>` clones are released before `self.backend` itself is
/// dropped. Without the join, worker threads may still be alive when
/// `backend` drops, keeping the Arc refcount > 0 and preventing `T::drop()`
/// from running.
///
/// This destructor never panics on a poisoned `thread_handles` mutex: a panic
/// inside `drop` while the thread is already unwinding aborts the process,
/// and would also leave the worker threads unjoined.
impl<T: SyncIndexer> Drop for ThreadPoolIndexer<T> {
    fn drop(&mut self) {
        // Ask every worker to leave its recv loop. A failed send is ignored on
        // purpose (presumably the worker has already exited and dropped its
        // receiver, in which case there is nothing left to terminate).
        for channel in self.worker_event_channels.iter() {
            let _ = channel.send(WorkerTask::Terminate);
        }

        // Take ownership of the join handles, leaving an empty collection
        // behind. If another thread panicked while holding the lock, recover
        // the guard from the poison error instead of panicking: the handles
        // are still valid and must still be joined.
        let handles = std::mem::take(
            &mut *self
                .thread_handles
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner()),
        );

        // Wait for each worker to finish. The join result is discarded, so a
        // worker that failed does not stop the remaining workers from being
        // joined.
        for handle in handles {
            let _ = handle.join();
        }
    }
}
#[async_trait] #[async_trait]
impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
async fn find_matches( async fn find_matches(
...@@ -217,12 +239,10 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -217,12 +239,10 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
} }
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> { async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
// Fast path: backend can dump directly from shared state (e.g. ConcurrentRadixTree). // Send DumpEvents to every worker as a FIFO barrier: each worker must
if let Some(events) = self.backend.dump_events() { // finish processing all previously queued Events before it handles
return Ok(events); // DumpEvents, so by the time all workers respond we know the shared
} // tree (if any) reflects every event that was enqueued before this call.
// Slow path: collect from each worker thread via channel (e.g. PositionalIndexer).
let mut receivers = Vec::new(); let mut receivers = Vec::new();
for channel in &self.worker_event_channels { for channel in &self.worker_event_channels {
...@@ -235,9 +255,8 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -235,9 +255,8 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
receivers.push(resp_rx); receivers.push(resp_rx);
} }
let mut event_id_counter = 0;
let mut all_events = Vec::new(); let mut all_events = Vec::new();
let mut event_id_counter = 0u64;
for resp_rx in receivers { for resp_rx in receivers {
let mut events = resp_rx let mut events = resp_rx
...@@ -251,6 +270,15 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -251,6 +270,15 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
all_events.extend(events); all_events.extend(events);
} }
// Shared-state backends keep their tree in concurrent structures
// readable from any thread. Now that the barrier above guarantees
// all queued writes have landed, dump directly.
if let Some(events) = self.backend.dump_events() {
return Ok(events);
}
// Per-thread-state backends returned their events through the DumpEvents
// responses collected above.
Ok(all_events) Ok(all_events)
} }
......
...@@ -15,6 +15,7 @@ pub mod zmq_wire; ...@@ -15,6 +15,7 @@ pub mod zmq_wire;
// Backward-compat re-exports: old top-level module paths still work // Backward-compat re-exports: old top-level module paths still work
pub use indexer::concurrent_radix_tree; pub use indexer::concurrent_radix_tree;
pub use indexer::concurrent_radix_tree_compressed;
pub use indexer::positional as nested_map; pub use indexer::positional as nested_map;
pub use indexer::pruning as approx; pub use indexer::pruning as approx;
pub use indexer::radix_tree; pub use indexer::radix_tree;
...@@ -38,6 +39,7 @@ pub use self::multi_worker_sequence::{ ...@@ -38,6 +39,7 @@ pub use self::multi_worker_sequence::{
}; };
pub use self::sequence::{ActiveSequences, RequestId}; pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy}; pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use event_sink::EventSink; pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment