Unverified commit 937398cf, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: Flash Indexer (#5785)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
Signed-off-by: jthomson04 <jothomson@nvidia.com>
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Signed-off-by: Janelle Cai <jcai18@mit.edu>
Co-authored-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Janelle Cai <jcai18@mit.edu>
parent de27efe6
...@@ -153,7 +153,11 @@ impl RadixTree { ...@@ -153,7 +153,11 @@ impl RadixTree {
/// An `OverlapScores` representing the match scores. /// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores { pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new(); let mut scores = OverlapScores::new();
let mut current = self.root.clone();
if sequence.is_empty() {
return scores;
}
let now = Instant::now(); let now = Instant::now();
tracing::trace!( tracing::trace!(
...@@ -161,46 +165,142 @@ impl RadixTree { ...@@ -161,46 +165,142 @@ impl RadixTree {
sequence.iter().map(|h| h.0).collect::<Vec<_>>() sequence.iter().map(|h| h.0).collect::<Vec<_>>()
); );
for (idx, block_hash) in sequence.iter().enumerate() { // Get first child from root.
let first_child = {
let current_borrow = self.root.borrow();
current_borrow.children.get(&sequence[0]).cloned()
};
let Some(first_child) = first_child else {
return scores;
};
// Initialize active worker set from first child.
let (mut active, mut active_count) = {
let borrow = first_child.borrow();
(borrow.workers.clone(), borrow.workers.len())
};
// Frequency tracking for first child.
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = first_child.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
if active.is_empty() {
return scores;
}
if early_exit && active_count == 1 {
for worker in &active {
scores.scores.insert(*worker, 1);
}
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
return scores;
}
let mut current = first_child;
let mut matched_depth = 1u32;
// Traverse remaining levels. In a clean tree, workers at a child node
// are always a subset of the parent (along the same path), so:
// - workers can only drop out, never join, as we descend
// - if child.workers.len() == active_count, the sets are identical
//
// However, because apply_event(Removed) does NOT cascade to descendants,
// a child may transiently have MORE workers than its parent (stale
// entries from an ancestor remove whose descendant remove events
// haven't arrived yet). We detect this via child_count > active_count
// and fall back to a full membership check.
for (idx, item) in sequence.iter().enumerate().skip(1) {
let next_block = { let next_block = {
let current_borrow = current.borrow(); let current_borrow = current.borrow();
current_borrow.children.get(block_hash).cloned() current_borrow.children.get(item).cloned()
}; };
if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.iter());
if let Some(expiration_duration) = self.expiration_duration { let Some(block) = next_block else {
let mut block_mut = block.borrow_mut(); break;
};
while let Some(access_time) = block_mut.recent_uses.front() { {
if now.duration_since(*access_time) > expiration_duration { let borrow = block.borrow();
block_mut.recent_uses.pop_front(); let child_count = borrow.workers.len();
} else {
break; if child_count < active_count {
// Workers dropped out. Record scores for those that left.
// Score = matched_depth (number of nodes they were present at).
for worker in &active {
if !borrow.workers.contains(worker) {
scores.scores.insert(*worker, matched_depth);
} }
} }
scores.add_frequency(block_mut.recent_uses.len()); active.clone_from(&borrow.workers);
block_mut.recent_uses.push_back(now); active_count = child_count;
} else if child_count > active_count {
// Stale entries: child retains workers already removed from
// an ancestor. Fall back to full membership check.
active.retain(|w| {
if borrow.workers.contains(w) {
true
} else {
scores.scores.insert(*w, matched_depth);
false
}
});
active_count = active.len();
} }
}
if early_exit && block.borrow().workers.len() == 1 { // Frequency tracking (always runs when enabled, independent of dropout).
break; if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
} }
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
current = block; if active_count == 0 {
} else {
tracing::trace!(
"RadixTree::find_matches: block not found at index {} for hash {}",
idx,
block_hash.0
);
break; break;
} }
if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32;
break;
}
current = block;
matched_depth = (idx + 1) as u32;
}
// Record scores for workers that survived through the deepest matched level.
for worker in &active {
scores.scores.insert(*worker, matched_depth);
} }
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores); tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores // Populate tree sizes for all workers that have scores.
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
let tree_size = self let tree_size = self
.lookup .lookup
...@@ -250,8 +350,19 @@ impl RadixTree { ...@@ -250,8 +350,19 @@ impl RadixTree {
None => self.root.clone(), None => self.root.clone(),
}; };
let mut needs_worker_insert = false;
// In each iteration we lock the parent and insert the worker
// deferred from the previous iteration, avoiding a second
// borrow on the same block.
for block_data in op.blocks { for block_data in op.blocks {
let mut parent_mut = current.borrow_mut(); let mut parent_mut = current.borrow_mut();
if needs_worker_insert {
parent_mut.workers.insert(worker);
}
needs_worker_insert = true;
let child = match parent_mut.children.get(&block_data.tokens_hash) { let child = match parent_mut.children.get(&block_data.tokens_hash) {
Some(block) => { Some(block) => {
// Verify our simplifying assumption: block_hash is uniform across workers // Verify our simplifying assumption: block_hash is uniform across workers
...@@ -265,7 +376,6 @@ impl RadixTree { ...@@ -265,7 +376,6 @@ impl RadixTree {
block.clone() block.clone()
} }
None => { None => {
// create new block or reuse existing from worker's lookup
let new_block = worker_lookup let new_block = worker_lookup
.get(&block_data.block_hash) .get(&block_data.block_hash)
.cloned() .cloned()
...@@ -275,7 +385,6 @@ impl RadixTree { ...@@ -275,7 +385,6 @@ impl RadixTree {
))) )))
}); });
// insert into radix tree
parent_mut parent_mut
.children .children
.insert(block_data.tokens_hash, new_block.clone()); .insert(block_data.tokens_hash, new_block.clone());
...@@ -284,36 +393,30 @@ impl RadixTree { ...@@ -284,36 +393,30 @@ impl RadixTree {
} }
}; };
// Update child and check for self referential blocks // Self-reference check: try_borrow_mut will fail if child
{ // is the same Rc as current (parent_mut holds a mutable borrow).
// Try to borrow the child mutably - if it fails, it's already borrowed if child.try_borrow_mut().is_err() {
// which means a self referencing block. tracing::warn!(
let mut child_mut = match child.try_borrow_mut() { worker_id = worker.worker_id.to_string(),
Ok(b) => b, dp_rank = worker.dp_rank,
Err(_) => { id,
tracing::warn!( block_hash = ?block_data.block_hash,
worker_id = worker.worker_id.to_string(), "Detected self referencing block in store event; rejecting sequence"
dp_rank = worker.dp_rank, );
id, return Err(KvCacheEventError::InvalidBlockSequence);
block_hash = ?block_data.block_hash,
"Detected self referencing block in store event; rejecting sequence"
);
return Err(KvCacheEventError::InvalidBlockSequence);
}
};
// add our worker to the block
child_mut.workers.insert(worker);
} }
// add the block to the worker's lookup table
worker_lookup.insert(block_data.block_hash, child.clone()); worker_lookup.insert(block_data.block_hash, child.clone());
// drop child so we can shift current to this block
drop(parent_mut); drop(parent_mut);
current = child; current = child;
} }
// Insert worker into the last child.
if needs_worker_insert {
current.borrow_mut().workers.insert(worker);
}
Ok(()) Ok(())
} }
KvCacheEventData::Removed(remove) => { KvCacheEventData::Removed(remove) => {
...@@ -474,64 +577,8 @@ impl RadixTree { ...@@ -474,64 +577,8 @@ impl RadixTree {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::{ use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, use crate::test_utils::{create_remove_event, create_store_event};
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
};
/// Creates blocks with artificial hash mapping (hash * 100) for testing RadixTree internals.
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
})
.collect()
}
fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
}
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
}
#[test] #[test]
fn test_radix_tree() { fn test_radix_tree() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared test utilities for radix tree tests.
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId,
};
/// Builds one stored-block record per input value, for use in tests.
///
/// Each value `v` becomes the block's `tokens_hash` unchanged, while the
/// `block_hash` is the artificial mapping `v * 100`. `mm_extra_info` is
/// always `None`. Output order follows input order.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
    let mut blocks = Vec::with_capacity(hashes.len());
    for value in hashes {
        blocks.push(KvCacheStoredBlockData {
            tokens_hash: LocalBlockHash(value),
            block_hash: ExternalSequenceBlockHash(value * 100),
            mm_extra_info: None,
        });
    }
    blocks
}
/// Wraps the blocks produced by `make_blocks(hashes)` in a `Stored` event
/// payload attached to the given optional `parent_hash`.
pub fn add_blocks(
    hashes: Vec<u64>,
    parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
    let blocks = make_blocks(hashes);
    let store = KvCacheStoreData {
        parent_hash,
        blocks,
    };
    KvCacheEventData::Stored(store)
}
/// Creates a store `RouterEvent` for `worker_id` with the given `event_id`.
///
/// The payload is built by `add_blocks(hashes, parent)`; `dp_rank` is fixed at 0.
pub fn create_store_event(
    worker_id: WorkerId,
    event_id: u64,
    hashes: Vec<u64>,
    parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let data = add_blocks(hashes, parent);
    let event = KvCacheEvent {
        event_id,
        data,
        dp_rank: 0,
    };
    RouterEvent { worker_id, event }
}
/// Creates a remove `RouterEvent` for `worker_id` with the given `event_id`.
///
/// Every input value `v` is turned into the external block hash `v * 100`
/// (the same artificial mapping used when storing); `dp_rank` is fixed at 0.
pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
    let removal = KvCacheRemoveData {
        block_hashes: hashes
            .into_iter()
            .map(|value| ExternalSequenceBlockHash(value * 100))
            .collect(),
    };
    let event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(removal),
        dp_rank: 0,
    };
    RouterEvent { worker_id, event }
}
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ConcurrentRadixTree, ThreadPoolIndexer};
use dynamo_runtime::{ use dynamo_runtime::{
component::{Client, Endpoint}, component::{Client, Endpoint},
discovery::DiscoveryQuery, discovery::DiscoveryQuery,
...@@ -17,6 +19,7 @@ use dynamo_runtime::{ ...@@ -17,6 +19,7 @@ use dynamo_runtime::{
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
}; };
use futures::stream; use futures::stream;
use tokio::sync::oneshot;
use validator::Validate; use validator::Validate;
// Re-export from dynamo-kv-router crate // Re-export from dynamo-kv-router crate
...@@ -43,10 +46,11 @@ use crate::{ ...@@ -43,10 +46,11 @@ use crate::{
discovery::RuntimeConfigWatch, discovery::RuntimeConfigWatch,
kv_router::{ kv_router::{
approx::PruneConfig, approx::PruneConfig,
indexer::{KvIndexer, KvIndexerInterface, KvRouterError}, indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{ protocols::{
DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse, DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
TokensWithHashes, WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq, TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
compute_block_hash_for_seq,
}, },
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError, sequence::SequenceError,
...@@ -113,12 +117,18 @@ pub trait WorkerSelector { ...@@ -113,12 +117,18 @@ pub trait WorkerSelector {
) -> Result<WorkerSelectionResult, KvSchedulerError>; ) -> Result<WorkerSelectionResult, KvSchedulerError>;
} }
#[derive(Clone)]
pub enum Indexer { pub enum Indexer {
/// Updates itself based on KV events emitted by backend workers or routing decisions. /// Single-threaded radix tree with channel-based event processing.
/// Supports TTL-based expiration and size-based pruning. /// Supports TTL-based expiration and size-based pruning.
/// Has the ability to persist and snapshot states. /// Has the ability to persist and snapshot states.
KvIndexer(KvIndexer), KvIndexer(KvIndexer),
/// Concurrent radix tree with a thread pool for event processing.
/// Uses sticky worker routing for per-worker event serialization.
/// Does not support TTL/pruning.
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),
/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0). /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them. /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None, None,
...@@ -132,30 +142,37 @@ impl Indexer { ...@@ -132,30 +142,37 @@ impl Indexer {
cancellation_token: tokio_util::sync::CancellationToken, cancellation_token: tokio_util::sync::CancellationToken,
) -> Self { ) -> Self {
if kv_router_config.overlap_score_weight == 0.0 { if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes return Indexer::None;
Indexer::None }
} else {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component); if kv_router_config.router_event_threads > 1 {
return Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
// If use_kv_events is false, enable TTL and pruning for approximate behavior ConcurrentRadixTree::new(),
let prune_config = if !kv_router_config.use_kv_events { kv_router_config.router_event_threads as usize,
Some(PruneConfig {
ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio,
})
} else {
None
};
Indexer::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None, // expiration_duration for frequency tracking
block_size, block_size,
kv_indexer_metrics, )));
prune_config,
))
} }
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
// If use_kv_events is false, enable TTL and pruning for approximate behavior
let prune_config = if !kv_router_config.use_kv_events {
Some(PruneConfig {
ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio,
})
} else {
None
};
Indexer::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None, // expiration_duration for frequency tracking
block_size,
kv_indexer_metrics,
prune_config,
))
} }
pub(crate) async fn find_matches( pub(crate) async fn find_matches(
...@@ -164,6 +181,7 @@ impl Indexer { ...@@ -164,6 +181,7 @@ impl Indexer {
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
match self { match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await, Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores { Indexer::None => Ok(OverlapScores {
scores: HashMap::new(), scores: HashMap::new(),
frequencies: Vec::new(), frequencies: Vec::new(),
...@@ -175,6 +193,7 @@ impl Indexer { ...@@ -175,6 +193,7 @@ impl Indexer {
pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> { pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self { match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await, Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::Concurrent(tpi) => tpi.dump_events().await,
Indexer::None => { Indexer::None => {
panic!( panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)" "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
...@@ -194,9 +213,55 @@ impl Indexer { ...@@ -194,9 +213,55 @@ impl Indexer {
.process_routing_decision_for_request(tokens_with_hashes, worker) .process_routing_decision_for_request(tokens_with_hashes, worker)
.await .await
} }
Indexer::Concurrent(tpi) => {
tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Indexer::None => Ok(()), Indexer::None => Ok(()),
} }
} }
/// Applies a router event to whichever indexer backend is active.
///
/// - `KvIndexer`: the event is sent over the indexer's event channel; a
///   failed send is logged as a warning and otherwise ignored.
/// - `Concurrent`: the event is handed to the thread-pool indexer's
///   `apply_event`.
/// - `None`: nothing happens.
pub(crate) async fn apply_event(&self, event: RouterEvent) {
    let indexer = match self {
        Indexer::None => return,
        Indexer::Concurrent(tpi) => return tpi.apply_event(event).await,
        Indexer::KvIndexer(indexer) => indexer,
    };

    if let Err(e) = indexer.event_sender().send(event).await {
        tracing::warn!("Failed to send event to indexer: {e}");
    }
}
/// Removes `worker_id` from whichever indexer backend is active.
///
/// - `KvIndexer`: the id is sent over the indexer's worker-removal channel;
///   a failed send is logged as a warning and otherwise ignored.
/// - `Concurrent`: calls `KvIndexerInterface::remove_worker` on the
///   thread-pool indexer.
/// - `None`: nothing happens.
pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
    match self {
        Indexer::None => {}
        Indexer::Concurrent(tpi) => {
            KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
        }
        Indexer::KvIndexer(indexer) => {
            let sent = indexer.remove_worker_sender().send(worker_id).await;
            if let Err(e) = sent {
                tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
            }
        }
    }
}
/// Returns the worker ids currently known to the active indexer backend.
///
/// - `KvIndexer`: sends a `GetWorkersRequest` carrying a oneshot reply
///   channel and awaits the answer. If either the request cannot be sent or
///   the reply channel is closed before an answer arrives, a warning is
///   logged and an empty list is returned.
/// - `Concurrent`: reads the list directly from the thread-pool indexer's
///   backend.
/// - `None`: always returns an empty list.
pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
    match self {
        Indexer::KvIndexer(indexer) => {
            let (resp_tx, resp_rx) = oneshot::channel();
            let req = GetWorkersRequest { resp: resp_tx };
            if let Err(e) = indexer.get_workers_sender().send(req).await {
                tracing::warn!("Failed to send get_workers request: {e}");
                return Vec::new();
            }
            // Log a dropped reply instead of silently treating it as
            // "no workers", so it is distinguishable from an empty indexer.
            match resp_rx.await {
                Ok(workers) => workers,
                Err(e) => {
                    tracing::warn!("Failed to receive worker IDs from indexer: {e}");
                    Vec::new()
                }
            }
        }
        Indexer::Concurrent(tpi) => tpi.backend().get_workers(),
        Indexer::None => Vec::new(),
    }
}
} }
/// A KvRouter only decides which worker you should use. It doesn't send you there. /// A KvRouter only decides which worker you should use. It doesn't send you there.
...@@ -255,18 +320,11 @@ impl KvRouter { ...@@ -255,18 +320,11 @@ impl KvRouter {
// Start KV event subscription if needed (use_kv_events=true and overlap_score_weight>0) // Start KV event subscription if needed (use_kv_events=true and overlap_score_weight>0)
if kv_router_config.should_subscribe_to_kv_events() { if kv_router_config.should_subscribe_to_kv_events() {
// Guaranteed to be KvIndexer since overlap_score_weight > 0.0
let Indexer::KvIndexer(kv_indexer) = &indexer else {
unreachable!(
"should_subscribe_to_kv_events implies overlap_score_weight > 0 implies KvIndexer"
)
};
subscriber::start_subscriber( subscriber::start_subscriber(
component.clone(), component.clone(),
&kv_router_config, &kv_router_config,
router_id, router_id,
kv_indexer, indexer.clone(),
cancellation_token.clone(), cancellation_token.clone(),
) )
.await?; .await?;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use derive_builder::Builder; use derive_builder::Builder;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::{Validate, ValidationError};
use crate::kv_router::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block}; use crate::kv_router::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
...@@ -21,6 +21,7 @@ pub struct RouterConfigOverride { ...@@ -21,6 +21,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters /// KV Router configuration parameters
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig { pub struct KvRouterConfig {
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
pub overlap_score_weight: f64, pub overlap_score_weight: f64,
...@@ -69,6 +70,12 @@ pub struct KvRouterConfig { ...@@ -69,6 +70,12 @@ pub struct KvRouterConfig {
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8) /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
#[validate(range(min = 0.0, max = 1.0))] #[validate(range(min = 0.0, max = 1.0))]
pub router_prune_target_ratio: f64, pub router_prune_target_ratio: f64,
/// Number of event processing threads for the KV indexer.
/// When > 1, uses ConcurrentRadixTree with a thread pool instead of the
/// single-threaded RadixTree. Default: 1.
#[validate(range(min = 1))]
pub router_event_threads: u32,
} }
impl Default for KvRouterConfig { impl Default for KvRouterConfig {
...@@ -87,10 +94,30 @@ impl Default for KvRouterConfig { ...@@ -87,10 +94,30 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0, router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default() router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8, router_prune_target_ratio: 0.8,
router_event_threads: 1,
} }
} }
} }
/// Schema-level validation for `KvRouterConfig`: checks combinations of
/// fields that per-field validators cannot express.
///
/// Rules are checked in order and the first violated one is reported:
/// 1. `durable_kv_events` requires `use_kv_events`.
/// 2. `router_event_threads > 1` requires `use_kv_events`.
/// 3. `router_track_output_blocks` requires `router_track_active_blocks`.
fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationError> {
    let violation = if config.durable_kv_events && !config.use_kv_events {
        Some("durable_kv_events requires use_kv_events=true")
    } else if config.router_event_threads > 1 && !config.use_kv_events {
        Some("router_event_threads > 1 requires use_kv_events=true")
    } else if config.router_track_output_blocks && !config.router_track_active_blocks {
        Some("router_track_output_blocks requires router_track_active_blocks=true")
    } else {
        None
    };

    match violation {
        Some(code) => Err(ValidationError::new(code)),
        None => Ok(()),
    }
}
impl KvRouterConfig { impl KvRouterConfig {
/// Compute sequence hashes for active block tracking based on configuration. /// Compute sequence hashes for active block tracking based on configuration.
/// ///
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
...@@ -577,6 +578,26 @@ fn convert_event( ...@@ -577,6 +578,26 @@ fn convert_event(
block_mm_infos, block_mm_infos,
.. ..
} => { } => {
// Reject self-referencing blocks: all block hashes (including parent) must be unique.
{
let mut seen = HashSet::with_capacity(block_hashes.len() + 1);
if let Some(parent) = parent_block_hash {
seen.insert(parent.into_u64());
}
let has_duplicate = block_hashes.iter().any(|h| !seen.insert(h.into_u64()));
if has_duplicate {
tracing::warn!(
event_id,
"Self-referencing block detected: duplicate hash in store event; dropping"
);
return KvCacheEvent {
event_id,
data: KvCacheEventData::Cleared,
dp_rank,
};
}
}
let num_block_tokens = vec![block_size as u64; block_hashes.len()]; let num_block_tokens = vec![block_size as u64; block_hashes.len()];
let block_hashes_u64: Vec<u64> = block_hashes let block_hashes_u64: Vec<u64> = block_hashes
.into_iter() .into_iter()
......
...@@ -14,12 +14,10 @@ use dynamo_runtime::{ ...@@ -14,12 +14,10 @@ use dynamo_runtime::{
}; };
use futures::StreamExt; use futures::StreamExt;
use rand::Rng; use rand::Rng;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::kv_router::{ use crate::kv_router::{
KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE, Indexer, KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest, KvIndexer},
protocols::{DpRank, RouterEvent, WorkerId}, protocols::{DpRank, RouterEvent, WorkerId},
router_discovery_query, router_discovery_query,
worker_query::WorkerQueryClient, worker_query::WorkerQueryClient,
...@@ -84,7 +82,7 @@ async fn get_instance_discovery_stream( ...@@ -84,7 +82,7 @@ async fn get_instance_discovery_stream(
async fn download_stable_snapshot( async fn download_stable_snapshot(
nats_client: &dynamo_runtime::transports::nats::Client, nats_client: &dynamo_runtime::transports::nats::Client,
bucket_name: &str, bucket_name: &str,
kv_events_tx: &mpsc::Sender<RouterEvent>, indexer: &Indexer,
) -> Result<()> { ) -> Result<()> {
let url = url::Url::parse(&format!( let url = url::Url::parse(&format!(
"nats://{}/{bucket_name}/{RADIX_STATE_FILE}", "nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
...@@ -147,9 +145,7 @@ async fn download_stable_snapshot( ...@@ -147,9 +145,7 @@ async fn download_stable_snapshot(
// Send all events to the indexer // Send all events to the indexer
for event in prev_events { for event in prev_events {
if let Err(e) = kv_events_tx.send(event).await { indexer.apply_event(event).await;
tracing::warn!("Failed to send initial event to indexer: {e:?}");
}
} }
tracing::info!("Successfully sent all initial events to indexer"); tracing::info!("Successfully sent all initial events to indexer");
...@@ -162,57 +158,27 @@ struct SnapshotResources { ...@@ -162,57 +158,27 @@ struct SnapshotResources {
nats_client: dynamo_runtime::transports::nats::Client, nats_client: dynamo_runtime::transports::nats::Client,
bucket_name: String, bucket_name: String,
instances_rx: tokio::sync::watch::Receiver<Vec<dynamo_runtime::component::Instance>>, instances_rx: tokio::sync::watch::Receiver<Vec<dynamo_runtime::component::Instance>>,
get_workers_tx: mpsc::Sender<GetWorkersRequest>, indexer: Indexer,
snapshot_tx: mpsc::Sender<DumpRequest>,
} }
impl SnapshotResources { impl SnapshotResources {
/// Perform snapshot upload and purge operations /// Perform snapshot upload and purge operations
async fn purge_then_snapshot( async fn purge_then_snapshot(&self, nats_queue: &mut NatsQueue) -> anyhow::Result<()> {
&self,
nats_queue: &mut NatsQueue,
remove_worker_tx: &mpsc::Sender<WorkerId>,
) -> anyhow::Result<()> {
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree"); tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
// Clean up stale workers before snapshot // Clean up stale workers before snapshot
// Get current worker IDs from instances_rx
let current_instances = self.instances_rx.borrow().clone(); let current_instances = self.instances_rx.borrow().clone();
let current_worker_ids: std::collections::HashSet<u64> = current_instances let current_worker_ids: std::collections::HashSet<u64> = current_instances
.iter() .iter()
.map(|instance| instance.instance_id) .map(|instance| instance.instance_id)
.collect(); .collect();
// Get worker IDs from the indexer let indexer_worker_ids = self.indexer.get_workers().await;
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); for worker_id in indexer_worker_ids {
let get_workers_req = GetWorkersRequest { resp: resp_tx }; if !current_worker_ids.contains(&worker_id) {
tracing::info!("Removing stale worker {worker_id} from indexer during snapshot");
if let Err(e) = self.get_workers_tx.send(get_workers_req).await { self.indexer.remove_worker(worker_id).await;
tracing::warn!("Failed to send get_workers request during snapshot: {e:?}");
} else {
match resp_rx.await {
Ok(indexer_worker_ids) => {
// Find workers in indexer but not in current instances
for worker_id in indexer_worker_ids {
if !current_worker_ids.contains(&worker_id) {
tracing::info!(
"Removing stale worker {worker_id} from indexer during snapshot"
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!(
"Failed to send remove_worker for stale worker {worker_id}: {e:?}"
);
}
}
}
}
Err(e) => {
tracing::warn!("Failed to receive worker IDs from indexer: {e:?}");
}
} }
} }
...@@ -220,18 +186,11 @@ impl SnapshotResources { ...@@ -220,18 +186,11 @@ impl SnapshotResources {
nats_queue.purge_acknowledged().await?; nats_queue.purge_acknowledged().await?;
// Now request a snapshot from the indexer (which reflects the post-purge state) // Now request a snapshot from the indexer (which reflects the post-purge state)
let (resp_tx, resp_rx) = oneshot::channel(); let events = self
let dump_req = DumpRequest { resp: resp_tx }; .indexer
.dump_events()
self.snapshot_tx
.send(dump_req)
.await .await
.map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?; .map_err(|e| anyhow::anyhow!("Failed to dump events for snapshot: {e:?}"))?;
// Wait for the dump response
let events = resp_rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;
// Upload the snapshot to NATS object store in background (non-blocking) // Upload the snapshot to NATS object store in background (non-blocking)
let nats_client = self.nats_client.clone(); let nats_client = self.nats_client.clone();
...@@ -262,14 +221,10 @@ impl SnapshotResources { ...@@ -262,14 +221,10 @@ impl SnapshotResources {
} }
/// Start a unified background task for event consumption and optional snapshot management /// Start a unified background task for event consumption and optional snapshot management
#[allow(clippy::too_many_arguments)]
pub async fn start_kv_router_background( pub async fn start_kv_router_background(
component: Component, component: Component,
consumer_id: String, consumer_id: String,
kv_events_tx: mpsc::Sender<RouterEvent>, indexer: Indexer,
remove_worker_tx: mpsc::Sender<WorkerId>,
maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>,
maybe_snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
router_snapshot_threshold: Option<u32>, router_snapshot_threshold: Option<u32>,
router_reset_states: bool, router_reset_states: bool,
...@@ -307,7 +262,7 @@ pub async fn start_kv_router_background( ...@@ -307,7 +262,7 @@ pub async fn start_kv_router_background(
// Handle initial state based on router_reset_states flag // Handle initial state based on router_reset_states flag
if !router_reset_states { if !router_reset_states {
// Try to download initial state from object store with stability check // Try to download initial state from object store with stability check
download_stable_snapshot(&nats_client, &bucket_name, &kv_events_tx).await?; download_stable_snapshot(&nats_client, &bucket_name, &indexer).await?;
} else { } else {
// Delete the bucket to reset state // Delete the bucket to reset state
tracing::info!("Resetting router state, deleting bucket: {bucket_name}"); tracing::info!("Resetting router state, deleting bucket: {bucket_name}");
...@@ -335,22 +290,13 @@ pub async fn start_kv_router_background( ...@@ -335,22 +290,13 @@ pub async fn start_kv_router_background(
let client = generate_endpoint.client().await?; let client = generate_endpoint.client().await?;
let instances_rx = client.instance_source.as_ref().clone(); let instances_rx = client.instance_source.as_ref().clone();
// Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided // Only set up snapshot-related resources if snapshot threshold is configured
let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = ( let snapshot_resources = router_snapshot_threshold.map(|_| SnapshotResources {
maybe_get_workers_tx, nats_client,
maybe_snapshot_tx, bucket_name,
router_snapshot_threshold, instances_rx,
) { indexer: indexer.clone(),
Some(SnapshotResources { });
nats_client,
bucket_name,
instances_rx,
get_workers_tx,
snapshot_tx,
})
} else {
None
};
tokio::spawn(async move { tokio::spawn(async move {
// Create interval with jitter // Create interval with jitter
...@@ -392,9 +338,7 @@ pub async fn start_kv_router_background( ...@@ -392,9 +338,7 @@ pub async fn start_kv_router_background(
"DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}" "DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}"
); );
if let Err(e) = remove_worker_tx.send(worker_id).await { indexer.remove_worker(worker_id).await;
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
}
} }
// Handle event consumption // Handle event consumption
...@@ -410,12 +354,7 @@ pub async fn start_kv_router_background( ...@@ -410,12 +354,7 @@ pub async fn start_kv_router_background(
}; };
// Forward the RouterEvent to the indexer // Forward the RouterEvent to the indexer
if let Err(e) = kv_events_tx.send(event).await { indexer.apply_event(event).await;
tracing::warn!(
"failed to send kv event to indexer; shutting down: {e:?}"
);
break;
}
}, },
Ok(None) => { Ok(None) => {
tracing::trace!("Dequeue timeout, continuing"); tracing::trace!("Dequeue timeout, continuing");
...@@ -449,7 +388,6 @@ pub async fn start_kv_router_background( ...@@ -449,7 +388,6 @@ pub async fn start_kv_router_background(
match resources.purge_then_snapshot( match resources.purge_then_snapshot(
&mut nats_queue, &mut nats_queue,
&remove_worker_tx,
).await { ).await {
Ok(_) => tracing::info!("Successfully performed purge and snapshot"), Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
Err(e) => tracing::debug!("Could not perform purge and snapshot: {e:?}"), Err(e) => tracing::debug!("Could not perform purge and snapshot: {e:?}"),
...@@ -510,15 +448,13 @@ pub async fn start_kv_router_background( ...@@ -510,15 +448,13 @@ pub async fn start_kv_router_background(
/// This is appropriate when workers have local indexers enabled. /// This is appropriate when workers have local indexers enabled.
pub async fn start_kv_router_background_event_plane( pub async fn start_kv_router_background_event_plane(
component: Component, component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>, indexer: Indexer,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
transport_kind: EventTransportKind, transport_kind: EventTransportKind,
) -> Result<()> { ) -> Result<()> {
// WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery. // WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery.
// No blocking wait — recovery happens asynchronously as endpoints are discovered. // No blocking wait — recovery happens asynchronously as endpoints are discovered.
let worker_query_client = let worker_query_client = WorkerQueryClient::spawn(component.clone(), indexer.clone()).await?;
WorkerQueryClient::spawn(component.clone(), remove_worker_tx, kv_events_tx.clone()).await?;
// Subscribe to KV events using the selected event plane transport // Subscribe to KV events using the selected event plane transport
let mut subscriber = let mut subscriber =
...@@ -611,12 +547,7 @@ pub async fn start_kv_router_background_event_plane( ...@@ -611,12 +547,7 @@ pub async fn start_kv_router_background_event_plane(
.or_insert(event_id); .or_insert(event_id);
// Forward the RouterEvent to the indexer // Forward the RouterEvent to the indexer
if let Err(e) = kv_events_tx.send(event).await { indexer.apply_event(event).await;
tracing::warn!(
"failed to send kv event to indexer; shutting down: {e:?}"
);
break;
}
} }
} }
} }
...@@ -670,7 +601,7 @@ pub async fn start_subscriber( ...@@ -670,7 +601,7 @@ pub async fn start_subscriber(
component: Component, component: Component,
kv_router_config: &KvRouterConfig, kv_router_config: &KvRouterConfig,
router_id: u64, router_id: u64,
kv_indexer: &KvIndexer, indexer: Indexer,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let transport_kind = EventTransportKind::from_env_or_default(); let transport_kind = EventTransportKind::from_env_or_default();
...@@ -690,14 +621,7 @@ pub async fn start_subscriber( ...@@ -690,14 +621,7 @@ pub async fn start_subscriber(
start_kv_router_background( start_kv_router_background(
component, component,
consumer_id, consumer_id,
kv_indexer.event_sender(), indexer,
kv_indexer.remove_worker_sender(),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.get_workers_sender()),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.snapshot_event_sender()),
cancellation_token, cancellation_token,
kv_router_config.router_snapshot_threshold, kv_router_config.router_snapshot_threshold,
kv_router_config.router_reset_states, kv_router_config.router_reset_states,
...@@ -719,8 +643,7 @@ pub async fn start_subscriber( ...@@ -719,8 +643,7 @@ pub async fn start_subscriber(
start_kv_router_background_event_plane( start_kv_router_background_event_plane(
component.clone(), component.clone(),
kv_indexer.event_sender(), indexer,
kv_indexer.remove_worker_sender(),
cancellation_token, cancellation_token,
transport_kind, transport_kind,
) )
......
...@@ -17,10 +17,10 @@ use dynamo_runtime::protocols::maybe_error::MaybeError; ...@@ -17,10 +17,10 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
use dynamo_runtime::stream; use dynamo_runtime::stream;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::mpsc;
use crate::kv_router::Indexer;
use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse}; use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::{DpRank, RouterEvent, WorkerId}; use crate::kv_router::protocols::{DpRank, WorkerId};
use crate::kv_router::worker_kv_indexer_query_endpoint; use crate::kv_router::worker_kv_indexer_query_endpoint;
// Recovery retry configuration // Recovery retry configuration
...@@ -43,10 +43,8 @@ pub struct WorkerQueryClient { ...@@ -43,10 +43,8 @@ pub struct WorkerQueryClient {
component: Component, component: Component,
/// Routers keyed by dp_rank — each dp_rank has its own endpoint. Created lazily. /// Routers keyed by dp_rank — each dp_rank has its own endpoint. Created lazily.
routers: Arc<DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>>, routers: Arc<DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>>,
/// Channel to send recovered events to the router indexer /// Indexer for applying recovered events and worker removals.
kv_events_tx: mpsc::Sender<RouterEvent>, indexer: Indexer,
/// Channel to send worker removal events to the router indexer
remove_worker_tx: mpsc::Sender<WorkerId>,
} }
impl WorkerQueryClient { impl WorkerQueryClient {
...@@ -55,16 +53,11 @@ impl WorkerQueryClient { ...@@ -55,16 +53,11 @@ impl WorkerQueryClient {
/// The background loop watches `ComponentEndpoints` discovery for query endpoints, /// The background loop watches `ComponentEndpoints` discovery for query endpoints,
/// recovers each `(worker_id, dp_rank)` as it appears, and sends worker removal /// recovers each `(worker_id, dp_rank)` as it appears, and sends worker removal
/// events when all dp_ranks for a worker disappear. /// events when all dp_ranks for a worker disappear.
pub async fn spawn( pub async fn spawn(component: Component, indexer: Indexer) -> Result<Arc<Self>> {
component: Component,
remove_worker_tx: mpsc::Sender<WorkerId>,
kv_events_tx: mpsc::Sender<RouterEvent>,
) -> Result<Arc<Self>> {
let client = Arc::new(Self { let client = Arc::new(Self {
component: component.clone(), component: component.clone(),
routers: Arc::new(DashMap::new()), routers: Arc::new(DashMap::new()),
kv_events_tx, indexer,
remove_worker_tx,
}); });
let client_bg = client.clone(); let client_bg = client.clone();
...@@ -151,11 +144,7 @@ impl WorkerQueryClient { ...@@ -151,11 +144,7 @@ impl WorkerQueryClient {
tracing::warn!( tracing::warn!(
"WorkerQueryClient: all dp_ranks gone for worker {worker_id}, removing" "WorkerQueryClient: all dp_ranks gone for worker {worker_id}, removing"
); );
if let Err(e) = self.remove_worker_tx.send(worker_id).await { self.indexer.remove_worker(worker_id).await;
tracing::warn!(
"Failed to send worker removal for worker {worker_id}: {e}"
);
}
} }
} }
} }
...@@ -354,12 +343,7 @@ impl WorkerQueryClient { ...@@ -354,12 +343,7 @@ impl WorkerQueryClient {
tracing::info!("Recovered {count} events from worker {worker_id} dp_rank {dp_rank}"); tracing::info!("Recovered {count} events from worker {worker_id} dp_rank {dp_rank}");
for event in events { for event in events {
if let Err(e) = self.kv_events_tx.send(event).await { self.indexer.apply_event(event).await;
tracing::error!(
"Failed to send recovered event to indexer for worker {worker_id} dp_rank {dp_rank}: {e}"
);
anyhow::bail!("Failed to send recovered event: {e}");
}
} }
Ok(count) Ok(count)
......
...@@ -1305,6 +1305,7 @@ def _test_router_indexers_sync( ...@@ -1305,6 +1305,7 @@ def _test_router_indexers_sync(
test_nats_interruption: bool = False, test_nats_interruption: bool = False,
nats_server: Optional["NatsServer"] = None, nats_server: Optional["NatsServer"] = None,
durable_kv_events: bool = False, durable_kv_events: bool = False,
router_event_threads: int = 1,
): ):
"""Test that two KV routers have synchronized indexer states after processing requests. """Test that two KV routers have synchronized indexer states after processing requests.
...@@ -1349,6 +1350,7 @@ def _test_router_indexers_sync( ...@@ -1349,6 +1350,7 @@ def _test_router_indexers_sync(
kv_router_config = KvRouterConfig( kv_router_config = KvRouterConfig(
router_snapshot_threshold=20, router_snapshot_threshold=20,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
) )
async def send_requests_to_router(router, num_requests, router_name, endpoint): async def send_requests_to_router(router, num_requests, router_name, endpoint):
...@@ -1881,6 +1883,7 @@ def _test_router_decisions( ...@@ -1881,6 +1883,7 @@ def _test_router_decisions(
block_size: int = BLOCK_SIZE, block_size: int = BLOCK_SIZE,
use_kv_events: bool = True, use_kv_events: bool = True,
durable_kv_events: bool = False, durable_kv_events: bool = False,
router_event_threads: int = 1,
): ):
"""Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes. """Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
...@@ -1911,6 +1914,7 @@ def _test_router_decisions( ...@@ -1911,6 +1914,7 @@ def _test_router_decisions(
router_snapshot_threshold=20, router_snapshot_threshold=20,
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
) )
kv_push_router = KvPushRouter( kv_push_router = KvPushRouter(
endpoint=endpoint, endpoint=endpoint,
......
...@@ -522,16 +522,18 @@ def test_kv_push_router_bindings( ...@@ -522,16 +522,18 @@ def test_kv_push_router_bindings(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"store_backend,durable_kv_events,request_plane", "store_backend,durable_kv_events,request_plane,router_event_threads",
[ [
("etcd", True, "nats"), # JetStream mode - uses JetStream ("etcd", True, "nats", 1), # JetStream mode - uses JetStream
("etcd", False, "tcp"), # NATS core mode (with gap detection) - no JetStream ("etcd", False, "tcp", 1), # NATS core mode (with gap detection) - no JetStream
("file", True, "nats"), # File backend - uses JetStream ("file", True, "nats", 1), # File backend - uses JetStream
("etcd", False, "tcp", 2), # NATS core mode - multi-threaded indexer
], ],
ids=[ ids=[
"jetstream", "jetstream",
"nats_core", "nats_core",
"file", "file",
"nats_core_multi_thread",
], ],
indirect=["request_plane", "durable_kv_events"], indirect=["request_plane", "durable_kv_events"],
) )
...@@ -544,6 +546,7 @@ def test_indexers_sync( ...@@ -544,6 +546,7 @@ def test_indexers_sync(
store_backend, store_backend,
durable_kv_events, durable_kv_events,
request_plane, request_plane,
router_event_threads,
): ):
""" """
Test that two KV routers have synchronized indexer states after processing requests. Test that two KV routers have synchronized indexer states after processing requests.
...@@ -596,6 +599,7 @@ def test_indexers_sync( ...@@ -596,6 +599,7 @@ def test_indexers_sync(
test_nats_interruption=not durable_kv_events, test_nats_interruption=not durable_kv_events,
nats_server=nats_process if not durable_kv_events else None, nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
) )
logger.info("Indexers sync test completed successfully") logger.info("Indexers sync test completed successfully")
...@@ -639,15 +643,16 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -639,15 +643,16 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.timeout(90) # bumped for xdist contention (was 29s; ~9.55s serial avg) @pytest.mark.timeout(90) # bumped for xdist contention (was 29s; ~9.55s serial avg)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"durable_kv_events,use_kv_events", "durable_kv_events,use_kv_events,router_event_threads",
[ [
(True, True), # JetStream mode with KV events (True, True, 1), # JetStream mode with KV events
(False, True), # NATS Core mode with local indexer (default) (False, True, 1), # NATS Core mode with local indexer (default)
(False, False), # Approximate mode (--no-kv-events) - no KV events (False, False, 1), # Approximate mode (--no-kv-events) - no KV events
(False, True, 2), # NATS Core mode - multi-threaded indexer
], ],
ids=["jetstream", "nats_core", "no_kv_events"], ids=["jetstream", "nats_core", "no_kv_events", "nats_core_multi_thread"],
indirect=["durable_kv_events"], indirect=["durable_kv_events"],
) )
def test_router_decisions( def test_router_decisions(
...@@ -657,6 +662,7 @@ def test_router_decisions( ...@@ -657,6 +662,7 @@ def test_router_decisions(
durable_kv_events, durable_kv_events,
use_kv_events, use_kv_events,
request_plane, request_plane,
router_event_threads,
): ):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
...@@ -704,6 +710,7 @@ def test_router_decisions( ...@@ -704,6 +710,7 @@ def test_router_decisions(
test_dp_rank=True, test_dp_rank=True,
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
) )
......
...@@ -371,15 +371,19 @@ def test_sglang_kv_router_basic( ...@@ -371,15 +371,19 @@ def test_sglang_kv_router_basic(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.skip(reason="Broken by sglang changes")
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"router_event_threads",
[1, 2],
ids=["single_thread", "multi_thread"],
)
def test_router_decisions_sglang_multiple_workers( def test_router_decisions_sglang_multiple_workers(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_models, predownload_models,
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
request_plane, request_plane,
router_event_threads,
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting SGLang router prefix reuse test with two workers") logger.info("Starting SGLang router prefix reuse test with two workers")
...@@ -396,15 +400,18 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -396,15 +400,18 @@ def test_router_decisions_sglang_multiple_workers(
logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)") logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Initialize SGLang workers
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(sglang_workers.namespace) namespace = runtime.namespace(sglang_workers.namespace)
component = namespace.component("backend") component = namespace.component("backend")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
_test_router_decisions( _test_router_decisions(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=False sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=False,
router_event_threads=router_event_threads,
) )
......
...@@ -419,6 +419,11 @@ def test_router_decisions_trtllm_attention_dp( ...@@ -419,6 +419,11 @@ def test_router_decisions_trtllm_attention_dp(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"router_event_threads",
[1, 2],
ids=["single_thread", "multi_thread"],
)
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
def test_router_decisions_trtllm_multiple_workers( def test_router_decisions_trtllm_multiple_workers(
request, request,
...@@ -426,6 +431,7 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -426,6 +431,7 @@ def test_router_decisions_trtllm_multiple_workers(
predownload_models, predownload_models,
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
request_plane, request_plane,
router_event_threads,
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting TRT-LLM router prefix reuse test with two workers") logger.info("Starting TRT-LLM router prefix reuse test with two workers")
...@@ -444,8 +450,6 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -444,8 +450,6 @@ def test_router_decisions_trtllm_multiple_workers(
) )
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
# Initialize TRT-LLM workers
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(trtllm_workers.namespace) namespace = runtime.namespace(trtllm_workers.namespace)
component = namespace.component("tensorrt_llm") component = namespace.component("tensorrt_llm")
...@@ -458,6 +462,7 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -458,6 +462,7 @@ def test_router_decisions_trtllm_multiple_workers(
request, request,
test_dp_rank=False, test_dp_rank=False,
block_size=TRTLLM_BLOCK_SIZE, block_size=TRTLLM_BLOCK_SIZE,
router_event_threads=router_event_threads,
) )
......
...@@ -385,12 +385,18 @@ def test_vllm_kv_router_basic( ...@@ -385,12 +385,18 @@ def test_vllm_kv_router_basic(
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"router_event_threads",
[1, 2],
ids=["single_thread", "multi_thread"],
)
def test_router_decisions_vllm_multiple_workers( def test_router_decisions_vllm_multiple_workers(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_models, predownload_models,
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
request_plane, request_plane,
router_event_threads,
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting vLLM router prefix reuse test with two workers") logger.info("Starting vLLM router prefix reuse test with two workers")
...@@ -414,7 +420,12 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -414,7 +420,12 @@ def test_router_decisions_vllm_multiple_workers(
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
_test_router_decisions( _test_router_decisions(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=False vllm_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=False,
router_event_threads=router_event_threads,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment