"components/vscode:/vscode.git/clone" did not exist on "a3d46840760eb1e11e6880ea645fe7b50fa7cd66"
Unverified Commit e3d00b89 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
...@@ -660,6 +660,7 @@ async def init_llm_worker( ...@@ -660,6 +660,7 @@ async def init_llm_worker(
kv_block_size=config.kv_block_size, kv_block_size=config.kv_block_size,
zmq_endpoint=consolidator_output_connect_endpoint, zmq_endpoint=consolidator_output_connect_endpoint,
zmq_topic="", zmq_topic="",
enable_local_indexer=config.enable_local_indexer,
) )
logging.info( logging.info(
f"Created worker-side publisher for consolidated events: " f"Created worker-side publisher for consolidated events: "
......
...@@ -8,7 +8,7 @@ use common::*; ...@@ -8,7 +8,7 @@ use common::*;
use clap::Parser; use clap::Parser;
use common::NoopSequencePublisher; use common::NoopSequencePublisher;
use dynamo_kv_router::protocols::{PrefillLoadHint, WorkerWithDpRank}; use dynamo_kv_router::protocols::{PrefillLoadHint, WorkerWithDpRank};
use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest}; use dynamo_kv_router::{ActiveSequencesMultiWorker, SequenceRequest};
use dynamo_mocker::loadgen::Trace; use dynamo_mocker::loadgen::Trace;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use std::collections::HashMap; use std::collections::HashMap;
...@@ -379,12 +379,7 @@ async fn apply_entry( ...@@ -379,12 +379,7 @@ async fn apply_entry(
isl, isl,
output_length, output_length,
} => { } => {
let _ = multi.potential_blocks_and_tokens( let _ = multi.potential_blocks_and_tokens(Some(&block_hashes), isl, HashMap::new());
Some(&block_hashes),
isl,
OverlapScores::default(),
decay_now,
);
let _ = multi.add_request( let _ = multi.add_request(
SequenceRequest { SequenceRequest {
request_id, request_id,
......
...@@ -521,7 +521,6 @@ impl RouterHandles { ...@@ -521,7 +521,6 @@ impl RouterHandles {
None, None,
0.0, 0.0,
None, None,
None,
allowed_worker_ids, allowed_worker_ids,
) )
.await .await
...@@ -884,12 +883,13 @@ pub unsafe extern "C" fn add_request( ...@@ -884,12 +883,13 @@ pub unsafe extern "C" fn add_request(
} }
}; };
let cached_tokens = overlap_blocks as usize * decode_router.block_size() as usize;
decode_router decode_router
.add_request( .add_request(
request_id_str.clone(), request_id_str.clone(),
&tokens, &tokens,
None, None,
overlap_blocks, cached_tokens,
None, None,
worker, worker,
None, // lora_name None, // lora_name
......
...@@ -7,7 +7,7 @@ from typing import List, Optional ...@@ -7,7 +7,7 @@ from typing import List, Optional
import tensorrt_llm import tensorrt_llm
from kvbm import KvbmLeader from kvbm import KvbmLeader
from kvbm.trtllm_integration.consolidator_config import is_truthy from kvbm.trtllm_integration.consolidator_config import get_consolidator_mode, is_truthy
from kvbm.trtllm_integration.rust import KvbmRequest from kvbm.trtllm_integration.rust import KvbmRequest
from kvbm.trtllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader from kvbm.trtllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from kvbm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput from kvbm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput
...@@ -55,7 +55,9 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): ...@@ -55,7 +55,9 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
trtllm_ep = None trtllm_ep = None
consolidator_output_ep = None consolidator_output_ep = None
consolidator_mode = None
if consolidator_enabled: if consolidator_enabled:
consolidator_mode = get_consolidator_mode()
# Get consolidator endpoint from environment variable # Get consolidator endpoint from environment variable
# DYN_KVBM_TRTLLM_ZMQ_PORT contains just the port number (e.g., "20081") # DYN_KVBM_TRTLLM_ZMQ_PORT contains just the port number (e.g., "20081")
zmq_port = os.getenv("DYN_KVBM_TRTLLM_ZMQ_PORT") zmq_port = os.getenv("DYN_KVBM_TRTLLM_ZMQ_PORT")
...@@ -105,6 +107,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): ...@@ -105,6 +107,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
leader, leader,
consolidator_trtllm_endpoint=trtllm_ep, consolidator_trtllm_endpoint=trtllm_ep,
consolidator_output_endpoint=consolidator_output_ep, consolidator_output_endpoint=consolidator_output_ep,
consolidator_mode=consolidator_mode,
) )
@nvtx_annotate(category="scheduler") @nvtx_annotate(category="scheduler")
...@@ -132,6 +135,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): ...@@ -132,6 +135,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_block_ids, req.new_block_ids,
req.computed_position, req.computed_position,
req.priorities, # Pass retention priorities for offload filtering req.priorities, # Pass retention priorities for offload filtering
list(req.block_hashes),
) )
resumed_from_preemption = False resumed_from_preemption = False
...@@ -143,6 +147,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): ...@@ -143,6 +147,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_block_ids, req.new_block_ids,
req.computed_position, req.computed_position,
req.priorities, # Pass retention priorities for offload filtering req.priorities, # Pass retention priorities for offload filtering
list(req.block_hashes),
) )
output.add_num_scheduled_tokens( output.add_num_scheduled_tokens(
......
...@@ -8,21 +8,27 @@ Helper functions for KV Event Consolidator configuration for TensorRT-LLM. ...@@ -8,21 +8,27 @@ Helper functions for KV Event Consolidator configuration for TensorRT-LLM.
import logging import logging
import os import os
from kvbm.utils import get_consolidator_mode, is_truthy
__all__ = [
"get_consolidator_endpoints",
"get_consolidator_mode",
"is_truthy",
"should_enable_consolidator",
]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def is_truthy(val: str) -> bool: def _get_connector_module(kv_connector_config) -> str | None:
""" """Extract connector_module from either a dict or a TRT-LLM config object."""
Check if a string represents a truthy value. if kv_connector_config is None:
Truthy values: "1", "true", "on", "yes" (case-insensitive) return None
Args: if isinstance(kv_connector_config, dict):
val: The string value to check return kv_connector_config.get("connector_module")
Returns: return getattr(kv_connector_config, "connector_module", None)
True if the value is truthy, False otherwise
"""
return val.lower() in ("1", "true", "on", "yes")
def should_enable_consolidator(arg_map) -> bool: def should_enable_consolidator(arg_map) -> bool:
...@@ -48,23 +54,19 @@ def should_enable_consolidator(arg_map) -> bool: ...@@ -48,23 +54,19 @@ def should_enable_consolidator(arg_map) -> bool:
) )
return False return False
# Check if KVBM connector is enabled by extracting connector_module # Check if KVBM connector is enabled
# from kv_connector_config (works whether arg_map holds raw dicts or typed objects) if not isinstance(arg_map, dict):
kv_connector_config = ( logger.warning("KV Event Consolidator is not enabled: arg_map is not a dict")
arg_map.get("kv_connector_config") if isinstance(arg_map, dict) else None return False
)
if kv_connector_config is None: kv_connector_config = arg_map.get("kv_connector_config")
connector_module = _get_connector_module(kv_connector_config) or ""
if not connector_module:
logger.warning( logger.warning(
"KV Event Consolidator is not enabled: no kv_connector_config found" "KV Event Consolidator is not enabled: kv_connector_config has no connector_module"
) )
return False return False
if isinstance(kv_connector_config, dict):
connector_module = kv_connector_config.get("connector_module", "")
else:
# Access directly so AttributeError surfaces if the contract changes
connector_module = kv_connector_config.connector_module or ""
has_kvbm_connector = "kvbm.trtllm_integration.connector" in connector_module has_kvbm_connector = "kvbm.trtllm_integration.connector" in connector_module
if not has_kvbm_connector: if not has_kvbm_connector:
......
...@@ -2,8 +2,33 @@ ...@@ -2,8 +2,33 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging
import os import os
logger = logging.getLogger(__name__)
def is_truthy(val: str) -> bool:
    """Return True when ``val`` spells an enabled flag.

    Accepted spellings are "1", "true", "on" and "yes"; surrounding
    whitespace and letter case are ignored. ``None`` (for example the result
    of looking up an unset environment variable) is treated as not truthy
    instead of raising.
    """
    if val is None:
        return False
    return val.strip().lower() in ("1", "true", "on", "yes")
def get_consolidator_mode() -> str:
    """Return the KV event consolidator mode from DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE.

    The environment value is stripped and lower-cased before validation.

    Returns:
        "dedup" or "passthrough". An unset variable yields "dedup"; any other
        value logs a warning and also falls back to "dedup".
    """
    mode = os.getenv("DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE", "dedup").strip().lower()
    if mode in ("dedup", "passthrough"):
        return mode

    # Unknown spelling: report the normalized value and use the default mode.
    logger.warning(
        "Invalid DYN_KVBM_KV_EVENTS_CONSOLIDATOR_MODE=%r. Falling back to 'dedup'.",
        mode,
    )
    return "dedup"
try: try:
from nvtx import annotate # type: ignore from nvtx import annotate # type: ignore
except ImportError: except ImportError:
......
...@@ -30,6 +30,7 @@ if TYPE_CHECKING: ...@@ -30,6 +30,7 @@ if TYPE_CHECKING:
from kvbm import KvbmLeader from kvbm import KvbmLeader
from kvbm.utils import is_dyn_runtime_enabled from kvbm.utils import is_dyn_runtime_enabled
from kvbm.vllm_integration.consolidator_config import get_consolidator_mode
from kvbm.vllm_integration.rust import KvbmRequest from kvbm.vllm_integration.rust import KvbmRequest
from kvbm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader from kvbm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from kvbm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput from kvbm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput
...@@ -72,10 +73,12 @@ class KvConnectorLeader: ...@@ -72,10 +73,12 @@ class KvConnectorLeader:
# Get kv event consolidator endpoints from vllm_config (pre-computed in main.py) # Get kv event consolidator endpoints from vllm_config (pre-computed in main.py)
consolidator_vllm_endpoint = None consolidator_vllm_endpoint = None
consolidator_output_endpoint = None consolidator_output_endpoint = None
consolidator_mode = None
self._consolidator_output_port = None self._consolidator_output_port = None
_consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints") _consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
if _consolidator_eps: if _consolidator_eps:
consolidator_mode = get_consolidator_mode()
# Unpack all three endpoints # Unpack all three endpoints
# [0]: vllm_endpoint (for consolidator to subscribe to vLLM) # [0]: vllm_endpoint (for consolidator to subscribe to vLLM)
# [1]: output_bind_endpoint (for consolidator to bind/publish) # [1]: output_bind_endpoint (for consolidator to bind/publish)
...@@ -97,6 +100,7 @@ class KvConnectorLeader: ...@@ -97,6 +100,7 @@ class KvConnectorLeader:
leader, leader,
consolidator_vllm_endpoint=consolidator_vllm_endpoint, consolidator_vllm_endpoint=consolidator_vllm_endpoint,
consolidator_output_endpoint=consolidator_output_endpoint, consolidator_output_endpoint=consolidator_output_endpoint,
consolidator_mode=consolidator_mode,
) )
else: else:
# No kv event consolidator - pass None to Rust # No kv event consolidator - pass None to Rust
...@@ -107,6 +111,7 @@ class KvConnectorLeader: ...@@ -107,6 +111,7 @@ class KvConnectorLeader:
leader, leader,
consolidator_vllm_endpoint=None, consolidator_vllm_endpoint=None,
consolidator_output_endpoint=None, consolidator_output_endpoint=None,
consolidator_mode=None,
) )
# KV Connector # KV Connector
......
...@@ -9,23 +9,17 @@ import logging ...@@ -9,23 +9,17 @@ import logging
import os import os
from typing import Optional, Tuple from typing import Optional, Tuple
from kvbm.utils import get_consolidator_mode, is_truthy
from vllm.distributed.kv_events import ZmqEventPublisher from vllm.distributed.kv_events import ZmqEventPublisher
logger = logging.getLogger(__name__) __all__ = [
"get_consolidator_endpoints",
"get_consolidator_mode",
"is_truthy",
"should_enable_consolidator",
]
def is_truthy(val: str) -> bool: logger = logging.getLogger(__name__)
"""
Check if a string represents a truthy value.
Truthy values: "1", "true", "on", "yes" (case-insensitive)
Args:
val: The string value to check
Returns:
True if the value is truthy, False otherwise
"""
return val.lower() in ("1", "true", "on", "yes")
def should_enable_consolidator(vllm_config) -> bool: def should_enable_consolidator(vllm_config) -> bool:
......
...@@ -6,7 +6,7 @@ use anyhow::Result; ...@@ -6,7 +6,7 @@ use anyhow::Result;
use dynamo_llm::block_manager::block::{ use dynamo_llm::block_manager::block::{
data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical, data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical,
}; };
use dynamo_llm::block_manager::kv_consolidator::EventSource; use dynamo_llm::block_manager::kv_consolidator::{EventSource, KvEventConsolidationMode};
use dynamo_llm::block_manager::offload::filter::FrequencyFilter; use dynamo_llm::block_manager::offload::filter::FrequencyFilter;
use dynamo_llm::block_manager::{BasicMetadata, BlockParallelismStrategy}; use dynamo_llm::block_manager::{BasicMetadata, BlockParallelismStrategy};
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
...@@ -252,7 +252,12 @@ pub struct BlockManagerBuilder { ...@@ -252,7 +252,12 @@ pub struct BlockManagerBuilder {
page_size: usize, page_size: usize,
disable_device_pool: bool, disable_device_pool: bool,
kvbm_metrics: Option<dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics>, kvbm_metrics: Option<dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics>,
consolidator_config: Option<(String, Option<String>, EventSource)>, // (engine_endpoint, output_endpoint (optional), engine_source) consolidator_config: Option<(
String,
Option<String>,
EventSource,
KvEventConsolidationMode,
)>, // (engine_endpoint, output_endpoint (optional), engine_source, mode)
} }
impl BlockManagerBuilder { impl BlockManagerBuilder {
...@@ -293,8 +298,9 @@ impl BlockManagerBuilder { ...@@ -293,8 +298,9 @@ impl BlockManagerBuilder {
engine_endpoint: String, engine_endpoint: String,
output_endpoint: Option<String>, output_endpoint: Option<String>,
engine_source: EventSource, engine_source: EventSource,
mode: KvEventConsolidationMode,
) -> Self { ) -> Self {
self.consolidator_config = Some((engine_endpoint, output_endpoint, engine_source)); self.consolidator_config = Some((engine_endpoint, output_endpoint, engine_source, mode));
self self
} }
...@@ -368,9 +374,9 @@ impl BlockManagerBuilder { ...@@ -368,9 +374,9 @@ impl BlockManagerBuilder {
config_builder = config_builder.kvbm_metrics(Some(kvbm_metrics)); config_builder = config_builder.kvbm_metrics(Some(kvbm_metrics));
} }
if let Some((engine_ep, output_ep, engine_source)) = self.consolidator_config { if let Some((engine_ep, output_ep, engine_source, mode)) = self.consolidator_config {
config_builder = config_builder =
config_builder.consolidator_config(engine_ep, output_ep, engine_source); config_builder.consolidator_config(engine_ep, output_ep, engine_source, mode);
} }
let config = config_builder.build()?; let config = config_builder.build()?;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
use dynamo_llm::block_manager::{ use dynamo_llm::block_manager::{
block::BlockId, connector::protocol::WorkerTransferRequest, pool::BlockPoolError, block::BlockId, connector::protocol::WorkerTransferRequest, pool::BlockPoolError,
}; };
use dynamo_llm::tokens::SequenceHash;
pub mod leader; pub mod leader;
pub mod trtllm_leader; pub mod trtllm_leader;
...@@ -42,7 +43,7 @@ impl SchedulerOutput { ...@@ -42,7 +43,7 @@ impl SchedulerOutput {
// I am surprised that vLLM's NewRequestData does not include the salt hash. // I am surprised that vLLM's NewRequestData does not include the salt hash.
// It has almost everything else to compute the block hashes worker side. // It has almost everything else to compute the block hashes worker side.
#[pyo3(signature = (request_id, prompt_token_ids, block_ids, num_computed_tokens, priorities=None))] #[pyo3(signature = (request_id, prompt_token_ids, block_ids, num_computed_tokens, priorities=None, external_sequence_hashes=None))]
pub fn add_new_request( pub fn add_new_request(
&mut self, &mut self,
request_id: String, request_id: String,
...@@ -50,6 +51,7 @@ impl SchedulerOutput { ...@@ -50,6 +51,7 @@ impl SchedulerOutput {
block_ids: Vec<BlockId>, block_ids: Vec<BlockId>,
num_computed_tokens: usize, num_computed_tokens: usize,
priorities: Option<Vec<u32>>, priorities: Option<Vec<u32>>,
external_sequence_hashes: Option<Vec<SequenceHash>>,
) { ) {
self.new_requests.push(NewRequestData { self.new_requests.push(NewRequestData {
request_id, request_id,
...@@ -57,11 +59,13 @@ impl SchedulerOutput { ...@@ -57,11 +59,13 @@ impl SchedulerOutput {
block_ids, block_ids,
num_computed_tokens, num_computed_tokens,
priorities, priorities,
external_sequence_hashes,
}); });
} }
/// This is called by the leader to update the cached requests /// This is called by the leader to update the cached requests
#[pyo3(signature = (request_id, resumed_from_preemption, new_token_ids, new_block_ids, num_computed_tokens, priorities=None))] #[pyo3(signature = (request_id, resumed_from_preemption, new_token_ids, new_block_ids, num_computed_tokens, priorities=None, external_sequence_hashes=None))]
#[allow(clippy::too_many_arguments)]
pub fn add_cached_request( pub fn add_cached_request(
&mut self, &mut self,
request_id: String, request_id: String,
...@@ -70,6 +74,7 @@ impl SchedulerOutput { ...@@ -70,6 +74,7 @@ impl SchedulerOutput {
new_block_ids: Vec<BlockId>, new_block_ids: Vec<BlockId>,
num_computed_tokens: usize, num_computed_tokens: usize,
priorities: Option<Vec<u32>>, priorities: Option<Vec<u32>>,
external_sequence_hashes: Option<Vec<SequenceHash>>,
) { ) {
self.cached_requests.push(CachedRequestData { self.cached_requests.push(CachedRequestData {
request_id, request_id,
...@@ -78,6 +83,7 @@ impl SchedulerOutput { ...@@ -78,6 +83,7 @@ impl SchedulerOutput {
new_block_ids, new_block_ids,
num_computed_tokens, num_computed_tokens,
priorities, priorities,
external_sequence_hashes,
}); });
} }
...@@ -108,6 +114,8 @@ pub struct NewRequestData { ...@@ -108,6 +114,8 @@ pub struct NewRequestData {
/// Retention priorities for each block (same length as block_ids). /// Retention priorities for each block (same length as block_ids).
/// Used for priority-based offload filtering. /// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>, pub priorities: Option<Vec<u32>>,
/// TRT-LLM cumulative external sequence-hash chain for all completed blocks.
pub external_sequence_hashes: Option<Vec<SequenceHash>>,
} }
impl std::fmt::Debug for NewRequestData { impl std::fmt::Debug for NewRequestData {
...@@ -131,6 +139,8 @@ pub struct CachedRequestData { ...@@ -131,6 +139,8 @@ pub struct CachedRequestData {
/// Retention priorities for each new block (same length as new_block_ids). /// Retention priorities for each new block (same length as new_block_ids).
/// Used for priority-based offload filtering. /// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>, pub priorities: Option<Vec<u32>>,
/// TRT-LLM cumulative external sequence-hash chain for all completed blocks.
pub external_sequence_hashes: Option<Vec<SequenceHash>>,
} }
impl std::fmt::Debug for CachedRequestData { impl std::fmt::Debug for CachedRequestData {
...@@ -188,3 +198,40 @@ impl ConnectorMetadata { ...@@ -188,3 +198,40 @@ impl ConnectorMetadata {
self.operations.extend(xfer_reqs); self.operations.extend(xfer_reqs);
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// The hash chains handed to `add_new_request` and `add_cached_request`
    /// must be readable back, unchanged, from the queued request entries.
    #[test]
    fn test_scheduler_output_preserves_external_sequence_hashes() {
        let new_chain = vec![101];
        let cached_chain = vec![101, 202];

        let mut output = SchedulerOutput::new();

        // First scheduling step: one completed block, one hash.
        output.add_new_request(
            "req-1".to_string(),
            vec![1, 2, 3, 4],
            vec![10],
            4,
            None,
            Some(new_chain.clone()),
        );

        // Follow-up step for the same request: the chain has grown by one hash.
        output.add_cached_request(
            "req-1".to_string(),
            false,
            vec![5, 6, 7, 8],
            vec![11],
            8,
            None,
            Some(cached_chain.clone()),
        );

        assert_eq!(
            output.new_requests[0].external_sequence_hashes,
            Some(new_chain)
        );
        assert_eq!(
            output.cached_requests[0].external_sequence_hashes,
            Some(cached_chain)
        );
    }
}
...@@ -22,7 +22,7 @@ use dynamo_llm::block_manager::{ ...@@ -22,7 +22,7 @@ use dynamo_llm::block_manager::{
locality::Logical, locality::Logical,
}, },
connector::{protocol::RequestType, *}, connector::{protocol::RequestType, *},
kv_consolidator::EventSource, kv_consolidator::{EventSource, KvEventConsolidationMode},
}; };
use dynamo_llm::tokens::{SaltHash, TokenBlockSequence, Tokens}; use dynamo_llm::tokens::{SaltHash, TokenBlockSequence, Tokens};
use dynamo_runtime::config::environment_names::kvbm as env_kvbm; use dynamo_runtime::config::environment_names::kvbm as env_kvbm;
...@@ -35,6 +35,24 @@ use tokio::sync::oneshot; ...@@ -35,6 +35,24 @@ use tokio::sync::oneshot;
type VllmLocality = Logical<DistributedLeaderWorkerResources>; type VllmLocality = Logical<DistributedLeaderWorkerResources>;
/// Converts the optional mode string received from Python into a
/// `KvEventConsolidationMode`.
///
/// A missing value selects `Dedup`. A value that fails to parse is reported
/// with a warning and also resolves to `Dedup`.
fn parse_consolidator_mode(mode: Option<String>) -> KvEventConsolidationMode {
    match mode {
        None => KvEventConsolidationMode::Dedup,
        Some(raw) => raw.parse().unwrap_or_else(|error| {
            tracing::warn!(
                "Invalid KV event consolidator mode {:?}: {}. Falling back to dedup.",
                raw,
                error
            );
            KvEventConsolidationMode::Dedup
        }),
    }
}
impl From<SlotError> for PyErr { impl From<SlotError> for PyErr {
fn from(err: SlotError) -> Self { fn from(err: SlotError) -> Self {
to_pyerr(err) to_pyerr(err)
...@@ -94,6 +112,7 @@ impl KvConnectorLeader { ...@@ -94,6 +112,7 @@ impl KvConnectorLeader {
leader_py: PyKvbmLeader, leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>, consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>, consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self { ) -> Self {
tracing::info!( tracing::info!(
"KvConnectorLeader initialized with worker_id: {}", "KvConnectorLeader initialized with worker_id: {}",
...@@ -118,6 +137,7 @@ impl KvConnectorLeader { ...@@ -118,6 +137,7 @@ impl KvConnectorLeader {
// Capture consolidator endpoints for the async block // Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone(); let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone(); let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move { handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await; let ready = leader.wait_worker_sync_ready().await;
...@@ -148,6 +168,7 @@ impl KvConnectorLeader { ...@@ -148,6 +168,7 @@ impl KvConnectorLeader {
vllm_ep, vllm_ep,
Some(output_ep), Some(output_ep),
EventSource::Vllm, EventSource::Vllm,
consolidator_mode,
); );
} }
...@@ -435,6 +456,7 @@ impl Leader for KvConnectorLeader { ...@@ -435,6 +456,7 @@ impl Leader for KvConnectorLeader {
new_req.num_computed_tokens, new_req.num_computed_tokens,
scheduled_tokens, scheduled_tokens,
None, None,
None,
)?; )?;
let pending_ops_opt = slot.take_pending_operations(); let pending_ops_opt = slot.take_pending_operations();
...@@ -506,6 +528,7 @@ impl Leader for KvConnectorLeader { ...@@ -506,6 +528,7 @@ impl Leader for KvConnectorLeader {
cached_req.num_computed_tokens, cached_req.num_computed_tokens,
scheduled_tokens, scheduled_tokens,
None, None,
None,
)?; )?;
if let Some(pending_ops) = slot.take_pending_operations() { if let Some(pending_ops) = slot.take_pending_operations() {
...@@ -621,7 +644,7 @@ pub struct PyKvConnectorLeader { ...@@ -621,7 +644,7 @@ pub struct PyKvConnectorLeader {
#[pymethods] #[pymethods]
impl PyKvConnectorLeader { impl PyKvConnectorLeader {
#[new] #[new]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_vllm_endpoint=None, consolidator_output_endpoint=None))] #[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_vllm_endpoint=None, consolidator_output_endpoint=None, consolidator_mode=None))]
pub fn new( pub fn new(
worker_id: String, worker_id: String,
drt: Option<PyObject>, drt: Option<PyObject>,
...@@ -629,6 +652,7 @@ impl PyKvConnectorLeader { ...@@ -629,6 +652,7 @@ impl PyKvConnectorLeader {
leader: PyKvbmLeader, leader: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>, consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>, consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let _ = &drt; // drt is currently un-used in leader let _ = &drt; // drt is currently un-used in leader
...@@ -646,6 +670,7 @@ impl PyKvConnectorLeader { ...@@ -646,6 +670,7 @@ impl PyKvConnectorLeader {
leader, leader,
consolidator_vllm_endpoint, consolidator_vllm_endpoint,
consolidator_output_endpoint, consolidator_output_endpoint,
consolidator_mode,
)) ))
} else { } else {
Box::new(KvConnectorLeader::new( Box::new(KvConnectorLeader::new(
...@@ -654,6 +679,7 @@ impl PyKvConnectorLeader { ...@@ -654,6 +679,7 @@ impl PyKvConnectorLeader {
leader, leader,
consolidator_vllm_endpoint, consolidator_vllm_endpoint,
consolidator_output_endpoint, consolidator_output_endpoint,
consolidator_mode,
)) ))
}; };
Ok(Self { connector_leader }) Ok(Self { connector_leader })
......
...@@ -92,6 +92,7 @@ impl KvConnectorLeaderRecorder { ...@@ -92,6 +92,7 @@ impl KvConnectorLeaderRecorder {
leader_py: PyKvbmLeader, leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>, consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>, consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self { ) -> Self {
tracing::info!( tracing::info!(
"KvConnectorLeaderRecorder initialized with worker_id: {}", "KvConnectorLeaderRecorder initialized with worker_id: {}",
...@@ -131,6 +132,7 @@ impl KvConnectorLeaderRecorder { ...@@ -131,6 +132,7 @@ impl KvConnectorLeaderRecorder {
// Capture consolidator endpoints for the async block // Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone(); let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone(); let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = super::parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move { handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await; let ready = leader.wait_worker_sync_ready().await;
...@@ -156,6 +158,7 @@ impl KvConnectorLeaderRecorder { ...@@ -156,6 +158,7 @@ impl KvConnectorLeaderRecorder {
vllm_ep, vllm_ep,
Some(output_ep), Some(output_ep),
EventSource::Vllm, EventSource::Vllm,
consolidator_mode,
); );
} }
......
...@@ -16,7 +16,7 @@ use dynamo_llm::{ ...@@ -16,7 +16,7 @@ use dynamo_llm::{
connector::protocol::{LeaderTransferRequest, RequestType, TransferType}, connector::protocol::{LeaderTransferRequest, RequestType, TransferType},
distributed::{BlockTransferPool, BlockTransferRequest, KvbmLeader}, distributed::{BlockTransferPool, BlockTransferRequest, KvbmLeader},
}, },
tokens::TokenBlock, tokens::{SequenceHash, TokenBlock},
}; };
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle; use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -114,6 +114,7 @@ pub trait Slot: std::fmt::Debug { ...@@ -114,6 +114,7 @@ pub trait Slot: std::fmt::Debug {
num_computed_tokens: usize, num_computed_tokens: usize,
num_scheduled_tokens: usize, num_scheduled_tokens: usize,
priorities: Option<&[u32]>, priorities: Option<&[u32]>,
external_sequence_hashes: Option<&[SequenceHash]>,
) -> Result<(), SlotError>; ) -> Result<(), SlotError>;
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>; fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;
...@@ -481,6 +482,29 @@ impl VllmConnectorSlot { ...@@ -481,6 +482,29 @@ impl VllmConnectorSlot {
&self.device_blocks &self.device_blocks
} }
/// Pushes the externally supplied sequence-hash chain into this slot's
/// token sequence.
///
/// # Panics
///
/// Panics when `external_sequence_hashes` does not contain exactly one entry
/// per block currently returned by `self.sequence.blocks()`. Each block is
/// afterwards checked with `assert_external_hashes_assigned`, which
/// presumably panics as well if a block was left without a hash (confirm
/// against its definition).
///
/// Every non-panicking path returns `Ok(())`.
fn sync_external_sequence_hashes(
    &mut self,
    external_sequence_hashes: &[SequenceHash],
) -> Result<(), SlotError> {
    let provided = external_sequence_hashes.len();
    let completed = self.sequence.blocks().len();
    assert_eq!(
        provided,
        completed,
        "external_sequence_hashes length ({}) must match completed block count ({}) for request {}",
        provided,
        completed,
        self.request_id
    );

    self.sequence
        .sync_external_sequence_hashes(external_sequence_hashes);

    // Verify after syncing that no block was skipped.
    for synced_block in self.sequence.blocks() {
        synced_block.assert_external_hashes_assigned();
    }

    Ok(())
}
fn mark_as_skipped_prefill(&mut self) -> Result<(), SlotError> { fn mark_as_skipped_prefill(&mut self) -> Result<(), SlotError> {
if self.state != SlotState::Prefilling { if self.state != SlotState::Prefilling {
return Err(SlotError::InvalidState(format!( return Err(SlotError::InvalidState(format!(
...@@ -597,6 +621,7 @@ impl Slot for VllmConnectorSlot { ...@@ -597,6 +621,7 @@ impl Slot for VllmConnectorSlot {
num_computed_tokens: usize, num_computed_tokens: usize,
num_scheduled_tokens: usize, num_scheduled_tokens: usize,
priorities: Option<&[u32]>, priorities: Option<&[u32]>,
external_sequence_hashes: Option<&[SequenceHash]>,
) -> Result<(), SlotError> { ) -> Result<(), SlotError> {
tracing::debug!( tracing::debug!(
"ENTRY: apply_scheduler_output: req={}, tokens.len={}, block_ids.len={}, computed={}, scheduled={}, \ "ENTRY: apply_scheduler_output: req={}, tokens.len={}, block_ids.len={}, computed={}, scheduled={}, \
...@@ -634,6 +659,10 @@ impl Slot for VllmConnectorSlot { ...@@ -634,6 +659,10 @@ impl Slot for VllmConnectorSlot {
self.state = SlotState::Prefilling; self.state = SlotState::Prefilling;
} }
if let Some(external_sequence_hashes) = external_sequence_hashes {
self.sync_external_sequence_hashes(external_sequence_hashes)?;
}
// Use max to advance both current_position and evaluated_blocks at least by num_computed_tokens. // Use max to advance both current_position and evaluated_blocks at least by num_computed_tokens.
// This logic is to prevent redundant block offloading. // This logic is to prevent redundant block offloading.
self.current_position = max(self.current_position, num_computed_tokens); self.current_position = max(self.current_position, num_computed_tokens);
...@@ -849,6 +878,12 @@ impl Slot for VllmConnectorSlot { ...@@ -849,6 +878,12 @@ impl Slot for VllmConnectorSlot {
.copied() .copied()
.collect(); .collect();
if external_sequence_hashes.is_some() {
for block in &offload_token_blocks {
block.assert_external_hashes_assigned();
}
}
self.offload_blocks( self.offload_blocks(
&offload_block_ids, &offload_block_ids,
&offload_token_blocks, &offload_token_blocks,
...@@ -1878,7 +1913,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> AnyBlocks for AnyImmutab ...@@ -1878,7 +1913,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> AnyBlocks for AnyImmutab
mod connector_tests { mod connector_tests {
use super::*; use super::*;
use crate::block_manager::cache_stats::CacheStatsTracker; use crate::block_manager::cache_stats::CacheStatsTracker;
use dynamo_llm::tokens::{SaltHash, Tokens}; use dynamo_llm::tokens::{SaltHash, SequenceHash, Tokens};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
...@@ -1914,6 +1949,12 @@ mod connector_tests { ...@@ -1914,6 +1949,12 @@ mod connector_tests {
(start..start + count).collect() (start..start + count).collect()
} }
/// Builds `count` consecutive hash values beginning at `start`
/// (`start`, `start + 1`, ...), for use as a fake external hash chain.
fn external_hashes(start: SequenceHash, count: usize) -> Vec<SequenceHash> {
    let mut hashes = Vec::with_capacity(count);
    for offset in 0..count {
        hashes.push(start + offset as SequenceHash);
    }
    hashes
}
/// Drains all pending offload requests from the channel and returns their block IDs. /// Drains all pending offload requests from the channel and returns their block IDs.
fn drain_offload_block_ids( fn drain_offload_block_ids(
rx: &mut mpsc::UnboundedReceiver<LocalTransferRequest>, rx: &mut mpsc::UnboundedReceiver<LocalTransferRequest>,
...@@ -1941,7 +1982,7 @@ mod connector_tests { ...@@ -1941,7 +1982,7 @@ mod connector_tests {
assert_eq!(slot.num_device_blocks_allocated(), 3); assert_eq!(slot.num_device_blocks_allocated(), 3);
// Step 2: apply_scheduler_output with empty blocks (vLLM pattern) // Step 2: apply_scheduler_output with empty blocks (vLLM pattern)
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None) slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap(); .unwrap();
// device_blocks should still be exactly 3 — no double-add // device_blocks should still be exactly 3 — no double-add
...@@ -1964,13 +2005,70 @@ mod connector_tests { ...@@ -1964,13 +2005,70 @@ mod connector_tests {
// Step 2: apply_scheduler_output with THE SAME blocks (TRT-LLM pattern) // Step 2: apply_scheduler_output with THE SAME blocks (TRT-LLM pattern)
// Without the dedup guard, this doubles device_blocks to len=6. // Without the dedup guard, this doubles device_blocks to len=6.
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None) slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, None, None)
.unwrap(); .unwrap();
// device_blocks must still be exactly 3 — dedup guard prevented the double-add // device_blocks must still be exactly 3 — dedup guard prevented the double-add
assert_eq!(slot.num_device_blocks_allocated(), 3); assert_eq!(slot.num_device_blocks_allocated(), 3);
} }
/// After a scheduler update that supplies external hashes, every completed
/// block must report the hash given at its index, and its external parent
/// hash must be the preceding block's hash (`None` for the first block).
#[test]
fn test_trtllm_external_hashes_are_assigned_to_completed_blocks() {
    let num_tokens = 96; // 3 blocks of 32
    let (mut slot, _rx) = create_test_slot(num_tokens, 0);
    let device_blocks = block_ids(100, 3);
    let hashes = external_hashes(10_000, 3);

    slot.append_mutable_device_blocks(&device_blocks).unwrap();
    slot.apply_scheduler_output(&[], &device_blocks, 0, num_tokens, None, Some(&hashes))
        .unwrap();

    let completed_blocks = slot.sequence.blocks();
    assert_eq!(completed_blocks.len(), 3);

    for (idx, block) in completed_blocks.iter().enumerate() {
        assert_eq!(block.external_sequence_hash(), Some(hashes[idx]));

        // The first block has no predecessor, so no parent hash is expected.
        let expected_parent = idx.checked_sub(1).map(|parent_idx| hashes[parent_idx]);
        assert_eq!(block.external_parent_sequence_hash(), expected_parent);
    }
}
/// Re-applying scheduler output with an external hash that differs from the
/// one already recorded for a block is expected to panic with an
/// "external_sequence_hash mismatch" message.
#[test]
#[should_panic(expected = "external_sequence_hash mismatch")]
fn test_trtllm_external_hash_chain_mismatch_panics() {
    let num_tokens = 96; // 3 blocks of 32
    let (mut slot, _rx) = create_test_slot(num_tokens, 0);
    let device_blocks = block_ids(100, 3);
    let original_hashes = external_hashes(20_000, 3);

    slot.append_mutable_device_blocks(&device_blocks).unwrap();
    slot.apply_scheduler_output(
        &[],
        &device_blocks,
        0,
        num_tokens,
        None,
        Some(&original_hashes),
    )
    .unwrap();

    // Same chain, except the middle hash is off by one.
    let mut tampered_hashes = original_hashes.clone();
    tampered_hashes[1] += 1;

    slot.apply_scheduler_output(
        &[],
        &device_blocks,
        0,
        num_tokens,
        None,
        Some(&tampered_hashes),
    )
    .unwrap();
}
// --------------------------------------------------------------- // ---------------------------------------------------------------
// Test 3: Decode adds a new block correctly // Test 3: Decode adds a new block correctly
// --------------------------------------------------------------- // ---------------------------------------------------------------
...@@ -1982,14 +2080,14 @@ mod connector_tests { ...@@ -1982,14 +2080,14 @@ mod connector_tests {
// Prefill: append + apply with empty blocks (vLLM pattern) // Prefill: append + apply with empty blocks (vLLM pattern)
slot.append_mutable_device_blocks(&prefill_blocks).unwrap(); slot.append_mutable_device_blocks(&prefill_blocks).unwrap();
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None) slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap(); .unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 3); assert_eq!(slot.num_device_blocks_allocated(), 3);
// Decode: new block at boundary (token 96 = block 3) // Decode: new block at boundary (token 96 = block 3)
let decode_block = block_ids(200, 1); let decode_block = block_ids(200, 1);
let decode_token: Vec<u32> = vec![9999]; let decode_token: Vec<u32> = vec![9999];
slot.apply_scheduler_output(&decode_token, &decode_block, 95, 1, None) slot.apply_scheduler_output(&decode_token, &decode_block, 95, 1, None, None)
.unwrap(); .unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4); assert_eq!(slot.num_device_blocks_allocated(), 4);
} }
...@@ -2024,7 +2122,7 @@ mod connector_tests { ...@@ -2024,7 +2122,7 @@ mod connector_tests {
// Empty tokens → Prefilling state, and next_position(96) == total_tokens(96) // Empty tokens → Prefilling state, and next_position(96) == total_tokens(96)
// so the early-return does not fire and offload proceeds. // so the early-return does not fire and offload proceeds.
slot.append_mutable_device_blocks(&blocks).unwrap(); slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &[], 0, num_tokens, None) slot.apply_scheduler_output(&[], &[], 0, num_tokens, None, None)
.unwrap(); .unwrap();
let offloads = drain_offload_block_ids(&mut rx); let offloads = drain_offload_block_ids(&mut rx);
...@@ -2045,7 +2143,7 @@ mod connector_tests { ...@@ -2045,7 +2143,7 @@ mod connector_tests {
// Use the TRT-LLM pattern: append_mutable first, then apply with same blocks + priorities. // Use the TRT-LLM pattern: append_mutable first, then apply with same blocks + priorities.
// The dedup guard prevents the double-add, but priorities are still processed. // The dedup guard prevents the double-add, but priorities are still processed.
slot.append_mutable_device_blocks(&blocks).unwrap(); slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities)) slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities), None)
.unwrap(); .unwrap();
// device_blocks should be 3 (dedup prevented doubling) // device_blocks should be 3 (dedup prevented doubling)
...@@ -2069,7 +2167,7 @@ mod connector_tests { ...@@ -2069,7 +2167,7 @@ mod connector_tests {
let priorities: Vec<u32> = vec![80, 80, 10, 10]; let priorities: Vec<u32> = vec![80, 80, 10, 10];
slot.append_mutable_device_blocks(&blocks).unwrap(); slot.append_mutable_device_blocks(&blocks).unwrap();
slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities)) slot.apply_scheduler_output(&[], &blocks, 0, num_tokens, Some(&priorities), None)
.unwrap(); .unwrap();
let offloads = drain_offload_block_ids(&mut rx); let offloads = drain_offload_block_ids(&mut rx);
...@@ -2079,7 +2177,7 @@ mod connector_tests { ...@@ -2079,7 +2177,7 @@ mod connector_tests {
// Because offload was terminated, no further offloading should happen. // Because offload was terminated, no further offloading should happen.
let decode_block = block_ids(200, 1); let decode_block = block_ids(200, 1);
let decode_token: Vec<u32> = vec![9999]; let decode_token: Vec<u32> = vec![9999];
slot.apply_scheduler_output(&decode_token, &decode_block, 127, 1, None) slot.apply_scheduler_output(&decode_token, &decode_block, 127, 1, None, None)
.unwrap(); .unwrap();
let further_offloads = drain_offload_block_ids(&mut rx); let further_offloads = drain_offload_block_ids(&mut rx);
...@@ -2102,14 +2200,16 @@ mod connector_tests { ...@@ -2102,14 +2200,16 @@ mod connector_tests {
slot.append_mutable_device_blocks(&blocks).unwrap(); slot.append_mutable_device_blocks(&blocks).unwrap();
// Chunk 1: schedule first 64 tokens → evaluates blocks 0,1 // Chunk 1: schedule first 64 tokens → evaluates blocks 0,1
slot.apply_scheduler_output(&[], &[], 0, 64, None).unwrap(); slot.apply_scheduler_output(&[], &[], 0, 64, None, None)
.unwrap();
let offloads_1 = drain_offload_block_ids(&mut rx); let offloads_1 = drain_offload_block_ids(&mut rx);
assert_eq!(offloads_1.len(), 1); assert_eq!(offloads_1.len(), 1);
assert_eq!(offloads_1[0], vec![100, 101]); // blocks 0,1 assert_eq!(offloads_1[0], vec![100, 101]); // blocks 0,1
// Chunk 2: schedule next 64 tokens → evaluates blocks 2,3 // Chunk 2: schedule next 64 tokens → evaluates blocks 2,3
// (uses cached_request pattern: empty tokens, empty blocks) // (uses cached_request pattern: empty tokens, empty blocks)
slot.apply_scheduler_output(&[], &[], 64, 64, None).unwrap(); slot.apply_scheduler_output(&[], &[], 64, 64, None, None)
.unwrap();
let offloads_2 = drain_offload_block_ids(&mut rx); let offloads_2 = drain_offload_block_ids(&mut rx);
assert_eq!(offloads_2.len(), 1); assert_eq!(offloads_2.len(), 1);
assert_eq!(offloads_2[0], vec![102, 103]); // blocks 2,3 assert_eq!(offloads_2[0], vec![102, 103]); // blocks 2,3
...@@ -2133,7 +2233,7 @@ mod connector_tests { ...@@ -2133,7 +2233,7 @@ mod connector_tests {
// Step 2: apply_scheduler_output with overlapping blocks [12, 13]. // Step 2: apply_scheduler_output with overlapping blocks [12, 13].
// Suffix [12] of device_blocks matches prefix [12] of block_ids. // Suffix [12] of device_blocks matches prefix [12] of block_ids.
// Only block 13 is new and gets appended. // Only block 13 is new and gets appended.
slot.apply_scheduler_output(&[], &[12, 13], 0, 128, None) slot.apply_scheduler_output(&[], &[12, 13], 0, 128, None, None)
.unwrap(); .unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4); assert_eq!(slot.num_device_blocks_allocated(), 4);
...@@ -2152,7 +2252,7 @@ mod connector_tests { ...@@ -2152,7 +2252,7 @@ mod connector_tests {
// Chunk 1: 3 blocks, all high priority, schedule 96 tokens // Chunk 1: 3 blocks, all high priority, schedule 96 tokens
slot.append_mutable_device_blocks(&[10, 11, 12]).unwrap(); slot.append_mutable_device_blocks(&[10, 11, 12]).unwrap();
slot.apply_scheduler_output(&[], &[10, 11, 12], 0, 96, Some(&[80, 80, 80])) slot.apply_scheduler_output(&[], &[10, 11, 12], 0, 96, Some(&[80, 80, 80]), None)
.unwrap(); .unwrap();
let offloads_1 = drain_offload_block_ids(&mut rx); let offloads_1 = drain_offload_block_ids(&mut rx);
...@@ -2171,7 +2271,7 @@ mod connector_tests { ...@@ -2171,7 +2271,7 @@ mod connector_tests {
slot.append_mutable_device_blocks(&[13]).unwrap(); slot.append_mutable_device_blocks(&[13]).unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 4); assert_eq!(slot.num_device_blocks_allocated(), 4);
slot.apply_scheduler_output(&[], &[12, 13], 96, 32, Some(&[80, 10])) slot.apply_scheduler_output(&[], &[12, 13], 96, 32, Some(&[80, 10]), None)
.unwrap(); .unwrap();
// Candidate is block 13 (index 3, evaluated_blocks=3). // Candidate is block 13 (index 3, evaluated_blocks=3).
...@@ -2198,7 +2298,7 @@ mod connector_tests { ...@@ -2198,7 +2298,7 @@ mod connector_tests {
// block_ids[0]=11 is found at device_blocks[1], but device_blocks[1..] = [11,12] // block_ids[0]=11 is found at device_blocks[1], but device_blocks[1..] = [11,12]
// does NOT match block_ids[..2] = [11,14]. Contract violation. // does NOT match block_ids[..2] = [11,14]. Contract violation.
slot.apply_scheduler_output(&[], &[11, 14], 0, 128, None) slot.apply_scheduler_output(&[], &[11, 14], 0, 128, None, None)
.unwrap(); .unwrap();
} }
...@@ -2218,7 +2318,7 @@ mod connector_tests { ...@@ -2218,7 +2318,7 @@ mod connector_tests {
// Overlap: suffix [13,14] matches prefix [13,14], overlap=2. // Overlap: suffix [13,14] matches prefix [13,14], overlap=2.
// new_ids = [10]. But 10 ∈ device_blocks → contract violation. // new_ids = [10]. But 10 ∈ device_blocks → contract violation.
slot.apply_scheduler_output(&[], &[13, 14, 10], 0, 192, None) slot.apply_scheduler_output(&[], &[13, 14, 10], 0, 192, None, None)
.unwrap(); .unwrap();
} }
...@@ -2238,7 +2338,7 @@ mod connector_tests { ...@@ -2238,7 +2338,7 @@ mod connector_tests {
// block_ids[0]=10 found at device_blocks[0], suffix_len=3. // block_ids[0]=10 found at device_blocks[0], suffix_len=3.
// device_blocks[0..3]=[10,11,12] == block_ids[0..3]=[10,11,12] → overlap=3. // device_blocks[0..3]=[10,11,12] == block_ids[0..3]=[10,11,12] → overlap=3.
// new_ids=[13,14], both genuinely new → extend. // new_ids=[13,14], both genuinely new → extend.
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14], 0, 160, None) slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14], 0, 160, None, None)
.unwrap(); .unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 5); assert_eq!(slot.num_device_blocks_allocated(), 5);
...@@ -2261,7 +2361,7 @@ mod connector_tests { ...@@ -2261,7 +2361,7 @@ mod connector_tests {
assert_eq!(slot.num_device_blocks_allocated(), 5); assert_eq!(slot.num_device_blocks_allocated(), 5);
// Full overlap of all 5 existing blocks, plus 2 new. // Full overlap of all 5 existing blocks, plus 2 new.
slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14, 15, 16], 0, 224, None) slot.apply_scheduler_output(&[], &[10, 11, 12, 13, 14, 15, 16], 0, 224, None, None)
.unwrap(); .unwrap();
assert_eq!(slot.num_device_blocks_allocated(), 7); assert_eq!(slot.num_device_blocks_allocated(), 7);
......
...@@ -14,12 +14,30 @@ use crate::block_manager::{distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRe ...@@ -14,12 +14,30 @@ use crate::block_manager::{distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRe
use crate::get_current_tokio_handle; use crate::get_current_tokio_handle;
use anyhow; use anyhow;
use dynamo_llm::block_manager::connector::protocol::RequestType; use dynamo_llm::block_manager::connector::protocol::RequestType;
use dynamo_llm::block_manager::kv_consolidator::EventSource; use dynamo_llm::block_manager::kv_consolidator::{EventSource, KvEventConsolidationMode};
use dynamo_llm::block_manager::metrics_kvbm::{KvbmMetrics, KvbmMetricsRegistry}; use dynamo_llm::block_manager::metrics_kvbm::{KvbmMetrics, KvbmMetricsRegistry};
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use tokio::runtime::Handle; use tokio::runtime::Handle;
/// Converts the optional, string-typed consolidator mode into a
/// [`KvEventConsolidationMode`].
///
/// Returns `KvEventConsolidationMode::Dedup` when no mode is given. When the
/// string fails to parse, a warning is logged and `Dedup` is returned as well.
fn parse_consolidator_mode(mode: Option<String>) -> KvEventConsolidationMode {
    let Some(raw) = mode else {
        return KvEventConsolidationMode::Dedup;
    };

    raw.parse::<KvEventConsolidationMode>()
        .unwrap_or_else(|error| {
            tracing::warn!(
                "Invalid KV event consolidator mode {:?}: {}. Falling back to dedup.",
                raw,
                error
            );
            KvEventConsolidationMode::Dedup
        })
}
pub trait Leader: Send + Sync + std::fmt::Debug { pub trait Leader: Send + Sync + std::fmt::Debug {
fn get_num_new_matched_tokens( fn get_num_new_matched_tokens(
&mut self, &mut self,
...@@ -71,6 +89,7 @@ impl KvConnectorLeader { ...@@ -71,6 +89,7 @@ impl KvConnectorLeader {
leader_py: PyKvbmLeader, leader_py: PyKvbmLeader,
consolidator_trtllm_endpoint: Option<String>, consolidator_trtllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>, consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> Self { ) -> Self {
tracing::info!( tracing::info!(
"KvConnectorLeader initialized with worker_id: {}", "KvConnectorLeader initialized with worker_id: {}",
...@@ -95,6 +114,7 @@ impl KvConnectorLeader { ...@@ -95,6 +114,7 @@ impl KvConnectorLeader {
// Capture consolidator endpoints for the async block // Capture consolidator endpoints for the async block
let consolidator_trtllm_ep = consolidator_trtllm_endpoint.clone(); let consolidator_trtllm_ep = consolidator_trtllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone(); let consolidator_output_ep = consolidator_output_endpoint.clone();
let consolidator_mode = parse_consolidator_mode(consolidator_mode.clone());
handle.spawn(async move { handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await; let ready = leader.wait_worker_sync_ready().await;
...@@ -125,6 +145,7 @@ impl KvConnectorLeader { ...@@ -125,6 +145,7 @@ impl KvConnectorLeader {
trtllm_ep, trtllm_ep,
consolidator_output_ep, consolidator_output_ep,
EventSource::Trtllm, EventSource::Trtllm,
consolidator_mode,
); );
} }
...@@ -389,6 +410,7 @@ impl Leader for KvConnectorLeader { ...@@ -389,6 +410,7 @@ impl Leader for KvConnectorLeader {
new_req.num_computed_tokens, new_req.num_computed_tokens,
scheduled_tokens, scheduled_tokens,
new_req.priorities.as_deref(), new_req.priorities.as_deref(),
new_req.external_sequence_hashes.as_deref(),
)?; )?;
let pending_ops_opt = slot.take_pending_operations(); let pending_ops_opt = slot.take_pending_operations();
...@@ -440,6 +462,7 @@ impl Leader for KvConnectorLeader { ...@@ -440,6 +462,7 @@ impl Leader for KvConnectorLeader {
cached_req.num_computed_tokens, cached_req.num_computed_tokens,
scheduled_tokens, scheduled_tokens,
cached_req.priorities.as_deref(), cached_req.priorities.as_deref(),
cached_req.external_sequence_hashes.as_deref(),
)?; )?;
if let Some(pending_ops) = slot.take_pending_operations() { if let Some(pending_ops) = slot.take_pending_operations() {
...@@ -518,7 +541,7 @@ pub struct PyTrtllmKvConnectorLeader { ...@@ -518,7 +541,7 @@ pub struct PyTrtllmKvConnectorLeader {
#[pymethods] #[pymethods]
impl PyTrtllmKvConnectorLeader { impl PyTrtllmKvConnectorLeader {
#[new] #[new]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_trtllm_endpoint=None, consolidator_output_endpoint=None))] #[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_trtllm_endpoint=None, consolidator_output_endpoint=None, consolidator_mode=None))]
pub fn new( pub fn new(
worker_id: u64, worker_id: u64,
drt: Option<PyObject>, drt: Option<PyObject>,
...@@ -526,6 +549,7 @@ impl PyTrtllmKvConnectorLeader { ...@@ -526,6 +549,7 @@ impl PyTrtllmKvConnectorLeader {
leader: PyKvbmLeader, leader: PyKvbmLeader,
consolidator_trtllm_endpoint: Option<String>, consolidator_trtllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>, consolidator_output_endpoint: Option<String>,
consolidator_mode: Option<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let _ = &drt; // drt is currently un-used in leader let _ = &drt; // drt is currently un-used in leader
...@@ -535,6 +559,7 @@ impl PyTrtllmKvConnectorLeader { ...@@ -535,6 +559,7 @@ impl PyTrtllmKvConnectorLeader {
leader, leader,
consolidator_trtllm_endpoint, consolidator_trtllm_endpoint,
consolidator_output_endpoint, consolidator_output_endpoint,
consolidator_mode,
)); ));
Ok(Self { connector_leader }) Ok(Self { connector_leader })
} }
......
...@@ -126,10 +126,12 @@ impl AicPerfConfig { ...@@ -126,10 +126,12 @@ impl AicPerfConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", use_remote_indexer=false, serve_indexer=false, shared_cache_multiplier=0.0, shared_cache_type="none"))] #[pyo3(signature = (overlap_score_weight=1.0, host_cache_hit_weight=0.75, disk_cache_hit_weight=0.25, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", use_remote_indexer=false, serve_indexer=false, shared_cache_multiplier=0.0, shared_cache_type="none"))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
host_cache_hit_weight: f64,
disk_cache_hit_weight: f64,
router_temperature: f64, router_temperature: f64,
use_kv_events: bool, use_kv_events: bool,
durable_kv_events: bool, durable_kv_events: bool,
...@@ -155,6 +157,8 @@ impl KvRouterConfig { ...@@ -155,6 +157,8 @@ impl KvRouterConfig {
KvRouterConfig { KvRouterConfig {
inner: RsKvRouterConfig { inner: RsKvRouterConfig {
overlap_score_weight, overlap_score_weight,
host_cache_hit_weight,
disk_cache_hit_weight,
router_temperature, router_temperature,
use_kv_events, use_kv_events,
durable_kv_events, durable_kv_events,
......
...@@ -1078,7 +1078,6 @@ impl KvRouter { ...@@ -1078,7 +1078,6 @@ impl KvRouter {
lora_name.clone(), lora_name.clone(),
0.0, 0.0,
None, None,
None,
None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
) )
.await .await
......
...@@ -68,7 +68,8 @@ use std::collections::VecDeque; ...@@ -68,7 +68,8 @@ use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use super::{ use super::{
EventKind, EventWarningKind, KvIndexerMetrics, PreBoundEventCounters, SyncIndexer, WorkerTask, EventKind, EventWarningKind, KvIndexerMetrics, MatchDetails, PreBoundEventCounters,
SyncIndexer, WorkerTask,
}; };
use crate::cleanup::{self, CleanableNode, CleanupGuard, CleanupState}; use crate::cleanup::{self, CleanableNode, CleanupGuard, CleanupState};
use crate::protocols::*; use crate::protocols::*;
...@@ -479,26 +480,36 @@ impl ConcurrentRadixTreeCompressed { ...@@ -479,26 +480,36 @@ impl ConcurrentRadixTreeCompressed {
// ------------------------------------------------------------------ // ------------------------------------------------------------------
/// Traverse the radix tree to find the best match for a given sequence of /// Traverse the radix tree to find the best match for a given sequence of
/// [`LocalBlockHash`]es. /// [`LocalBlockHash`]es, returning both overlap scores and the last matched
/// `ExternalSequenceBlockHash` per worker (used for lower-tier continuation).
/// ///
/// Workers in `full_edge_workers` are tracked in the `active` set and continue /// Workers in `full_edge_workers` are tracked in the `active` set and continue
/// into children. Workers in `worker_cutoffs` are scored at the node where their /// into children. Workers in `worker_cutoffs` are scored at the node where their
/// cutoff falls short and are never propagated into children. /// cutoff falls short and are never propagated into children.
pub fn find_matches_impl( pub fn find_match_details_impl(
&self, &self,
sequence: &[LocalBlockHash], sequence: &[LocalBlockHash],
early_exit: bool, early_exit: bool,
) -> OverlapScores { ) -> MatchDetails {
let mut scores = OverlapScores::new(); let mut details = MatchDetails::new();
if sequence.is_empty() { if sequence.is_empty() {
return scores; return details;
} }
let MatchDetails {
overlap_scores: ref mut scores,
ref mut last_matched_hashes,
} = details;
let mut active: FxHashSet<WorkerWithDpRank> = FxHashSet::default(); let mut active: FxHashSet<WorkerWithDpRank> = FxHashSet::default();
let mut active_count: usize = 0; let mut active_count: usize = 0;
let mut matched_depth: u32 = 0; let mut matched_depth: u32 = 0;
let mut seq_pos: usize = 0; let mut seq_pos: usize = 0;
let mut first_node = true; let mut first_node = true;
// Last ExternalSequenceBlockHash from the previous fully-matched edge.
// Workers that drop at a node boundary (not present in the new node)
// were last matched at the end of the previous edge.
let mut prev_edge_last_hash: Option<ExternalSequenceBlockHash> = None;
let mut next_child = { let mut next_child = {
let root_guard = read_lock!(self, self.root); let root_guard = read_lock!(self, self.root);
...@@ -531,38 +542,49 @@ impl ConcurrentRadixTreeCompressed { ...@@ -531,38 +542,49 @@ impl ConcurrentRadixTreeCompressed {
} }
edge_match_len = match_len; edge_match_len = match_len;
// Helper: ExternalSequenceBlockHash at a given depth within this edge.
let edge_hash_at = |depth: usize| -> ExternalSequenceBlockHash {
debug_assert!(depth > 0 && depth <= guard.edge.len());
guard.edge[depth - 1].1
};
let prev_depth = matched_depth; let prev_depth = matched_depth;
if first_node { if first_node {
// Seed active set from full-edge workers (they can continue to children).
// Score partial workers immediately; they never continue into children.
active = guard.full_edge_workers.clone(); active = guard.full_edge_workers.clone();
active_count = active.len(); active_count = active.len();
for (&w, &k) in &guard.worker_cutoffs { for (&w, &k) in &guard.worker_cutoffs {
let contribution = k.min(edge_match_len) as u32; let contribution = k.min(edge_match_len);
if contribution > 0 { if contribution > 0 {
scores.scores.insert(w, contribution); scores.scores.insert(w, contribution as u32);
last_matched_hashes.insert(w, edge_hash_at(contribution));
} }
} }
first_node = false; first_node = false;
} else { } else {
let has_partial = !guard.worker_cutoffs.is_empty(); let has_partial = !guard.worker_cutoffs.is_empty();
if has_partial { if has_partial {
// Slow path: check each active worker against both maps.
active.retain(|w| { active.retain(|w| {
if guard.full_edge_workers.contains(w) { if guard.full_edge_workers.contains(w) {
true true
} else if let Some(&k) = guard.worker_cutoffs.get(w) { } else if let Some(&k) = guard.worker_cutoffs.get(w) {
let effective = k.min(edge_match_len) as u32; let effective = k.min(edge_match_len);
scores.scores.insert(*w, prev_depth + effective); scores.scores.insert(*w, prev_depth + effective as u32);
if effective > 0 {
last_matched_hashes.insert(*w, edge_hash_at(effective));
} else if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false false
} else { } else {
scores.scores.insert(*w, prev_depth); scores.scores.insert(*w, prev_depth);
if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false false
} }
}); });
} else { } else {
// Fast path: no partial workers — all coverage is full or absent.
let full_count = guard.full_edge_workers.len(); let full_count = guard.full_edge_workers.len();
if full_count != active_count { if full_count != active_count {
active.retain(|w| { active.retain(|w| {
...@@ -570,11 +592,13 @@ impl ConcurrentRadixTreeCompressed { ...@@ -570,11 +592,13 @@ impl ConcurrentRadixTreeCompressed {
true true
} else { } else {
scores.scores.insert(*w, prev_depth); scores.scores.insert(*w, prev_depth);
if let Some(h) = prev_edge_last_hash {
last_matched_hashes.insert(*w, h);
}
false false
} }
}); });
} }
// full_count == active_count: sets are identical (fast path).
} }
active_count = active.len(); active_count = active.len();
} }
...@@ -590,6 +614,9 @@ impl ConcurrentRadixTreeCompressed { ...@@ -590,6 +614,9 @@ impl ConcurrentRadixTreeCompressed {
} else { } else {
None None
}; };
// Track the deepest matched hash in this edge (both full and partial).
prev_edge_last_hash = Some(guard.edge[edge_match_len - 1].1);
} }
if active_count == 0 { if active_count == 0 {
...@@ -605,15 +632,32 @@ impl ConcurrentRadixTreeCompressed { ...@@ -605,15 +632,32 @@ impl ConcurrentRadixTreeCompressed {
} }
} }
// Record scores and hashes for workers that survived to the deepest level.
if let Some(h) = prev_edge_last_hash {
for worker in &active { for worker in &active {
scores.scores.insert(*worker, matched_depth); scores.scores.insert(*worker, matched_depth);
last_matched_hashes.insert(*worker, h);
}
} else {
for worker in &active {
scores.scores.insert(*worker, matched_depth);
}
} }
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
if let Some(s) = self.tree_sizes.get(worker) { if let Some(s) = self.tree_sizes.get(worker) {
scores.tree_sizes.insert(*worker, s.load(Ordering::Relaxed)); scores.tree_sizes.insert(*worker, s.load(Ordering::Relaxed));
} }
} }
scores details
}
pub fn find_matches_impl(
&self,
sequence: &[LocalBlockHash],
early_exit: bool,
) -> OverlapScores {
self.find_match_details_impl(sequence, early_exit)
.overlap_scores
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
......
...@@ -12,7 +12,8 @@ use tokio_util::sync::CancellationToken; ...@@ -12,7 +12,8 @@ use tokio_util::sync::CancellationToken;
use super::{ use super::{
DumpRequest, EventKind, GetWorkersRequest, KvIndexerInterface, KvIndexerMetrics, KvRouterError, DumpRequest, EventKind, GetWorkersRequest, KvIndexerInterface, KvIndexerMetrics, KvRouterError,
MatchRequest, PreBoundEventCounters, RadixTree, RoutingDecisionRequest, MatchDetails, MatchDetailsRequest, MatchRequest, PreBoundEventCounters, RadixTree,
RoutingDecisionRequest,
}; };
use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager}; use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager};
use crate::protocols::*; use crate::protocols::*;
...@@ -95,6 +96,8 @@ pub struct KvIndexer { ...@@ -95,6 +96,8 @@ pub struct KvIndexer {
event_tx: mpsc::Sender<RouterEvent>, event_tx: mpsc::Sender<RouterEvent>,
/// A sender for `MatchRequest`s. /// A sender for `MatchRequest`s.
match_tx: mpsc::Sender<MatchRequest>, match_tx: mpsc::Sender<MatchRequest>,
/// A sender for `MatchDetailsRequest`s.
match_details_tx: mpsc::Sender<MatchDetailsRequest>,
/// A sender for remove worker requests. /// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for remove worker dp_rank requests. /// A sender for remove worker dp_rank requests.
...@@ -136,6 +139,7 @@ impl KvIndexer { ...@@ -136,6 +139,7 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (match_details_tx, match_details_rx) = mpsc::channel::<MatchDetailsRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) = let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16); mpsc::channel::<(WorkerId, DpRank)>(16);
...@@ -156,6 +160,7 @@ impl KvIndexer { ...@@ -156,6 +160,7 @@ impl KvIndexer {
runtime.block_on(async move { runtime.block_on(async move {
let cancel = cancel_clone; let cancel = cancel_clone;
let mut match_rx = match_rx; let mut match_rx = match_rx;
let mut match_details_rx = match_details_rx;
let mut event_rx = event_rx; let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx; let mut remove_worker_rx = remove_worker_rx;
let mut remove_worker_dp_rank_rx = remove_worker_dp_rank_rx; let mut remove_worker_dp_rank_rx = remove_worker_dp_rank_rx;
...@@ -320,6 +325,11 @@ impl KvIndexer { ...@@ -320,6 +325,11 @@ impl KvIndexer {
let _ = req.resp.send(matches); let _ = req.resp.send(matches);
} }
Some(req) = match_details_rx.recv() => {
let matches = trie.find_match_details(req.sequence, req.early_exit);
let _ = req.resp.send(matches);
}
_ = expiry_fut => { _ = expiry_fut => {
// TTL-based expiry triggered // TTL-based expiry triggered
let Some(ref mut pm) = prune_manager else { continue }; let Some(ref mut pm) = prune_manager else { continue };
...@@ -351,6 +361,7 @@ impl KvIndexer { ...@@ -351,6 +361,7 @@ impl KvIndexer {
cancel: token, cancel: token,
event_tx, event_tx,
match_tx, match_tx,
match_details_tx,
remove_worker_tx, remove_worker_tx,
remove_worker_dp_rank_tx, remove_worker_dp_rank_tx,
get_workers_tx, get_workers_tx,
...@@ -382,6 +393,21 @@ impl KvIndexer { ...@@ -382,6 +393,21 @@ impl KvIndexer {
self.event_tx.clone() self.event_tx.clone()
} }
/// Look up match details for `sequence` via the indexer's background task.
///
/// A `MatchDetailsRequest` is sent over `match_details_tx` with `false` as its
/// second constructor argument (presumably the `early_exit` flag — confirm
/// against `MatchDetailsRequest::new`), and the reply is awaited on a oneshot
/// channel.
///
/// # Errors
///
/// * `KvRouterError::IndexerOffline` if the request cannot be sent.
/// * `KvRouterError::IndexerDroppedRequest` if the reply channel is closed
///   before a response arrives.
pub async fn find_match_details(
    &self,
    sequence: Vec<LocalBlockHash>,
) -> Result<MatchDetails, KvRouterError> {
    let (reply_tx, reply_rx) = oneshot::channel();
    let request = MatchDetailsRequest::new(sequence, false, reply_tx);

    if self.match_details_tx.send(request).await.is_err() {
        return Err(KvRouterError::IndexerOffline);
    }

    match reply_rx.await {
        Ok(details) => Ok(details),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
#[cfg(test)] #[cfg(test)]
pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> { pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
self.dump_tx.clone() self.dump_tx.clone()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::{ use std::{
collections::VecDeque, collections::{HashMap, VecDeque},
sync::{Arc, Mutex}, sync::{Arc, Mutex},
}; };
...@@ -13,7 +13,7 @@ use tokio_util::sync::CancellationToken; ...@@ -13,7 +13,7 @@ use tokio_util::sync::CancellationToken;
use super::{ use super::{
GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError, GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError,
WorkerKvQueryResponse, LowerTierIndexer, ThreadPoolIndexer, WorkerKvQueryResponse,
}; };
use crate::protocols::*; use crate::protocols::*;
...@@ -205,6 +205,8 @@ impl RecoverySnapshotCache { ...@@ -205,6 +205,8 @@ impl RecoverySnapshotCache {
pub struct LocalKvIndexer { pub struct LocalKvIndexer {
/// The underlying indexer /// The underlying indexer
indexer: KvIndexer, indexer: KvIndexer,
/// Lazily-created exact lower-tier indexes partitioned by storage tier.
lower_tier_indexers: Arc<Mutex<HashMap<StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>>>>,
/// Circular buffer of recent events /// Circular buffer of recent events
pub(super) event_buffer: Mutex<VecDeque<RouterEvent>>, pub(super) event_buffer: Mutex<VecDeque<RouterEvent>>,
/// Coordinates single-flight tree dumps and the cached recovery snapshot. /// Coordinates single-flight tree dumps and the cached recovery snapshot.
...@@ -229,6 +231,7 @@ impl LocalKvIndexer { ...@@ -229,6 +231,7 @@ impl LocalKvIndexer {
) -> Self { ) -> Self {
Self { Self {
indexer: KvIndexer::new(token, kv_block_size, metrics), indexer: KvIndexer::new(token, kv_block_size, metrics),
lower_tier_indexers: Arc::new(Mutex::new(HashMap::new())),
event_buffer: Mutex::new(VecDeque::with_capacity(max_buffer_size)), event_buffer: Mutex::new(VecDeque::with_capacity(max_buffer_size)),
recovery_cache: Arc::new(RecoverySnapshotCache::new()), recovery_cache: Arc::new(RecoverySnapshotCache::new()),
max_buffer_size, max_buffer_size,
...@@ -335,13 +338,7 @@ impl LocalKvIndexer { ...@@ -335,13 +338,7 @@ impl LocalKvIndexer {
/// ///
/// This forwards the event to the underlying indexer and records it on success. /// This forwards the event to the underlying indexer and records it on success.
pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> { pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> {
// Forward to underlying indexer let result = self.apply_event_by_tier(&event).await;
let result = self
.indexer
.event_sender()
.send(event.clone())
.await
.map_err(|_| KvRouterError::IndexerOffline);
if result.is_ok() { if result.is_ok() {
let should_invalidate = matches!(event.event.data, KvCacheEventData::Cleared); let should_invalidate = matches!(event.event.data, KvCacheEventData::Cleared);
let detected_gap = self.record_event(event); let detected_gap = self.record_event(event);
...@@ -617,6 +614,63 @@ impl LocalKvIndexer { ...@@ -617,6 +614,63 @@ impl LocalKvIndexer {
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> { pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
self.indexer.get_workers_sender() self.indexer.get_workers_sender()
} }
/// Get the KV block size.
pub fn block_size(&self) -> u32 {
self.indexer.block_size()
}
/// Forward `event` to the primary indexer through its event channel.
///
/// # Errors
///
/// Returns `KvRouterError::IndexerOffline` if the send fails.
async fn apply_event_to_primary(&self, event: RouterEvent) -> Result<(), KvRouterError> {
    let primary_tx = self.indexer.event_sender();
    match primary_tx.send(event).await {
        Ok(()) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerOffline),
    }
}
/// Apply `event` to the lower-tier indexer for `event.storage_tier`,
/// creating that indexer first if it does not exist yet.
///
/// Always returns `Ok(())`; whatever `apply_event` yields is not inspected.
async fn apply_event_to_lower_tier(&self, event: RouterEvent) -> Result<(), KvRouterError> {
    // Read the tier before `event` is moved into `apply_event`.
    let tier_indexer = self.get_or_create_lower_tier_indexer(event.storage_tier);
    tier_indexer.apply_event(event).await;
    Ok(())
}
/// Route `event` to the indexer(s) responsible for it.
///
/// * `Cleared` events go to the primary indexer first and then to every
///   lower-tier indexer that currently exists. A failure on the primary
///   returns early, before any lower tier is touched.
/// * Any other event goes to the primary indexer when
///   `event.storage_tier.is_gpu()` is true, and to the lower-tier indexer
///   for its storage tier otherwise.
async fn apply_event_by_tier(&self, event: &RouterEvent) -> Result<(), KvRouterError> {
    if matches!(event.event.data, KvCacheEventData::Cleared) {
        self.apply_event_to_primary(event.clone()).await?;
        for tier_indexer in self.all_lower_tier_indexers() {
            tier_indexer.apply_event(event.clone()).await;
        }
        return Ok(());
    }

    if event.storage_tier.is_gpu() {
        self.apply_event_to_primary(event.clone()).await
    } else {
        self.apply_event_to_lower_tier(event.clone()).await
    }
}
/// Return the lower-tier indexer for `storage_tier`, creating and
/// registering it on first use.
///
/// Lookup and insertion happen under a single lock of
/// `lower_tier_indexers`, so there is at most one indexer per tier.
/// The `1` passed to `ThreadPoolIndexer::new` is presumably the worker
/// thread count — confirm against that constructor.
///
/// # Panics
///
/// Panics if the `lower_tier_indexers` mutex is poisoned. In debug builds
/// it also asserts that `storage_tier` is not a GPU tier.
fn get_or_create_lower_tier_indexer(
    &self,
    storage_tier: StorageTier,
) -> Arc<ThreadPoolIndexer<LowerTierIndexer>> {
    debug_assert!(!storage_tier.is_gpu());

    let mut by_tier = self.lower_tier_indexers.lock().unwrap();
    if let Some(existing) = by_tier.get(&storage_tier) {
        return Arc::clone(existing);
    }

    let created = Arc::new(ThreadPoolIndexer::new(
        LowerTierIndexer::new(),
        1,
        self.block_size(),
    ));
    by_tier.insert(storage_tier, Arc::clone(&created));
    created
}
/// Snapshot the handles of every lower-tier indexer that currently exists.
///
/// The map lock is held only while the `Arc` handles are cloned, so
/// callers can await on the returned indexers without holding it.
///
/// # Panics
///
/// Panics if the `lower_tier_indexers` mutex is poisoned.
fn all_lower_tier_indexers(&self) -> Vec<Arc<ThreadPoolIndexer<LowerTierIndexer>>> {
    let by_tier = self.lower_tier_indexers.lock().unwrap();
    let mut handles = Vec::with_capacity(by_tier.len());
    handles.extend(by_tier.values().map(Arc::clone));
    handles
}
} }
// Implement KvIndexerInterface by delegating to the underlying indexer // Implement KvIndexerInterface by delegating to the underlying indexer
...@@ -646,10 +700,16 @@ impl KvIndexerInterface for LocalKvIndexer { ...@@ -646,10 +700,16 @@ impl KvIndexerInterface for LocalKvIndexer {
} }
/// Remove `worker` from every existing lower-tier indexer, then ask the
/// primary indexer to remove it as well.
///
/// The request to the primary indexer is best-effort: a failed send on
/// the remove-worker channel is ignored.
async fn remove_worker(&self, worker: WorkerId) {
    for indexer in self.all_lower_tier_indexers() {
        indexer.remove_worker(worker).await;
    }
    let _ = self.indexer.remove_worker_sender().send(worker).await;
}
/// Remove a single `(worker, dp_rank)` pair from every existing lower-tier
/// indexer, then from the primary indexer.
///
/// Calls go through `KvIndexerInterface::remove_worker_dp_rank` explicitly
/// so the trait method is the one invoked on each indexer.
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
    for indexer in self.all_lower_tier_indexers() {
        KvIndexerInterface::remove_worker_dp_rank(&*indexer, worker, dp_rank).await;
    }
    KvIndexerInterface::remove_worker_dp_rank(&self.indexer, worker, dp_rank).await;
}
...@@ -658,7 +718,27 @@ impl KvIndexerInterface for LocalKvIndexer { ...@@ -658,7 +718,27 @@ impl KvIndexerInterface for LocalKvIndexer {
} }
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> { async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await let mut events = self.indexer.dump_events().await?;
// Also dump lower-tier indexer state so the router receives
// host-pinned / disk block information during recovery.
let lower_tiers: Vec<(StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>)> = {
let indexers = self.lower_tier_indexers.lock().unwrap();
indexers
.iter()
.map(|(&tier, idx)| (tier, idx.clone()))
.collect()
};
for (tier, indexer) in lower_tiers {
if let Ok(tier_events) = indexer.dump_events().await {
for mut event in tier_events {
event.storage_tier = tier;
events.push(event);
}
}
}
Ok(events)
} }
async fn process_routing_decision_for_request( async fn process_routing_decision_for_request(
...@@ -674,6 +754,231 @@ impl KvIndexerInterface for LocalKvIndexer { ...@@ -674,6 +754,231 @@ impl KvIndexerInterface for LocalKvIndexer {
} }
async fn flush(&self) -> usize { async fn flush(&self) -> usize {
self.indexer.flush().await let queued = self.indexer.flush().await;
for indexer in self.all_lower_tier_indexers() {
let _ = indexer.dump_events().await;
}
queued
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use rustc_hash::FxHashMap;
use tokio_util::sync::CancellationToken;
use super::LocalKvIndexer;
use crate::indexer::{KvIndexerInterface, KvIndexerMetrics, LowerTierContinuation};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, StorageTier, WorkerWithDpRank,
};
/// Build a `Stored` router event holding exactly one block, tagged with
/// `storage_tier`.
///
/// The block uses `block_hash` / `tokens_hash`, is parented on
/// `parent_hash`, and leaves `start_position` and `mm_extra_info` unset.
fn lower_tier_store_event(
    worker_id: u64,
    dp_rank: u32,
    event_id: u64,
    parent_hash: u64,
    tokens_hash: u64,
    block_hash: u64,
    storage_tier: StorageTier,
) -> RouterEvent {
    let stored_block = KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash: LocalBlockHash(tokens_hash),
        mm_extra_info: None,
    };
    let store_data = KvCacheStoreData {
        parent_hash: Some(ExternalSequenceBlockHash(parent_hash)),
        start_position: None,
        blocks: vec![stored_block],
    };
    let cache_event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Stored(store_data),
        dp_rank,
    };
    RouterEvent::with_storage_tier(worker_id, cache_event, storage_tier)
}
/// Query the lower-tier indexer for `storage_tier` and return the contiguous
/// hit count recorded for `(worker_id, dp_rank)` on the single block
/// `tokens_hash`, continuing from `parent_hash`.
///
/// Returns 0 when no indexer exists for the tier or when the worker is
/// absent from the query result.
fn lower_tier_hits(
    indexer: &LocalKvIndexer,
    storage_tier: StorageTier,
    worker_id: u64,
    dp_rank: u32,
    parent_hash: u64,
    tokens_hash: u64,
) -> usize {
    // The guard is a temporary of this statement, so the lock is released
    // before the query runs.
    let tier_indexer = indexer
        .lower_tier_indexers
        .lock()
        .unwrap()
        .get(&storage_tier)
        .cloned();
    let tier_indexer = match tier_indexer {
        Some(found) => found,
        None => return 0,
    };

    let worker_key = || WorkerWithDpRank::new(worker_id, dp_rank);
    let mut continuations = FxHashMap::default();
    continuations.insert(
        worker_key(),
        LowerTierContinuation::new(0, ExternalSequenceBlockHash(parent_hash)),
    );

    let query = [LocalBlockHash(tokens_hash)];
    tier_indexer
        .backend()
        .query_contiguous_hits(&query, &continuations)
        .get(&worker_key())
        .copied()
        .unwrap_or(0)
}
/// A host-pinned store event must land in the event buffer and in the
/// host-pinned lower-tier indexer, while the primary index reports no
/// overlap for the same block.
#[tokio::test]
async fn lower_tier_events_are_buffered_without_touching_primary_index() {
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let indexer = LocalKvIndexer::new(CancellationToken::new(), 4, metrics, 16);

    let host_event = lower_tier_store_event(7, 0, 1, 900, 11, 101, StorageTier::HostPinned);
    indexer
        .apply_event_with_buffer(host_event.clone())
        .await
        .unwrap();
    let _ = indexer.flush().await;

    // The event is recorded in the buffer and exactly one tier index exists.
    assert_eq!(indexer.get_all_events_in_buffer(), vec![host_event]);
    let tier_count = indexer.lower_tier_indexers.lock().unwrap().len();
    assert_eq!(tier_count, 1);

    let host_hits = lower_tier_hits(&indexer, StorageTier::HostPinned, 7, 0, 900, 11);
    assert_eq!(host_hits, 1);

    // The primary index must not have seen the lower-tier block.
    let primary_overlap = indexer
        .find_matches(vec![LocalBlockHash(11)])
        .await
        .unwrap();
    assert!(primary_overlap.scores.is_empty());
}
/// Store events for different storage tiers must create separate indexers,
/// and a block stored in one tier must not be visible through the other.
#[tokio::test]
async fn lower_tier_events_are_partitioned_by_storage_tier() {
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let indexer = LocalKvIndexer::new(CancellationToken::new(), 4, metrics, 16);
    let tier_count = || indexer.lower_tier_indexers.lock().unwrap().len();

    // Tier indexers are created lazily: none exist before the first event.
    assert_eq!(tier_count(), 0);

    let host_event = lower_tier_store_event(19, 0, 1, 1000, 31, 301, StorageTier::HostPinned);
    indexer.apply_event_with_buffer(host_event).await.unwrap();
    let _ = indexer.flush().await;
    assert_eq!(tier_count(), 1);

    let disk_event = lower_tier_store_event(19, 0, 2, 2000, 31, 302, StorageTier::Disk);
    indexer.apply_event_with_buffer(disk_event).await.unwrap();
    let _ = indexer.flush().await;
    assert_eq!(tier_count(), 2);

    // (tier queried, parent hash used as continuation, expected hits)
    let expectations = [
        (StorageTier::HostPinned, 1000, 1),
        (StorageTier::Disk, 2000, 1),
        (StorageTier::HostPinned, 2000, 0),
        (StorageTier::Disk, 1000, 0),
    ];
    for (tier, parent_hash, expected_hits) in expectations {
        assert_eq!(
            lower_tier_hits(&indexer, tier, 19, 0, parent_hash, 31),
            expected_hits
        );
    }
}
#[tokio::test]
async fn cleared_event_clears_all_lower_tier_dp_ranks_for_worker() {
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
16,
);
indexer
.apply_event_with_buffer(lower_tier_store_event(
11,
0,
1,
1000,
21,
201,
StorageTier::HostPinned,
))
.await
.unwrap();
indexer
.apply_event_with_buffer(lower_tier_store_event(
11,
1,
2,
2000,
22,
202,
StorageTier::HostPinned,
))
.await
.unwrap();
indexer
.apply_event_with_buffer(RouterEvent::with_storage_tier(
11,
KvCacheEvent {
event_id: 3,
data: KvCacheEventData::Cleared,
dp_rank: 0,
},
StorageTier::HostPinned,
))
.await
.unwrap();
let _ = indexer.flush().await;
assert_eq!(
lower_tier_hits(&indexer, StorageTier::HostPinned, 11, 0, 1000, 21),
0
);
assert_eq!(
lower_tier_hits(&indexer, StorageTier::HostPinned, 11, 1, 2000, 22),
0
);
} }
} }
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment