Unverified Commit ed4d8068 authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

feat: radix tree implementation (#7459)

parent 585b4df7
......@@ -11,7 +11,9 @@ use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded,
};
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, RouterEvent};
use dynamo_kv_router::{ConcurrentRadixTree, PositionalIndexer, ThreadPoolIndexer};
use dynamo_kv_router::{
ConcurrentRadixTree, ConcurrentRadixTreeCompressed, PositionalIndexer, ThreadPoolIndexer,
};
use serde::Serialize;
use std::sync::Arc;
use tokio::time::{Duration, Instant};
......@@ -47,6 +49,13 @@ enum IndexerArgs {
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
/// Compressed concurrent radix tree indexer (compressed edges).
ConcurrentRadixTreeCompressed {
/// Number of OS threads that consume and apply KV cache events.
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
}
impl IndexerArgs {
......@@ -75,6 +84,13 @@ impl IndexerArgs {
IndexerArgs::ConcurrentRadixTree { num_event_workers } => Arc::new(
ThreadPoolIndexer::new(ConcurrentRadixTree::new(), num_event_workers, block_size),
),
IndexerArgs::ConcurrentRadixTreeCompressed { num_event_workers } => {
Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTreeCompressed::new(),
num_event_workers,
block_size,
))
}
}
}
......@@ -83,7 +99,10 @@ impl IndexerArgs {
}
fn is_multi_threaded(name: &str) -> bool {
matches!(name, "nested-map" | "concurrent-radix-tree")
matches!(
name,
"nested-map" | "concurrent-radix-tree" | "concurrent-radix-tree-compressed"
)
}
/// Construct an indexer from a short name string.
......@@ -103,9 +122,12 @@ impl IndexerArgs {
"concurrent-radix-tree" => IndexerArgs::ConcurrentRadixTree {
num_event_workers: nw,
},
"concurrent-radix-tree-compressed" => IndexerArgs::ConcurrentRadixTreeCompressed {
num_event_workers: nw,
},
_ => anyhow::bail!(
"Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \
nested-map, concurrent-radix-tree",
nested-map, concurrent-radix-tree, concurrent-radix-tree-compressed",
name
),
};
......@@ -125,7 +147,8 @@ struct Args {
/// Comma-separated list of indexer names to benchmark and compare on the
/// same plot. Overrides the subcommand indexer when present. Valid names:
/// radix-tree, radix-tree-sharded, nested-map, concurrent-radix-tree.
/// radix-tree, radix-tree-sharded, nested-map, concurrent-radix-tree,
/// concurrent-radix-tree-compressed.
#[clap(long, value_delimiter = ',')]
compare: Vec<String>,
......@@ -536,6 +559,7 @@ async fn main() -> anyhow::Result<()> {
IndexerArgs::RadixTreeSharded { .. } => "radix-tree-sharded",
IndexerArgs::NestedMap { .. } => "nested-map",
IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree",
IndexerArgs::ConcurrentRadixTreeCompressed { .. } => "concurrent-radix-tree-compressed",
};
vec![name.to_string()]
} else {
......
......@@ -347,8 +347,6 @@ impl ConcurrentRadixTree {
let num_blocks_added = op.blocks.len();
// In each iteration, we lock the parent block and insert the worker into it from
// the previous iteration. This avoids locking a block twice.
for block_data in op.blocks {
let child = {
let mut parent_guard = current.write();
......@@ -364,7 +362,6 @@ impl ConcurrentRadixTree {
// parent_guard is dropped at the end of this block
match parent_guard.children.get(&block_data.tokens_hash) {
Some(existing) => {
// Verify our simplifying assumption: block_hash is uniform across workers
{
let existing_guard = existing.read();
if existing_guard.block_hash != Some(block_data.block_hash) {
......@@ -410,8 +407,6 @@ impl ConcurrentRadixTree {
}
}
// Insert worker into the last child (not yet handled since there is
// no subsequent iteration to pick it up).
if needs_worker_insert {
current.write().workers.insert(worker);
}
......@@ -451,7 +446,6 @@ impl ConcurrentRadixTree {
continue;
};
// Remove the worker from this block's worker set.
let mut guard = block.write();
guard.workers.remove(&worker);
if guard.workers.is_empty() {
......@@ -569,7 +563,6 @@ impl ConcurrentRadixTree {
// Queue entries: (current_block, parent_hash, tokens_hash)
let mut queue = VecDeque::new();
// Process root's children first
{
let root_guard = self.root.read();
for (tokens_hash, child_block) in &root_guard.children {
......
This diff is collapsed.
......@@ -40,6 +40,7 @@ mod traits;
mod types;
pub mod concurrent_radix_tree;
pub mod concurrent_radix_tree_compressed;
pub mod positional;
pub mod pruning;
pub mod radix_tree;
......
......@@ -10,6 +10,7 @@ use tokio::time;
use tokio_util::sync::CancellationToken;
use super::concurrent_radix_tree::ConcurrentRadixTree;
use super::concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
use super::positional::PositionalIndexer;
use super::*;
use crate::protocols::*;
......@@ -204,7 +205,10 @@ fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
#[template]
#[rstest]
fn indexer_template(#[values("single", "sharded", "flat", "concurrent")] variant: &str) {}
fn indexer_template(
#[values("single", "sharded", "flat", "concurrent", "concurrent_compressed")] variant: &str,
) {
}
fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> {
let token = CancellationToken::new();
......@@ -224,6 +228,11 @@ fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> {
4,
kv_block_size,
)),
"concurrent_compressed" => Box::new(ThreadPoolIndexer::new(
ConcurrentRadixTreeCompressed::new(),
4,
kv_block_size,
)),
_ => panic!("Unknown variant: {}", variant),
}
}
......
......@@ -123,6 +123,28 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
}
}
impl<T: SyncIndexer> Drop for ThreadPoolIndexer<T> {
    /// Shuts down the worker threads and waits for them to exit.
    ///
    /// Every worker is sent `WorkerTask::Terminate` so that it leaves its
    /// receive loop and releases its `Arc<T>` clone. The threads are then
    /// joined so those clones are gone before `self.backend` is dropped;
    /// otherwise a still-running worker could keep the refcount above zero
    /// and prevent `T::drop()` from running.
    ///
    /// This method never panics. A panic raised inside `drop` while the
    /// thread is already unwinding aborts the process, so send failures,
    /// join failures and a poisoned `thread_handles` mutex are all tolerated.
    fn drop(&mut self) {
        // Best effort: a failed send is ignored (presumably the worker has
        // already gone away, in which case there is nothing left to stop).
        for channel in self.worker_event_channels.iter() {
            let _ = channel.send(WorkerTask::Terminate);
        }

        // Move the handles out of the mutex. The guard is a temporary that is
        // released at the end of this statement, so the lock is not held
        // while joining. A poisoned mutex still contains the handles, so
        // recover the guard instead of panicking and leaving threads unjoined.
        let handles = std::mem::take(
            &mut *self
                .thread_handles
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );

        // A worker that panicked yields `Err` from `join`; it has exited
        // either way, which is all that matters here.
        for handle in handles {
            let _ = handle.join();
        }
    }
}
#[async_trait]
impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
async fn find_matches(
......@@ -217,12 +239,10 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
// Fast path: backend can dump directly from shared state (e.g. ConcurrentRadixTree).
if let Some(events) = self.backend.dump_events() {
return Ok(events);
}
// Slow path: collect from each worker thread via channel (e.g. PositionalIndexer).
// Send DumpEvents to every worker as a FIFO barrier: each worker must
// finish processing all previously queued Events before it handles
// DumpEvents, so by the time all workers respond we know the shared
// tree (if any) reflects every event that was enqueued before this call.
let mut receivers = Vec::new();
for channel in &self.worker_event_channels {
......@@ -235,9 +255,8 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
receivers.push(resp_rx);
}
let mut event_id_counter = 0;
let mut all_events = Vec::new();
let mut event_id_counter = 0u64;
for resp_rx in receivers {
let mut events = resp_rx
......@@ -251,6 +270,15 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
all_events.extend(events);
}
// Shared-state backends keep their tree in concurrent structures
// readable from any thread. Now that the barrier above guarantees
// all queued writes have landed, dump directly.
if let Some(events) = self.backend.dump_events() {
return Ok(events);
}
// Per-thread-state backends returned their events through the DumpEvents
// responses collected above.
Ok(all_events)
}
......
......@@ -15,6 +15,7 @@ pub mod zmq_wire;
// Backward-compat re-exports: old top-level module paths still work
pub use indexer::concurrent_radix_tree;
pub use indexer::concurrent_radix_tree_compressed;
pub use indexer::positional as nested_map;
pub use indexer::pruning as approx;
pub use indexer::radix_tree;
......@@ -38,6 +39,7 @@ pub use self::multi_worker_sequence::{
};
pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment