Unverified Commit e3d00b89 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
......@@ -660,6 +660,7 @@ async def init_llm_worker(
kv_block_size=config.kv_block_size,
zmq_endpoint=consolidator_output_connect_endpoint,
zmq_topic="",
enable_local_indexer=config.enable_local_indexer,
)
logging.info(
f"Created worker-side publisher for consolidated events: "
......
......@@ -8,7 +8,7 @@ use common::*;
use clap::Parser;
use common::NoopSequencePublisher;
use dynamo_kv_router::protocols::{PrefillLoadHint, WorkerWithDpRank};
use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest};
use dynamo_kv_router::{ActiveSequencesMultiWorker, SequenceRequest};
use dynamo_mocker::loadgen::Trace;
use dynamo_tokens::SequenceHash;
use std::collections::HashMap;
......@@ -379,12 +379,7 @@ async fn apply_entry(
isl,
output_length,
} => {
let _ = multi.potential_blocks_and_tokens(
Some(&block_hashes),
isl,
OverlapScores::default(),
decay_now,
);
let _ = multi.potential_blocks_and_tokens(Some(&block_hashes), isl, HashMap::new());
let _ = multi.add_request(
SequenceRequest {
request_id,
......
......@@ -521,7 +521,6 @@ impl RouterHandles {
None,
0.0,
None,
None,
allowed_worker_ids,
)
.await
......@@ -884,12 +883,13 @@ pub unsafe extern "C" fn add_request(
}
};
let cached_tokens = overlap_blocks as usize * decode_router.block_size() as usize;
decode_router
.add_request(
request_id_str.clone(),
&tokens,
None,
overlap_blocks,
cached_tokens,
None,
worker,
None, // lora_name
......
......@@ -7,7 +7,7 @@ from typing import List, Optional
import tensorrt_llm
from kvbm import KvbmLeader
from kvbm.trtllm_integration.consolidator_config import is_truthy
from kvbm.trtllm_integration.consolidator_config import get_consolidator_mode, is_truthy
from kvbm.trtllm_integration.rust import KvbmRequest
from kvbm.trtllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from kvbm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput
......@@ -55,7 +55,9 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
trtllm_ep = None
consolidator_output_ep = None
consolidator_mode = None
if consolidator_enabled:
consolidator_mode = get_consolidator_mode()
# Get consolidator endpoint from environment variable
# DYN_KVBM_TRTLLM_ZMQ_PORT contains just the port number (e.g., "20081")
zmq_port = os.getenv("DYN_KVBM_TRTLLM_ZMQ_PORT")
......@@ -105,6 +107,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
leader,
consolidator_trtllm_endpoint=trtllm_ep,
consolidator_output_endpoint=consolidator_output_ep,
consolidator_mode=consolidator_mode,
)
@nvtx_annotate(category="scheduler")
......@@ -132,6 +135,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_block_ids,
req.computed_position,
req.priorities, # Pass retention priorities for offload filtering
list(req.block_hashes),
)
resumed_from_preemption = False
......@@ -143,6 +147,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_block_ids,
req.computed_position,
req.priorities, # Pass retention priorities for offload filtering
list(req.block_hashes),
)
output.add_num_scheduled_tokens(
......
......@@ -8,21 +8,27 @@ Helper functions for KV Event Consolidator configuration for TensorRT-LLM.
import logging
import os
from kvbm.utils import get_consolidator_mode, is_truthy
__all__ = [
"get_consolidator_endpoints",
"get_consolidator_mode",
"is_truthy",
"should_enable_consolidator",
]
logger = logging.getLogger(__name__)
def is_truthy(val: str) -> bool:
"""
Check if a string represents a truthy value.
Truthy values: "1", "true", "on", "yes" (case-insensitive)
def _get_connector_module(kv_connector_config) -> str | None:
"""Extract connector_module from either a dict or a TRT-LLM config object."""
if kv_connector_config is None:
return None
Args:
val: The string value to check
if isinstance(kv_connector_config, dict):
return kv_connector_config.get("connector_module")
Returns:
True if the value is truthy, False otherwise
"""
return val.lower() in ("1", "true", "on", "yes")
return getattr(kv_connector_config, "connector_module", None)
def should_enable_consolidator(arg_map) -> bool:
......@@ -48,23 +54,19 @@ def should_enable_consolidator(arg_map) -> bool:
)
return False
# Check if KVBM connector is enabled by extracting connector_module
# from kv_connector_config (works whether arg_map holds raw dicts or typed objects)
kv_connector_config = (
arg_map.get("kv_connector_config") if isinstance(arg_map, dict) else None
)
if kv_connector_config is None:
# Check if KVBM connector is enabled
if not isinstance(arg_map, dict):
logger.warning("KV Event Consolidator is not enabled: arg_map is not a dict")
return False
kv_connector_config = arg_map.get("kv_connector_config")
connector_module = _get_connector_module(kv_connector_config) or ""
if not connector_module:
logger.warning(
"KV Event Consolidator is not enabled: no kv_connector_config found"
"KV Event Consolidator is not enabled: kv_connector_config has no connector_module"
)
return False
if isinstance(kv_connector_config, dict):
connector_module = kv_connector_config.get("connector_module", "")
else:
# Access directly so AttributeError surfaces if the contract changes
connector_module = kv_connector_config.connector_module or ""
has_kvbm_connector = "kvbm.trtllm_integration.connector" in connector_module
if not has_kvbm_connector:
......
......@@ -2,8 +2,33 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
logger = logging.getLogger(__name__)
def is_truthy(val: str) -> bool:
    """Report whether ``val`` is one of the accepted "enabled" spellings.

    Surrounding whitespace and letter case are ignored. Accepted spellings:
    "1", "true", "on", "yes".
    """
    normalized = val.strip().lower()
    return normalized in ("1", "true", "on", "yes")
def get_consolidator_mode() -> str:
    """Read the KV event consolidator mode from the environment.

    The value of DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE is stripped and
    lower-cased before it is checked. "dedup" and "passthrough" are returned
    as-is; an unset variable defaults to "dedup"; any other value is logged
    as a warning and replaced with "dedup".
    """
    raw = os.getenv("DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE", "dedup")
    mode = raw.strip().lower()
    if mode not in ("dedup", "passthrough"):
        logger.warning(
            "Invalid DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE=%r. Falling back to 'dedup'.",
            mode,
        )
        return "dedup"
    return mode
try:
from nvtx import annotate # type: ignore
except ImportError:
......
......@@ -30,6 +30,7 @@ if TYPE_CHECKING:
from kvbm import KvbmLeader
from kvbm.utils import is_dyn_runtime_enabled
from kvbm.vllm_integration.consolidator_config import get_consolidator_mode
from kvbm.vllm_integration.rust import KvbmRequest
from kvbm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from kvbm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput
......@@ -72,10 +73,12 @@ class KvConnectorLeader:
# Get kv event consolidator endpoints from vllm_config (pre-computed in main.py)
consolidator_vllm_endpoint = None
consolidator_output_endpoint = None
consolidator_mode = None
self._consolidator_output_port = None
_consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
if _consolidator_eps:
consolidator_mode = get_consolidator_mode()
# Unpack all three endpoints
# [0]: vllm_endpoint (for consolidator to subscribe to vLLM)
# [1]: output_bind_endpoint (for consolidator to bind/publish)
......@@ -97,6 +100,7 @@ class KvConnectorLeader:
leader,
consolidator_vllm_endpoint=consolidator_vllm_endpoint,
consolidator_output_endpoint=consolidator_output_endpoint,
consolidator_mode=consolidator_mode,
)
else:
# No kv event consolidator - pass None to Rust
......@@ -107,6 +111,7 @@ class KvConnectorLeader:
leader,
consolidator_vllm_endpoint=None,
consolidator_output_endpoint=None,
consolidator_mode=None,
)
# KV Connector
......
......@@ -9,23 +9,17 @@ import logging
import os
from typing import Optional, Tuple
from kvbm.utils import get_consolidator_mode, is_truthy
from vllm.distributed.kv_events import ZmqEventPublisher
logger = logging.getLogger(__name__)
__all__ = [
"get_consolidator_endpoints",
"get_consolidator_mode",
"is_truthy",
"should_enable_consolidator",
]
def is_truthy(val: str) -> bool:
    """
    Check if a string represents a truthy value.

    Truthy values: "1", "true", "on", "yes" (case-insensitive, surrounding
    whitespace ignored).

    Args:
        val: The string value to check

    Returns:
        True if the value is truthy, False otherwise
    """
    # Strip before comparing so values such as "true " or "1\n" are still
    # recognized instead of silently reading as false.
    return val.strip().lower() in ("1", "true", "on", "yes")
logger = logging.getLogger(__name__)
def should_enable_consolidator(vllm_config) -> bool:
......
......@@ -6,7 +6,7 @@ use anyhow::Result;
use dynamo_llm::block_manager::block::{
data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical,
};
use dynamo_llm::block_manager::kv_consolidator::EventSource;
use dynamo_llm::block_manager::kv_consolidator::{EventSource, KvEventConsolidationMode};
use dynamo_llm::block_manager::offload::filter::FrequencyFilter;
use dynamo_llm::block_manager::{BasicMetadata, BlockParallelismStrategy};
use dynamo_runtime::DistributedRuntime;
......@@ -252,7 +252,12 @@ pub struct BlockManagerBuilder {
page_size: usize,
disable_device_pool: bool,
kvbm_metrics: Option<dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics>,
consolidator_config: Option<(String, Option<String>, EventSource)>, // (engine_endpoint, output_endpoint (optional), engine_source)
consolidator_config: Option<(
String,
Option<String>,
EventSource,
KvEventConsolidationMode,
)>, // (engine_endpoint, output_endpoint (optional), engine_source, mode)
}
impl BlockManagerBuilder {
......@@ -293,8 +298,9 @@ impl BlockManagerBuilder {
engine_endpoint: String,
output_endpoint: Option<String>,
engine_source: EventSource,
mode: KvEventConsolidationMode,
) -> Self {
self.consolidator_config = Some((engine_endpoint, output_endpoint, engine_source));
self.consolidator_config = Some((engine_endpoint, output_endpoint, engine_source, mode));
self
}
......@@ -368,9 +374,9 @@ impl BlockManagerBuilder {
config_builder = config_builder.kvbm_metrics(Some(kvbm_metrics));
}
if let Some((engine_ep, output_ep, engine_source)) = self.consolidator_config {
if let Some((engine_ep, output_ep, engine_source, mode)) = self.consolidator_config {
config_builder =
config_builder.consolidator_config(engine_ep, output_ep, engine_source);
config_builder.consolidator_config(engine_ep, output_ep, engine_source, mode);
}
let config = config_builder.build()?;
......
......@@ -4,6 +4,7 @@
use dynamo_llm::block_manager::{
block::BlockId, connector::protocol::WorkerTransferRequest, pool::BlockPoolError,
};
use dynamo_llm::tokens::SequenceHash;
pub mod leader;
pub mod trtllm_leader;
......@@ -42,7 +43,7 @@ impl SchedulerOutput {
// I am surprised that vLLM's NewRequestData does not include the salt hash.
// It has almost everything else to compute the block hashes worker side.
#[pyo3(signature = (request_id, prompt_token_ids, block_ids, num_computed_tokens, priorities=None))]
#[pyo3(signature = (request_id, prompt_token_ids, block_ids, num_computed_tokens, priorities=None, external_sequence_hashes=None))]
pub fn add_new_request(
&mut self,
request_id: String,
......@@ -50,6 +51,7 @@ impl SchedulerOutput {
block_ids: Vec<BlockId>,
num_computed_tokens: usize,
priorities: Option<Vec<u32>>,
external_sequence_hashes: Option<Vec<SequenceHash>>,
) {
self.new_requests.push(NewRequestData {
request_id,
......@@ -57,11 +59,13 @@ impl SchedulerOutput {
block_ids,
num_computed_tokens,
priorities,
external_sequence_hashes,
});
}
/// This is called by the leader to update the cached requests
#[pyo3(signature = (request_id, resumed_from_preemption, new_token_ids, new_block_ids, num_computed_tokens, priorities=None))]
#[pyo3(signature = (request_id, resumed_from_preemption, new_token_ids, new_block_ids, num_computed_tokens, priorities=None, external_sequence_hashes=None))]
#[allow(clippy::too_many_arguments)]
pub fn add_cached_request(
&mut self,
request_id: String,
......@@ -70,6 +74,7 @@ impl SchedulerOutput {
new_block_ids: Vec<BlockId>,
num_computed_tokens: usize,
priorities: Option<Vec<u32>>,
external_sequence_hashes: Option<Vec<SequenceHash>>,
) {
self.cached_requests.push(CachedRequestData {
request_id,
......@@ -78,6 +83,7 @@ impl SchedulerOutput {
new_block_ids,
num_computed_tokens,
priorities,
external_sequence_hashes,
});
}
......@@ -108,6 +114,8 @@ pub struct NewRequestData {
/// Retention priorities for each block (same length as block_ids).
/// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>,
/// TRT-LLM cumulative external sequence-hash chain for all completed blocks.
pub external_sequence_hashes: Option<Vec<SequenceHash>>,
}
impl std::fmt::Debug for NewRequestData {
......@@ -131,6 +139,8 @@ pub struct CachedRequestData {
/// Retention priorities for each new block (same length as new_block_ids).
/// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>,
/// TRT-LLM cumulative external sequence-hash chain for all completed blocks.
pub external_sequence_hashes: Option<Vec<SequenceHash>>,
}
impl std::fmt::Debug for CachedRequestData {
......@@ -188,3 +198,40 @@ impl ConnectorMetadata {
self.operations.extend(xfer_reqs);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The optional external sequence hashes handed to `add_new_request` and
    /// `add_cached_request` must be stored unchanged on the queued entries.
    #[test]
    fn test_scheduler_output_preserves_external_sequence_hashes() {
        let new_hashes = vec![101];
        let cached_hashes = vec![101, 202];

        let mut output = SchedulerOutput::new();
        output.add_new_request(
            "req-1".to_string(),
            vec![1, 2, 3, 4],
            vec![10],
            4,
            None,
            Some(new_hashes.clone()),
        );
        output.add_cached_request(
            "req-1".to_string(),
            false,
            vec![5, 6, 7, 8],
            vec![11],
            8,
            None,
            Some(cached_hashes.clone()),
        );

        let stored_new = &output.new_requests[0].external_sequence_hashes;
        let stored_cached = &output.cached_requests[0].external_sequence_hashes;
        assert_eq!(*stored_new, Some(new_hashes));
        assert_eq!(*stored_cached, Some(cached_hashes));
    }
}
......@@ -22,7 +22,7 @@ use dynamo_llm::block_manager::{
locality::Logical,
},
connector::{protocol::RequestType, *},
kv_consolidator::EventSource,
kv_consolidator::{EventSource, KvEventConsolidationMode},
};
use dynamo_llm::tokens::{SaltHash, TokenBlockSequence, Tokens};
use dynamo_runtime::config::environment_names::kvbm as env_kvbm;
......@@ -35,6 +35,24 @@ use tokio::sync::oneshot;
type VllmLocality = Logical<DistributedLeaderWorkerResources>;
/// Converts the optional mode string received from Python into a
/// [`KvEventConsolidationMode`].
///
/// `None` selects `Dedup`. A string that fails to parse is reported with a
/// warning and also falls back to `Dedup`, so this function never fails.
fn parse_consolidator_mode(mode: Option<String>) -> KvEventConsolidationMode {
    match mode {
        None => KvEventConsolidationMode::Dedup,
        Some(raw) => raw
            .parse::<KvEventConsolidationMode>()
            .unwrap_or_else(|error| {
                tracing::warn!(
                    "Invalid KV event consolidator mode {:?}: {}. Falling back to dedup.",
                    raw,
                    error
                );
                KvEventConsolidationMode::Dedup
            }),
    }
}
impl From<SlotError> for PyErr {
fn from(err: SlotError) -> Self {
to_pyerr(err)
......@@ -94,6 +112,7 @@ impl KvConnectorLeader {
leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeader initialized with worker_id: {}",
......@@ -118,6 +137,7 @@ impl KvConnectorLeader {
// Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -148,6 +168,7 @@ impl KvConnectorLeader {
vllm_ep,
Some(output_ep),
EventSource::Vllm,
consolidator_mode,
);
}
......@@ -435,6 +456,7 @@ impl Leader for KvConnectorLeader {
new_req.num_computed_tokens,
scheduled_tokens,
None,
None,
)?;
let pending_ops_opt = slot.take_pending_operations();
......@@ -506,6 +528,7 @@ impl Leader for KvConnectorLeader {
cached_req.num_computed_tokens,
scheduled_tokens,
None,
None,
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
......@@ -621,7 +644,7 @@ pub struct PyKvConnectorLeader {
#[pymethods]
impl PyKvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_vllm_endpoint=None, consolidator_output_endpoint=None))]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_vllm_endpoint=None, consolidator_output_endpoint=None, consolidator_mode=None))]
pub fn new(
worker_id: String,
drt: Option<PyObject>,
......@@ -629,6 +652,7 @@ impl PyKvConnectorLeader {
leader: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> PyResult<Self> {
let _ = &drt; // drt is currently un-used in leader
......@@ -646,6 +670,7 @@ impl PyKvConnectorLeader {
leader,
consolidator_vllm_endpoint,
consolidator_output_endpoint,
consolidator_mode,
))
} else {
Box::new(KvConnectorLeader::new(
......@@ -654,6 +679,7 @@ impl PyKvConnectorLeader {
leader,
consolidator_vllm_endpoint,
consolidator_output_endpoint,
consolidator_mode,
))
};
Ok(Self { connector_leader })
......
......@@ -92,6 +92,7 @@ impl KvConnectorLeaderRecorder {
leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeaderRecorder initialized with worker_id: {}",
......@@ -131,6 +132,7 @@ impl KvConnectorLeaderRecorder {
// Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = super::parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -156,6 +158,7 @@ impl KvConnectorLeaderRecorder {
vllm_ep,
Some(output_ep),
EventSource::Vllm,
consolidator_mode,
);
}
......
......@@ -16,7 +16,7 @@ use dynamo_llm::{
connector::protocol::{LeaderTransferRequest, RequestType, TransferType},
distributed::{BlockTransferPool, BlockTransferRequest, KvbmLeader},
},
tokens::TokenBlock,
tokens::{SequenceHash, TokenBlock},
};
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
use tokio_util::sync::CancellationToken;
......@@ -114,6 +114,7 @@ pub trait Slot: std::fmt::Debug {
num_computed_tokens: usize,
num_scheduled_tokens: usize,
priorities: Option<&[u32]>,
external_sequence_hashes: Option<&[SequenceHash]>,
) -> Result<(), SlotError>;
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;
......@@ -481,6 +482,29 @@ impl VllmConnectorSlot {
&self.device_blocks
}
/// Applies an externally supplied sequence-hash chain to the blocks of
/// `self.sequence`.
///
/// The slice is forwarded to `self.sequence.sync_external_sequence_hashes`,
/// after which every block is checked with `assert_external_hashes_assigned`.
/// Always returns `Ok(())` when it returns at all.
///
/// # Panics
///
/// Panics if `external_sequence_hashes.len()` differs from the number of
/// blocks reported by `self.sequence.blocks()`.
/// NOTE(review): the delegated sync and the per-block assertion presumably
/// panic as well on a mismatched or missing hash — confirm against their
/// definitions.
fn sync_external_sequence_hashes(
    &mut self,
    external_sequence_hashes: &[SequenceHash],
) -> Result<(), SlotError> {
    let provided = external_sequence_hashes.len();
    let completed = self.sequence.blocks().len();
    assert_eq!(
        provided, completed,
        "external_sequence_hashes length ({}) must match completed block count ({}) for request {}",
        provided, completed, self.request_id
    );

    self.sequence
        .sync_external_sequence_hashes(external_sequence_hashes);

    let synced_blocks = self.sequence.blocks();
    for synced in synced_blocks {
        synced.assert_external_hashes_assigned();
    }

    Ok(())
}
fn mark_as_skipped_prefill(&mut self) -> Result<(), SlotError> {
if self.state != SlotState::Prefilling {
return Err(SlotError::InvalidState(format!(
......@@ -597,6 +621,7 @@ impl Slot for VllmConnectorSlot {
num_computed_tokens: usize,
num_scheduled_tokens: usize,
priorities: Option<&[u32]>,
external_sequence_hashes: Option<&[SequenceHash]>,
) -> Result<(), SlotError> {
tracing::debug!(
"ENTRY: apply_scheduler_output: req={}, tokens.len={}, block_ids.len={}, computed={}, scheduled={}, \
......@@ -634,6 +659,10 @@ impl Slot for VllmConnectorSlot {
self.state = SlotState::Prefilling;
}
if let Some(external_sequence_hashes) = external_sequence_hashes {
self.sync_external_sequence_hashes(external_sequence_hashes)?;
}
// Use max to advance both current_position and evaluated_blocks at least by num_computed_tokens.
// This logic is to prevent redundant block offloading.
self.current_position = max(self.current_position, num_computed_tokens);
......@@ -849,6 +878,12 @@ impl Slot for VllmConnectorSlot {
.copied()
.collect();
if external_sequence_hashes.is_some() {
for block in &offload_token_blocks {
block.assert_external_hashes_assigned();
}
}
self.offload_blocks(
&offload_block_ids,
&offload_token_blocks,
......@@ -1878,7 +1913,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> AnyBlocks for AnyImmutab
mod connector_tests {
use super::*;
use crate::block_manager::cache_stats::CacheStatsTracker;
use dynamo_llm::tokens::{SaltHash, Tokens};
use dynamo_llm::tokens::{SaltHash, SequenceHash, Tokens};
use std::sync::Arc;
use tokio::sync::mpsc;
......@@ -1914,6 +1949,12 @@ mod connector_tests {
(start..start + count).collect()
}
/// Builds `count` consecutive hash values beginning at `start`
/// (`start`, `start + 1`, ...), for use as a fake external hash chain.
fn external_hashes(start: SequenceHash, count: usize) -> Vec<SequenceHash> {
    let mut hashes = Vec::with_capacity(count);
    for offset in 0..count {
        hashes.push(start + offset as SequenceHash);
    }
    hashes
}
/// Drains all pending offload requests from the channel and returns their block IDs.
fn drain_offload_block_ids(
rx: &mut mpsc::UnboundedReceiver<LocalTransferRequest>,
......@@ -1941,7 +1982,7 @@ mod connector_tests {
assert_eq!(slot.num_device_blocks_allocated(), 3);
// Step 2: apply_scheduler_output with empty blocks (vLLM pattern)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap();
// device_blocks should still be exactly 3 — no double-add
......@@ -1964,13 +2005,70 @@ mod connector_tests {
// Step 2: apply_scheduler_output with THE SAME blocks (TRT-LLM pattern)
// Without the dedup guard, this doubles device_blocks to len=6.
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None)
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, None)
.unwrap();
// device_blocks must still be exactly 3 — dedup guard prevented the double-add
assert_eq!(slot.num_device_blocks_allocated(), 3);
}
/// After `apply_scheduler_output` receives an external hash chain, each
/// completed block must report its own hash and the previous entry of the
/// chain as its parent (no parent for the first block).
#[test]
fn test_trtllm_external_hashes_are_assigned_to_completed_blocks() {
    let num_tokens = 96; // 3 blocks of 32
    let (mut slot, _rx) = create_test_slot(num_tokens, 0);
    let blocks = block_ids(100, 3);
    let hashes = external_hashes(10_000, 3);

    slot.append_mutable_device_blocks(&blocks).unwrap();
    slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, Some(&hashes))
        .unwrap();

    let completed_blocks = slot.sequence.blocks();
    assert_eq!(completed_blocks.len(), 3);

    // Walk the chain once, carrying the previous hash along as the expected parent.
    let mut expected_parent = None;
    for (block, &expected_hash) in completed_blocks.iter().zip(&hashes) {
        assert_eq!(block.external_sequence_hash(), Some(expected_hash));
        assert_eq!(block.external_parent_sequence_hash(), expected_parent);
        expected_parent = Some(expected_hash);
    }
}
/// Re-applying scheduler output with a chain whose middle hash differs from
/// the one already synced is expected to panic with
/// "external_sequence_hash mismatch".
#[test]
#[should_panic(expected = "external_sequence_hash mismatch")]
fn test_trtllm_external_hash_chain_mismatch_panics() {
    let num_tokens = 96; // 3 blocks of 32
    let (mut slot, _rx) = create_test_slot(num_tokens, 0);
    let blocks = block_ids(100, 3);
    let hashes = external_hashes(20_000, 3);

    slot.append_mutable_device_blocks(&blocks).unwrap();
    slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, Some(&hashes))
        .unwrap();

    // Same chain, except the second entry is off by one.
    let mut mismatched_hashes = hashes.clone();
    mismatched_hashes[1] += 1;

    slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, Some(&mismatched_hashes))
        .unwrap();
}
// ---------------------------------------------------------------
// Test 3: Decode adds a new block correctly
// ---------------------------------------------------------------
......@@ -1982,14 +2080,14 @@ mod connector_tests {
// Prefill: append + apply with empty blocks (vLLM pattern)
slot.append_mutable_device_blocks(&prefill_blocks).unwrap();
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 3);
// Decode: new block at boundary (token 96 = block 3)
let decode_block = block_ids(200, 1);
let decode_token: Vec<u32> = vec![9999];
slot.apply_scheduler_output(&decode_token, &decode_block, 95, 1, None)
slot.apply_scheduler_output(&decode_token, &decode_block, 95, 1, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4);
}
......@@ -2024,7 +2122,7 @@ mod connector_tests {
// Empty tokens → Prefilling state, and next_position(96) == total_tokens(96)
// so the early-return does not fire and offload proceeds.
slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap();
let offloads = drain_offload_block_ids(&mut rx);
......@@ -2045,7 +2143,7 @@ mod connector_tests {
// Use the TRT-LLM pattern: append_mutable first, then apply with same blocks + priorities.
// The dedup guard prevents the double-add, but priorities are still processed.
slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities))
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities), None)
.unwrap();
// device_blocks should be 3 (dedup prevented doubling)
......@@ -2069,7 +2167,7 @@ mod connector_tests {
let priorities: Vec<u32> = vec![80, 80, 10, 10];
slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities))
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities), None)
.unwrap();
let offloads = drain_offload_block_ids(&mut rx);
......@@ -2079,7 +2177,7 @@ mod connector_tests {
// Because offload was terminated, no further offloading should happen.
let decode_block = block_ids(200, 1);
let decode_token: Vec<u32> = vec![9999];
slot.apply_scheduler_output(&decode_token, &decode_block, 127, 1, None)
slot.apply_scheduler_output(&decode_token, &decode_block, 127, 1, None, None)
.unwrap();
let further_offloads = drain_offload_block_ids(&mut rx);
......@@ -2102,14 +2200,16 @@ mod connector_tests {
slot.append_mutable_device_blocks(&blocks).unwrap();
// Chunk 1: schedule first 64 tokens → evaluates blocks 0,1
slot.apply_scheduler_output(&[], &[], 0, 64, None).unwrap();
slot.apply_scheduler_output(&[], &[], 0, 64, None, None)
.unwrap();
let offloads_1 = drain_offload_block_ids(&mut rx);
assert_eq!(offloads_1.len(), 1);
assert_eq!(offloads_1[0], vec![100, 101]); // blocks 0,1
// Chunk 2: schedule next 64 tokens → evaluates blocks 2,3
// (uses cached_request pattern: empty tokens, empty blocks)
slot.apply_scheduler_output(&[], &[], 64, 64, None).unwrap();
slot.apply_scheduler_output(&[], &[], 64, 64, None, None)
.unwrap();
let offloads_2 = drain_offload_block_ids(&mut rx);
assert_eq!(offloads_2.len(), 1);
assert_eq!(offloads_2[0], vec![102, 103]); // blocks 2,3
......@@ -2133,7 +2233,7 @@ mod connector_tests {
// Step 2: apply_scheduler_output with overlapping blocks [12, 13].
// Suffix [12] of device_blocks matches prefix [12] of block_ids.
// Only block 13 is new and gets appended.
slot.apply_scheduler_output(&[], &[12, 13], 0, 128, None)
slot.apply_scheduler_output(&[], &[12, 13], 0, 128, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4);
......@@ -2152,7 +2252,7 @@ mod connector_tests {
// Chunk 1: 3 blocks, all high priority, schedule 96 tokens
slot.append_mutable_device_blocks(&[10, 11, 12]).unwrap();
slot.apply_scheduler_output(&[], &[10, 11, 12], 0, 96, Some(&[80, 80, 80]))
slot.apply_scheduler_output(&[], &[10, 11, 12], 0, 96, Some(&[80, 80, 80]), None)
.unwrap();
let offloads_1 = drain_offload_block_ids(&mut rx);
......@@ -2171,7 +2271,7 @@ mod connector_tests {
slot.append_mutable_device_blocks(&[13]).unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4);
slot.apply_scheduler_output(&[], &[12, 13], 96, 32, Some(&[80, 10]))
slot.apply_scheduler_output(&[], &[12, 13], 96, 32, Some(&[80, 10]), None)
.unwrap();
// Candidate is block 13 (index 3, evaluated_blocks=3).
......@@ -2198,7 +2298,7 @@ mod connector_tests {
// block_ids[0]=11 is found at device_blocks[1], but device_blocks[1..] = [11,12]
// does NOT match block_ids[..2] = [11,14]. Contract violation.
slot.apply_scheduler_output(&[], &[11, 14], 0, 128, None)
slot.apply_scheduler_output(&[], &[11, 14], 0, 128, None, None)
.unwrap();
}
......@@ -2218,7 +2318,7 @@ mod connector_tests {
// Overlap: suffix [13,14] matches prefix [13,14], overlap=2.
// new_ids = [10]. But 10 ∈ device_blocks → contract violation.
slot.apply_scheduler_output(&[], &[13, 14, 10], 0, 192, None)
slot.apply_scheduler_output(&[], &[13, 14, 10], 0, 192, None, None)
.unwrap();
}
......@@ -2238,7 +2338,7 @@ mod connector_tests {
// block_ids[0]=10 found at device_blocks[0], suffix_len=3.
// device_blocks[0..3]=[10,11,12] == block_ids[0..3]=[10,11,12] → overlap=3.
// new_ids=[13,14], both genuinely new → extend.
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14], 0, 160, None)
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14], 0, 160, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 5);
......@@ -2261,7 +2361,7 @@ mod connector_tests {
assert_eq!(slot.num_device_blocks_allocated(), 5);
// Full overlap of all 5 existing blocks, plus 2 new.
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14, 15, 16], 0, 224, None)
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14, 15, 16], 0, 224, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 7);
......
......@@ -14,12 +14,30 @@ use crate::block_manager::{distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRe
use crate::get_current_tokio_handle;
use anyhow;
use dynamo_llm::block_manager::connector::protocol::RequestType;
use dynamo_llm::block_manager::kv_consolidator::EventSource;
use dynamo_llm::block_manager::kv_consolidator::{EventSource, KvEventConsolidationMode};
use dynamo_llm::block_manager::metrics_kvbm::{KvbmMetrics, KvbmMetricsRegistry};
use std::collections::HashSet;
use std::sync::{Arc, OnceLock};
use tokio::runtime::Handle;
/// Converts the optional mode string received from Python into a
/// [`KvEventConsolidationMode`].
///
/// `None` selects `Dedup`. A string that fails to parse is reported with a
/// warning and also falls back to `Dedup`, so this function never fails.
fn parse_consolidator_mode(mode: Option<String>) -> KvEventConsolidationMode {
    match mode {
        None => KvEventConsolidationMode::Dedup,
        Some(raw) => raw
            .parse::<KvEventConsolidationMode>()
            .unwrap_or_else(|error| {
                tracing::warn!(
                    "Invalid KV event consolidator mode {:?}: {}. Falling back to dedup.",
                    raw,
                    error
                );
                KvEventConsolidationMode::Dedup
            }),
    }
}
pub trait Leader: Send + Sync + std::fmt::Debug {
fn get_num_new_matched_tokens(
&mut self,
......@@ -71,6 +89,7 @@ impl KvConnectorLeader {
leader_py: PyKvbmLeader,
consolidator_trtllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeader initialized with worker_id: {}",
......@@ -95,6 +114,7 @@ impl KvConnectorLeader {
// Capture consolidator endpoints for the async block
let consolidator_trtllm_ep = consolidator_trtllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -125,6 +145,7 @@ impl KvConnectorLeader {
trtllm_ep,
consolidator_output_ep,
EventSource::Trtllm,
consolidator_mode,
);
}
......@@ -389,6 +410,7 @@ impl Leader for KvConnectorLeader {
new_req.num_computed_tokens,
scheduled_tokens,
new_req.priorities.as_deref(),
new_req.external_sequence_hashes.as_deref(),
)?;
let pending_ops_opt = slot.take_pending_operations();
......@@ -440,6 +462,7 @@ impl Leader for KvConnectorLeader {
cached_req.num_computed_tokens,
scheduled_tokens,
cached_req.priorities.as_deref(),
cached_req.external_sequence_hashes.as_deref(),
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
......@@ -518,7 +541,7 @@ pub struct PyTrtllmKvConnectorLeader {
#[pymethods]
impl PyTrtllmKvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_trtllm_endpoint=None, consolidator_output_endpoint=None))]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_trtllm_endpoint=None, consolidator_output_endpoint=None, consolidator_mode=None))]
pub fn new(
worker_id: u64,
drt: Option<PyObject>,
......@@ -526,6 +549,7 @@ impl PyTrtllmKvConnectorLeader {
leader: PyKvbmLeader,
consolidator_trtllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> PyResult<Self> {
let _ = &drt; // drt is currently un-used in leader
......@@ -535,6 +559,7 @@ impl PyTrtllmKvConnectorLeader {
leader,
consolidator_trtllm_endpoint,
consolidator_output_endpoint,
consolidator_mode,
));
Ok(Self { connector_leader })
}
......
......@@ -126,10 +126,12 @@ impl AicPerfConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", use_remote_indexer=false, serve_indexer=false, shared_cache_multiplier=0.0, shared_cache_type="none"))]
#[pyo3(signature = (overlap_score_weight=1.0, host_cache_hit_weight=0.75, disk_cache_hit_weight=0.25, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", use_remote_indexer=false, serve_indexer=false, shared_cache_multiplier=0.0, shared_cache_type="none"))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
host_cache_hit_weight: f64,
disk_cache_hit_weight: f64,
router_temperature: f64,
use_kv_events: bool,
durable_kv_events: bool,
......@@ -155,6 +157,8 @@ impl KvRouterConfig {
KvRouterConfig {
inner: RsKvRouterConfig {
overlap_score_weight,
host_cache_hit_weight,
disk_cache_hit_weight,
router_temperature,
use_kv_events,
durable_kv_events,
......
......@@ -1078,7 +1078,6 @@ impl KvRouter {
lora_name.clone(),
0.0,
None,
None,
None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
)
.await
......
......@@ -68,7 +68,8 @@ use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::{
EventKind, EventWarningKind, KvIndexerMetrics, PreBoundEventCounters, SyncIndexer, WorkerTask,
EventKind, EventWarningKind, KvIndexerMetrics, MatchDetails, PreBoundEventCounters,
SyncIndexer, WorkerTask,
};
use crate::cleanup::{self, CleanableNode, CleanupGuard, CleanupState};
use crate::protocols::*;
......@@ -479,26 +480,36 @@ impl ConcurrentRadixTreeCompressed {
// ------------------------------------------------------------------
/// Traverse the radix tree to find the best match for a given sequence of
/// [`LocalBlockHash`]es.
/// [`LocalBlockHash`]es, returning both overlap scores and the last matched
/// `ExternalSequenceBlockHash` per worker (used for lower-tier continuation).
///
/// Workers in `full_edge_workers` are tracked in the `active` set and continue
/// into children. Workers in `worker_cutoffs` are scored at the node where their
/// cutoff falls short and are never propagated into children.
pub fn find_matches_impl(
pub fn find_match_details_impl(
&self,
sequence: &[LocalBlockHash],
early_exit: bool,
) -> OverlapScores {
let mut scores = OverlapScores::new();
) -> MatchDetails {
let mut details = MatchDetails::new();
if sequence.is_empty() {
return scores;
return details;
}
let MatchDetails {
overlap_scores: ref mut scores,
ref mut last_matched_hashes,
} = details;
let mut active: FxHashSet<WorkerWithDpRank> = FxHashSet::default();
let mut active_count: usize = 0;
let mut matched_depth: u32 = 0;
let mut seq_pos: usize = 0;
let mut first_node = true;
// Last ExternalSequenceBlockHash from the previous fully-matched edge.
// Workers that drop at a node boundary (not present in the new node)
// were last matched at the end of the previous edge.
let mut prev_edge_last_hash: Option<ExternalSequenceBlockHash> = None;
let mut next_child = {
let root_guard = read_lock!(self, self.root);
......@@ -531,38 +542,49 @@ impl ConcurrentRadixTreeCompressed {
}
edge_match_len = match_len;
// Helper: ExternalSequenceBlockHash at a given depth within this edge.
let edge_hash_at = |depth: usize| -> ExternalSequenceBlockHash {
debug_assert!(depth > 0 && depth <= guard.edge.len());
guard.edge[depth - 1].1
};
let prev_depth = matched_depth;
if first_node {
// Seed active set from full-edge workers (they can continue to children).
// Score partial workers immediately; they never continue into children.
active = guard.full_edge_workers.clone();
active_count = active.len();
for (&w, &k) in &guard.worker_cutoffs {
let contribution = k.min(edge_match_len) as u32;
let contribution = k.min(edge_match_len);
if contribution > 0 {
scores.scores.insert(w, contribution);
scores.scores.insert(w, contribution as u32);
last_matched_hashes.insert(w, edge_hash_at(contribution));
}
}
first_node = false;
} else {
let has_partial = !guard.worker_cutoffs.is_empty();
if has_partial {
// Slow path: check each active worker against both maps.
active.retain(|w| {
if guard.full_edge_workers.contains(w) {
true
} else if let Some(&k) = guard.worker_cutoffs.get(w) {
let effective = k.min(edge_match_len) as u32;
scores.scores.insert(*w, prev_depth + effective);
let effective = k.min(edge_match_len);
scores.scores.insert(*w, prev_depth + effective as u32);
if effective > 0 {
last_matched_hashes.insert(*w, edge_hash_at(effective));
} else if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false
} else {
scores.scores.insert(*w, prev_depth);
if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false
}
});
} else {
// Fast path: no partial workers — all coverage is full or absent.
let full_count = guard.full_edge_workers.len();
if full_count != active_count {
active.retain(|w| {
......@@ -570,11 +592,13 @@ impl ConcurrentRadixTreeCompressed {
true
} else {
scores.scores.insert(*w, prev_depth);
if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false
}
});
}
// full_count == active_count: sets are identical (fast path).
}
active_count = active.len();
}
......@@ -590,6 +614,9 @@ impl ConcurrentRadixTreeCompressed {
} else {
None
};
// Track the deepest matched hash in this edge (both full and partial).
prev_edge_last_hash = Some(guard.edge[edge_match_len - 1].1);
}
if active_count == 0 {
......@@ -605,15 +632,32 @@ impl ConcurrentRadixTreeCompressed {
}
}
// Record scores and hashes for workers that survived to the deepest level.
if let Some(h) = prev_edge_last_hash {
for worker in &active {
scores.scores.insert(*worker, matched_depth);
last_matched_hashes.insert(*worker, h);
}
} else {
for worker in &active {
scores.scores.insert(*worker, matched_depth);
}
}
for worker in scores.scores.keys() {
if let Some(s) = self.tree_sizes.get(worker) {
scores.tree_sizes.insert(*worker, s.load(Ordering::Relaxed));
}
}
scores
details
}
pub fn find_matches_impl(
&self,
sequence: &[LocalBlockHash],
early_exit: bool,
) -> OverlapScores {
self.find_match_details_impl(sequence, early_exit)
.overlap_scores
}
// ------------------------------------------------------------------
......
......@@ -12,7 +12,8 @@ use tokio_util::sync::CancellationToken;
use super::{
DumpRequest, EventKind, GetWorkersRequest, KvIndexerInterface, KvIndexerMetrics, KvRouterError,
MatchRequest, PreBoundEventCounters, RadixTree, RoutingDecisionRequest,
MatchDetails, MatchDetailsRequest, MatchRequest, PreBoundEventCounters, RadixTree,
RoutingDecisionRequest,
};
use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager};
use crate::protocols::*;
......@@ -95,6 +96,8 @@ pub struct KvIndexer {
event_tx: mpsc::Sender<RouterEvent>,
/// A sender for `MatchRequest`s.
match_tx: mpsc::Sender<MatchRequest>,
/// A sender for `MatchDetailsRequest`s.
match_details_tx: mpsc::Sender<MatchDetailsRequest>,
/// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for remove worker dp_rank requests.
......@@ -136,6 +139,7 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (match_details_tx, match_details_rx) = mpsc::channel::<MatchDetailsRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16);
......@@ -156,6 +160,7 @@ impl KvIndexer {
runtime.block_on(async move {
let cancel = cancel_clone;
let mut match_rx = match_rx;
let mut match_details_rx = match_details_rx;
let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx;
let mut remove_worker_dp_rank_rx = remove_worker_dp_rank_rx;
......@@ -320,6 +325,11 @@ impl KvIndexer {
let _ = req.resp.send(matches);
}
Some(req) = match_details_rx.recv() => {
let matches = trie.find_match_details(req.sequence, req.early_exit);
let _ = req.resp.send(matches);
}
_ = expiry_fut => {
// TTL-based expiry triggered
let Some(ref mut pm) = prune_manager else { continue };
......@@ -351,6 +361,7 @@ impl KvIndexer {
cancel: token,
event_tx,
match_tx,
match_details_tx,
remove_worker_tx,
remove_worker_dp_rank_tx,
get_workers_tx,
......@@ -382,6 +393,21 @@ impl KvIndexer {
self.event_tx.clone()
}
/// Sends a `MatchDetailsRequest` for `sequence` (with `early_exit` disabled)
/// to the indexer task and waits for its reply.
///
/// # Errors
///
/// * `KvRouterError::IndexerOffline` if the request cannot be queued.
/// * `KvRouterError::IndexerDroppedRequest` if the reply channel is closed
///   before a response arrives.
pub async fn find_match_details(
    &self,
    sequence: Vec<LocalBlockHash>,
) -> Result<MatchDetails, KvRouterError> {
    let (reply_tx, reply_rx) = oneshot::channel();
    let request = MatchDetailsRequest::new(sequence, false, reply_tx);

    if self.match_details_tx.send(request).await.is_err() {
        return Err(KvRouterError::IndexerOffline);
    }

    match reply_rx.await {
        Ok(details) => Ok(details),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
#[cfg(test)]
pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
self.dump_tx.clone()
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::{
collections::VecDeque,
collections::{HashMap, VecDeque},
sync::{Arc, Mutex},
};
......@@ -13,7 +13,7 @@ use tokio_util::sync::CancellationToken;
use super::{
GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError,
WorkerKvQueryResponse,
LowerTierIndexer, ThreadPoolIndexer, WorkerKvQueryResponse,
};
use crate::protocols::*;
......@@ -205,6 +205,8 @@ impl RecoverySnapshotCache {
pub struct LocalKvIndexer {
/// The underlying indexer
indexer: KvIndexer,
/// Lazily-created exact lower-tier indexes partitioned by storage tier.
lower_tier_indexers: Arc<Mutex<HashMap<StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>>>>,
/// Circular buffer of recent events
pub(super) event_buffer: Mutex<VecDeque<RouterEvent>>,
/// Coordinates single-flight tree dumps and the cached recovery snapshot.
......@@ -229,6 +231,7 @@ impl LocalKvIndexer {
) -> Self {
Self {
indexer: KvIndexer::new(token, kv_block_size, metrics),
lower_tier_indexers: Arc::new(Mutex::new(HashMap::new())),
event_buffer: Mutex::new(VecDeque::with_capacity(max_buffer_size)),
recovery_cache: Arc::new(RecoverySnapshotCache::new()),
max_buffer_size,
......@@ -335,13 +338,7 @@ impl LocalKvIndexer {
///
/// This forwards the event to the underlying indexer and records it on success.
pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> {
// Forward to underlying indexer
let result = self
.indexer
.event_sender()
.send(event.clone())
.await
.map_err(|_| KvRouterError::IndexerOffline);
let result = self.apply_event_by_tier(&event).await;
if result.is_ok() {
let should_invalidate = matches!(event.event.data, KvCacheEventData::Cleared);
let detected_gap = self.record_event(event);
......@@ -617,6 +614,63 @@ impl LocalKvIndexer {
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
self.indexer.get_workers_sender()
}
/// Get the KV block size.
pub fn block_size(&self) -> u32 {
self.indexer.block_size()
}
/// Queues `event` on the primary indexer's event channel.
///
/// Returns `KvRouterError::IndexerOffline` when the channel rejects the send.
async fn apply_event_to_primary(&self, event: RouterEvent) -> Result<(), KvRouterError> {
    let primary_tx = self.indexer.event_sender();
    match primary_tx.send(event).await {
        Ok(()) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerOffline),
    }
}
/// Hands `event` to the lower-tier indexer for the event's storage tier,
/// creating that indexer on first use. Always returns `Ok(())`.
async fn apply_event_to_lower_tier(&self, event: RouterEvent) -> Result<(), KvRouterError> {
    let tier_indexer = self.get_or_create_lower_tier_indexer(event.storage_tier);
    tier_indexer.apply_event(event).await;
    Ok(())
}
/// Routes `event` to the index that should receive it.
///
/// * `Cleared` events go to the primary indexer first and are then fanned
///   out to every lower-tier indexer that currently exists.
/// * Any other event goes to the primary indexer when its storage tier is
///   GPU, otherwise to the lower-tier indexer for that tier.
async fn apply_event_by_tier(&self, event: &RouterEvent) -> Result<(), KvRouterError> {
    if matches!(event.event.data, KvCacheEventData::Cleared) {
        self.apply_event_to_primary(event.clone()).await?;
        for tier_indexer in self.all_lower_tier_indexers() {
            tier_indexer.apply_event(event.clone()).await;
        }
        return Ok(());
    }

    if event.storage_tier.is_gpu() {
        self.apply_event_to_primary(event.clone()).await
    } else {
        self.apply_event_to_lower_tier(event.clone()).await
    }
}
/// Returns the lower-tier indexer registered for `storage_tier`, building
/// and registering a single-threaded one if none exists yet.
///
/// The registry mutex is held for the whole lookup-or-create step. Debug
/// builds assert that `storage_tier` is not the GPU tier.
fn get_or_create_lower_tier_indexer(
    &self,
    storage_tier: StorageTier,
) -> Arc<ThreadPoolIndexer<LowerTierIndexer>> {
    debug_assert!(!storage_tier.is_gpu());

    let mut by_tier = self.lower_tier_indexers.lock().unwrap();
    if let Some(existing) = by_tier.get(&storage_tier) {
        return Arc::clone(existing);
    }

    let created = Arc::new(ThreadPoolIndexer::new(
        LowerTierIndexer::new(),
        1,
        self.block_size(),
    ));
    by_tier.insert(storage_tier, Arc::clone(&created));
    created
}
/// Returns handles to every lower-tier indexer registered at the time of
/// the call, so callers can use them without holding the registry lock.
fn all_lower_tier_indexers(&self) -> Vec<Arc<ThreadPoolIndexer<LowerTierIndexer>>> {
    let by_tier = self.lower_tier_indexers.lock().unwrap();
    let mut snapshot = Vec::with_capacity(by_tier.len());
    snapshot.extend(by_tier.values().cloned());
    snapshot
}
}
// Implement KvIndexerInterface by delegating to the underlying indexer
......@@ -646,10 +700,16 @@ impl KvIndexerInterface for LocalKvIndexer {
}
/// Removes `worker` from every existing lower-tier indexer, then queues the
/// removal on the primary indexer. A failed send to the primary indexer is
/// ignored.
async fn remove_worker(&self, worker: WorkerId) {
    let lower_tiers = self.all_lower_tier_indexers();
    for tier_indexer in lower_tiers {
        tier_indexer.remove_worker(worker).await;
    }

    let primary_tx = self.indexer.remove_worker_sender();
    let _ = primary_tx.send(worker).await;
}
/// Removes the single `(worker, dp_rank)` pair from every existing
/// lower-tier indexer and then from the primary indexer.
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
    let lower_tiers = self.all_lower_tier_indexers();
    for tier_indexer in lower_tiers {
        KvIndexerInterface::remove_worker_dp_rank(&*tier_indexer, worker, dp_rank).await;
    }

    KvIndexerInterface::remove_worker_dp_rank(&self.indexer, worker, dp_rank).await;
}
......@@ -658,7 +718,27 @@ impl KvIndexerInterface for LocalKvIndexer {
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
let mut events = self.indexer.dump_events().await?;
// Also dump lower-tier indexer state so the router receives
// host-pinned / disk block information during recovery.
let lower_tiers: Vec<(StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>)> = {
let indexers = self.lower_tier_indexers.lock().unwrap();
indexers
.iter()
.map(|(&tier, idx)| (tier, idx.clone()))
.collect()
};
for (tier, indexer) in lower_tiers {
if let Ok(tier_events) = indexer.dump_events().await {
for mut event in tier_events {
event.storage_tier = tier;
events.push(event);
}
}
}
Ok(events)
}
async fn process_routing_decision_for_request(
......@@ -674,6 +754,231 @@ impl KvIndexerInterface for LocalKvIndexer {
}
async fn flush(&self) -> usize {
self.indexer.flush().await
let queued = self.indexer.flush().await;
for indexer in self.all_lower_tier_indexers() {
let _ = indexer.dump_events().await;
}
queued
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rustc_hash::FxHashMap;
    use tokio_util::sync::CancellationToken;

    use super::LocalKvIndexer;
    use crate::indexer::{KvIndexerInterface, KvIndexerMetrics, LowerTierContinuation};
    use crate::protocols::{
        ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
        KvCacheStoredBlockData, LocalBlockHash, RouterEvent, StorageTier, WorkerWithDpRank,
    };

    /// Builds the indexer every test starts from: block size 4, unregistered
    /// metrics, and an event buffer of 16 entries.
    fn new_test_indexer() -> LocalKvIndexer {
        LocalKvIndexer::new(
            CancellationToken::new(),
            4,
            Arc::new(KvIndexerMetrics::new_unregistered()),
            16,
        )
    }

    /// Builds a single-block `Stored` event tagged with `storage_tier`.
    fn lower_tier_store_event(
        worker_id: u64,
        dp_rank: u32,
        event_id: u64,
        parent_hash: u64,
        tokens_hash: u64,
        block_hash: u64,
        storage_tier: StorageTier,
    ) -> RouterEvent {
        let stored_block = KvCacheStoredBlockData {
            block_hash: ExternalSequenceBlockHash(block_hash),
            tokens_hash: LocalBlockHash(tokens_hash),
            mm_extra_info: None,
        };
        let store_data = KvCacheStoreData {
            parent_hash: Some(ExternalSequenceBlockHash(parent_hash)),
            start_position: None,
            blocks: vec![stored_block],
        };
        let cache_event = KvCacheEvent {
            event_id,
            data: KvCacheEventData::Stored(store_data),
            dp_rank,
        };
        RouterEvent::with_storage_tier(worker_id, cache_event, storage_tier)
    }

    /// Returns how many contiguous lower-tier blocks the given worker has for
    /// `tokens_hash` when continuing from `parent_hash`, or 0 when no indexer
    /// exists for `storage_tier`.
    fn lower_tier_hits(
        indexer: &LocalKvIndexer,
        storage_tier: StorageTier,
        worker_id: u64,
        dp_rank: u32,
        parent_hash: u64,
        tokens_hash: u64,
    ) -> usize {
        let tier_indexer = indexer
            .lower_tier_indexers
            .lock()
            .unwrap()
            .get(&storage_tier)
            .cloned();
        let Some(tier_indexer) = tier_indexer else {
            return 0;
        };

        let worker = WorkerWithDpRank::new(worker_id, dp_rank);
        let mut continuations = FxHashMap::default();
        continuations.insert(
            worker,
            LowerTierContinuation::new(0, ExternalSequenceBlockHash(parent_hash)),
        );

        let hits = tier_indexer
            .backend()
            .query_contiguous_hits(&[LocalBlockHash(tokens_hash)], &continuations);
        hits.get(&worker).copied().unwrap_or(0)
    }

    #[tokio::test]
    async fn lower_tier_events_are_buffered_without_touching_primary_index() {
        let indexer = new_test_indexer();
        let event = lower_tier_store_event(7, 0, 1, 900, 11, 101, StorageTier::HostPinned);

        indexer
            .apply_event_with_buffer(event.clone())
            .await
            .unwrap();
        let _ = indexer.flush().await;

        // The event is recorded in the buffer and lands in the host-pinned index...
        assert_eq!(indexer.get_all_events_in_buffer(), vec![event]);
        assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 1);
        assert_eq!(
            lower_tier_hits(&indexer, StorageTier::HostPinned, 7, 0, 900, 11),
            1
        );

        // ...while the primary index sees nothing.
        let overlap = indexer
            .find_matches(vec![LocalBlockHash(11)])
            .await
            .unwrap();
        assert!(overlap.scores.is_empty());
    }

    #[tokio::test]
    async fn lower_tier_events_are_partitioned_by_storage_tier() {
        let indexer = new_test_indexer();
        assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 0);

        let host_event = lower_tier_store_event(19, 0, 1, 1000, 31, 301, StorageTier::HostPinned);
        indexer.apply_event_with_buffer(host_event).await.unwrap();
        let _ = indexer.flush().await;
        assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 1);

        let disk_event = lower_tier_store_event(19, 0, 2, 2000, 31, 302, StorageTier::Disk);
        indexer.apply_event_with_buffer(disk_event).await.unwrap();
        let _ = indexer.flush().await;
        assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 2);

        // Each tier only knows about the edge stored on that tier.
        assert_eq!(
            lower_tier_hits(&indexer, StorageTier::HostPinned, 19, 0, 1000, 31),
            1
        );
        assert_eq!(
            lower_tier_hits(&indexer, StorageTier::Disk, 19, 0, 2000, 31),
            1
        );
        assert_eq!(
            lower_tier_hits(&indexer, StorageTier::HostPinned, 19, 0, 2000, 31),
            0
        );
        assert_eq!(
            lower_tier_hits(&indexer, StorageTier::Disk, 19, 0, 1000, 31),
            0
        );
    }

    #[tokio::test]
    async fn cleared_event_clears_all_lower_tier_dp_ranks_for_worker() {
        let indexer = new_test_indexer();

        let rank0_store = lower_tier_store_event(11, 0, 1, 1000, 21, 201, StorageTier::HostPinned);
        indexer.apply_event_with_buffer(rank0_store).await.unwrap();

        let rank1_store = lower_tier_store_event(11, 1, 2, 2000, 22, 202, StorageTier::HostPinned);
        indexer.apply_event_with_buffer(rank1_store).await.unwrap();

        // The clear is issued for dp_rank 0 only, yet both ranks must be wiped.
        let cleared = RouterEvent::with_storage_tier(
            11,
            KvCacheEvent {
                event_id: 3,
                data: KvCacheEventData::Cleared,
                dp_rank: 0,
            },
            StorageTier::HostPinned,
        );
        indexer.apply_event_with_buffer(cleared).await.unwrap();
        let _ = indexer.flush().await;

        assert_eq!(
            lower_tier_hits(&indexer, StorageTier::HostPinned, 11, 0, 1000, 21),
            0
        );
        assert_eq!(
            lower_tier_hits(&indexer, StorageTier::HostPinned, 11, 1, 2000, 22),
            0
        );
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Exact lower-tier KV continuation index.
//!
//! This structure stores worker ownership over shared continuation edges in the
//! event hash space: `(parent_sequence_hash, local_hash) -> child_sequence_hash`.
//!
//! Unlike the primary KV indexers, this index does not attempt prefix-overlap
//! scoring. Queries continue from a caller-provided per-worker continuation
//! point and count how many consecutive lower-tier blocks are present.
//!
//! The index treats lower-tier state as a set of unique continuation edges. If a
//! duplicate or conflicting store arrives, the existing mapping wins and the new
//! event is ignored.
use std::hash::BuildHasher;
use std::sync::Arc;
use dashmap::DashMap;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use super::{KvIndexerMetrics, SyncIndexer, WorkerTask};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerWithDpRank,
};
/// Set of workers, each identified by a `(worker_id, dp_rank)` pair.
type WorkerSet = FxHashSet<WorkerWithDpRank>;
/// Workers grouped by the sequence hash they are currently continuing from
/// (`None` when a continuation carries no last matched hash).
type FrontierBuckets = FxHashMap<Option<ExternalSequenceBlockHash>, WorkerSet>;
/// Per-worker end state of a query walk, read back by `query_match_details`
/// as `(final position, final hash)`.
type FinalStates = FxHashMap<WorkerWithDpRank, (usize, Option<ExternalSequenceBlockHash>)>;
/// Per-worker reverse index from a stored block hash to the edge key it was
/// stored under, used to find the edge again on removal.
type WorkerBlockIndex =
FxHashMap<WorkerWithDpRank, FxHashMap<ExternalSequenceBlockHash, TransitionKey>>;
/// Key of a continuation edge: the sequence hash of the parent block (if
/// any) together with the local hash of the block that follows it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct TransitionKey {
/// Sequence hash of the preceding block; `None` when the store event had no
/// parent hash.
parent_hash: Option<ExternalSequenceBlockHash>,
/// Local hash of the stored block (its `tokens_hash`).
local_hash: LocalBlockHash,
}
#[derive(Debug, Clone)]
enum EdgeOwnersEntry {
Single {
child_hash: ExternalSequenceBlockHash,
owner: WorkerWithDpRank,
},
Multi {
child_hash: ExternalSequenceBlockHash,
owners: WorkerSet,
},
}
impl EdgeOwnersEntry {
fn new(child_hash: ExternalSequenceBlockHash, owner: WorkerWithDpRank) -> Self {
Self::Single { child_hash, owner }
}
fn child_hash(&self) -> ExternalSequenceBlockHash {
match self {
Self::Single { child_hash, .. } | Self::Multi { child_hash, .. } => *child_hash,
}
}
fn insert(&mut self, child_hash: ExternalSequenceBlockHash, owner: WorkerWithDpRank) -> bool {
match self {
Self::Single {
child_hash: existing_hash,
owner: existing_owner,
} => {
if *existing_hash != child_hash {
return false;
}
if *existing_owner == owner {
return true;
}
let mut owners = WorkerSet::default();
owners.insert(*existing_owner);
owners.insert(owner);
*self = Self::Multi { child_hash, owners };
true
}
Self::Multi {
child_hash: existing_hash,
owners,
} => {
if *existing_hash != child_hash {
return false;
}
owners.insert(owner);
true
}
}
}
fn remove(&mut self, owner: WorkerWithDpRank) -> bool {
match self {
Self::Single {
owner: existing_owner,
..
} => *existing_owner == owner,
Self::Multi { child_hash, owners } => {
if !owners.remove(&owner) {
return false;
}
if owners.is_empty() {
return true;
}
if owners.len() == 1 {
let remaining_owner = owners.iter().next().copied().unwrap();
*self = Self::Single {
child_hash: *child_hash,
owner: remaining_owner,
};
}
false
}
}
}
fn contains(&self, owner: &WorkerWithDpRank) -> bool {
match self {
Self::Single {
owner: existing_owner,
..
} => existing_owner == owner,
Self::Multi { owners, .. } => owners.contains(owner),
}
}
fn collect_workers(&self) -> Vec<WorkerWithDpRank> {
match self {
Self::Single { owner, .. } => vec![*owner],
Self::Multi { owners, .. } => owners.iter().copied().collect(),
}
}
}
/// Point from which a worker's lower-tier lookup continues.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LowerTierContinuation {
    /// Index into the queried hash sequence at which the walk starts.
    pub start_pos: usize,
    /// Sequence hash of the last block already matched, or `None` when the
    /// continuation was created with [`LowerTierContinuation::from_root`].
    pub last_matched_hash: Option<ExternalSequenceBlockHash>,
}

impl LowerTierContinuation {
    /// Continuation that resumes at `start_pos` after `last_matched_hash`.
    pub fn new(start_pos: usize, last_matched_hash: ExternalSequenceBlockHash) -> Self {
        Self {
            last_matched_hash: Some(last_matched_hash),
            ..Self::from_root(start_pos)
        }
    }

    /// Continuation that starts at `start_pos` with no previously matched
    /// block.
    pub fn from_root(start_pos: usize) -> Self {
        let last_matched_hash = None;
        Self {
            start_pos,
            last_matched_hash,
        }
    }
}
/// Result of `LowerTierIndexer::query_match_details`.
#[derive(Debug, Clone, Default)]
pub struct LowerTierMatchDetails {
/// Per worker, the number of positions advanced past that worker's
/// `start_pos` (0 when nothing matched).
pub hits: FxHashMap<WorkerWithDpRank, usize>,
/// Per worker, the continuation to resume from. Equal to the input
/// continuation when the worker had zero hits.
pub next_continuations: FxHashMap<WorkerWithDpRank, LowerTierContinuation>,
}
/// Standalone lower-tier continuation index.
///
/// `edges` maps every `(parent_hash, local_hash)` transition to the child
/// sequence hash it leads to and the workers that own it. The per-worker
/// reverse index ([`WorkerBlockIndex`]) is not stored here; the mutating
/// methods receive it from their caller.
pub struct LowerTierIndexer {
    edges: DashMap<TransitionKey, EdgeOwnersEntry, FxBuildHasher>,
}

impl LowerTierIndexer {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            edges: DashMap::with_hasher(FxBuildHasher),
        }
    }

    /// Applies one router event to the index.
    ///
    /// * `Stored` adds the event's blocks for the emitting `(worker, dp_rank)`.
    /// * `Removed` removes the listed block hashes for that worker and
    ///   reports `BlockNotFound` if any of them was unknown.
    /// * `Cleared` drops all state of the emitting worker id, across every
    ///   dp rank.
    fn apply_event(
        &self,
        worker_blocks: &mut WorkerBlockIndex,
        event: RouterEvent,
    ) -> Result<(), KvCacheEventError> {
        let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
        match event.event.data {
            KvCacheEventData::Stored(store_data) => {
                self.store_blocks_impl(worker_blocks, worker, store_data);
                Ok(())
            }
            KvCacheEventData::Removed(remove_data) => {
                self.remove_blocks_impl(worker_blocks, worker, &remove_data.block_hashes)
            }
            KvCacheEventData::Cleared => {
                self.clear_worker_impl(worker_blocks, event.worker_id);
                Ok(())
            }
        }
    }

    /// Records the chain of blocks in `store_data` as edges owned by `worker`.
    ///
    /// The first block hangs off `store_data.parent_hash`; each following
    /// block hangs off the block before it. The walk stops at the first block
    /// that conflicts with existing state, so later blocks of the same event
    /// are not recorded.
    fn store_blocks_impl(
        &self,
        worker_blocks: &mut WorkerBlockIndex,
        worker: WorkerWithDpRank,
        store_data: KvCacheStoreData,
    ) {
        let mut parent_hash = store_data.parent_hash;
        let worker_map = worker_blocks.entry(worker).or_default();

        for block in store_data.blocks {
            let key = TransitionKey {
                parent_hash,
                local_hash: block.tokens_hash,
            };

            // If this worker already has a different parent/local for the same
            // block_hash, or if the shared edge is owned by a conflicting
            // child_hash, stop the walk: any further blocks in this chain would
            // hang off an edge this index never accepted for the worker.
            if worker_map
                .get(&block.block_hash)
                .is_some_and(|existing_key| *existing_key != key)
            {
                break;
            }

            let inserted = match self.edges.entry(key) {
                dashmap::mapref::entry::Entry::Occupied(mut edge) => {
                    edge.get_mut().insert(block.block_hash, worker)
                }
                dashmap::mapref::entry::Entry::Vacant(edge) => {
                    edge.insert(EdgeOwnersEntry::new(block.block_hash, worker));
                    true
                }
            };
            if !inserted {
                break;
            }

            worker_map.insert(block.block_hash, key);
            parent_hash = Some(block.block_hash);
        }
    }

    /// Removes `block_hashes` from `worker`'s state and releases the edges
    /// they were stored under.
    ///
    /// Every hash in the list is processed. An unknown hash does not stop the
    /// loop, so the remaining blocks of the same event are still removed and
    /// an emptied worker entry is still dropped.
    ///
    /// # Errors
    ///
    /// Returns `KvCacheEventError::BlockNotFound` when the worker has no
    /// recorded blocks at all, or when at least one hash in `block_hashes`
    /// was not recorded for the worker.
    fn remove_blocks_impl(
        &self,
        worker_blocks: &mut WorkerBlockIndex,
        worker: WorkerWithDpRank,
        block_hashes: &[ExternalSequenceBlockHash],
    ) -> Result<(), KvCacheEventError> {
        let Some(worker_map) = worker_blocks.get_mut(&worker) else {
            return Err(KvCacheEventError::BlockNotFound);
        };

        let mut outcome = Ok(());
        for block_hash in block_hashes {
            match worker_map.remove(block_hash) {
                Some(key) => self.release_edge_owner(&key, worker),
                // Stores rejected by `store_blocks_impl` are never recorded,
                // so a later removal can legitimately name an unknown hash.
                // Remember the failure but keep removing the rest.
                None => outcome = Err(KvCacheEventError::BlockNotFound),
            }
        }

        if worker_map.is_empty() {
            worker_blocks.remove(&worker);
        }
        outcome
    }

    /// Drops `worker` from the edge stored under `key`, deleting the edge
    /// once no owner is left. Does nothing if the edge does not exist.
    fn release_edge_owner(&self, key: &TransitionKey, worker: WorkerWithDpRank) {
        // The guard returned by `get_mut` is dropped by the end of this
        // statement, i.e. before `remove` accesses the same map.
        let edge_is_orphaned = match self.edges.get_mut(key) {
            Some(mut edge) => edge.remove(worker),
            None => false,
        };
        if edge_is_orphaned {
            self.edges.remove(key);
        }
    }

    /// Drops the state of every dp rank recorded for `worker_id`.
    fn clear_worker_impl(&self, worker_blocks: &mut WorkerBlockIndex, worker_id: u64) {
        // Collect first: the removal below mutates `worker_blocks`.
        let workers: Vec<_> = worker_blocks
            .keys()
            .filter(|worker| worker.worker_id == worker_id)
            .copied()
            .collect();
        for worker in workers {
            self.remove_worker_dp_rank_impl(worker_blocks, worker);
        }
    }

    /// Drops all blocks recorded for one `(worker, dp_rank)` pair and
    /// releases its ownership of the corresponding edges.
    fn remove_worker_dp_rank_impl(
        &self,
        worker_blocks: &mut WorkerBlockIndex,
        worker: WorkerWithDpRank,
    ) {
        let Some(worker_map) = worker_blocks.remove(&worker) else {
            return;
        };
        for key in worker_map.into_values() {
            self.release_edge_owner(&key, worker);
        }
    }

    /// Removes all state for `worker_id` (every dp rank).
    fn remove_worker(&self, worker_blocks: &mut WorkerBlockIndex, worker_id: u64) {
        self.clear_worker_impl(worker_blocks, worker_id);
    }

    /// Removes all state for the single `(worker_id, dp_rank)` pair.
    fn remove_worker_dp_rank(
        &self,
        worker_blocks: &mut WorkerBlockIndex,
        worker_id: u64,
        dp_rank: u32,
    ) {
        self.remove_worker_dp_rank_impl(worker_blocks, WorkerWithDpRank::new(worker_id, dp_rank));
    }

    /// Returns the workers owning the parentless edge for `local_hash`, or
    /// an empty vector when no such edge exists.
    pub fn root_workers(&self, local_hash: LocalBlockHash) -> Vec<WorkerWithDpRank> {
        self.edges
            .get(&TransitionKey {
                parent_hash: None,
                local_hash,
            })
            .map(|edge| edge.collect_workers())
            .unwrap_or_default()
    }

    /// Reconstruct store events from the per-worker block index. Each block
    /// becomes a single-block `Stored` event with the correct parent hash,
    /// suitable for replaying into a fresh indexer to recreate the same state.
    ///
    /// Event ids are assigned sequentially from 0 in iteration order.
    fn dump_events(worker_blocks: &WorkerBlockIndex) -> Vec<RouterEvent> {
        let total_blocks: usize = worker_blocks.values().map(|blocks| blocks.len()).sum();
        let mut events = Vec::with_capacity(total_blocks);
        let mut event_id = 0u64;

        for (worker, block_map) in worker_blocks {
            for (block_hash, key) in block_map {
                events.push(RouterEvent::new(
                    worker.worker_id,
                    KvCacheEvent {
                        event_id,
                        data: KvCacheEventData::Stored(KvCacheStoreData {
                            parent_hash: key.parent_hash,
                            start_position: None,
                            blocks: vec![KvCacheStoredBlockData {
                                block_hash: *block_hash,
                                tokens_hash: key.local_hash,
                                mm_extra_info: None,
                            }],
                        }),
                        dp_rank: worker.dp_rank,
                    },
                ));
                event_id += 1;
            }
        }
        events
    }

    /// Convenience wrapper around [`Self::query_match_details`] that returns
    /// only the per-worker hit counts.
    pub fn query_contiguous_hits<S>(
        &self,
        local_hashes: &[LocalBlockHash],
        continuations: &std::collections::HashMap<WorkerWithDpRank, LowerTierContinuation, S>,
    ) -> FxHashMap<WorkerWithDpRank, usize>
    where
        S: BuildHasher,
    {
        self.query_match_details(local_hashes, continuations).hits
    }

    /// For each worker, counts how many contiguous lower-tier blocks match
    /// starting from the worker's continuation point, and returns the updated
    /// continuation state.
    ///
    /// Workers may start at different positions in `local_hashes` (each has its
    /// own `LowerTierContinuation`). The algorithm groups workers that share a
    /// start position into "breakpoints", sorts them, and advances each group
    /// forward through the hash sequence one position at a time. When a group
    /// reaches the next breakpoint it pauses so the two groups can be merged
    /// (workers that converge onto the same edge path are walked together).
    pub fn query_match_details<S>(
        &self,
        local_hashes: &[LocalBlockHash],
        continuations: &std::collections::HashMap<WorkerWithDpRank, LowerTierContinuation, S>,
    ) -> LowerTierMatchDetails
    where
        S: BuildHasher,
    {
        // Build the sorted breakpoint list. Each entry is a position in the
        // hash sequence and a set of (parent_hash -> workers) groups that start
        // walking from that position. The set of positions is fixed — the walk
        // never creates new breakpoints, it only merges overflow workers into
        // the next existing one.
        let mut breakpoints: Vec<(usize, FrontierBuckets)> = Vec::new();
        {
            let mut pos_index: FxHashMap<usize, usize> = FxHashMap::default();
            for (worker, continuation) in continuations {
                let idx = match pos_index.get(&continuation.start_pos) {
                    Some(&idx) => idx,
                    None => {
                        let idx = breakpoints.len();
                        pos_index.insert(continuation.start_pos, idx);
                        breakpoints.push((continuation.start_pos, FrontierBuckets::default()));
                        idx
                    }
                };
                breakpoints[idx]
                    .1
                    .entry(continuation.last_matched_hash)
                    .or_default()
                    .insert(*worker);
            }
            breakpoints.sort_unstable_by_key(|(pos, _)| *pos);
        }

        let mut final_states = FinalStates::default();

        // Process breakpoints front-to-back. Each group walks forward until it
        // hits the next breakpoint or runs out of matching edges. Workers that
        // survive to the next breakpoint are collected as "overflow" and merged
        // into that breakpoint's buckets before it gets processed.
        for idx in 0..breakpoints.len() {
            let pos = breakpoints[idx].0;
            let states = std::mem::take(&mut breakpoints[idx].1);
            // The walk for this group never goes past the end of the sequence.
            let next_breakpoint = breakpoints
                .get(idx + 1)
                .map(|(p, _)| *p)
                .unwrap_or(local_hashes.len())
                .min(local_hashes.len());

            let mut overflow = FrontierBuckets::default();
            for (parent_hash, workers) in states {
                advance_state_to_breakpoint(
                    self,
                    local_hashes,
                    pos,
                    parent_hash,
                    workers,
                    next_breakpoint,
                    &mut overflow,
                    &mut final_states,
                );
            }

            if !overflow.is_empty()
                && let Some((_, next_buckets)) = breakpoints.get_mut(idx + 1)
            {
                for (hash, workers) in overflow {
                    next_buckets.entry(hash).or_default().extend(workers);
                }
            }
        }

        // Convert final_states into the result. Workers that never appeared in
        // final_states (e.g. empty sequence) keep their original continuation.
        let mut results = LowerTierMatchDetails::default();
        for (worker, continuation) in continuations {
            let (final_pos, final_hash) = final_states
                .get(worker)
                .copied()
                .unwrap_or((continuation.start_pos, continuation.last_matched_hash));
            let hits = final_pos.saturating_sub(continuation.start_pos);
            results.hits.insert(*worker, hits);

            let next_continuation = if hits == 0 {
                *continuation
            } else {
                LowerTierContinuation {
                    start_pos: final_pos,
                    last_matched_hash: final_hash.or(continuation.last_matched_hash),
                }
            };
            results
                .next_continuations
                .insert(*worker, next_continuation);
        }
        results
    }
}

impl Default for LowerTierIndexer {
    /// Equivalent to [`LowerTierIndexer::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl SyncIndexer for LowerTierIndexer {
fn worker(
&self,
event_receiver: flume::Receiver<WorkerTask>,
_metrics: Option<Arc<KvIndexerMetrics>>,
) -> anyhow::Result<()> {
let mut worker_blocks = WorkerBlockIndex::default();
while let Ok(task) = event_receiver.recv() {
match task {
WorkerTask::Event(event) => {
if let Err(error) = self.apply_event(&mut worker_blocks, event) {
tracing::warn!(%error, "Failed to apply lower-tier event");
}
}
WorkerTask::RemoveWorker(worker_id) => {
self.remove_worker(&mut worker_blocks, worker_id);
}
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank(&mut worker_blocks, worker_id, dp_rank);
}
WorkerTask::DumpEvents(sender) => {
let _ = sender.send(Ok(Self::dump_events(&worker_blocks)));
}
WorkerTask::CleanupStaleChildren => {}
WorkerTask::Terminate => {
break;
}
}
}
tracing::debug!("LowerTierIndexer worker thread shutting down");
Ok(())
}
fn find_matches(&self, sequence: &[LocalBlockHash], _early_exit: bool) -> OverlapScores {
let Some(&first_hash) = sequence.first() else {
return OverlapScores::default();
};
let mut continuations = FxHashMap::default();
for worker in self.root_workers(first_hash) {
continuations.insert(worker, LowerTierContinuation::from_root(0));
}
let hits = self.query_contiguous_hits(sequence, &continuations);
let mut scores = OverlapScores::default();
for (worker, hits) in hits {
if hits > 0 {
scores
.scores
.insert(worker, hits.min(u32::MAX as usize) as u32);
}
}
scores
}
}
/// Advances a set of workers that share one `(start_pos, start_hash)` state
/// through `local_hashes`, stopping no later than `next_breakpoint`.
///
/// Each step looks up the edge keyed by the current parent hash and the local
/// hash at the current position. Workers that own the edge move on to the
/// edge's child hash; all others are recorded in `final_states` at the current
/// position with the current parent hash. If no edge exists, every remaining
/// worker is recorded there.
///
/// Whenever exactly one worker is left (including on entry), the rest of the
/// walk is handed to [`advance_single_worker`].
///
/// Workers still active when the walk stops are recorded in `final_states` if
/// the position is at or past the end of `local_hashes`; otherwise they are
/// added to `overflow` under their current hash so the caller can merge them
/// into the next breakpoint.
///
/// `local_hashes` is indexed directly, so `next_breakpoint` must not exceed
/// `local_hashes.len()`.
#[allow(clippy::too_many_arguments)]
fn advance_state_to_breakpoint(
    index: &LowerTierIndexer,
    local_hashes: &[LocalBlockHash],
    start_pos: usize,
    start_hash: Option<ExternalSequenceBlockHash>,
    workers: WorkerSet,
    next_breakpoint: usize,
    overflow: &mut FrontierBuckets,
    final_states: &mut FinalStates,
) {
    let mut pos = start_pos;
    let mut parent = start_hash;
    let mut live = workers;
    // Scratch set reused across steps to collect the workers that continue.
    let mut survivors = WorkerSet::default();

    loop {
        // With a single worker left there is nothing to partition, so the
        // scalar walk finishes the job (it also handles overflow/finalize).
        if live.len() == 1 {
            let Some(worker) = live.into_iter().next() else {
                return;
            };
            advance_single_worker(
                index,
                local_hashes,
                worker,
                &mut pos,
                &mut parent,
                next_breakpoint,
                overflow,
                final_states,
            );
            return;
        }
        if pos >= next_breakpoint || live.is_empty() {
            break;
        }

        let key = TransitionKey {
            parent_hash: parent,
            local_hash: local_hashes[pos],
        };
        let Some(edge) = index.edges.get(&key) else {
            // No edge at all: nobody can continue past this position.
            finalize_workers(final_states, live.drain(), pos, parent);
            break;
        };

        match edge.value() {
            EdgeOwnersEntry::Single { owner, .. } => {
                // Pull the owner out first so that draining the set
                // finalizes exactly the non-owners.
                let owner_is_live = live.remove(owner);
                finalize_workers(final_states, live.drain(), pos, parent);
                if !owner_is_live {
                    break;
                }
                live.insert(*owner);
            }
            EdgeOwnersEntry::Multi { owners, .. } => {
                survivors.clear();
                if owners.len() <= live.len() {
                    // Owner side is no larger: move each live owner across,
                    // then whatever remains in `live` did not match.
                    for owner in owners {
                        if live.remove(owner) {
                            survivors.insert(*owner);
                        }
                    }
                    finalize_workers(final_states, live.drain(), pos, parent);
                } else {
                    // Live side is smaller: test each live worker for ownership.
                    for worker in live.drain() {
                        if owners.contains(&worker) {
                            survivors.insert(worker);
                        } else {
                            final_states.insert(worker, (pos, parent));
                        }
                    }
                }
                std::mem::swap(&mut live, &mut survivors);
                if live.is_empty() {
                    break;
                }
            }
        }

        parent = Some(edge.child_hash());
        pos += 1;
    }

    if live.is_empty() {
        return;
    }
    // Survivors either still have blocks ahead of them (hand them to the next
    // breakpoint) or have consumed the whole query (finalize them).
    if pos < local_hashes.len() {
        overflow.entry(parent).or_default().extend(live);
    } else {
        finalize_workers(final_states, live, pos, parent);
    }
}
/// Scalar version of the walk for exactly one worker.
///
/// Repeatedly looks up the edge for `(*cur_hash, local_hashes[*cur_pos])`.
/// While the worker owns the edge, `cur_hash` becomes the edge's child hash
/// and `cur_pos` is incremented. If the edge is missing or not owned by the
/// worker, the worker is recorded in `final_states` at the current position
/// and the function returns.
///
/// If the walk reaches `next_breakpoint`, the worker is recorded in
/// `final_states` when the position is at or past the end of `local_hashes`,
/// and otherwise inserted into `overflow` under its current hash.
///
/// `cur_pos` and `cur_hash` are updated in place. `local_hashes` is indexed
/// directly, so `next_breakpoint` must not exceed `local_hashes.len()`.
#[allow(clippy::too_many_arguments)]
fn advance_single_worker(
    index: &LowerTierIndexer,
    local_hashes: &[LocalBlockHash],
    worker: WorkerWithDpRank,
    cur_pos: &mut usize,
    cur_hash: &mut Option<ExternalSequenceBlockHash>,
    next_breakpoint: usize,
    overflow: &mut FrontierBuckets,
    final_states: &mut FinalStates,
) {
    while *cur_pos < next_breakpoint {
        let key = TransitionKey {
            parent_hash: *cur_hash,
            local_hash: local_hashes[*cur_pos],
        };
        // A missing edge and an edge owned by someone else stop the walk
        // in exactly the same way, so fold both into `None`.
        let owned_child = match index.edges.get(&key) {
            Some(edge) if edge.contains(&worker) => Some(edge.child_hash()),
            _ => None,
        };
        let Some(child_hash) = owned_child else {
            final_states.insert(worker, (*cur_pos, *cur_hash));
            return;
        };
        *cur_hash = Some(child_hash);
        *cur_pos += 1;
    }

    if *cur_pos < local_hashes.len() {
        // Stopped at a breakpoint with query left: continue in the next group.
        overflow.entry(*cur_hash).or_default().insert(worker);
    } else {
        final_states.insert(worker, (*cur_pos, *cur_hash));
    }
}
/// Inserts `(pos, parent_hash)` into `final_states` for every worker yielded
/// by `workers`.
fn finalize_workers(
    final_states: &mut FinalStates,
    workers: impl IntoIterator<Item = WorkerWithDpRank>,
    pos: usize,
    parent_hash: Option<ExternalSequenceBlockHash>,
) {
    workers.into_iter().for_each(|worker| {
        final_states.insert(worker, (pos, parent_hash));
    });
}
#[cfg(test)]
mod tests {
use super::{LowerTierContinuation, LowerTierIndexer, WorkerBlockIndex};
use rustc_hash::FxHashMap;
use crate::indexer::{KvIndexerInterface, ThreadPoolIndexer};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEventData, KvCacheStoreData, LocalBlockHash,
WorkerWithDpRank,
};
use crate::test_utils::{remove_event, router_event, stored_blocks_with_sequence_hashes};
fn local_hashes(values: &[u64]) -> Vec<LocalBlockHash> {
values.iter().copied().map(LocalBlockHash).collect()
}
/// Builds a `Stored` router event for the given worker and dp rank.
///
/// `local_values` and `external_hashes` are passed to
/// `stored_blocks_with_sequence_hashes` to build the block list.
/// `parent_hash` is wrapped in [`ExternalSequenceBlockHash`] when present,
/// and `start_position` is always `None`.
fn store_event(
    worker_id: u64,
    dp_rank: u32,
    event_id: u64,
    parent_hash: Option<u64>,
    local_values: &[u64],
    external_hashes: &[u64],
) -> crate::protocols::RouterEvent {
    let blocks = stored_blocks_with_sequence_hashes(&local_hashes(local_values), external_hashes);
    let stored = KvCacheStoreData {
        parent_hash: parent_hash.map(ExternalSequenceBlockHash),
        start_position: None,
        blocks,
    };
    // `router_event` takes the event id before the dp rank, unlike this helper.
    router_event(worker_id, event_id, dp_rank, KvCacheEventData::Stored(stored))
}
/// Test fixture that bundles a [`LowerTierIndexer`] with the
/// [`WorkerBlockIndex`] its mutating methods take as a separate argument.
struct TestLowerTierIndex {
/// Indexer under test.
index: LowerTierIndexer,
/// Worker block state handed by `&mut` to the indexer's mutating calls.
worker_blocks: WorkerBlockIndex,
}
/// Thin forwarding wrappers so tests can drive the indexer without passing
/// the `WorkerBlockIndex` around by hand.
impl TestLowerTierIndex {
    /// Creates a fixture from `LowerTierIndexer::new()` and a default block index.
    fn new() -> Self {
        let index = LowerTierIndexer::new();
        let worker_blocks = WorkerBlockIndex::default();
        Self {
            index,
            worker_blocks,
        }
    }

    /// Applies `event` to the indexer using the fixture's block index.
    fn apply_event(
        &mut self,
        event: crate::protocols::RouterEvent,
    ) -> Result<(), crate::protocols::KvCacheEventError> {
        let Self {
            index,
            worker_blocks,
        } = self;
        index.apply_event(worker_blocks, event)
    }

    /// Forwards to `LowerTierIndexer::remove_worker`.
    fn remove_worker(&mut self, worker_id: u64) {
        let Self {
            index,
            worker_blocks,
        } = self;
        index.remove_worker(worker_blocks, worker_id);
    }

    /// Forwards to `LowerTierIndexer::remove_worker_dp_rank`.
    fn remove_worker_dp_rank(&mut self, worker_id: u64, dp_rank: u32) {
        let Self {
            index,
            worker_blocks,
        } = self;
        index.remove_worker_dp_rank(worker_blocks, worker_id, dp_rank);
    }

    /// Forwards to `LowerTierIndexer::root_workers`.
    fn root_workers(&self, local_hash: LocalBlockHash) -> Vec<WorkerWithDpRank> {
        self.index.root_workers(local_hash)
    }

    /// Forwards to `LowerTierIndexer::query_contiguous_hits`.
    fn query_contiguous_hits<S: std::hash::BuildHasher>(
        &self,
        local_hashes: &[LocalBlockHash],
        continuations: &std::collections::HashMap<WorkerWithDpRank, LowerTierContinuation, S>,
    ) -> FxHashMap<WorkerWithDpRank, usize> {
        let indexer = &self.index;
        indexer.query_contiguous_hits(local_hashes, continuations)
    }

    /// Forwards to `LowerTierIndexer::query_match_details`.
    fn query_match_details<S: std::hash::BuildHasher>(
        &self,
        local_hashes: &[LocalBlockHash],
        continuations: &std::collections::HashMap<WorkerWithDpRank, LowerTierContinuation, S>,
    ) -> super::LowerTierMatchDetails {
        let indexer = &self.index;
        indexer.query_match_details(local_hashes, continuations)
    }

    /// Dumps events from the fixture's block index.
    fn dump_events(&self) -> Vec<crate::protocols::RouterEvent> {
        let blocks = &self.worker_blocks;
        LowerTierIndexer::dump_events(blocks)
    }
}
/// A chain stored with no parent hash is fully matched by a root continuation.
#[test]
fn root_query_uses_none_parent_transition() {
    let worker = WorkerWithDpRank::new(7, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(7, 0, 0, None, &[11, 12, 13], &[101, 102, 103]))
        .unwrap();

    let continuations: FxHashMap<_, _> = [(worker, LowerTierContinuation::from_root(0))]
        .into_iter()
        .collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[11, 12, 13]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(3));
}
/// `root_workers` reports only the worker that stored the hash with no
/// parent, not the one that stored the same local hash under a parent.
#[test]
fn root_workers_only_include_matching_root_edges() {
    let mut fixture = TestLowerTierIndex::new();
    let events = [
        store_event(7, 0, 0, None, &[11, 12], &[101, 102]),
        store_event(8, 0, 1, Some(500), &[11], &[201]),
    ];
    for event in events {
        fixture.apply_event(event).unwrap();
    }

    let roots = fixture.root_workers(LocalBlockHash(11));
    let expected = WorkerWithDpRank::new(7, 0);

    assert_eq!(roots.len(), 1);
    assert!(roots.contains(&expected));
}
#[tokio::test]
async fn thread_pool_backend_applies_lower_tier_events() {
let index = ThreadPoolIndexer::new(LowerTierIndexer::new(), 2, 1);
let worker = WorkerWithDpRank::new(7, 0);
index
.apply_event(store_event(7, 0, 0, None, &[11, 12], &[101, 102]))
.await;
let _ = index.dump_events().await.unwrap();
let mut continuations = FxHashMap::default();
continuations.insert(worker, LowerTierContinuation::from_root(0));
let hits = index
.backend()
.query_contiguous_hits(&local_hashes(&[11, 12]), &continuations);
assert_eq!(hits.get(&worker), Some(&2));
}
#[tokio::test]
async fn thread_pool_backend_remove_worker_dp_rank_keeps_other_rank() {
let index = ThreadPoolIndexer::new(LowerTierIndexer::new(), 2, 1);
let worker_dp0 = WorkerWithDpRank::new(43, 0);
let worker_dp1 = WorkerWithDpRank::new(43, 1);
index
.apply_event(store_event(43, 0, 0, None, &[11], &[101]))
.await;
index
.apply_event(store_event(43, 1, 1, None, &[11], &[101]))
.await;
let _ = index.dump_events().await.unwrap();
index.remove_worker_dp_rank(43, 0).await;
let _ = index.dump_events().await.unwrap();
let mut continuations = FxHashMap::default();
continuations.insert(worker_dp0, LowerTierContinuation::from_root(0));
continuations.insert(worker_dp1, LowerTierContinuation::from_root(0));
let hits = index
.backend()
.query_contiguous_hits(&local_hashes(&[11]), &continuations);
assert_eq!(hits.get(&worker_dp0), Some(&0));
assert_eq!(hits.get(&worker_dp1), Some(&1));
}
#[tokio::test]
async fn thread_pool_backend_cleared_event_preserves_other_workers() {
let index = ThreadPoolIndexer::new(LowerTierIndexer::new(), 2, 1);
let worker_a = WorkerWithDpRank::new(29, 0);
let worker_b = WorkerWithDpRank::new(30, 0);
index
.apply_event(store_event(29, 0, 0, None, &[101, 102], &[1001, 1002]))
.await;
index
.apply_event(store_event(30, 0, 1, None, &[101, 102], &[1001, 1002]))
.await;
index
.apply_event(router_event(29, 2, 0, KvCacheEventData::Cleared))
.await;
let _ = index.dump_events().await.unwrap();
let mut continuations = FxHashMap::default();
continuations.insert(worker_a, LowerTierContinuation::from_root(0));
continuations.insert(worker_b, LowerTierContinuation::from_root(0));
let hits = index
.backend()
.query_contiguous_hits(&local_hashes(&[101, 102]), &continuations);
assert_eq!(hits.get(&worker_a), Some(&0));
assert_eq!(hits.get(&worker_b), Some(&2));
}
/// A chain stored under parent 999 is matched when the continuation starts at
/// position 2 with last-matched hash 999, even though the parent block itself
/// was never stored.
#[test]
fn missing_parent_tail_queries_exactly_from_last_matched_hash() {
    let worker = WorkerWithDpRank::new(3, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            3,
            0,
            0,
            Some(999),
            &[21, 22, 23],
            &[201, 202, 203],
        ))
        .unwrap();

    let start = LowerTierContinuation::new(2, ExternalSequenceBlockHash(999));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[1, 2, 21, 22, 23]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(3));
}
/// A continuation may start from a hash in the middle of a stored chain
/// (301 here); the two blocks after it are matched.
#[test]
fn mid_segment_continuation_works_without_materialization() {
    let worker = WorkerWithDpRank::new(5, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            5,
            0,
            0,
            Some(700),
            &[31, 32, 33],
            &[301, 302, 303],
        ))
        .unwrap();

    let start = LowerTierContinuation::new(2, ExternalSequenceBlockHash(301));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[10, 31, 32, 33]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(2));
}
/// The same local hash stored under two different parents forms two separate
/// branches; which one is followed depends on the continuation's hash.
#[test]
fn branch_matching_is_exact_by_parent_hash() {
    let worker = WorkerWithDpRank::new(9, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(9, 0, 0, Some(500), &[91, 92], &[901, 902]))
        .unwrap();
    fixture
        .apply_event(store_event(9, 0, 1, Some(700), &[91, 93], &[903, 904]))
        .unwrap();

    let query = local_hashes(&[91, 92]);
    let starting_from = |parent: u64| -> FxHashMap<WorkerWithDpRank, LowerTierContinuation> {
        let start = LowerTierContinuation::new(0, ExternalSequenceBlockHash(parent));
        [(worker, start)].into_iter().collect()
    };

    // Under parent 500 both query blocks are on the stored branch.
    let branch_a_hits = fixture.query_contiguous_hits(&query, &starting_from(500));
    assert_eq!(branch_a_hits.get(&worker).copied(), Some(2));

    // Under parent 700 only the first block matches; the branch continues with 93.
    let branch_b_hits = fixture.query_contiguous_hits(&query, &starting_from(700));
    assert_eq!(branch_b_hits.get(&worker).copied(), Some(1));
}
/// Two workers starting at different positions (0 and 2) on the same chain
/// both reach the end of the query and report the same next continuation.
#[test]
fn shared_worker_traversal_fuses_at_descendant_breakpoint() {
    let worker_a = WorkerWithDpRank::new(1, 0);
    let worker_b = WorkerWithDpRank::new(2, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            1,
            0,
            0,
            None,
            &[11, 12, 13, 14],
            &[101, 102, 103, 104],
        ))
        .unwrap();
    fixture
        .apply_event(store_event(2, 0, 1, Some(102), &[13, 14], &[103, 104]))
        .unwrap();

    let continuations: FxHashMap<_, _> = [
        (worker_a, LowerTierContinuation::from_root(0)),
        (
            worker_b,
            LowerTierContinuation::new(2, ExternalSequenceBlockHash(102)),
        ),
    ]
    .into_iter()
    .collect();
    let details = fixture.query_match_details(&local_hashes(&[11, 12, 13, 14]), &continuations);

    assert_eq!(details.hits.get(&worker_a).copied(), Some(4));
    assert_eq!(details.hits.get(&worker_b).copied(), Some(2));

    let end_of_query = LowerTierContinuation::new(4, ExternalSequenceBlockHash(104));
    for worker in [worker_a, worker_b] {
        assert_eq!(
            details.next_continuations.get(&worker),
            Some(&end_of_query)
        );
    }
}
/// Three workers with start positions 0, 1 and 3 on the same chain all reach
/// the end of the query; hits differ by start position, while the next
/// continuation is identical.
#[test]
fn shared_worker_traversal_fuses_across_multiple_breakpoints() {
    let worker_a = WorkerWithDpRank::new(1, 0);
    let worker_b = WorkerWithDpRank::new(2, 0);
    let worker_c = WorkerWithDpRank::new(3, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            1,
            0,
            0,
            None,
            &[11, 12, 13, 14],
            &[101, 102, 103, 104],
        ))
        .unwrap();
    fixture
        .apply_event(store_event(
            2,
            0,
            1,
            Some(101),
            &[12, 13, 14],
            &[102, 103, 104],
        ))
        .unwrap();
    fixture
        .apply_event(store_event(3, 0, 2, Some(103), &[14], &[104]))
        .unwrap();

    let continuations: FxHashMap<_, _> = [
        (worker_a, LowerTierContinuation::from_root(0)),
        (
            worker_b,
            LowerTierContinuation::new(1, ExternalSequenceBlockHash(101)),
        ),
        (
            worker_c,
            LowerTierContinuation::new(3, ExternalSequenceBlockHash(103)),
        ),
    ]
    .into_iter()
    .collect();
    let details = fixture.query_match_details(&local_hashes(&[11, 12, 13, 14]), &continuations);

    let end_of_query = LowerTierContinuation::new(4, ExternalSequenceBlockHash(104));
    for (worker, expected_hits) in [(worker_a, 4), (worker_b, 3), (worker_c, 1)] {
        assert_eq!(details.hits.get(&worker).copied(), Some(expected_hits));
        assert_eq!(
            details.next_continuations.get(&worker),
            Some(&end_of_query)
        );
    }
}
/// Storing the same event twice does not require two removes: a single
/// remove drops the hit count back to zero.
#[test]
fn duplicate_store_is_idempotent_for_remove() {
    let worker = WorkerWithDpRank::new(13, 0);
    let mut fixture = TestLowerTierIndex::new();
    let stored = store_event(13, 0, 0, Some(800), &[61], &[601]);
    fixture.apply_event(stored.clone()).unwrap();
    fixture.apply_event(stored).unwrap();

    let query = local_hashes(&[61]);
    let start = LowerTierContinuation::new(0, ExternalSequenceBlockHash(800));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();

    let before_remove = fixture.query_contiguous_hits(&query, &continuations);
    assert_eq!(before_remove.get(&worker).copied(), Some(1));

    fixture
        .apply_event(remove_event(13, 1, 0, vec![ExternalSequenceBlockHash(601)]))
        .unwrap();

    let after_one_remove = fixture.query_contiguous_hits(&query, &continuations);
    assert_eq!(after_one_remove.get(&worker).copied(), Some(0));
}
/// When two workers store the same chain and one removes its blocks, the
/// other worker still matches the whole chain.
#[test]
fn removing_one_owner_preserves_shared_edge_for_other_workers() {
    let removed = WorkerWithDpRank::new(1, 0);
    let kept = WorkerWithDpRank::new(2, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(1, 0, 0, None, &[11, 12], &[101, 102]))
        .unwrap();
    fixture
        .apply_event(store_event(2, 0, 1, None, &[11, 12], &[101, 102]))
        .unwrap();

    let removed_hashes = vec![
        ExternalSequenceBlockHash(101),
        ExternalSequenceBlockHash(102),
    ];
    fixture
        .apply_event(remove_event(1, 2, 0, removed_hashes))
        .unwrap();

    let continuations: FxHashMap<_, _> = [removed, kept]
        .into_iter()
        .map(|worker| (worker, LowerTierContinuation::from_root(0)))
        .collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[11, 12]), &continuations);

    assert_eq!(hits.get(&removed).copied(), Some(0));
    assert_eq!(hits.get(&kept).copied(), Some(2));
}
/// Removing the middle block of a three-block chain limits the contiguous
/// match to the first block.
#[test]
fn remove_stops_contiguous_walk_at_missing_edge() {
    let worker = WorkerWithDpRank::new(17, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            17,
            0,
            0,
            Some(900),
            &[71, 72, 73],
            &[701, 702, 703],
        ))
        .unwrap();
    fixture
        .apply_event(remove_event(17, 1, 0, vec![ExternalSequenceBlockHash(702)]))
        .unwrap();

    let start = LowerTierContinuation::new(0, ExternalSequenceBlockHash(900));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[71, 72, 73]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(1));
}
/// A continuation whose last-matched hash (9999) was never stored as a
/// parent yields zero hits.
#[test]
fn unknown_last_matched_hash_returns_zero() {
    let worker = WorkerWithDpRank::new(19, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(19, 0, 0, Some(1000), &[81, 82], &[801, 802]))
        .unwrap();

    let start = LowerTierContinuation::new(0, ExternalSequenceBlockHash(9999));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[81, 82]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(0));
}
/// A continuation whose start position equals the query length has nothing
/// left to match and reports zero hits.
#[test]
fn start_pos_past_end_returns_zero() {
    let worker = WorkerWithDpRank::new(23, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(23, 0, 0, Some(1100), &[91], &[901]))
        .unwrap();

    let start = LowerTierContinuation::new(1, ExternalSequenceBlockHash(1100));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[91]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(0));
}
/// After a `Cleared` event, a previously stored chain no longer matches.
#[test]
fn cleared_event_removes_all_lower_tier_state() {
    let worker = WorkerWithDpRank::new(29, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            29,
            0,
            0,
            Some(1200),
            &[101, 102],
            &[1001, 1002],
        ))
        .unwrap();
    fixture
        .apply_event(router_event(29, 1, 0, KvCacheEventData::Cleared))
        .unwrap();

    let start = LowerTierContinuation::new(0, ExternalSequenceBlockHash(1200));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[101, 102]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(0));
}
/// A `Cleared` event sent with dp rank 0 leaves both dp ranks of worker 29
/// with zero hits.
#[test]
fn cleared_event_is_worker_wide_across_dp_ranks() {
    let rank0 = WorkerWithDpRank::new(29, 0);
    let rank1 = WorkerWithDpRank::new(29, 1);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(29, 0, 0, Some(1200), &[101], &[1001]))
        .unwrap();
    fixture
        .apply_event(store_event(29, 1, 1, Some(2200), &[201], &[2001]))
        .unwrap();
    fixture
        .apply_event(router_event(29, 2, 0, KvCacheEventData::Cleared))
        .unwrap();

    let continuations: FxHashMap<_, _> = [
        (
            rank0,
            LowerTierContinuation::new(0, ExternalSequenceBlockHash(1200)),
        ),
        (
            rank1,
            LowerTierContinuation::new(0, ExternalSequenceBlockHash(2200)),
        ),
    ]
    .into_iter()
    .collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[101]), &continuations);

    assert_eq!(hits.get(&rank0).copied(), Some(0));
    assert_eq!(hits.get(&rank1).copied(), Some(0));
}
/// Clearing worker 29 does not disturb worker 30, which stored the same
/// chain with the same hashes.
#[test]
fn cleared_event_preserves_shared_edges_for_other_workers() {
    let cleared = WorkerWithDpRank::new(29, 0);
    let untouched = WorkerWithDpRank::new(30, 0);
    let mut fixture = TestLowerTierIndex::new();
    let events = [
        store_event(29, 0, 0, None, &[101, 102], &[1001, 1002]),
        store_event(30, 0, 1, None, &[101, 102], &[1001, 1002]),
        router_event(29, 2, 0, KvCacheEventData::Cleared),
    ];
    for event in events {
        fixture.apply_event(event).unwrap();
    }

    let continuations: FxHashMap<_, _> = [cleared, untouched]
        .into_iter()
        .map(|worker| (worker, LowerTierContinuation::from_root(0)))
        .collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[101, 102]), &continuations);

    assert_eq!(hits.get(&cleared).copied(), Some(0));
    assert_eq!(hits.get(&untouched).copied(), Some(2));
}
/// `remove_worker` leaves every dp rank of the worker with zero hits.
#[test]
fn remove_worker_drops_all_ranks() {
    let rank0 = WorkerWithDpRank::new(41, 0);
    let rank1 = WorkerWithDpRank::new(41, 1);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(41, 0, 0, Some(3000), &[1], &[301]))
        .unwrap();
    fixture
        .apply_event(store_event(41, 1, 1, Some(4000), &[2], &[401]))
        .unwrap();

    fixture.remove_worker(41);

    let continuations: FxHashMap<_, _> = [
        (
            rank0,
            LowerTierContinuation::new(0, ExternalSequenceBlockHash(3000)),
        ),
        (
            rank1,
            LowerTierContinuation::new(0, ExternalSequenceBlockHash(4000)),
        ),
    ]
    .into_iter()
    .collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[1]), &continuations);

    assert_eq!(hits.get(&rank0).copied(), Some(0));
    assert_eq!(hits.get(&rank1).copied(), Some(0));
}
/// `remove_worker_dp_rank` drops only the named rank; the other rank of the
/// same worker still matches its stored block.
#[test]
fn remove_worker_dp_rank_keeps_other_ranks() {
    let removed_rank = WorkerWithDpRank::new(43, 0);
    let kept_rank = WorkerWithDpRank::new(43, 1);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(43, 0, 0, Some(5000), &[1], &[501]))
        .unwrap();
    fixture
        .apply_event(store_event(43, 1, 1, Some(6000), &[2], &[601]))
        .unwrap();

    fixture.remove_worker_dp_rank(43, 0);

    let continuations: FxHashMap<_, _> = [
        (
            removed_rank,
            LowerTierContinuation::new(0, ExternalSequenceBlockHash(5000)),
        ),
        (
            kept_rank,
            LowerTierContinuation::new(0, ExternalSequenceBlockHash(6000)),
        ),
    ]
    .into_iter()
    .collect();
    // The query uses local hash 2, the block stored by rank 1.
    let hits = fixture.query_contiguous_hits(&local_hashes(&[2]), &continuations);

    assert_eq!(hits.get(&removed_rank).copied(), Some(0));
    assert_eq!(hits.get(&kept_rank).copied(), Some(1));
}
/// Removing the first block of a two-block chain breaks the walk from the
/// original parent, but a continuation that starts at the removed block's
/// hash can still match the second block.
#[test]
fn removing_parent_block_keeps_child_continuation_edge() {
    let worker = WorkerWithDpRank::new(31, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            31,
            0,
            0,
            Some(1300),
            &[111, 112],
            &[1101, 1102],
        ))
        .unwrap();
    fixture
        .apply_event(remove_event(
            31,
            1,
            0,
            vec![ExternalSequenceBlockHash(1101)],
        ))
        .unwrap();

    let query = local_hashes(&[111, 112]);

    // Walking from the original parent hits the removed block immediately.
    let from_parent: FxHashMap<_, _> = [(
        worker,
        LowerTierContinuation::new(0, ExternalSequenceBlockHash(1300)),
    )]
    .into_iter()
    .collect();
    let root_hits = fixture.query_contiguous_hits(&query, &from_parent);
    assert_eq!(root_hits.get(&worker).copied(), Some(0));

    // Starting after the removed block, the second block is still found.
    let from_removed_block: FxHashMap<_, _> = [(
        worker,
        LowerTierContinuation::new(1, ExternalSequenceBlockHash(1101)),
    )]
    .into_iter()
    .collect();
    let child_hits = fixture.query_contiguous_hits(&query, &from_removed_block);
    assert_eq!(child_hits.get(&worker).copied(), Some(1));
}
/// Storing the same `(parent, local hash)` transition twice with different
/// external hashes does not return an error, and the transition still counts
/// as one hit.
#[test]
fn conflicting_transition_insert_is_ignored() {
    let worker = WorkerWithDpRank::new(37, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(37, 0, 0, Some(1400), &[121], &[1201]))
        .unwrap();
    fixture
        .apply_event(store_event(37, 0, 1, Some(1400), &[121], &[1202]))
        .unwrap();

    let start = LowerTierContinuation::new(0, ExternalSequenceBlockHash(1400));
    let continuations: FxHashMap<_, _> = [(worker, start)].into_iter().collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[121]), &continuations);

    assert_eq!(hits.get(&worker).copied(), Some(1));
}
/// Re-using an external hash (1301) for a different `(parent, local hash)`
/// transition does not return an error; the first mapping still matches and
/// the second one yields no hits.
#[test]
fn conflicting_child_hash_mapping_is_ignored() {
    let worker = WorkerWithDpRank::new(47, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(47, 0, 0, Some(1500), &[131], &[1301]))
        .unwrap();
    fixture
        .apply_event(store_event(47, 0, 1, Some(2500), &[231], &[1301]))
        .unwrap();

    let original: FxHashMap<_, _> = [(
        worker,
        LowerTierContinuation::new(0, ExternalSequenceBlockHash(1500)),
    )]
    .into_iter()
    .collect();
    let original_hits = fixture.query_contiguous_hits(&local_hashes(&[131]), &original);
    assert_eq!(original_hits.get(&worker).copied(), Some(1));

    let conflicting: FxHashMap<_, _> = [(
        worker,
        LowerTierContinuation::new(0, ExternalSequenceBlockHash(2500)),
    )]
    .into_iter()
    .collect();
    let conflicting_hits = fixture.query_contiguous_hits(&local_hashes(&[231]), &conflicting);
    assert_eq!(conflicting_hits.get(&worker).copied(), Some(0));
}
// --- Tests targeting optimization edge cases ---
/// One worker, one continuation: the walk goes through the single-worker
/// path and matches all five blocks, ending at the last stored hash.
#[test]
fn single_worker_fast_path_full_match() {
    let worker = WorkerWithDpRank::new(50, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(
            50,
            0,
            0,
            None,
            &[1, 2, 3, 4, 5],
            &[101, 102, 103, 104, 105],
        ))
        .unwrap();

    let continuations: FxHashMap<_, _> = [(worker, LowerTierContinuation::from_root(0))]
        .into_iter()
        .collect();
    let details = fixture.query_match_details(&local_hashes(&[1, 2, 3, 4, 5]), &continuations);

    assert_eq!(details.hits.get(&worker).copied(), Some(5));
    let expected_next = LowerTierContinuation::new(5, ExternalSequenceBlockHash(105));
    assert_eq!(
        details.next_continuations.get(&worker),
        Some(&expected_next)
    );
}
/// Single-worker path where the queried worker (51) is not the one that
/// stored the chain (50), so it gets zero hits.
#[test]
fn single_worker_fast_path_no_match() {
    let non_owner = WorkerWithDpRank::new(51, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(50, 0, 0, None, &[1, 2], &[101, 102]))
        .unwrap();

    let continuations: FxHashMap<_, _> = [(non_owner, LowerTierContinuation::from_root(0))]
        .into_iter()
        .collect();
    let hits = fixture.query_contiguous_hits(&local_hashes(&[1, 2]), &continuations);

    assert_eq!(hits.get(&non_owner).copied(), Some(0));
}
/// Single-worker path with a query one block longer than the stored chain:
/// the walk stops after two hits and the continuation points at the last
/// matched block.
#[test]
fn single_worker_fast_path_partial_match() {
    let worker = WorkerWithDpRank::new(52, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(52, 0, 0, None, &[1, 2], &[101, 102]))
        .unwrap();

    let continuations: FxHashMap<_, _> = [(worker, LowerTierContinuation::from_root(0))]
        .into_iter()
        .collect();
    let details = fixture.query_match_details(&local_hashes(&[1, 2, 3]), &continuations);

    assert_eq!(details.hits.get(&worker).copied(), Some(2));
    let expected_next = LowerTierContinuation::new(2, ExternalSequenceBlockHash(102));
    assert_eq!(
        details.next_continuations.get(&worker),
        Some(&expected_next)
    );
}
/// Two workers are queried but only worker 60 stored the chain. The
/// non-owner is finalized at the first position with zero hits while the
/// owner matches all three blocks.
#[test]
fn single_edge_owner_splits_active_set() {
    let owner = WorkerWithDpRank::new(60, 0);
    let non_owner = WorkerWithDpRank::new(61, 0);
    let mut fixture = TestLowerTierIndex::new();
    fixture
        .apply_event(store_event(60, 0, 0, None, &[1, 2, 3], &[101, 102, 103]))
        .unwrap();

    let continuations: FxHashMap<_, _> = [owner, non_owner]
        .into_iter()
        .map(|worker| (worker, LowerTierContinuation::from_root(0)))
        .collect();
    let details = fixture.query_match_details(&local_hashes(&[1, 2, 3]), &continuations);

    assert_eq!(details.hits.get(&owner).copied(), Some(3));
    assert_eq!(details.hits.get(&non_owner).copied(), Some(0));
}
/// Three workers store the same chain but only two of them are queried.
///
/// With three owners and two active workers, `owners.len() <= active.len()`
/// is false, so the multi-owner step takes the branch that iterates the
/// active set and checks each worker against the owners.
#[test]
fn multi_edge_subset_of_owners_active() {
    let mut fixture = TestLowerTierIndex::new();
    // Workers 70, 71 and 72 all store the identical two-block chain.
    let events = [
        store_event(70, 0, 0, None, &[1, 2], &[101, 102]),
        store_event(71, 0, 1, None, &[1, 2], &[101, 102]),
        store_event(72, 0, 2, None, &[1, 2], &[101, 102]),
    ];
    for event in events {
        fixture.apply_event(event).unwrap();
    }

    // Only 70 and 71 are queried; 72 owns the edges but is not active.
    let queried = [WorkerWithDpRank::new(70, 0), WorkerWithDpRank::new(71, 0)];
    let continuations: FxHashMap<_, _> = queried
        .into_iter()
        .map(|worker| (worker, LowerTierContinuation::from_root(0)))
        .collect();
    let details = fixture.query_match_details(&local_hashes(&[1, 2]), &continuations);

    for worker in queried {
        assert_eq!(details.hits.get(&worker).copied(), Some(2));
    }
}
/// Two workers share the first two blocks; one of them stops there, leaving
/// a single active worker that finishes the remaining blocks alone.
#[test]
fn multi_to_single_worker_transition_mid_walk() {
    let long_chain = WorkerWithDpRank::new(80, 0);
    let short_chain = WorkerWithDpRank::new(81, 0);
    let mut fixture = TestLowerTierIndex::new();
    // Worker 80 stores [1, 2, 3, 4]; worker 81 stores only [1, 2].
    fixture
        .apply_event(store_event(
            80,
            0,
            0,
            None,
            &[1, 2, 3, 4],
            &[101, 102, 103, 104],
        ))
        .unwrap();
    fixture
        .apply_event(store_event(81, 0, 1, None, &[1, 2], &[101, 102]))
        .unwrap();

    let continuations: FxHashMap<_, _> = [long_chain, short_chain]
        .into_iter()
        .map(|worker| (worker, LowerTierContinuation::from_root(0)))
        .collect();
    let details = fixture.query_match_details(&local_hashes(&[1, 2, 3, 4]), &continuations);

    let expectations = [
        (long_chain, 4, ExternalSequenceBlockHash(104)),
        (short_chain, 2, ExternalSequenceBlockHash(102)),
    ];
    for (worker, expected_hits, last_hash) in expectations {
        assert_eq!(details.hits.get(&worker).copied(), Some(expected_hits));
        let expected_next = LowerTierContinuation::new(expected_hits, last_hash);
        assert_eq!(
            details.next_continuations.get(&worker),
            Some(&expected_next)
        );
    }
}
/// The stored chain belongs to worker 90, which is not part of the query, so
/// every queried worker is expected to drop off immediately with zero hits
/// (the Single-variant case where the owner is outside the active set).
#[test]
fn single_edge_no_active_worker_owns_it() {
    let first_outsider = WorkerWithDpRank::new(91, 0);
    let second_outsider = WorkerWithDpRank::new(92, 0);

    let mut index = TestLowerTierIndex::new();
    let owner_only_chain = store_event(90, 0, 0, None, &[1, 2], &[101, 102]);
    index.apply_event(owner_only_chain).unwrap();

    // Only workers that never stored anything take part in the query.
    let mut continuations = FxHashMap::default();
    for worker in [first_outsider, second_outsider] {
        continuations.insert(worker, LowerTierContinuation::from_root(0));
    }

    let details = index.query_match_details(&local_hashes(&[1, 2]), &continuations);

    assert_eq!(details.hits.get(&first_outsider), Some(&0));
    assert_eq!(details.hits.get(&second_outsider), Some(&0));
}
/// Worker 95 walks from position 0 while worker 96 joins at position 2.
/// The second worker's start position is meant to create a breakpoint at
/// pos 2: the single-worker fast path should stop there, the two workers get
/// merged in the frontier, and the walk carries on to the end of the chain.
#[test]
fn single_worker_stops_at_breakpoint_then_continues() {
    let from_root = WorkerWithDpRank::new(95, 0);
    let from_middle = WorkerWithDpRank::new(96, 0);

    let mut index = TestLowerTierIndex::new();
    // Worker 95 stores the full chain starting at the root.
    let rooted_chain = store_event(95, 0, 0, None, &[1, 2, 3, 4], &[101, 102, 103, 104]);
    index.apply_event(rooted_chain).unwrap();
    // Worker 96 stores only the tail [3,4], parented on external hash 102.
    let tail_chain = store_event(96, 0, 1, Some(102), &[3, 4], &[103, 104]);
    index.apply_event(tail_chain).unwrap();

    let mut continuations = FxHashMap::default();
    continuations.insert(from_root, LowerTierContinuation::from_root(0));
    let mid_start = LowerTierContinuation::new(2, ExternalSequenceBlockHash(102));
    continuations.insert(from_middle, mid_start);

    let details = index.query_match_details(&local_hashes(&[1, 2, 3, 4]), &continuations);

    // Worker 95 matches all four blocks; worker 96 matches the two after its start.
    assert_eq!(details.hits.get(&from_root), Some(&4));
    assert_eq!(details.hits.get(&from_middle), Some(&2));
}
/// Queries a two-owner edge with four active workers, so the active set is
/// larger than the owner set (intended to hit the "iterate owners" side of
/// the Multi-edge path).
#[test]
fn multi_edge_fewer_owners_than_active_workers() {
    let mut index = TestLowerTierIndex::new();
    // Workers 100 and 101 both store the same single-block edge.
    let first_owner_event = store_event(100, 0, 0, None, &[1], &[101]);
    index.apply_event(first_owner_event).unwrap();
    let second_owner_event = store_event(101, 0, 1, None, &[1], &[101]);
    index.apply_event(second_owner_event).unwrap();

    // Four workers (100..=103) take part in the query; only the first two own the edge.
    let continuations: FxHashMap<WorkerWithDpRank, LowerTierContinuation> = (100..104)
        .map(|id| {
            (
                WorkerWithDpRank::new(id, 0),
                LowerTierContinuation::from_root(0),
            )
        })
        .collect();

    let details = index.query_match_details(&local_hashes(&[1]), &continuations);
    let hits_for = |id| details.hits.get(&WorkerWithDpRank::new(id, 0));

    // Owners match the block, the other two workers match nothing.
    assert_eq!(hits_for(100), Some(&1));
    assert_eq!(hits_for(101), Some(&1));
    assert_eq!(hits_for(102), Some(&0));
    assert_eq!(hits_for(103), Some(&0));
}
/// Querying with no workers at all must not panic and must produce empty
/// `hits` and `next_continuations` maps.
#[test]
fn empty_continuations_returns_empty_results() {
    let mut index = TestLowerTierIndex::new();
    let stored_chain = store_event(110, 0, 0, None, &[1, 2], &[101, 102]);
    index.apply_event(stored_chain).unwrap();

    // No worker is registered for the query.
    let no_workers = FxHashMap::<WorkerWithDpRank, LowerTierContinuation>::default();
    let details = index.query_match_details(&local_hashes(&[1, 2]), &no_workers);

    assert!(details.hits.is_empty());
    assert!(details.next_continuations.is_empty());
}
/// Querying an empty hash sequence still reports the queried worker, with a
/// hit count of zero.
#[test]
fn empty_sequence_returns_zero_hits() {
    let worker = WorkerWithDpRank::new(111, 0);

    let mut index = TestLowerTierIndex::new();
    let stored_block = store_event(111, 0, 0, None, &[1], &[101]);
    index.apply_event(stored_block).unwrap();

    let mut continuations = FxHashMap::default();
    continuations.insert(worker, LowerTierContinuation::from_root(0));

    // Nothing to walk: the query sequence is empty.
    let details = index.query_match_details(&local_hashes(&[]), &continuations);
    assert_eq!(details.hits.get(&worker), Some(&0));
}
// --- dump_events tests ---
/// Helper: builds a fresh index and applies every dumped event to it in
/// order, panicking if any event is rejected.
fn replay_dump(events: Vec<crate::protocols::RouterEvent>) -> TestLowerTierIndex {
    events
        .into_iter()
        .fold(TestLowerTierIndex::new(), |mut rebuilt, event| {
            rebuilt.apply_event(event).unwrap();
            rebuilt
        })
}
/// A freshly constructed index has nothing to dump.
#[test]
fn dump_empty_indexer_returns_no_events() {
    let untouched = TestLowerTierIndex::new();
    let dumped = untouched.dump_events();
    assert!(dumped.is_empty());
}
/// Dumps a single three-block chain, replays the dump into a fresh index,
/// and checks that both indexes answer the same query identically.
#[test]
fn dump_round_trip_single_chain() {
    let worker = WorkerWithDpRank::new(7, 0);

    let mut index = TestLowerTierIndex::new();
    let chain = store_event(7, 0, 0, None, &[11, 12, 13], &[101, 102, 103]);
    index.apply_event(chain).unwrap();

    // One dumped event is expected per stored block.
    let events = index.dump_events();
    assert_eq!(events.len(), 3);
    let restored = replay_dump(events);

    let mut continuations = FxHashMap::default();
    continuations.insert(worker, LowerTierContinuation::from_root(0));

    let query = local_hashes(&[11, 12, 13]);
    let original = index.query_contiguous_hits(&query, &continuations);
    let replayed = restored.query_contiguous_hits(&query, &continuations);

    assert_eq!(original, replayed);
    assert_eq!(replayed.get(&worker), Some(&3));
}
/// Round-trips a dump that contains two workers: worker 1 with a rooted
/// chain and worker 2 with a chain hanging off external parent hash 500.
///
/// Besides comparing the original and restored indexes, the test pins the
/// concrete hit counts so that two equally empty results cannot pass.
#[test]
fn dump_round_trip_multiple_workers() {
    let mut index = TestLowerTierIndex::new();
    index
        .apply_event(store_event(1, 0, 0, None, &[11, 12], &[101, 102]))
        .unwrap();
    index
        .apply_event(store_event(2, 0, 1, Some(500), &[21, 22], &[201, 202]))
        .unwrap();
    let events = index.dump_events();
    assert_eq!(events.len(), 4);
    let restored = replay_dump(events);

    // Worker 1: root chain
    let worker_1 = WorkerWithDpRank::new(1, 0);
    let mut c1 = FxHashMap::default();
    c1.insert(worker_1, LowerTierContinuation::from_root(0));
    let original_1 = index.query_contiguous_hits(&local_hashes(&[11, 12]), &c1);
    let replayed_1 = restored.query_contiguous_hits(&local_hashes(&[11, 12]), &c1);
    assert_eq!(original_1, replayed_1);
    // Both stored blocks must be found, not merely "the same on both sides".
    assert_eq!(replayed_1.get(&worker_1), Some(&2));

    // Worker 2: non-root chain
    let worker_2 = WorkerWithDpRank::new(2, 0);
    let mut c2 = FxHashMap::default();
    c2.insert(
        worker_2,
        LowerTierContinuation::new(0, ExternalSequenceBlockHash(500)),
    );
    let original_2 = index.query_contiguous_hits(&local_hashes(&[21, 22]), &c2);
    let replayed_2 = restored.query_contiguous_hits(&local_hashes(&[21, 22]), &c2);
    assert_eq!(original_2, replayed_2);
    assert_eq!(replayed_2.get(&worker_2), Some(&2));
}
/// Round-trips a dump in which two workers own the same chain.
///
/// In addition to comparing the original and restored indexes, the test
/// asserts the concrete per-worker hit counts so that an empty result on
/// both sides cannot make it pass vacuously.
#[test]
fn dump_round_trip_shared_edges() {
    let worker_1 = WorkerWithDpRank::new(1, 0);
    let worker_2 = WorkerWithDpRank::new(2, 0);

    let mut index = TestLowerTierIndex::new();
    // Two workers own the same chain.
    index
        .apply_event(store_event(1, 0, 0, None, &[11, 12], &[101, 102]))
        .unwrap();
    index
        .apply_event(store_event(2, 0, 1, None, &[11, 12], &[101, 102]))
        .unwrap();
    let events = index.dump_events();
    // 2 blocks * 2 workers = 4 events (each worker's blocks are dumped
    // independently even if they share the same underlying edges).
    assert_eq!(events.len(), 4);
    let restored = replay_dump(events);

    let mut continuations = FxHashMap::default();
    continuations.insert(worker_1, LowerTierContinuation::from_root(0));
    continuations.insert(worker_2, LowerTierContinuation::from_root(0));

    let original = index.query_contiguous_hits(&local_hashes(&[11, 12]), &continuations);
    let replayed = restored.query_contiguous_hits(&local_hashes(&[11, 12]), &continuations);
    assert_eq!(original, replayed);
    // Each worker must still own both blocks after the replay.
    assert_eq!(replayed.get(&worker_1), Some(&2));
    assert_eq!(replayed.get(&worker_2), Some(&2));
}
/// Stores a three-block chain, removes the middle block, and checks that
/// the dump (and an index rebuilt from it) no longer contains that block.
///
/// The expected single hit is asserted explicitly rather than only being
/// described in a comment.
#[test]
fn dump_after_removal_excludes_removed_blocks() {
    let worker = WorkerWithDpRank::new(5, 0);

    let mut index = TestLowerTierIndex::new();
    index
        .apply_event(store_event(
            5,
            0,
            0,
            Some(800),
            &[31, 32, 33],
            &[301, 302, 303],
        ))
        .unwrap();
    // Remove the middle block.
    index
        .apply_event(remove_event(5, 1, 0, vec![ExternalSequenceBlockHash(302)]))
        .unwrap();
    let events = index.dump_events();
    // Only 2 blocks remain (301 and 303).
    assert_eq!(events.len(), 2);
    let restored = replay_dump(events);

    let mut continuations = FxHashMap::default();
    continuations.insert(
        worker,
        LowerTierContinuation::new(0, ExternalSequenceBlockHash(800)),
    );

    // Original and restored should give the same result: only 1 hit
    // (block 301 matches, 302 is gone so the chain breaks).
    let original = index.query_contiguous_hits(&local_hashes(&[31, 32, 33]), &continuations);
    let replayed = restored.query_contiguous_hits(&local_hashes(&[31, 32, 33]), &continuations);
    assert_eq!(original, replayed);
    assert_eq!(replayed.get(&worker), Some(&1));
}
/// Round-trips a dump in which the same worker id stores different chains
/// under two dp ranks.
///
/// Each rank's concrete hit count is asserted in addition to the
/// original-vs-restored comparison, so matching empty results cannot pass.
#[test]
fn dump_round_trip_multiple_dp_ranks() {
    let mut index = TestLowerTierIndex::new();
    index
        .apply_event(store_event(10, 0, 0, None, &[1, 2], &[101, 102]))
        .unwrap();
    index
        .apply_event(store_event(10, 1, 1, None, &[3, 4], &[301, 302]))
        .unwrap();
    let events = index.dump_events();
    assert_eq!(events.len(), 4);
    let restored = replay_dump(events);

    // Verify dp_rank=0 chain
    let rank_0 = WorkerWithDpRank::new(10, 0);
    let mut c0 = FxHashMap::default();
    c0.insert(rank_0, LowerTierContinuation::from_root(0));
    let original_0 = index.query_contiguous_hits(&local_hashes(&[1, 2]), &c0);
    let replayed_0 = restored.query_contiguous_hits(&local_hashes(&[1, 2]), &c0);
    assert_eq!(original_0, replayed_0);
    assert_eq!(replayed_0.get(&rank_0), Some(&2));

    // Verify dp_rank=1 chain
    let rank_1 = WorkerWithDpRank::new(10, 1);
    let mut c1 = FxHashMap::default();
    c1.insert(rank_1, LowerTierContinuation::from_root(0));
    let original_1 = index.query_contiguous_hits(&local_hashes(&[3, 4]), &c1);
    let replayed_1 = restored.query_contiguous_hits(&local_hashes(&[3, 4]), &c1);
    assert_eq!(original_1, replayed_1);
    assert_eq!(replayed_1.get(&rank_1), Some(&2));
}
/// Same round trip as the single-chain test, but driven through a
/// `ThreadPoolIndexer` wrapping a `LowerTierIndexer`: dump, replay into a
/// second pool, then compare query results taken from both backends.
#[tokio::test]
async fn thread_pool_dump_events_round_trip() {
    let index = ThreadPoolIndexer::new(LowerTierIndexer::new(), 2, 1);
    let worker = WorkerWithDpRank::new(7, 0);

    let chain = store_event(7, 0, 0, None, &[11, 12, 13], &[101, 102, 103]);
    index.apply_event(chain).await;

    let events = index.dump_events().await.unwrap();
    assert_eq!(events.len(), 3);

    // Replay into a fresh ThreadPoolIndexer.
    let restored = ThreadPoolIndexer::new(LowerTierIndexer::new(), 2, 1);
    for event in events {
        restored.apply_event(event).await;
    }
    // NOTE(review): the dump result is discarded; presumably this call acts
    // as a barrier so the replayed events are applied before the backend is
    // queried directly below — confirm.
    drop(restored.dump_events().await.unwrap());

    let mut continuations = FxHashMap::default();
    continuations.insert(worker, LowerTierContinuation::from_root(0));

    let query = local_hashes(&[11, 12, 13]);
    let original = index.backend().query_contiguous_hits(&query, &continuations);
    let replayed = restored
        .backend()
        .query_contiguous_hits(&query, &continuations);

    assert_eq!(original, replayed);
    assert_eq!(replayed.get(&worker), Some(&3));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment