Unverified commit e3d00b89, authored by jthomson04, committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
......@@ -660,6 +660,7 @@ async def init_llm_worker(
kv_block_size=config.kv_block_size,
zmq_endpoint=consolidator_output_connect_endpoint,
zmq_topic="",
enable_local_indexer=config.enable_local_indexer,
)
logging.info(
f"Created worker-side publisher for consolidated events: "
......
......@@ -8,7 +8,7 @@ use common::*;
use clap::Parser;
use common::NoopSequencePublisher;
use dynamo_kv_router::protocols::{PrefillLoadHint, WorkerWithDpRank};
use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest};
use dynamo_kv_router::{ActiveSequencesMultiWorker, SequenceRequest};
use dynamo_mocker::loadgen::Trace;
use dynamo_tokens::SequenceHash;
use std::collections::HashMap;
......@@ -379,12 +379,7 @@ async fn apply_entry(
isl,
output_length,
} => {
let _ = multi.potential_blocks_and_tokens(
Some(&block_hashes),
isl,
OverlapScores::default(),
decay_now,
);
let _ = multi.potential_blocks_and_tokens(Some(&block_hashes), isl, HashMap::new());
let _ = multi.add_request(
SequenceRequest {
request_id,
......
......@@ -521,7 +521,6 @@ impl RouterHandles {
None,
0.0,
None,
None,
allowed_worker_ids,
)
.await
......@@ -884,12 +883,13 @@ pub unsafe extern "C" fn add_request(
}
};
let cached_tokens = overlap_blocks as usize * decode_router.block_size() as usize;
decode_router
.add_request(
request_id_str.clone(),
&tokens,
None,
overlap_blocks,
cached_tokens,
None,
worker,
None, // lora_name
......
......@@ -7,7 +7,7 @@ from typing import List, Optional
import tensorrt_llm
from kvbm import KvbmLeader
from kvbm.trtllm_integration.consolidator_config import is_truthy
from kvbm.trtllm_integration.consolidator_config import get_consolidator_mode, is_truthy
from kvbm.trtllm_integration.rust import KvbmRequest
from kvbm.trtllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from kvbm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput
......@@ -55,7 +55,9 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
trtllm_ep = None
consolidator_output_ep = None
consolidator_mode = None
if consolidator_enabled:
consolidator_mode = get_consolidator_mode()
# Get consolidator endpoint from environment variable
# DYN_KVBM_TRTLLM_ZMQ_PORT contains just the port number (e.g., "20081")
zmq_port = os.getenv("DYN_KVBM_TRTLLM_ZMQ_PORT")
......@@ -105,6 +107,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
leader,
consolidator_trtllm_endpoint=trtllm_ep,
consolidator_output_endpoint=consolidator_output_ep,
consolidator_mode=consolidator_mode,
)
@nvtx_annotate(category="scheduler")
......@@ -132,6 +135,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_block_ids,
req.computed_position,
req.priorities, # Pass retention priorities for offload filtering
list(req.block_hashes),
)
resumed_from_preemption = False
......@@ -143,6 +147,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_block_ids,
req.computed_position,
req.priorities, # Pass retention priorities for offload filtering
list(req.block_hashes),
)
output.add_num_scheduled_tokens(
......
......@@ -8,21 +8,27 @@ Helper functions for KV Event Consolidator configuration for TensorRT-LLM.
import logging
import os
from kvbm.utils import get_consolidator_mode, is_truthy
__all__ = [
"get_consolidator_endpoints",
"get_consolidator_mode",
"is_truthy",
"should_enable_consolidator",
]
logger = logging.getLogger(__name__)
def is_truthy(val: str) -> bool:
"""
Check if a string represents a truthy value.
Truthy values: "1", "true", "on", "yes" (case-insensitive)
def _get_connector_module(kv_connector_config) -> str | None:
"""Extract connector_module from either a dict or a TRT-LLM config object."""
if kv_connector_config is None:
return None
Args:
val: The string value to check
if isinstance(kv_connector_config, dict):
return kv_connector_config.get("connector_module")
Returns:
True if the value is truthy, False otherwise
"""
return val.lower() in ("1", "true", "on", "yes")
return getattr(kv_connector_config, "connector_module", None)
def should_enable_consolidator(arg_map) -> bool:
......@@ -48,23 +54,19 @@ def should_enable_consolidator(arg_map) -> bool:
)
return False
# Check if KVBM connector is enabled by extracting connector_module
# from kv_connector_config (works whether arg_map holds raw dicts or typed objects)
kv_connector_config = (
arg_map.get("kv_connector_config") if isinstance(arg_map, dict) else None
)
if kv_connector_config is None:
# Check if KVBM connector is enabled
if not isinstance(arg_map, dict):
logger.warning("KV Event Consolidator is not enabled: arg_map is not a dict")
return False
kv_connector_config = arg_map.get("kv_connector_config")
connector_module = _get_connector_module(kv_connector_config) or ""
if not connector_module:
logger.warning(
"KV Event Consolidator is not enabled: no kv_connector_config found"
"KV Event Consolidator is not enabled: kv_connector_config has no connector_module"
)
return False
if isinstance(kv_connector_config, dict):
connector_module = kv_connector_config.get("connector_module", "")
else:
# Access directly so AttributeError surfaces if the contract changes
connector_module = kv_connector_config.connector_module or ""
has_kvbm_connector = "kvbm.trtllm_integration.connector" in connector_module
if not has_kvbm_connector:
......
......@@ -2,8 +2,33 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
logger = logging.getLogger(__name__)
def is_truthy(val: str) -> bool:
    """Report whether ``val`` spells an enabled flag.

    Leading/trailing whitespace is ignored and the comparison is
    case-insensitive. Accepted spellings: "1", "true", "on", "yes".
    """
    normalized = val.strip().lower()
    return normalized in {"1", "true", "on", "yes"}
def get_consolidator_mode() -> str:
    """Resolve the KV event consolidator mode from the environment.

    Reads DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE, trims surrounding whitespace
    and lowercases it. The result is always "dedup" or "passthrough": an unset
    variable yields "dedup", and any unrecognized value is logged as a warning
    and also yields "dedup".
    """
    raw = os.getenv("DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE", "dedup")
    mode = raw.strip().lower()
    if mode not in ("dedup", "passthrough"):
        # The warning reports the normalized value, not the raw environment string.
        logger.warning(
            "Invalid DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE=%r. Falling back to 'dedup'.",
            mode,
        )
        return "dedup"
    return mode
try:
from nvtx import annotate # type: ignore
except ImportError:
......
......@@ -30,6 +30,7 @@ if TYPE_CHECKING:
from kvbm import KvbmLeader
from kvbm.utils import is_dyn_runtime_enabled
from kvbm.vllm_integration.consolidator_config import get_consolidator_mode
from kvbm.vllm_integration.rust import KvbmRequest
from kvbm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from kvbm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput
......@@ -72,10 +73,12 @@ class KvConnectorLeader:
# Get kv event consolidator endpoints from vllm_config (pre-computed in main.py)
consolidator_vllm_endpoint = None
consolidator_output_endpoint = None
consolidator_mode = None
self._consolidator_output_port = None
_consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
if _consolidator_eps:
consolidator_mode = get_consolidator_mode()
# Unpack all three endpoints
# [0]: vllm_endpoint (for consolidator to subscribe to vLLM)
# [1]: output_bind_endpoint (for consolidator to bind/publish)
......@@ -97,6 +100,7 @@ class KvConnectorLeader:
leader,
consolidator_vllm_endpoint=consolidator_vllm_endpoint,
consolidator_output_endpoint=consolidator_output_endpoint,
consolidator_mode=consolidator_mode,
)
else:
# No kv event consolidator - pass None to Rust
......@@ -107,6 +111,7 @@ class KvConnectorLeader:
leader,
consolidator_vllm_endpoint=None,
consolidator_output_endpoint=None,
consolidator_mode=None,
)
# KV Connector
......
......@@ -9,23 +9,17 @@ import logging
import os
from typing import Optional, Tuple
from kvbm.utils import get_consolidator_mode, is_truthy
from vllm.distributed.kv_events import ZmqEventPublisher
logger = logging.getLogger(__name__)
__all__ = [
"get_consolidator_endpoints",
"get_consolidator_mode",
"is_truthy",
"should_enable_consolidator",
]
def is_truthy(val: str) -> bool:
    """
    Check if a string represents a truthy value.

    Truthy values: "1", "true", "on", "yes" (case-insensitive, ignoring
    leading and trailing whitespace).

    Args:
        val: The string value to check

    Returns:
        True if the value is truthy, False otherwise
    """
    # Strip before comparing so padded values such as "true\n" or " 1 " are
    # accepted, matching the behavior of kvbm.utils.is_truthy.
    return val.strip().lower() in ("1", "true", "on", "yes")
logger = logging.getLogger(__name__)
def should_enable_consolidator(vllm_config) -> bool:
......
......@@ -6,7 +6,7 @@ use anyhow::Result;
use dynamo_llm::block_manager::block::{
data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical,
};
use dynamo_llm::block_manager::kv_consolidator::EventSource;
use dynamo_llm::block_manager::kv_consolidator::{EventSource, KvEventConsolidationMode};
use dynamo_llm::block_manager::offload::filter::FrequencyFilter;
use dynamo_llm::block_manager::{BasicMetadata, BlockParallelismStrategy};
use dynamo_runtime::DistributedRuntime;
......@@ -252,7 +252,12 @@ pub struct BlockManagerBuilder {
page_size: usize,
disable_device_pool: bool,
kvbm_metrics: Option<dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics>,
consolidator_config: Option<(String, Option<String>, EventSource)>, // (engine_endpoint, output_endpoint (optional), engine_source)
consolidator_config: Option<(
String,
Option<String>,
EventSource,
KvEventConsolidationMode,
)>, // (engine_endpoint, output_endpoint (optional), engine_source, mode)
}
impl BlockManagerBuilder {
......@@ -293,8 +298,9 @@ impl BlockManagerBuilder {
engine_endpoint: String,
output_endpoint: Option<String>,
engine_source: EventSource,
mode: KvEventConsolidationMode,
) -> Self {
self.consolidator_config = Some((engine_endpoint, output_endpoint, engine_source));
self.consolidator_config = Some((engine_endpoint, output_endpoint, engine_source, mode));
self
}
......@@ -368,9 +374,9 @@ impl BlockManagerBuilder {
config_builder = config_builder.kvbm_metrics(Some(kvbm_metrics));
}
if let Some((engine_ep, output_ep, engine_source)) = self.consolidator_config {
if let Some((engine_ep, output_ep, engine_source, mode)) = self.consolidator_config {
config_builder =
config_builder.consolidator_config(engine_ep, output_ep, engine_source);
config_builder.consolidator_config(engine_ep, output_ep, engine_source, mode);
}
let config = config_builder.build()?;
......
......@@ -4,6 +4,7 @@
use dynamo_llm::block_manager::{
block::BlockId, connector::protocol::WorkerTransferRequest, pool::BlockPoolError,
};
use dynamo_llm::tokens::SequenceHash;
pub mod leader;
pub mod trtllm_leader;
......@@ -42,7 +43,7 @@ impl SchedulerOutput {
// I am surprised that vLLM's NewRequestData does not include the salt hash.
// It has almost everything else to compute the block hashes worker side.
#[pyo3(signature = (request_id, prompt_token_ids, block_ids, num_computed_tokens, priorities=None))]
#[pyo3(signature = (request_id, prompt_token_ids, block_ids, num_computed_tokens, priorities=None, external_sequence_hashes=None))]
pub fn add_new_request(
&mut self,
request_id: String,
......@@ -50,6 +51,7 @@ impl SchedulerOutput {
block_ids: Vec<BlockId>,
num_computed_tokens: usize,
priorities: Option<Vec<u32>>,
external_sequence_hashes: Option<Vec<SequenceHash>>,
) {
self.new_requests.push(NewRequestData {
request_id,
......@@ -57,11 +59,13 @@ impl SchedulerOutput {
block_ids,
num_computed_tokens,
priorities,
external_sequence_hashes,
});
}
/// This is called by the leader to update the cached requests
#[pyo3(signature = (request_id, resumed_from_preemption, new_token_ids, new_block_ids, num_computed_tokens, priorities=None))]
#[pyo3(signature = (request_id, resumed_from_preemption, new_token_ids, new_block_ids, num_computed_tokens, priorities=None, external_sequence_hashes=None))]
#[allow(clippy::too_many_arguments)]
pub fn add_cached_request(
&mut self,
request_id: String,
......@@ -70,6 +74,7 @@ impl SchedulerOutput {
new_block_ids: Vec<BlockId>,
num_computed_tokens: usize,
priorities: Option<Vec<u32>>,
external_sequence_hashes: Option<Vec<SequenceHash>>,
) {
self.cached_requests.push(CachedRequestData {
request_id,
......@@ -78,6 +83,7 @@ impl SchedulerOutput {
new_block_ids,
num_computed_tokens,
priorities,
external_sequence_hashes,
});
}
......@@ -108,6 +114,8 @@ pub struct NewRequestData {
/// Retention priorities for each block (same length as block_ids).
/// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>,
/// TRT-LLM cumulative external sequence-hash chain for all completed blocks.
pub external_sequence_hashes: Option<Vec<SequenceHash>>,
}
impl std::fmt::Debug for NewRequestData {
......@@ -131,6 +139,8 @@ pub struct CachedRequestData {
/// Retention priorities for each new block (same length as new_block_ids).
/// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>,
/// TRT-LLM cumulative external sequence-hash chain for all completed blocks.
pub external_sequence_hashes: Option<Vec<SequenceHash>>,
}
impl std::fmt::Debug for CachedRequestData {
......@@ -188,3 +198,40 @@ impl ConnectorMetadata {
self.operations.extend(xfer_reqs);
}
}
// Unit tests for SchedulerOutput's handling of the optional
// external_sequence_hashes argument.
#[cfg(test)]
mod tests {
use super::*;
// Checks that the hashes passed to add_new_request / add_cached_request are
// stored unchanged on the pushed NewRequestData / CachedRequestData entries.
#[test]
fn test_scheduler_output_preserves_external_sequence_hashes() {
let mut output = SchedulerOutput::new();
// Arguments: request_id, prompt_token_ids, block_ids, num_computed_tokens,
// priorities, external_sequence_hashes.
output.add_new_request(
"req-1".to_string(),
vec![1, 2, 3, 4],
vec![10],
4,
None,
Some(vec![101]),
);
// Arguments: request_id, resumed_from_preemption, new_token_ids,
// new_block_ids, num_computed_tokens, priorities, external_sequence_hashes.
output.add_cached_request(
"req-1".to_string(),
false,
vec![5, 6, 7, 8],
vec![11],
8,
None,
Some(vec![101, 202]),
);
// The new-request entry keeps the single-hash chain it was given.
assert_eq!(
output.new_requests[0].external_sequence_hashes,
Some(vec![101])
);
// The cached-request entry keeps the two-hash chain it was given.
assert_eq!(
output.cached_requests[0].external_sequence_hashes,
Some(vec![101, 202])
);
}
}
......@@ -22,7 +22,7 @@ use dynamo_llm::block_manager::{
locality::Logical,
},
connector::{protocol::RequestType, *},
kv_consolidator::EventSource,
kv_consolidator::{EventSource, KvEventConsolidationMode},
};
use dynamo_llm::tokens::{SaltHash, TokenBlockSequence, Tokens};
use dynamo_runtime::config::environment_names::kvbm as env_kvbm;
......@@ -35,6 +35,24 @@ use tokio::sync::oneshot;
type VllmLocality = Logical<DistributedLeaderWorkerResources>;
/// Converts an optional mode string into a [`KvEventConsolidationMode`].
///
/// `None` selects `Dedup`. A string that fails to parse is reported through a
/// warning and likewise resolves to `Dedup`, so this function never fails.
fn parse_consolidator_mode(mode: Option<String>) -> KvEventConsolidationMode {
    match mode {
        None => KvEventConsolidationMode::Dedup,
        Some(raw) => raw
            .parse::<KvEventConsolidationMode>()
            .unwrap_or_else(|error| {
                tracing::warn!(
                    "Invalid KV event consolidator mode {:?}: {}. Falling back to dedup.",
                    raw,
                    error
                );
                KvEventConsolidationMode::Dedup
            }),
    }
}
impl From<SlotError> for PyErr {
fn from(err: SlotError) -> Self {
to_pyerr(err)
......@@ -94,6 +112,7 @@ impl KvConnectorLeader {
leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeader initialized with worker_id: {}",
......@@ -118,6 +137,7 @@ impl KvConnectorLeader {
// Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -148,6 +168,7 @@ impl KvConnectorLeader {
vllm_ep,
Some(output_ep),
EventSource::Vllm,
consolidator_mode,
);
}
......@@ -435,6 +456,7 @@ impl Leader for KvConnectorLeader {
new_req.num_computed_tokens,
scheduled_tokens,
None,
None,
)?;
let pending_ops_opt = slot.take_pending_operations();
......@@ -506,6 +528,7 @@ impl Leader for KvConnectorLeader {
cached_req.num_computed_tokens,
scheduled_tokens,
None,
None,
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
......@@ -621,7 +644,7 @@ pub struct PyKvConnectorLeader {
#[pymethods]
impl PyKvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_vllm_endpoint=None, consolidator_output_endpoint=None))]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_vllm_endpoint=None, consolidator_output_endpoint=None, consolidator_mode=None))]
pub fn new(
worker_id: String,
drt: Option<PyObject>,
......@@ -629,6 +652,7 @@ impl PyKvConnectorLeader {
leader: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> PyResult<Self> {
let _ = &drt; // drt is currently un-used in leader
......@@ -646,6 +670,7 @@ impl PyKvConnectorLeader {
leader,
consolidator_vllm_endpoint,
consolidator_output_endpoint,
consolidator_mode,
))
} else {
Box::new(KvConnectorLeader::new(
......@@ -654,6 +679,7 @@ impl PyKvConnectorLeader {
leader,
consolidator_vllm_endpoint,
consolidator_output_endpoint,
consolidator_mode,
))
};
Ok(Self { connector_leader })
......
......@@ -92,6 +92,7 @@ impl KvConnectorLeaderRecorder {
leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeaderRecorder initialized with worker_id: {}",
......@@ -131,6 +132,7 @@ impl KvConnectorLeaderRecorder {
// Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = super::parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -156,6 +158,7 @@ impl KvConnectorLeaderRecorder {
vllm_ep,
Some(output_ep),
EventSource::Vllm,
consolidator_mode,
);
}
......
......@@ -16,7 +16,7 @@ use dynamo_llm::{
connector::protocol::{LeaderTransferRequest, RequestType, TransferType},
distributed::{BlockTransferPool, BlockTransferRequest, KvbmLeader},
},
tokens::TokenBlock,
tokens::{SequenceHash, TokenBlock},
};
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
use tokio_util::sync::CancellationToken;
......@@ -114,6 +114,7 @@ pub trait Slot: std::fmt::Debug {
num_computed_tokens: usize,
num_scheduled_tokens: usize,
priorities: Option<&[u32]>,
external_sequence_hashes: Option<&[SequenceHash]>,
) -> Result<(), SlotError>;
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;
......@@ -481,6 +482,29 @@ impl VllmConnectorSlot {
&self.device_blocks
}
/// Applies an externally supplied sequence-hash chain to this slot's
/// completed blocks.
///
/// # Panics
///
/// Panics if `external_sequence_hashes` does not contain exactly one entry per
/// completed block of `self.sequence`. The per-block
/// `assert_external_hashes_assigned` check that follows presumably panics as
/// well when a block is left without its external hashes (verify against its
/// definition).
fn sync_external_sequence_hashes(
    &mut self,
    external_sequence_hashes: &[SequenceHash],
) -> Result<(), SlotError> {
    let provided = external_sequence_hashes.len();
    let completed = self.sequence.blocks().len();
    assert_eq!(
        provided, completed,
        "external_sequence_hashes length ({}) must match completed block count ({}) for request {}",
        provided, completed, self.request_id
    );

    self.sequence
        .sync_external_sequence_hashes(external_sequence_hashes);

    // Re-check every completed block after the sync.
    self.sequence.blocks().iter().for_each(|block| {
        block.assert_external_hashes_assigned();
    });

    Ok(())
}
fn mark_as_skipped_prefill(&mut self) -> Result<(), SlotError> {
if self.state != SlotState::Prefilling {
return Err(SlotError::InvalidState(format!(
......@@ -597,6 +621,7 @@ impl Slot for VllmConnectorSlot {
num_computed_tokens: usize,
num_scheduled_tokens: usize,
priorities: Option<&[u32]>,
external_sequence_hashes: Option<&[SequenceHash]>,
) -> Result<(), SlotError> {
tracing::debug!(
"ENTRY: apply_scheduler_output: req={}, tokens.len={}, block_ids.len={}, computed={}, scheduled={}, \
......@@ -634,6 +659,10 @@ impl Slot for VllmConnectorSlot {
self.state = SlotState::Prefilling;
}
if let Some(external_sequence_hashes) = external_sequence_hashes {
self.sync_external_sequence_hashes(external_sequence_hashes)?;
}
// Use max to advance both current_position and evaluated_blocks at least by num_computed_tokens.
// This logic is to prevent redundant block offloading.
self.current_position = max(self.current_position, num_computed_tokens);
......@@ -849,6 +878,12 @@ impl Slot for VllmConnectorSlot {
.copied()
.collect();
if external_sequence_hashes.is_some() {
for block in &offload_token_blocks {
block.assert_external_hashes_assigned();
}
}
self.offload_blocks(
&offload_block_ids,
&offload_token_blocks,
......@@ -1878,7 +1913,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> AnyBlocks for AnyImmutab
mod connector_tests {
use super::*;
use crate::block_manager::cache_stats::CacheStatsTracker;
use dynamo_llm::tokens::{SaltHash, Tokens};
use dynamo_llm::tokens::{SaltHash, SequenceHash, Tokens};
use std::sync::Arc;
use tokio::sync::mpsc;
......@@ -1914,6 +1949,12 @@ mod connector_tests {
(start..start + count).collect()
}
/// Builds `count` consecutive hash values beginning at `start`
/// (`start`, `start + 1`, ...), for use as a fake external hash chain.
fn external_hashes(start: SequenceHash, count: usize) -> Vec<SequenceHash> {
    let mut hashes = Vec::with_capacity(count);
    for offset in 0..count {
        hashes.push(start + offset as SequenceHash);
    }
    hashes
}
/// Drains all pending offload requests from the channel and returns their block IDs.
fn drain_offload_block_ids(
rx: &mut mpsc::UnboundedReceiver<LocalTransferRequest>,
......@@ -1941,7 +1982,7 @@ mod connector_tests {
assert_eq!(slot.num_device_blocks_allocated(), 3);
// Step 2: apply_scheduler_output with empty blocks (vLLM pattern)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap();
// device_blocks should still be exactly 3 — no double-add
......@@ -1964,13 +2005,70 @@ mod connector_tests {
// Step 2: apply_scheduler_output with THE SAME blocks (TRT-LLM pattern)
// Without the dedup guard, this doubles device_blocks to len=6.
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None)
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, None)
.unwrap();
// device_blocks must still be exactly 3 — dedup guard prevented the double-add
assert_eq!(slot.num_device_blocks_allocated(), 3);
}
// Passing external sequence hashes to apply_scheduler_output should attach one
// hash to each completed block and link each block to its predecessor's hash.
#[test]
fn test_trtllm_external_hashes_are_assigned_to_completed_blocks() {
let num_tokens = 96; // 3 blocks of 32
let (mut slot, _rx) = create_test_slot(num_tokens, 0);
let blocks = block_ids(100, 3);
// One external hash per completed block: 10_000, 10_001, 10_002.
let external_hashes = external_hashes(10_000, 3);
slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, Some(&external_hashes))
.unwrap();
let completed_blocks = slot.sequence.blocks();
assert_eq!(completed_blocks.len(), 3);
// Block 0 carries the first hash and has no external parent.
assert_eq!(
completed_blocks[0].external_sequence_hash(),
Some(external_hashes[0])
);
assert_eq!(completed_blocks[0].external_parent_sequence_hash(), None);
// Block 1 carries the second hash; its parent is block 0's hash.
assert_eq!(
completed_blocks[1].external_sequence_hash(),
Some(external_hashes[1])
);
assert_eq!(
completed_blocks[1].external_parent_sequence_hash(),
Some(external_hashes[0])
);
// Block 2 carries the third hash; its parent is block 1's hash.
assert_eq!(
completed_blocks[2].external_sequence_hash(),
Some(external_hashes[2])
);
assert_eq!(
completed_blocks[2].external_parent_sequence_hash(),
Some(external_hashes[1])
);
}
// Re-applying scheduler output with a hash chain that differs from the one
// already recorded is expected to panic with "external_sequence_hash mismatch".
#[test]
#[should_panic(expected = "external_sequence_hash mismatch")]
fn test_trtllm_external_hash_chain_mismatch_panics() {
let num_tokens = 96; // 3 blocks of 32
let (mut slot, _rx) = create_test_slot(num_tokens, 0);
let blocks = block_ids(100, 3);
let external_hashes = external_hashes(20_000, 3);
slot.append_mutable_device_blocks(&blocks).unwrap();
// First application records the original chain.
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, Some(&external_hashes))
.unwrap();
// Same chain, except the middle hash is off by one.
let mismatched_hashes = vec![
external_hashes[0],
external_hashes[1] + 1,
external_hashes[2],
];
// Second application supplies the altered chain; the panic is expected here.
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, Some(&mismatched_hashes))
.unwrap();
}
// ---------------------------------------------------------------
// Test 3: Decode adds a new block correctly
// ---------------------------------------------------------------
......@@ -1982,14 +2080,14 @@ mod connector_tests {
// Prefill: append + apply with empty blocks (vLLM pattern)
slot.append_mutable_device_blocks(&prefill_blocks).unwrap();
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 3);
// Decode: new block at boundary (token 96 = block 3)
let decode_block = block_ids(200, 1);
let decode_token: Vec<u32> = vec![9999];
slot.apply_scheduler_output(&decode_token, &decode_block, 95, 1, None)
slot.apply_scheduler_output(&decode_token, &decode_block, 95, 1, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4);
}
......@@ -2024,7 +2122,7 @@ mod connector_tests {
// Empty tokens → Prefilling state, and next_position(96) == total_tokens(96)
// so the early-return does not fire and offload proceeds.
slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap();
let offloads = drain_offload_block_ids(&mut rx);
......@@ -2045,7 +2143,7 @@ mod connector_tests {
// Use the TRT-LLM pattern: append_mutable first, then apply with same blocks + priorities.
// The dedup guard prevents the double-add, but priorities are still processed.
slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities))
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities), None)
.unwrap();
// device_blocks should be 3 (dedup prevented doubling)
......@@ -2069,7 +2167,7 @@ mod connector_tests {
let priorities: Vec<u32> = vec![80, 80, 10, 10];
slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities))
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities), None)
.unwrap();
let offloads = drain_offload_block_ids(&mut rx);
......@@ -2079,7 +2177,7 @@ mod connector_tests {
// Because offload was terminated, no further offloading should happen.
let decode_block = block_ids(200, 1);
let decode_token: Vec<u32> = vec![9999];
slot.apply_scheduler_output(&decode_token, &decode_block, 127, 1, None)
slot.apply_scheduler_output(&decode_token, &decode_block, 127, 1, None, None)
.unwrap();
let further_offloads = drain_offload_block_ids(&mut rx);
......@@ -2102,14 +2200,16 @@ mod connector_tests {
slot.append_mutable_device_blocks(&blocks).unwrap();
// Chunk 1: schedule first 64 tokens → evaluates blocks 0,1
slot.apply_scheduler_output(&[], &[], 0, 64, None).unwrap();
slot.apply_scheduler_output(&[], &[], 0, 64, None, None)
.unwrap();
let offloads_1 = drain_offload_block_ids(&mut rx);
assert_eq!(offloads_1.len(), 1);
assert_eq!(offloads_1[0], vec![100, 101]); // blocks 0,1
// Chunk 2: schedule next 64 tokens → evaluates blocks 2,3
// (uses cached_request pattern: empty tokens, empty blocks)
slot.apply_scheduler_output(&[], &[], 64, 64, None).unwrap();
slot.apply_scheduler_output(&[], &[], 64, 64, None, None)
.unwrap();
let offloads_2 = drain_offload_block_ids(&mut rx);
assert_eq!(offloads_2.len(), 1);
assert_eq!(offloads_2[0], vec![102, 103]); // blocks 2,3
......@@ -2133,7 +2233,7 @@ mod connector_tests {
// Step 2: apply_scheduler_output with overlapping blocks [12, 13].
// Suffix [12] of device_blocks matches prefix [12] of block_ids.
// Only block 13 is new and gets appended.
slot.apply_scheduler_output(&[], &[12, 13], 0, 128, None)
slot.apply_scheduler_output(&[], &[12, 13], 0, 128, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4);
......@@ -2152,7 +2252,7 @@ mod connector_tests {
// Chunk 1: 3 blocks, all high priority, schedule 96 tokens
slot.append_mutable_device_blocks(&[10, 11, 12]).unwrap();
slot.apply_scheduler_output(&[], &[10, 11, 12], 0, 96, Some(&[80, 80, 80]))
slot.apply_scheduler_output(&[], &[10, 11, 12], 0, 96, Some(&[80, 80, 80]), None)
.unwrap();
let offloads_1 = drain_offload_block_ids(&mut rx);
......@@ -2171,7 +2271,7 @@ mod connector_tests {
slot.append_mutable_device_blocks(&[13]).unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4);
slot.apply_scheduler_output(&[], &[12, 13], 96, 32, Some(&[80, 10]))
slot.apply_scheduler_output(&[], &[12, 13], 96, 32, Some(&[80, 10]), None)
.unwrap();
// Candidate is block 13 (index 3, evaluated_blocks=3).
......@@ -2198,7 +2298,7 @@ mod connector_tests {
// block_ids[0]=11 is found at device_blocks[1], but device_blocks[1..] = [11,12]
// does NOT match block_ids[..2] = [11,14]. Contract violation.
slot.apply_scheduler_output(&[], &[11, 14], 0, 128, None)
slot.apply_scheduler_output(&[], &[11, 14], 0, 128, None, None)
.unwrap();
}
......@@ -2218,7 +2318,7 @@ mod connector_tests {
// Overlap: suffix [13,14] matches prefix [13,14], overlap=2.
// new_ids = [10]. But 10 ∈ device_blocks → contract violation.
slot.apply_scheduler_output(&[], &[13, 14, 10], 0, 192, None)
slot.apply_scheduler_output(&[], &[13, 14, 10], 0, 192, None, None)
.unwrap();
}
......@@ -2238,7 +2338,7 @@ mod connector_tests {
// block_ids[0]=10 found at device_blocks[0], suffix_len=3.
// device_blocks[0..3]=[10,11,12] == block_ids[0..3]=[10,11,12] → overlap=3.
// new_ids=[13,14], both genuinely new → extend.
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14], 0, 160, None)
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14], 0, 160, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 5);
......@@ -2261,7 +2361,7 @@ mod connector_tests {
assert_eq!(slot.num_device_blocks_allocated(), 5);
// Full overlap of all 5 existing blocks, plus 2 new.
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14, 15, 16], 0, 224, None)
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14, 15, 16], 0, 224, None, None)
.unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 7);
......
......@@ -14,12 +14,30 @@ use crate::block_manager::{distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRe
use crate::get_current_tokio_handle;
use anyhow;
use dynamo_llm::block_manager::connector::protocol::RequestType;
use dynamo_llm::block_manager::kv_consolidator::EventSource;
use dynamo_llm::block_manager::kv_consolidator::{EventSource, KvEventConsolidationMode};
use dynamo_llm::block_manager::metrics_kvbm::{KvbmMetrics, KvbmMetricsRegistry};
use std::collections::HashSet;
use std::sync::{Arc, OnceLock};
use tokio::runtime::Handle;
/// Converts an optional mode string into a [`KvEventConsolidationMode`].
///
/// `None` selects `Dedup`. A string that fails to parse is reported through a
/// warning and likewise resolves to `Dedup`, so this function never fails.
fn parse_consolidator_mode(mode: Option<String>) -> KvEventConsolidationMode {
    match mode {
        None => KvEventConsolidationMode::Dedup,
        Some(raw) => raw
            .parse::<KvEventConsolidationMode>()
            .unwrap_or_else(|error| {
                tracing::warn!(
                    "Invalid KV event consolidator mode {:?}: {}. Falling back to dedup.",
                    raw,
                    error
                );
                KvEventConsolidationMode::Dedup
            }),
    }
}
pub trait Leader: Send + Sync + std::fmt::Debug {
fn get_num_new_matched_tokens(
&mut self,
......@@ -71,6 +89,7 @@ impl KvConnectorLeader {
leader_py: PyKvbmLeader,
consolidator_trtllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeader initialized with worker_id: {}",
......@@ -95,6 +114,7 @@ impl KvConnectorLeader {
// Capture consolidator endpoints for the async block
let consolidator_trtllm_ep = consolidator_trtllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -125,6 +145,7 @@ impl KvConnectorLeader {
trtllm_ep,
consolidator_output_ep,
EventSource::Trtllm,
consolidator_mode,
);
}
......@@ -389,6 +410,7 @@ impl Leader for KvConnectorLeader {
new_req.num_computed_tokens,
scheduled_tokens,
new_req.priorities.as_deref(),
new_req.external_sequence_hashes.as_deref(),
)?;
let pending_ops_opt = slot.take_pending_operations();
......@@ -440,6 +462,7 @@ impl Leader for KvConnectorLeader {
cached_req.num_computed_tokens,
scheduled_tokens,
cached_req.priorities.as_deref(),
cached_req.external_sequence_hashes.as_deref(),
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
......@@ -518,7 +541,7 @@ pub struct PyTrtllmKvConnectorLeader {
#[pymethods]
impl PyTrtllmKvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_trtllm_endpoint=None, consolidator_output_endpoint=None))]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_trtllm_endpoint=None, consolidator_output_endpoint=None, consolidator_mode=None))]
pub fn new(
worker_id: u64,
drt: Option<PyObject>,
......@@ -526,6 +549,7 @@ impl PyTrtllmKvConnectorLeader {
leader: PyKvbmLeader,
consolidator_trtllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> PyResult<Self> {
let _ = &drt; // drt is currently un-used in leader
......@@ -535,6 +559,7 @@ impl PyTrtllmKvConnectorLeader {
leader,
consolidator_trtllm_endpoint,
consolidator_output_endpoint,
consolidator_mode,
));
Ok(Self { connector_leader })
}
......
......@@ -126,10 +126,12 @@ impl AicPerfConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", use_remote_indexer=false, serve_indexer=false, shared_cache_multiplier=0.0, shared_cache_type="none"))]
#[pyo3(signature = (overlap_score_weight=1.0, host_cache_hit_weight=0.75, disk_cache_hit_weight=0.25, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", use_remote_indexer=false, serve_indexer=false, shared_cache_multiplier=0.0, shared_cache_type="none"))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
host_cache_hit_weight: f64,
disk_cache_hit_weight: f64,
router_temperature: f64,
use_kv_events: bool,
durable_kv_events: bool,
......@@ -155,6 +157,8 @@ impl KvRouterConfig {
KvRouterConfig {
inner: RsKvRouterConfig {
overlap_score_weight,
host_cache_hit_weight,
disk_cache_hit_weight,
router_temperature,
use_kv_events,
durable_kv_events,
......
......@@ -1078,7 +1078,6 @@ impl KvRouter {
lora_name.clone(),
0.0,
None,
None,
None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
)
.await
......
......@@ -68,7 +68,8 @@ use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::{
EventKind, EventWarningKind, KvIndexerMetrics, PreBoundEventCounters, SyncIndexer, WorkerTask,
EventKind, EventWarningKind, KvIndexerMetrics, MatchDetails, PreBoundEventCounters,
SyncIndexer, WorkerTask,
};
use crate::cleanup::{self, CleanableNode, CleanupGuard, CleanupState};
use crate::protocols::*;
......@@ -479,26 +480,36 @@ impl ConcurrentRadixTreeCompressed {
// ------------------------------------------------------------------
/// Traverse the radix tree to find the best match for a given sequence of
/// [`LocalBlockHash`]es.
/// [`LocalBlockHash`]es, returning both overlap scores and the last matched
/// `ExternalSequenceBlockHash` per worker (used for lower-tier continuation).
///
/// Workers in `full_edge_workers` are tracked in the `active` set and continue
/// into children. Workers in `worker_cutoffs` are scored at the node where their
/// cutoff falls short and are never propagated into children.
pub fn find_matches_impl(
pub fn find_match_details_impl(
&self,
sequence: &[LocalBlockHash],
early_exit: bool,
) -> OverlapScores {
let mut scores = OverlapScores::new();
) -> MatchDetails {
let mut details = MatchDetails::new();
if sequence.is_empty() {
return scores;
return details;
}
let MatchDetails {
overlap_scores: ref mut scores,
ref mut last_matched_hashes,
} = details;
let mut active: FxHashSet<WorkerWithDpRank> = FxHashSet::default();
let mut active_count: usize = 0;
let mut matched_depth: u32 = 0;
let mut seq_pos: usize = 0;
let mut first_node = true;
// Last ExternalSequenceBlockHash from the previous fully-matched edge.
// Workers that drop at a node boundary (not present in the new node)
// were last matched at the end of the previous edge.
let mut prev_edge_last_hash: Option<ExternalSequenceBlockHash> = None;
let mut next_child = {
let root_guard = read_lock!(self, self.root);
......@@ -531,38 +542,49 @@ impl ConcurrentRadixTreeCompressed {
}
edge_match_len = match_len;
// Helper: ExternalSequenceBlockHash at a given depth within this edge.
let edge_hash_at = |depth: usize| -> ExternalSequenceBlockHash {
debug_assert!(depth > 0 && depth <= guard.edge.len());
guard.edge[depth - 1].1
};
let prev_depth = matched_depth;
if first_node {
// Seed active set from full-edge workers (they can continue to children).
// Score partial workers immediately; they never continue into children.
active = guard.full_edge_workers.clone();
active_count = active.len();
for (&w, &k) in &guard.worker_cutoffs {
let contribution = k.min(edge_match_len) as u32;
let contribution = k.min(edge_match_len);
if contribution > 0 {
scores.scores.insert(w, contribution);
scores.scores.insert(w, contribution as u32);
last_matched_hashes.insert(w, edge_hash_at(contribution));
}
}
first_node = false;
} else {
let has_partial = !guard.worker_cutoffs.is_empty();
if has_partial {
// Slow path: check each active worker against both maps.
active.retain(|w| {
if guard.full_edge_workers.contains(w) {
true
} else if let Some(&k) = guard.worker_cutoffs.get(w) {
let effective = k.min(edge_match_len) as u32;
scores.scores.insert(*w, prev_depth + effective);
let effective = k.min(edge_match_len);
scores.scores.insert(*w, prev_depth + effective as u32);
if effective > 0 {
last_matched_hashes.insert(*w, edge_hash_at(effective));
} else if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false
} else {
scores.scores.insert(*w, prev_depth);
if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false
}
});
} else {
// Fast path: no partial workers — all coverage is full or absent.
let full_count = guard.full_edge_workers.len();
if full_count != active_count {
active.retain(|w| {
......@@ -570,11 +592,13 @@ impl ConcurrentRadixTreeCompressed {
true
} else {
scores.scores.insert(*w, prev_depth);
if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false
}
});
}
// full_count == active_count: sets are identical (fast path).
}
active_count = active.len();
}
......@@ -590,6 +614,9 @@ impl ConcurrentRadixTreeCompressed {
} else {
None
};
// Track the deepest matched hash in this edge (both full and partial).
prev_edge_last_hash = Some(guard.edge[edge_match_len - 1].1);
}
if active_count == 0 {
......@@ -605,15 +632,32 @@ impl ConcurrentRadixTreeCompressed {
}
}
for worker in &active {
scores.scores.insert(*worker, matched_depth);
// Record scores and hashes for workers that survived to the deepest level.
if let Some(h) = prev_edge_last_hash {
for worker in &active {
scores.scores.insert(*worker, matched_depth);
last_matched_hashes.insert(*worker, h);
}
} else {
for worker in &active {
scores.scores.insert(*worker, matched_depth);
}
}
for worker in scores.scores.keys() {
if let Some(s) = self.tree_sizes.get(worker) {
scores.tree_sizes.insert(*worker, s.load(Ordering::Relaxed));
}
}
scores
details
}
/// Overlap-only view of [`Self::find_match_details_impl`].
///
/// Runs the detailed traversal with the same `sequence` and `early_exit`
/// arguments and returns just the `overlap_scores` part of the result,
/// discarding the per-worker last-matched hashes.
pub fn find_matches_impl(
    &self,
    sequence: &[LocalBlockHash],
    early_exit: bool,
) -> OverlapScores {
    let details = self.find_match_details_impl(sequence, early_exit);
    details.overlap_scores
}
// ------------------------------------------------------------------
......
......@@ -12,7 +12,8 @@ use tokio_util::sync::CancellationToken;
use super::{
DumpRequest, EventKind, GetWorkersRequest, KvIndexerInterface, KvIndexerMetrics, KvRouterError,
MatchRequest, PreBoundEventCounters, RadixTree, RoutingDecisionRequest,
MatchDetails, MatchDetailsRequest, MatchRequest, PreBoundEventCounters, RadixTree,
RoutingDecisionRequest,
};
use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager};
use crate::protocols::*;
......@@ -95,6 +96,8 @@ pub struct KvIndexer {
event_tx: mpsc::Sender<RouterEvent>,
/// A sender for `MatchRequest`s.
match_tx: mpsc::Sender<MatchRequest>,
/// A sender for `MatchDetailsRequest`s.
match_details_tx: mpsc::Sender<MatchDetailsRequest>,
/// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for remove worker dp_rank requests.
......@@ -136,6 +139,7 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (match_details_tx, match_details_rx) = mpsc::channel::<MatchDetailsRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16);
......@@ -156,6 +160,7 @@ impl KvIndexer {
runtime.block_on(async move {
let cancel = cancel_clone;
let mut match_rx = match_rx;
let mut match_details_rx = match_details_rx;
let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx;
let mut remove_worker_dp_rank_rx = remove_worker_dp_rank_rx;
......@@ -320,6 +325,11 @@ impl KvIndexer {
let _ = req.resp.send(matches);
}
Some(req) = match_details_rx.recv() => {
let matches = trie.find_match_details(req.sequence, req.early_exit);
let _ = req.resp.send(matches);
}
_ = expiry_fut => {
// TTL-based expiry triggered
let Some(ref mut pm) = prune_manager else { continue };
......@@ -351,6 +361,7 @@ impl KvIndexer {
cancel: token,
event_tx,
match_tx,
match_details_tx,
remove_worker_tx,
remove_worker_dp_rank_tx,
get_workers_tx,
......@@ -382,6 +393,21 @@ impl KvIndexer {
self.event_tx.clone()
}
/// Ask the indexer task for the match details of `sequence`.
///
/// A `MatchDetailsRequest` (with `early_exit` fixed to `false`) is sent over
/// the match-details channel and the reply is awaited on a oneshot channel.
///
/// # Errors
///
/// * [`KvRouterError::IndexerOffline`] if the request cannot be enqueued.
/// * [`KvRouterError::IndexerDroppedRequest`] if the reply sender is dropped
///   before a response arrives.
pub async fn find_match_details(
    &self,
    sequence: Vec<LocalBlockHash>,
) -> Result<MatchDetails, KvRouterError> {
    let (reply_tx, reply_rx) = oneshot::channel();
    let request = MatchDetailsRequest::new(sequence, false, reply_tx);

    if self.match_details_tx.send(request).await.is_err() {
        return Err(KvRouterError::IndexerOffline);
    }

    match reply_rx.await {
        Ok(details) => Ok(details),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
#[cfg(test)]
pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
self.dump_tx.clone()
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::{
collections::VecDeque,
collections::{HashMap, VecDeque},
sync::{Arc, Mutex},
};
......@@ -13,7 +13,7 @@ use tokio_util::sync::CancellationToken;
use super::{
GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError,
WorkerKvQueryResponse,
LowerTierIndexer, ThreadPoolIndexer, WorkerKvQueryResponse,
};
use crate::protocols::*;
......@@ -205,6 +205,8 @@ impl RecoverySnapshotCache {
pub struct LocalKvIndexer {
/// The underlying indexer
indexer: KvIndexer,
/// Lazily-created exact lower-tier indexes partitioned by storage tier.
lower_tier_indexers: Arc<Mutex<HashMap<StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>>>>,
/// Circular buffer of recent events
pub(super) event_buffer: Mutex<VecDeque<RouterEvent>>,
/// Coordinates single-flight tree dumps and the cached recovery snapshot.
......@@ -229,6 +231,7 @@ impl LocalKvIndexer {
) -> Self {
Self {
indexer: KvIndexer::new(token, kv_block_size, metrics),
lower_tier_indexers: Arc::new(Mutex::new(HashMap::new())),
event_buffer: Mutex::new(VecDeque::with_capacity(max_buffer_size)),
recovery_cache: Arc::new(RecoverySnapshotCache::new()),
max_buffer_size,
......@@ -335,13 +338,7 @@ impl LocalKvIndexer {
///
/// This forwards the event to the underlying indexer and records it on success.
pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> {
// Forward to underlying indexer
let result = self
.indexer
.event_sender()
.send(event.clone())
.await
.map_err(|_| KvRouterError::IndexerOffline);
let result = self.apply_event_by_tier(&event).await;
if result.is_ok() {
let should_invalidate = matches!(event.event.data, KvCacheEventData::Cleared);
let detected_gap = self.record_event(event);
......@@ -617,6 +614,63 @@ impl LocalKvIndexer {
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
self.indexer.get_workers_sender()
}
/// Get the KV block size.
///
/// Delegates to the underlying primary indexer; the same value is passed to
/// every lower-tier indexer when it is created.
pub fn block_size(&self) -> u32 {
self.indexer.block_size()
}
/// Forward `event` to the primary indexer's event channel.
///
/// # Errors
///
/// Returns [`KvRouterError::IndexerOffline`] when the send fails.
async fn apply_event_to_primary(&self, event: RouterEvent) -> Result<(), KvRouterError> {
    let primary_tx = self.indexer.event_sender();
    match primary_tx.send(event).await {
        Ok(()) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerOffline),
    }
}
/// Apply `event` to the lower-tier indexer for `event.storage_tier`,
/// creating that indexer first if it does not exist yet.
///
/// Always returns `Ok(())`; the lower-tier `apply_event` call yields no
/// error to propagate here.
async fn apply_event_to_lower_tier(&self, event: RouterEvent) -> Result<(), KvRouterError> {
    let tier_indexer = self.get_or_create_lower_tier_indexer(event.storage_tier);
    tier_indexer.apply_event(event).await;
    Ok(())
}
/// Route `event` to the index (or indexes) responsible for it.
///
/// * `Cleared` events go to the primary indexer first and then to every
///   lower-tier indexer that currently exists, regardless of the event's own
///   storage tier. If the primary send fails, the error is returned and the
///   lower tiers are left untouched.
/// * Any other event goes to the primary indexer when its storage tier is
///   GPU, and to the matching lower-tier indexer otherwise.
async fn apply_event_by_tier(&self, event: &RouterEvent) -> Result<(), KvRouterError> {
    if matches!(event.event.data, KvCacheEventData::Cleared) {
        self.apply_event_to_primary(event.clone()).await?;
        for tier_indexer in self.all_lower_tier_indexers() {
            tier_indexer.apply_event(event.clone()).await;
        }
        return Ok(());
    }

    if event.storage_tier.is_gpu() {
        self.apply_event_to_primary(event.clone()).await
    } else {
        self.apply_event_to_lower_tier(event.clone()).await
    }
}
/// Return the lower-tier indexer registered for `storage_tier`, creating and
/// registering a new one if none exists.
///
/// A new indexer is built with a single worker thread and this indexer's
/// block size. Lookup and insertion happen under one lock acquisition, so
/// each tier gets at most one indexer.
///
/// # Panics
///
/// Panics if the `lower_tier_indexers` mutex is poisoned. In debug builds it
/// also asserts that `storage_tier` is not the GPU tier.
fn get_or_create_lower_tier_indexer(
    &self,
    storage_tier: StorageTier,
) -> Arc<ThreadPoolIndexer<LowerTierIndexer>> {
    debug_assert!(!storage_tier.is_gpu());

    let mut registry = self.lower_tier_indexers.lock().unwrap();
    if let Some(existing) = registry.get(&storage_tier) {
        return Arc::clone(existing);
    }

    let created = Arc::new(ThreadPoolIndexer::new(
        LowerTierIndexer::new(),
        1,
        self.block_size(),
    ));
    registry.insert(storage_tier, Arc::clone(&created));
    created
}
/// Snapshot the handles of all lower-tier indexers that exist right now.
///
/// The handles are cloned out while the lock is held, so callers can await
/// on them without keeping the mutex locked.
///
/// # Panics
///
/// Panics if the `lower_tier_indexers` mutex is poisoned.
fn all_lower_tier_indexers(&self) -> Vec<Arc<ThreadPoolIndexer<LowerTierIndexer>>> {
    self.lower_tier_indexers
        .lock()
        .unwrap()
        .values()
        .map(Arc::clone)
        .collect()
}
}
// Implement KvIndexerInterface by delegating to the underlying indexer
......@@ -646,10 +700,16 @@ impl KvIndexerInterface for LocalKvIndexer {
}
/// Remove `worker` from every existing lower-tier indexer, then ask the
/// primary indexer to remove it as well.
///
/// The primary removal is best-effort: a failed send is ignored.
async fn remove_worker(&self, worker: WorkerId) {
    let lower_tiers = self.all_lower_tier_indexers();
    for tier_indexer in lower_tiers {
        tier_indexer.remove_worker(worker).await;
    }

    let primary_tx = self.indexer.remove_worker_sender();
    let _ = primary_tx.send(worker).await;
}
/// Remove a single `(worker, dp_rank)` pair from every existing lower-tier
/// indexer and then from the primary indexer.
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
    let lower_tiers = self.all_lower_tier_indexers();
    for tier_indexer in lower_tiers {
        // Trait method is called by its full path, as for the primary indexer below.
        KvIndexerInterface::remove_worker_dp_rank(&*tier_indexer, worker, dp_rank).await;
    }

    KvIndexerInterface::remove_worker_dp_rank(&self.indexer, worker, dp_rank).await;
}
......@@ -658,7 +718,27 @@ impl KvIndexerInterface for LocalKvIndexer {
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
let mut events = self.indexer.dump_events().await?;
// Also dump lower-tier indexer state so the router receives
// host-pinned / disk block information during recovery.
let lower_tiers: Vec<(StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>)> = {
let indexers = self.lower_tier_indexers.lock().unwrap();
indexers
.iter()
.map(|(&tier, idx)| (tier, idx.clone()))
.collect()
};
for (tier, indexer) in lower_tiers {
if let Ok(tier_events) = indexer.dump_events().await {
for mut event in tier_events {
event.storage_tier = tier;
events.push(event);
}
}
}
Ok(events)
}
async fn process_routing_decision_for_request(
......@@ -674,6 +754,231 @@ impl KvIndexerInterface for LocalKvIndexer {
}
async fn flush(&self) -> usize {
self.indexer.flush().await
let queued = self.indexer.flush().await;
for indexer in self.all_lower_tier_indexers() {
let _ = indexer.dump_events().await;
}
queued
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use rustc_hash::FxHashMap;
use tokio_util::sync::CancellationToken;
use super::LocalKvIndexer;
use crate::indexer::{KvIndexerInterface, KvIndexerMetrics, LowerTierContinuation};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, StorageTier, WorkerWithDpRank,
};
/// Build a `Stored` router event carrying exactly one block.
///
/// The block has hash `block_hash` and token hash `tokens_hash`, hangs off
/// `parent_hash`, and the event is tagged with `storage_tier`.
fn lower_tier_store_event(
    worker_id: u64,
    dp_rank: u32,
    event_id: u64,
    parent_hash: u64,
    tokens_hash: u64,
    block_hash: u64,
    storage_tier: StorageTier,
) -> RouterEvent {
    let stored_block = KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash: LocalBlockHash(tokens_hash),
        mm_extra_info: None,
    };
    let store_data = KvCacheStoreData {
        parent_hash: Some(ExternalSequenceBlockHash(parent_hash)),
        start_position: None,
        blocks: vec![stored_block],
    };
    let cache_event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Stored(store_data),
        dp_rank,
    };

    RouterEvent::with_storage_tier(worker_id, cache_event, storage_tier)
}
/// Count the contiguous lower-tier hits for one worker on a single-block query.
///
/// Looks up the indexer for `storage_tier` (returning `0` if none has been
/// created), then queries it for `tokens_hash` with a continuation that starts
/// at `parent_hash` for `(worker_id, dp_rank)`. Returns `0` when the worker is
/// absent from the result.
fn lower_tier_hits(
    indexer: &LocalKvIndexer,
    storage_tier: StorageTier,
    worker_id: u64,
    dp_rank: u32,
    parent_hash: u64,
    tokens_hash: u64,
) -> usize {
    // The guard is a statement temporary, so the lock is released before querying.
    let tier_indexer = indexer
        .lower_tier_indexers
        .lock()
        .unwrap()
        .get(&storage_tier)
        .cloned();
    let Some(tier_indexer) = tier_indexer else {
        return 0;
    };

    let target = WorkerWithDpRank::new(worker_id, dp_rank);
    let mut continuations = FxHashMap::default();
    continuations.insert(
        target,
        LowerTierContinuation::new(0, ExternalSequenceBlockHash(parent_hash)),
    );

    tier_indexer
        .backend()
        .query_contiguous_hits(&[LocalBlockHash(tokens_hash)], &continuations)
        .get(&target)
        .copied()
        .unwrap_or(0)
}
// A host-pinned (non-GPU) store event must be recorded in the event buffer and
// indexed in a lower-tier indexer, while the primary index stays empty.
#[tokio::test]
async fn lower_tier_events_are_buffered_without_touching_primary_index() {
// NOTE(review): the literals 4 and 16 are presumably the KV block size and the
// event-buffer capacity; confirm against `LocalKvIndexer::new`.
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
16,
);
let event = lower_tier_store_event(7, 0, 1, 900, 11, 101, StorageTier::HostPinned);
indexer
.apply_event_with_buffer(event.clone())
.await
.unwrap();
// Flush before asserting so the event has been processed by the indexers.
let _ = indexer.flush().await;
// The event is buffered verbatim.
assert_eq!(indexer.get_all_events_in_buffer(), vec![event]);
// Exactly one lower-tier indexer was created (for HostPinned) and it sees the block.
assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 1);
assert_eq!(
lower_tier_hits(&indexer, StorageTier::HostPinned, 7, 0, 900, 11),
1
);
// The primary index must not know about the lower-tier block.
let overlap = indexer
.find_matches(vec![LocalBlockHash(11)])
.await
.unwrap();
assert!(overlap.scores.is_empty());
}
// Events for different non-GPU storage tiers must land in separate lower-tier
// indexers: a block stored on one tier is not visible when querying the other.
#[tokio::test]
async fn lower_tier_events_are_partitioned_by_storage_tier() {
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
16,
);
// Lower-tier indexers are created lazily: none exist before the first event.
assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 0);
// First event: HostPinned tier, parent hash 1000.
indexer
.apply_event_with_buffer(lower_tier_store_event(
19,
0,
1,
1000,
31,
301,
StorageTier::HostPinned,
))
.await
.unwrap();
let _ = indexer.flush().await;
assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 1);
// Second event: same worker and token hash, but Disk tier and parent hash 2000.
indexer
.apply_event_with_buffer(lower_tier_store_event(
19,
0,
2,
2000,
31,
302,
StorageTier::Disk,
))
.await
.unwrap();
let _ = indexer.flush().await;
assert_eq!(indexer.lower_tier_indexers.lock().unwrap().len(), 2);
// Each tier reports a hit only for the parent hash stored on that tier.
assert_eq!(
lower_tier_hits(&indexer, StorageTier::HostPinned, 19, 0, 1000, 31),
1
);
assert_eq!(
lower_tier_hits(&indexer, StorageTier::Disk, 19, 0, 2000, 31),
1
);
// Cross-tier lookups must miss.
assert_eq!(
lower_tier_hits(&indexer, StorageTier::HostPinned, 19, 0, 2000, 31),
0
);
assert_eq!(
lower_tier_hits(&indexer, StorageTier::Disk, 19, 0, 1000, 31),
0
);
}
// A `Cleared` event issued for one dp_rank of a worker must remove that
// worker's lower-tier blocks for every dp_rank, not only the one named in the event.
#[tokio::test]
async fn cleared_event_clears_all_lower_tier_dp_ranks_for_worker() {
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
16,
);
// Worker 11, dp_rank 0: one host-pinned block.
indexer
.apply_event_with_buffer(lower_tier_store_event(
11,
0,
1,
1000,
21,
201,
StorageTier::HostPinned,
))
.await
.unwrap();
// Worker 11, dp_rank 1: a second host-pinned block.
indexer
.apply_event_with_buffer(lower_tier_store_event(
11,
1,
2,
2000,
22,
202,
StorageTier::HostPinned,
))
.await
.unwrap();
// The Cleared event carries dp_rank 0 only.
indexer
.apply_event_with_buffer(RouterEvent::with_storage_tier(
11,
KvCacheEvent {
event_id: 3,
data: KvCacheEventData::Cleared,
dp_rank: 0,
},
StorageTier::HostPinned,
))
.await
.unwrap();
let _ = indexer.flush().await;
// Both dp_ranks must report no remaining hits.
assert_eq!(
lower_tier_hits(&indexer, StorageTier::HostPinned, 11, 0, 1000, 21),
0
);
assert_eq!(
lower_tier_hits(&indexer, StorageTier::HostPinned, 11, 1, 2000, 22),
0
);
}
}
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment