"lib/engines/vscode:/vscode.git/clone" did not exist on "73fdfb8ab84c9f56982d7d6074ef4d2f2a214150"
Unverified Commit 937398cf authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Flash Indexer (#5785)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarjthomson04 <jothomson@nvidia.com>
Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Signed-off-by: default avatarJanelle Cai <jcai18@mit.edu>
Co-authored-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent de27efe6
......@@ -153,7 +153,11 @@ impl RadixTree {
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new();
let mut current = self.root.clone();
if sequence.is_empty() {
return scores;
}
let now = Instant::now();
tracing::trace!(
......@@ -161,46 +165,142 @@ impl RadixTree {
sequence.iter().map(|h| h.0).collect::<Vec<_>>()
);
for (idx, block_hash) in sequence.iter().enumerate() {
// Get first child from root.
let first_child = {
let current_borrow = self.root.borrow();
current_borrow.children.get(&sequence[0]).cloned()
};
let Some(first_child) = first_child else {
return scores;
};
// Initialize active worker set from first child.
let (mut active, mut active_count) = {
let borrow = first_child.borrow();
(borrow.workers.clone(), borrow.workers.len())
};
// Frequency tracking for first child.
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = first_child.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
if active.is_empty() {
return scores;
}
if early_exit && active_count == 1 {
for worker in &active {
scores.scores.insert(*worker, 1);
}
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
return scores;
}
let mut current = first_child;
let mut matched_depth = 1u32;
// Traverse remaining levels. In a clean tree, workers at a child node
// are always a subset of the parent (along the same path), so:
// - workers can only drop out, never join, as we descend
// - if child.workers.len() == active_count, the sets are identical
//
// However, because apply_event(Removed) does NOT cascade to descendants,
// a child may transiently have MORE workers than its parent (stale
// entries from an ancestor remove whose descendant remove events
// haven't arrived yet). We detect this via child_count > active_count
// and fall back to a full membership check.
for (idx, item) in sequence.iter().enumerate().skip(1) {
let next_block = {
let current_borrow = current.borrow();
current_borrow.children.get(block_hash).cloned()
current_borrow.children.get(item).cloned()
};
if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.iter());
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
let Some(block) = next_block else {
break;
};
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
{
let borrow = block.borrow();
let child_count = borrow.workers.len();
if child_count < active_count {
// Workers dropped out. Record scores for those that left.
// Score = matched_depth (number of nodes they were present at).
for worker in &active {
if !borrow.workers.contains(worker) {
scores.scores.insert(*worker, matched_depth);
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
active.clone_from(&borrow.workers);
active_count = child_count;
} else if child_count > active_count {
// Stale entries: child retains workers already removed from
// an ancestor. Fall back to full membership check.
active.retain(|w| {
if borrow.workers.contains(w) {
true
} else {
scores.scores.insert(*w, matched_depth);
false
}
});
active_count = active.len();
}
}
if early_exit && block.borrow().workers.len() == 1 {
break;
// Frequency tracking (always runs when enabled, independent of dropout).
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
current = block;
} else {
tracing::trace!(
"RadixTree::find_matches: block not found at index {} for hash {}",
idx,
block_hash.0
);
if active_count == 0 {
break;
}
if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32;
break;
}
current = block;
matched_depth = (idx + 1) as u32;
}
// Record scores for workers that survived through the deepest matched level.
for worker in &active {
scores.scores.insert(*worker, matched_depth);
}
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores
// Populate tree sizes for all workers that have scores.
for worker in scores.scores.keys() {
let tree_size = self
.lookup
......@@ -250,8 +350,19 @@ impl RadixTree {
None => self.root.clone(),
};
let mut needs_worker_insert = false;
// In each iteration we lock the parent and insert the worker
// deferred from the previous iteration, avoiding a second
// borrow on the same block.
for block_data in op.blocks {
let mut parent_mut = current.borrow_mut();
if needs_worker_insert {
parent_mut.workers.insert(worker);
}
needs_worker_insert = true;
let child = match parent_mut.children.get(&block_data.tokens_hash) {
Some(block) => {
// Verify our simplifying assumption: block_hash is uniform across workers
......@@ -265,7 +376,6 @@ impl RadixTree {
block.clone()
}
None => {
// create new block or reuse existing from worker's lookup
let new_block = worker_lookup
.get(&block_data.block_hash)
.cloned()
......@@ -275,7 +385,6 @@ impl RadixTree {
)))
});
// insert into radix tree
parent_mut
.children
.insert(block_data.tokens_hash, new_block.clone());
......@@ -284,36 +393,30 @@ impl RadixTree {
}
};
// Update child and check for self referential blocks
{
// Try to borrow the child mutably - if it fails, it's already borrowed
// which means a self referencing block.
let mut child_mut = match child.try_borrow_mut() {
Ok(b) => b,
Err(_) => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block_data.block_hash,
"Detected self referencing block in store event; rejecting sequence"
);
return Err(KvCacheEventError::InvalidBlockSequence);
}
};
// add our worker to the block
child_mut.workers.insert(worker);
// Self-reference check: try_borrow_mut will fail if child
// is the same Rc as current (parent_mut holds a mutable borrow).
if child.try_borrow_mut().is_err() {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block_data.block_hash,
"Detected self referencing block in store event; rejecting sequence"
);
return Err(KvCacheEventError::InvalidBlockSequence);
}
// add the block to the worker's lookup table
worker_lookup.insert(block_data.block_hash, child.clone());
// drop child so we can shift current to this block
drop(parent_mut);
current = child;
}
// Insert worker into the last child.
if needs_worker_insert {
current.borrow_mut().workers.insert(worker);
}
Ok(())
}
KvCacheEventData::Removed(remove) => {
......@@ -474,64 +577,8 @@ impl RadixTree {
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
};
/// Creates blocks with artificial hash mapping (hash * 100) for testing RadixTree internals.
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
})
.collect()
}
fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
}
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
}
use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use crate::test_utils::{create_remove_event, create_store_event};
#[test]
fn test_radix_tree() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared test utilities for radix tree tests.
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId,
};
/// Creates blocks with artificial hash mapping (hash * 100) for testing.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
})
.collect()
}
pub fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
pub fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
}
pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
}
......@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Result;
use dynamo_kv_router::{ConcurrentRadixTree, ThreadPoolIndexer};
use dynamo_runtime::{
component::{Client, Endpoint},
discovery::DiscoveryQuery,
......@@ -17,6 +19,7 @@ use dynamo_runtime::{
traits::DistributedRuntimeProvider,
};
use futures::stream;
use tokio::sync::oneshot;
use validator::Validate;
// Re-export from dynamo-kv-router crate
......@@ -43,10 +46,11 @@ use crate::{
discovery::RuntimeConfigWatch,
kv_router::{
approx::PruneConfig,
indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{
DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
TokensWithHashes, WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq,
TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
compute_block_hash_for_seq,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError,
......@@ -113,12 +117,18 @@ pub trait WorkerSelector {
) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
#[derive(Clone)]
pub enum Indexer {
/// Updates itself based on KV events emitted by backend workers or routing decisions.
/// Single-threaded radix tree with channel-based event processing.
/// Supports TTL-based expiration and size-based pruning.
/// Has the ability to persist and snapshot states.
KvIndexer(KvIndexer),
/// Concurrent radix tree with a thread pool for event processing.
/// Uses sticky worker routing for per-worker event serialization.
/// Does not support TTL/pruning.
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),
/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None,
......@@ -132,30 +142,37 @@ impl Indexer {
cancellation_token: tokio_util::sync::CancellationToken,
) -> Self {
if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
// If use_kv_events is false, enable TTL and pruning for approximate behavior
let prune_config = if !kv_router_config.use_kv_events {
Some(PruneConfig {
ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio,
})
} else {
None
};
Indexer::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None, // expiration_duration for frequency tracking
return Indexer::None;
}
if kv_router_config.router_event_threads > 1 {
return Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
kv_router_config.router_event_threads as usize,
block_size,
kv_indexer_metrics,
prune_config,
))
)));
}
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
// If use_kv_events is false, enable TTL and pruning for approximate behavior
let prune_config = if !kv_router_config.use_kv_events {
Some(PruneConfig {
ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio,
})
} else {
None
};
Indexer::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None, // expiration_duration for frequency tracking
block_size,
kv_indexer_metrics,
prune_config,
))
}
pub(crate) async fn find_matches(
......@@ -164,6 +181,7 @@ impl Indexer {
) -> Result<OverlapScores, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
......@@ -175,6 +193,7 @@ impl Indexer {
pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::Concurrent(tpi) => tpi.dump_events().await,
Indexer::None => {
panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
......@@ -194,9 +213,55 @@ impl Indexer {
.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Indexer::Concurrent(tpi) => {
tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Indexer::None => Ok(()),
}
}
pub(crate) async fn apply_event(&self, event: RouterEvent) {
match self {
Indexer::KvIndexer(indexer) => {
if let Err(e) = indexer.event_sender().send(event).await {
tracing::warn!("Failed to send event to indexer: {e}");
}
}
Indexer::Concurrent(tpi) => tpi.apply_event(event).await,
Indexer::None => {}
}
}
pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
match self {
Indexer::KvIndexer(indexer) => {
if let Err(e) = indexer.remove_worker_sender().send(worker_id).await {
tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
}
}
Indexer::Concurrent(tpi) => {
KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
}
Indexer::None => {}
}
}
pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
match self {
Indexer::KvIndexer(indexer) => {
let (resp_tx, resp_rx) = oneshot::channel();
let req = GetWorkersRequest { resp: resp_tx };
if let Err(e) = indexer.get_workers_sender().send(req).await {
tracing::warn!("Failed to send get_workers request: {e}");
return Vec::new();
}
resp_rx.await.unwrap_or_default()
}
Indexer::Concurrent(tpi) => tpi.backend().get_workers(),
Indexer::None => Vec::new(),
}
}
}
/// A KvRouter only decides which worker you should use. It doesn't send you there.
......@@ -255,18 +320,11 @@ impl KvRouter {
// Start KV event subscription if needed (use_kv_events=true and overlap_score_weight>0)
if kv_router_config.should_subscribe_to_kv_events() {
// Guaranteed to be KvIndexer since overlap_score_weight > 0.0
let Indexer::KvIndexer(kv_indexer) = &indexer else {
unreachable!(
"should_subscribe_to_kv_events implies overlap_score_weight > 0 implies KvIndexer"
)
};
subscriber::start_subscriber(
component.clone(),
&kv_router_config,
router_id,
kv_indexer,
indexer.clone(),
cancellation_token.clone(),
)
.await?;
......
......@@ -4,7 +4,7 @@
use derive_builder::Builder;
use rand::Rng;
use serde::{Deserialize, Serialize};
use validator::Validate;
use validator::{Validate, ValidationError};
use crate::kv_router::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
......@@ -21,6 +21,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig {
#[validate(range(min = 0.0))]
pub overlap_score_weight: f64,
......@@ -69,6 +70,12 @@ pub struct KvRouterConfig {
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
#[validate(range(min = 0.0, max = 1.0))]
pub router_prune_target_ratio: f64,
/// Number of event processing threads for the KV indexer.
/// When > 1, uses ConcurrentRadixTree with a thread pool instead of the
/// single-threaded RadixTree. Default: 1.
#[validate(range(min = 1))]
pub router_event_threads: u32,
}
impl Default for KvRouterConfig {
......@@ -87,10 +94,30 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8,
router_event_threads: 1,
}
}
}
fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationError> {
if config.durable_kv_events && !config.use_kv_events {
return Err(ValidationError::new(
"durable_kv_events requires use_kv_events=true",
));
}
if !config.use_kv_events && config.router_event_threads > 1 {
return Err(ValidationError::new(
"router_event_threads > 1 requires use_kv_events=true",
));
}
if config.router_track_output_blocks && !config.router_track_active_blocks {
return Err(ValidationError::new(
"router_track_output_blocks requires router_track_active_blocks=true",
));
}
Ok(())
}
impl KvRouterConfig {
/// Compute sequence hashes for active block tracking based on configuration.
///
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
......@@ -577,6 +578,26 @@ fn convert_event(
block_mm_infos,
..
} => {
// Reject self-referencing blocks: all block hashes (including parent) must be unique.
{
let mut seen = HashSet::with_capacity(block_hashes.len() + 1);
if let Some(parent) = parent_block_hash {
seen.insert(parent.into_u64());
}
let has_duplicate = block_hashes.iter().any(|h| !seen.insert(h.into_u64()));
if has_duplicate {
tracing::warn!(
event_id,
"Self-referencing block detected: duplicate hash in store event; dropping"
);
return KvCacheEvent {
event_id,
data: KvCacheEventData::Cleared,
dp_rank,
};
}
}
let num_block_tokens = vec![block_size as u64; block_hashes.len()];
let block_hashes_u64: Vec<u64> = block_hashes
.into_iter()
......
......@@ -14,12 +14,10 @@ use dynamo_runtime::{
};
use futures::StreamExt;
use rand::Rng;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::kv_router::{
KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest, KvIndexer},
Indexer, KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
protocols::{DpRank, RouterEvent, WorkerId},
router_discovery_query,
worker_query::WorkerQueryClient,
......@@ -84,7 +82,7 @@ async fn get_instance_discovery_stream(
async fn download_stable_snapshot(
nats_client: &dynamo_runtime::transports::nats::Client,
bucket_name: &str,
kv_events_tx: &mpsc::Sender<RouterEvent>,
indexer: &Indexer,
) -> Result<()> {
let url = url::Url::parse(&format!(
"nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
......@@ -147,9 +145,7 @@ async fn download_stable_snapshot(
// Send all events to the indexer
for event in prev_events {
if let Err(e) = kv_events_tx.send(event).await {
tracing::warn!("Failed to send initial event to indexer: {e:?}");
}
indexer.apply_event(event).await;
}
tracing::info!("Successfully sent all initial events to indexer");
......@@ -162,57 +158,27 @@ struct SnapshotResources {
nats_client: dynamo_runtime::transports::nats::Client,
bucket_name: String,
instances_rx: tokio::sync::watch::Receiver<Vec<dynamo_runtime::component::Instance>>,
get_workers_tx: mpsc::Sender<GetWorkersRequest>,
snapshot_tx: mpsc::Sender<DumpRequest>,
indexer: Indexer,
}
impl SnapshotResources {
/// Perform snapshot upload and purge operations
async fn purge_then_snapshot(
&self,
nats_queue: &mut NatsQueue,
remove_worker_tx: &mpsc::Sender<WorkerId>,
) -> anyhow::Result<()> {
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
async fn purge_then_snapshot(&self, nats_queue: &mut NatsQueue) -> anyhow::Result<()> {
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now();
// Clean up stale workers before snapshot
// Get current worker IDs from instances_rx
let current_instances = self.instances_rx.borrow().clone();
let current_worker_ids: std::collections::HashSet<u64> = current_instances
.iter()
.map(|instance| instance.instance_id)
.collect();
// Get worker IDs from the indexer
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let get_workers_req = GetWorkersRequest { resp: resp_tx };
if let Err(e) = self.get_workers_tx.send(get_workers_req).await {
tracing::warn!("Failed to send get_workers request during snapshot: {e:?}");
} else {
match resp_rx.await {
Ok(indexer_worker_ids) => {
// Find workers in indexer but not in current instances
for worker_id in indexer_worker_ids {
if !current_worker_ids.contains(&worker_id) {
tracing::info!(
"Removing stale worker {worker_id} from indexer during snapshot"
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!(
"Failed to send remove_worker for stale worker {worker_id}: {e:?}"
);
}
}
}
}
Err(e) => {
tracing::warn!("Failed to receive worker IDs from indexer: {e:?}");
}
let indexer_worker_ids = self.indexer.get_workers().await;
for worker_id in indexer_worker_ids {
if !current_worker_ids.contains(&worker_id) {
tracing::info!("Removing stale worker {worker_id} from indexer during snapshot");
self.indexer.remove_worker(worker_id).await;
}
}
......@@ -220,18 +186,11 @@ impl SnapshotResources {
nats_queue.purge_acknowledged().await?;
// Now request a snapshot from the indexer (which reflects the post-purge state)
let (resp_tx, resp_rx) = oneshot::channel();
let dump_req = DumpRequest { resp: resp_tx };
self.snapshot_tx
.send(dump_req)
let events = self
.indexer
.dump_events()
.await
.map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?;
// Wait for the dump response
let events = resp_rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;
.map_err(|e| anyhow::anyhow!("Failed to dump events for snapshot: {e:?}"))?;
// Upload the snapshot to NATS object store in background (non-blocking)
let nats_client = self.nats_client.clone();
......@@ -262,14 +221,10 @@ impl SnapshotResources {
}
/// Start a unified background task for event consumption and optional snapshot management
#[allow(clippy::too_many_arguments)]
pub async fn start_kv_router_background(
component: Component,
consumer_id: String,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>,
maybe_snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
indexer: Indexer,
cancellation_token: CancellationToken,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
......@@ -307,7 +262,7 @@ pub async fn start_kv_router_background(
// Handle initial state based on router_reset_states flag
if !router_reset_states {
// Try to download initial state from object store with stability check
download_stable_snapshot(&nats_client, &bucket_name, &kv_events_tx).await?;
download_stable_snapshot(&nats_client, &bucket_name, &indexer).await?;
} else {
// Delete the bucket to reset state
tracing::info!("Resetting router state, deleting bucket: {bucket_name}");
......@@ -335,22 +290,13 @@ pub async fn start_kv_router_background(
let client = generate_endpoint.client().await?;
let instances_rx = client.instance_source.as_ref().clone();
// Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided
let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = (
maybe_get_workers_tx,
maybe_snapshot_tx,
router_snapshot_threshold,
) {
Some(SnapshotResources {
nats_client,
bucket_name,
instances_rx,
get_workers_tx,
snapshot_tx,
})
} else {
None
};
// Only set up snapshot-related resources if snapshot threshold is configured
let snapshot_resources = router_snapshot_threshold.map(|_| SnapshotResources {
nats_client,
bucket_name,
instances_rx,
indexer: indexer.clone(),
});
tokio::spawn(async move {
// Create interval with jitter
......@@ -392,9 +338,7 @@ pub async fn start_kv_router_background(
"DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}"
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
}
indexer.remove_worker(worker_id).await;
}
// Handle event consumption
......@@ -410,12 +354,7 @@ pub async fn start_kv_router_background(
};
// Forward the RouterEvent to the indexer
if let Err(e) = kv_events_tx.send(event).await {
tracing::warn!(
"failed to send kv event to indexer; shutting down: {e:?}"
);
break;
}
indexer.apply_event(event).await;
},
Ok(None) => {
tracing::trace!("Dequeue timeout, continuing");
......@@ -449,7 +388,6 @@ pub async fn start_kv_router_background(
match resources.purge_then_snapshot(
&mut nats_queue,
&remove_worker_tx,
).await {
Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
Err(e) => tracing::debug!("Could not perform purge and snapshot: {e:?}"),
......@@ -510,15 +448,13 @@ pub async fn start_kv_router_background(
/// This is appropriate when workers have local indexers enabled.
pub async fn start_kv_router_background_event_plane(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
indexer: Indexer,
cancellation_token: CancellationToken,
transport_kind: EventTransportKind,
) -> Result<()> {
// WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery.
// No blocking wait — recovery happens asynchronously as endpoints are discovered.
let worker_query_client =
WorkerQueryClient::spawn(component.clone(), remove_worker_tx, kv_events_tx.clone()).await?;
let worker_query_client = WorkerQueryClient::spawn(component.clone(), indexer.clone()).await?;
// Subscribe to KV events using the selected event plane transport
let mut subscriber =
......@@ -611,12 +547,7 @@ pub async fn start_kv_router_background_event_plane(
.or_insert(event_id);
// Forward the RouterEvent to the indexer
if let Err(e) = kv_events_tx.send(event).await {
tracing::warn!(
"failed to send kv event to indexer; shutting down: {e:?}"
);
break;
}
indexer.apply_event(event).await;
}
}
}
......@@ -670,7 +601,7 @@ pub async fn start_subscriber(
component: Component,
kv_router_config: &KvRouterConfig,
router_id: u64,
kv_indexer: &KvIndexer,
indexer: Indexer,
cancellation_token: CancellationToken,
) -> Result<()> {
let transport_kind = EventTransportKind::from_env_or_default();
......@@ -690,14 +621,7 @@ pub async fn start_subscriber(
start_kv_router_background(
component,
consumer_id,
kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.get_workers_sender()),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.snapshot_event_sender()),
indexer,
cancellation_token,
kv_router_config.router_snapshot_threshold,
kv_router_config.router_reset_states,
......@@ -719,8 +643,7 @@ pub async fn start_subscriber(
start_kv_router_background_event_plane(
component.clone(),
kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
indexer,
cancellation_token,
transport_kind,
)
......
......@@ -17,10 +17,10 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
use dynamo_runtime::stream;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use futures::StreamExt;
use tokio::sync::mpsc;
use crate::kv_router::Indexer;
use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::{DpRank, RouterEvent, WorkerId};
use crate::kv_router::protocols::{DpRank, WorkerId};
use crate::kv_router::worker_kv_indexer_query_endpoint;
// Recovery retry configuration
......@@ -43,10 +43,8 @@ pub struct WorkerQueryClient {
component: Component,
/// Routers keyed by dp_rank — each dp_rank has its own endpoint. Created lazily.
routers: Arc<DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>>,
/// Channel to send recovered events to the router indexer
kv_events_tx: mpsc::Sender<RouterEvent>,
/// Channel to send worker removal events to the router indexer
remove_worker_tx: mpsc::Sender<WorkerId>,
/// Indexer for applying recovered events and worker removals.
indexer: Indexer,
}
impl WorkerQueryClient {
......@@ -55,16 +53,11 @@ impl WorkerQueryClient {
/// The background loop watches `ComponentEndpoints` discovery for query endpoints,
/// recovers each `(worker_id, dp_rank)` as it appears, and sends worker removal
/// events when all dp_ranks for a worker disappear.
pub async fn spawn(
component: Component,
remove_worker_tx: mpsc::Sender<WorkerId>,
kv_events_tx: mpsc::Sender<RouterEvent>,
) -> Result<Arc<Self>> {
pub async fn spawn(component: Component, indexer: Indexer) -> Result<Arc<Self>> {
let client = Arc::new(Self {
component: component.clone(),
routers: Arc::new(DashMap::new()),
kv_events_tx,
remove_worker_tx,
indexer,
});
let client_bg = client.clone();
......@@ -151,11 +144,7 @@ impl WorkerQueryClient {
tracing::warn!(
"WorkerQueryClient: all dp_ranks gone for worker {worker_id}, removing"
);
if let Err(e) = self.remove_worker_tx.send(worker_id).await {
tracing::warn!(
"Failed to send worker removal for worker {worker_id}: {e}"
);
}
self.indexer.remove_worker(worker_id).await;
}
}
}
......@@ -354,12 +343,7 @@ impl WorkerQueryClient {
tracing::info!("Recovered {count} events from worker {worker_id} dp_rank {dp_rank}");
for event in events {
if let Err(e) = self.kv_events_tx.send(event).await {
tracing::error!(
"Failed to send recovered event to indexer for worker {worker_id} dp_rank {dp_rank}: {e}"
);
anyhow::bail!("Failed to send recovered event: {e}");
}
self.indexer.apply_event(event).await;
}
Ok(count)
......
......@@ -1305,6 +1305,7 @@ def _test_router_indexers_sync(
test_nats_interruption: bool = False,
nats_server: Optional["NatsServer"] = None,
durable_kv_events: bool = False,
router_event_threads: int = 1,
):
"""Test that two KV routers have synchronized indexer states after processing requests.
......@@ -1349,6 +1350,7 @@ def _test_router_indexers_sync(
kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
)
async def send_requests_to_router(router, num_requests, router_name, endpoint):
......@@ -1881,6 +1883,7 @@ def _test_router_decisions(
block_size: int = BLOCK_SIZE,
use_kv_events: bool = True,
durable_kv_events: bool = False,
router_event_threads: int = 1,
):
"""Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
......@@ -1911,6 +1914,7 @@ def _test_router_decisions(
router_snapshot_threshold=20,
use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
)
kv_push_router = KvPushRouter(
endpoint=endpoint,
......
......@@ -522,16 +522,18 @@ def test_kv_push_router_bindings(
@pytest.mark.parametrize(
"store_backend,durable_kv_events,request_plane",
"store_backend,durable_kv_events,request_plane,router_event_threads",
[
("etcd", True, "nats"), # JetStream mode - uses JetStream
("etcd", False, "tcp"), # NATS core mode (with gap detection) - no JetStream
("file", True, "nats"), # File backend - uses JetStream
("etcd", True, "nats", 1), # JetStream mode - uses JetStream
("etcd", False, "tcp", 1), # NATS core mode (with gap detection) - no JetStream
("file", True, "nats", 1), # File backend - uses JetStream
("etcd", False, "tcp", 2), # NATS core mode - multi-threaded indexer
],
ids=[
"jetstream",
"nats_core",
"file",
"nats_core_multi_thread",
],
indirect=["request_plane", "durable_kv_events"],
)
......@@ -544,6 +546,7 @@ def test_indexers_sync(
store_backend,
durable_kv_events,
request_plane,
router_event_threads,
):
"""
Test that two KV routers have synchronized indexer states after processing requests.
......@@ -596,6 +599,7 @@ def test_indexers_sync(
test_nats_interruption=not durable_kv_events,
nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
)
logger.info("Indexers sync test completed successfully")
......@@ -639,15 +643,16 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.timeout(90) # bumped for xdist contention (was 29s; ~9.55s serial avg)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"durable_kv_events,use_kv_events",
"durable_kv_events,use_kv_events,router_event_threads",
[
(True, True), # JetStream mode with KV events
(False, True), # NATS Core mode with local indexer (default)
(False, False), # Approximate mode (--no-kv-events) - no KV events
(True, True, 1), # JetStream mode with KV events
(False, True, 1), # NATS Core mode with local indexer (default)
(False, False, 1), # Approximate mode (--no-kv-events) - no KV events
(False, True, 2), # NATS Core mode - multi-threaded indexer
],
ids=["jetstream", "nats_core", "no_kv_events"],
ids=["jetstream", "nats_core", "no_kv_events", "nats_core_multi_thread"],
indirect=["durable_kv_events"],
)
def test_router_decisions(
......@@ -657,6 +662,7 @@ def test_router_decisions(
durable_kv_events,
use_kv_events,
request_plane,
router_event_threads,
):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
......@@ -704,6 +710,7 @@ def test_router_decisions(
test_dp_rank=True,
use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
)
......
......@@ -371,15 +371,19 @@ def test_sglang_kv_router_basic(
@pytest.mark.pre_merge
@pytest.mark.gpu_1
@pytest.mark.skip(reason="Broken by sglang changes")
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"router_event_threads",
[1, 2],
ids=["single_thread", "multi_thread"],
)
def test_router_decisions_sglang_multiple_workers(
request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
router_event_threads,
):
# runtime_services starts etcd and nats
logger.info("Starting SGLang router prefix reuse test with two workers")
......@@ -396,15 +400,18 @@ def test_router_decisions_sglang_multiple_workers(
logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Initialize SGLang workers
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(sglang_workers.namespace)
component = namespace.component("backend")
endpoint = component.endpoint("generate")
_test_router_decisions(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=False
sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=False,
router_event_threads=router_event_threads,
)
......
......@@ -419,6 +419,11 @@ def test_router_decisions_trtllm_attention_dp(
@pytest.mark.pre_merge
@pytest.mark.gpu_1
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"router_event_threads",
[1, 2],
ids=["single_thread", "multi_thread"],
)
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
def test_router_decisions_trtllm_multiple_workers(
request,
......@@ -426,6 +431,7 @@ def test_router_decisions_trtllm_multiple_workers(
predownload_models,
set_ucx_tls_no_mm,
request_plane,
router_event_threads,
):
# runtime_services starts etcd and nats
logger.info("Starting TRT-LLM router prefix reuse test with two workers")
......@@ -444,8 +450,6 @@ def test_router_decisions_trtllm_multiple_workers(
)
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
# Initialize TRT-LLM workers
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(trtllm_workers.namespace)
component = namespace.component("tensorrt_llm")
......@@ -458,6 +462,7 @@ def test_router_decisions_trtllm_multiple_workers(
request,
test_dp_rank=False,
block_size=TRTLLM_BLOCK_SIZE,
router_event_threads=router_event_threads,
)
......
......@@ -385,12 +385,18 @@ def test_vllm_kv_router_basic(
@pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"router_event_threads",
[1, 2],
ids=["single_thread", "multi_thread"],
)
def test_router_decisions_vllm_multiple_workers(
request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
router_event_threads,
):
# runtime_services starts etcd and nats
logger.info("Starting vLLM router prefix reuse test with two workers")
......@@ -414,7 +420,12 @@ def test_router_decisions_vllm_multiple_workers(
endpoint = component.endpoint("generate")
_test_router_decisions(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=False
vllm_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=False,
router_event_threads=router_event_threads,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment