"vscode:/vscode.git/clone" did not exist on "6deeecb1d6a9f4eb1770b4272bfa85a4b6226e0a"
Unverified Commit e3d00b89 authored by jthomson04, committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
......@@ -10,9 +10,12 @@ pub mod publisher;
pub mod subscriber;
pub mod tracker;
pub use config::KvEventConsolidatorConfig;
pub use config::{KvEventConsolidationMode, KvEventConsolidatorConfig};
pub use publisher::KvEventConsolidatorPublisher;
pub use tracker::{CacheStatusTracker, EventSource, StorageTier};
pub use tracker::{
CacheStatusTracker, DedupCacheStatusTracker, EventSource, PassthroughCacheStatusTracker,
StorageTier,
};
use anyhow::Result;
use std::sync::Arc;
......@@ -21,11 +24,14 @@ use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use subscriber::start_simple_zmq_listener;
use tracker::{RemoveEventInput, StoreEventInput};
/// Shared handle to a boxed [`CacheStatusTracker`] trait object behind an `RwLock`.
///
/// Boxing the tracker as a trait object lets the concrete implementation be
/// selected at runtime while every holder of the handle uses the same type.
pub type SharedCacheStatusTracker = Arc<RwLock<Box<dyn CacheStatusTracker>>>;
/// Handle for KVBM to send G2/G3 events directly to the KV Event Consolidator
#[derive(Clone, Debug)]
pub struct KvEventConsolidatorHandle {
pub(crate) tracker: Arc<RwLock<CacheStatusTracker>>,
pub(crate) tracker: SharedCacheStatusTracker,
}
impl KvEventConsolidatorHandle {
......@@ -45,7 +51,7 @@ impl KvEventConsolidatorHandle {
data_parallel_rank: Option<i32>,
) {
let mut tracker = self.tracker.write().await;
tracker.handle_store(
tracker.handle_store(StoreEventInput {
block_hash,
source,
token_ids,
......@@ -54,15 +60,24 @@ impl KvEventConsolidatorHandle {
lora_name,
tier,
data_parallel_rank,
);
});
}
/// Send a block remove event to the KV Event Consolidator
///
/// This is called by KVBM when a block is removed from G2 or G3.
pub async fn handle_remove(&self, block_hash: &str, source: EventSource) {
pub async fn handle_remove(
&self,
block_hash: &str,
source: EventSource,
tier: Option<StorageTier>,
) {
let mut tracker = self.tracker.write().await;
tracker.handle_remove(block_hash, source);
tracker.handle_remove(RemoveEventInput {
block_hash: block_hash.to_string(),
source,
tier,
});
}
/// Clear all blocks from the KV Event Consolidator
......@@ -77,7 +92,7 @@ impl KvEventConsolidatorHandle {
/// The main KV Event Consolidator that manages the event flow
pub struct KvEventConsolidator {
config: KvEventConsolidatorConfig,
tracker: Arc<RwLock<CacheStatusTracker>>,
tracker: SharedCacheStatusTracker,
subscriber_handle: Option<JoinHandle<()>>,
cancellation_token: CancellationToken,
publisher: Option<KvEventConsolidatorPublisher>,
......@@ -86,7 +101,11 @@ pub struct KvEventConsolidator {
impl KvEventConsolidator {
/// Create a new KV Event Consolidator
pub fn new(config: KvEventConsolidatorConfig) -> Result<Self> {
let tracker = Arc::new(RwLock::new(CacheStatusTracker::new()));
let tracker: Box<dyn CacheStatusTracker> = match config.mode {
KvEventConsolidationMode::Dedup => Box::new(DedupCacheStatusTracker::new()),
KvEventConsolidationMode::Passthrough => Box::new(PassthroughCacheStatusTracker::new()),
};
let tracker = Arc::new(RwLock::new(tracker));
let cancellation_token = CancellationToken::new();
Ok(Self {
......@@ -101,7 +120,8 @@ impl KvEventConsolidator {
/// Start the KV Event Consolidator
pub async fn start(&mut self) -> Result<()> {
tracing::info!(
"Starting KV Event Consolidator: subscribe from {}, publish to ZMQ at {}",
"Starting KV Event Consolidator in {} mode: subscribe from {}, publish to ZMQ at {}",
self.config.mode.as_str(),
self.config.engine_event_endpoint,
self.config.consolidated_event_endpoint
);
......@@ -152,7 +172,7 @@ impl KvEventConsolidator {
}
/// Get a reference to the cache status tracker (for debugging/metrics)
pub fn tracker(&self) -> Arc<RwLock<CacheStatusTracker>> {
pub fn tracker(&self) -> SharedCacheStatusTracker {
self.tracker.clone()
}
......
......@@ -12,10 +12,10 @@ use rmp_serde::Serializer;
use serde::Serialize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use super::tracker::{CacheStatusTracker, ConsolidatedEvent};
use super::SharedCacheStatusTracker;
use super::tracker::ConsolidatedEvent;
use crate::utils::zmq::{bind_pub_socket, send_multipart};
/// Event batch structure matching vLLM's format (array_like=True)
......@@ -70,6 +70,7 @@ impl Event {
block_size,
lora_name,
source: _,
tier,
} => {
let parsed_hash = block_hash
.parse::<u64>()
......@@ -106,12 +107,13 @@ impl Event {
token_ids: token_ids_i32,
block_size: block_size_i32,
lora_name,
medium: None,
medium: tier.map(|t| t.to_vllm_medium().to_string()),
})
}
ConsolidatedEvent::Remove {
block_hash,
source: _,
tier,
} => {
// Parse block hash - fail if invalid to prevent corruption
let parsed_hash = block_hash.parse::<u64>().with_context(|| {
......@@ -120,7 +122,7 @@ impl Event {
Ok(Event::BlockRemoved {
block_hashes: vec![parsed_hash],
medium: None, // Not provided by ConsolidatedEvent
medium: tier.map(|t| t.to_vllm_medium().to_string()),
})
}
ConsolidatedEvent::ClearAll => Ok(Event::AllBlocksCleared {}),
......@@ -131,14 +133,14 @@ impl Event {
/// ZMQ Publisher for consolidated events
pub struct KvEventConsolidatorPublisher {
endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>,
tracker: SharedCacheStatusTracker,
sequence: Arc<AtomicU64>,
task_handle: Option<JoinHandle<()>>,
}
impl KvEventConsolidatorPublisher {
/// Create a new publisher
pub fn new(endpoint: &str, tracker: Arc<RwLock<CacheStatusTracker>>) -> Result<Self> {
pub fn new(endpoint: &str, tracker: SharedCacheStatusTracker) -> Result<Self> {
let endpoint = endpoint.to_string();
let sequence = Arc::new(AtomicU64::new(0));
......@@ -177,7 +179,7 @@ impl KvEventConsolidatorPublisher {
/// Main publisher loop
async fn run_publisher_loop(
endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>,
tracker: SharedCacheStatusTracker,
sequence: Arc<AtomicU64>,
) -> Result<()> {
tracing::info!("Starting consolidated event publisher on {}", endpoint);
......@@ -239,10 +241,13 @@ impl KvEventConsolidatorPublisher {
Some(0), // data_parallel_rank (default)
);
// Serialize to msgpack
// Serialize to msgpack.
// Keep the outer batch as a tuple [ts, events, dp_rank], but force
// inner struct-like events to use named fields so RawKvEvent's map
// deserializer can safely decode optional fields like `medium`.
let mut payload = Vec::new();
batch
.serialize(&mut Serializer::new(&mut payload))
.serialize(&mut Serializer::new(&mut payload).with_struct_map())
.context("Failed to serialize event batch")?;
// Get sequence number
......@@ -270,3 +275,55 @@ impl KvEventConsolidatorPublisher {
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_kv_router::zmq_wire::{KvEventBatch, RawKvEvent};

    /// A `BlockStored` event that carries a `medium` but no `lora_name` must
    /// survive a msgpack round trip into the router-side `KvEventBatch`.
    #[test]
    fn test_block_stored_with_medium_and_no_lora_name_decodes() {
        let stored = Event::BlockStored {
            block_hashes: vec![42],
            parent_block_hash: Some(7),
            token_ids: vec![1, 2, 3, 4],
            block_size: 4,
            lora_name: None,
            medium: Some("CPU_TIER1".to_string()),
        };
        let outgoing = EventBatch(0.0, vec![stored], Some(0));

        // Encode with struct-as-map enabled, the same serializer configuration
        // this test is meant to validate against the decoder.
        let mut wire_bytes = Vec::new();
        let mut encoder = Serializer::new(&mut wire_bytes).with_struct_map();
        outgoing.serialize(&mut encoder).unwrap();

        let received: KvEventBatch = rmp_serde::from_slice(&wire_bytes).unwrap();
        assert_eq!(received.events.len(), 1);
        assert_eq!(received.data_parallel_rank, Some(0));

        match &received.events[0] {
            RawKvEvent::BlockStored {
                block_hashes,
                parent_block_hash,
                token_ids,
                block_size,
                medium,
                lora_name,
                ..
            } => {
                assert_eq!(block_hashes.len(), 1);
                assert!(parent_block_hash.is_some());
                assert_eq!(token_ids, &vec![1, 2, 3, 4]);
                assert_eq!(*block_size, 4);
                assert_eq!(medium.as_deref(), Some("CPU_TIER1"));
                assert_eq!(lora_name.as_deref(), None);
            }
            _ => panic!("expected BlockStored"),
        }
    }
}
......@@ -9,14 +9,15 @@ use anyhow::{Context, Result};
use futures::StreamExt;
use rmp_serde::Deserializer;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use dynamo_kv_router::zmq_wire::RawKvEvent;
use super::tracker::{CacheStatusTracker, EventSource, StorageTier};
use super::SharedCacheStatusTracker;
use super::tracker::{
CacheStatusTracker, EventSource, RemoveEventInput, StorageTier, StoreEventInput,
};
use crate::utils::zmq::{connect_sub_socket, multipart_message};
/// Event batch received from vLLM/TensorRT-LLM (array format)
......@@ -48,7 +49,7 @@ impl VllmEventBatch {
/// Start ZMQ listener and process events into tracker
pub async fn start_simple_zmq_listener(
endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>,
tracker: SharedCacheStatusTracker,
cancellation_token: CancellationToken,
engine_source: EventSource,
) -> Result<JoinHandle<()>> {
......@@ -65,7 +66,7 @@ pub async fn start_simple_zmq_listener(
async fn run_listener_loop(
endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>,
tracker: SharedCacheStatusTracker,
cancellation_token: CancellationToken,
engine_source: EventSource,
) -> Result<()> {
......@@ -136,7 +137,7 @@ async fn run_listener_loop(
// Process events
let mut tracker_guard = tracker.write().await;
for event in batch.events() {
process_event(&mut tracker_guard, event.clone(), dp_rank, engine_source);
process_event(&mut **tracker_guard, event.clone(), dp_rank, engine_source);
}
}
}
......@@ -146,7 +147,7 @@ async fn run_listener_loop(
}
fn process_event(
tracker: &mut CacheStatusTracker,
tracker: &mut dyn CacheStatusTracker,
event: RawKvEvent,
data_parallel_rank: Option<i32>,
engine_source: EventSource,
......@@ -204,16 +205,16 @@ fn process_event(
let block_tokens = token_chunks[i].clone();
let block_hash_u64 = block_hash.into_u64();
tracker.handle_store(
block_hash_u64.to_string(),
engine_source,
block_tokens,
current_parent.clone(),
tracker.handle_store(StoreEventInput {
block_hash: block_hash_u64.to_string(),
source: engine_source,
token_ids: block_tokens,
parent_hash: current_parent.clone(),
block_size,
lora_name.clone(),
Some(storage_tier),
lora_name: lora_name.clone(),
tier: Some(storage_tier),
data_parallel_rank,
);
});
// Next block's parent is this block (only if hash was valid)
current_parent = Some(block_hash_u64.to_string());
......@@ -233,7 +234,11 @@ fn process_event(
);
for block_hash in block_hashes {
tracker.handle_remove(&block_hash.into_u64().to_string(), engine_source);
tracker.handle_remove(RemoveEventInput {
block_hash: block_hash.into_u64().to_string(),
source: engine_source,
tier: Some(storage_tier),
});
}
}
......
......@@ -140,10 +140,6 @@ impl From<RouterStorageTier> for StorageTier {
}
}
/// Legacy type alias for backward compatibility
#[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier;
/// Minimal metadata for tracking which event sources have a block
/// All other metadata (tokens, parent, etc.) is stored in the ConsolidatedEvent when queued
#[derive(Debug, Clone)]
......@@ -195,65 +191,58 @@ pub enum ConsolidatedEvent {
block_size: usize,
lora_name: Option<String>,
source: String,
tier: Option<StorageTier>,
},
/// Block removed (removed from all sources)
Remove {
block_hash: String,
source: String, // The source where it was last removed
tier: Option<StorageTier>,
},
/// All blocks cleared
ClearAll,
}
/// Cache Status Tracker
///
/// Deduplication logic:
/// - Uses SequenceHash (computed from tokens + parent) as the key for deduplication
/// - SequenceHash is position-aware: same tokens at different positions = different keys
/// - Always uses KVBM's xxHash3 hashing function, regardless of source
/// - This allows vLLM and KVBM blocks at the same position to be deduplicated
/// - Emit Store: Only when a block is first stored from ANY source
/// - Emit Remove: Only when a block is removed from ALL sources
#[derive(Debug)]
pub struct CacheStatusTracker {
/// Map of SequenceHash -> BlockMetadata (tracking which sources have this block)
/// The key is position-aware: includes parent context
blocks: HashMap<SequenceHash, BlockMetadata>,
/// Arguments describing one block STORE notification passed to a
/// [`CacheStatusTracker`].
#[derive(Debug, Clone)]
pub struct StoreEventInput {
/// External block hash as reported by the source.
pub block_hash: String,
/// Event source that stored the block.
pub source: EventSource,
/// Token IDs contained in the block.
pub token_ids: Vec<u32>,
/// External hash of the parent block, if there is one.
pub parent_hash: Option<String>,
/// Block size; copied unchanged into the queued store event.
pub block_size: usize,
/// LoRA name, if any; copied unchanged into the queued store event.
pub lora_name: Option<String>,
/// Storage tier associated with the store, when known.
pub tier: Option<StorageTier>,
/// Data-parallel rank supplied by the caller, if any.
pub data_parallel_rank: Option<i32>,
}
/// Reverse mapping: external_block_hash -> SequenceHash (that we computed)
/// Needed because remove events only provide external hash, not token IDs
/// Maps each source's external hash to our computed sequence hash
hash_mapping: HashMap<String, SequenceHash>,
/// Arguments describing one block REMOVE notification passed to a
/// [`CacheStatusTracker`].
#[derive(Debug, Clone)]
pub struct RemoveEventInput {
/// External block hash of the block being removed.
pub block_hash: String,
/// Event source that removed the block.
pub source: EventSource,
/// Storage tier associated with the removal, when known.
pub tier: Option<StorageTier>,
}
/// Queue of events to be published
event_queue: Vec<ConsolidatedEvent>,
/// Common interface for trackers that turn store/remove/clear notifications
/// into a queue of [`ConsolidatedEvent`]s.
///
/// `Debug + Send + Sync` are required so implementations can be boxed as trait
/// objects and shared behind a lock.
pub trait CacheStatusTracker: std::fmt::Debug + Send + Sync {
/// Handles a STORE notification.
///
/// Returns `true` if a consolidated STORE event should be published as a result.
fn handle_store(&mut self, event: StoreEventInput) -> bool;
/// Handles a REMOVE notification.
///
/// Returns `true` if a consolidated REMOVE event should be published as a result.
fn handle_remove(&mut self, event: RemoveEventInput) -> bool;
/// Handles a CLEAR_ALL notification.
fn handle_clear_all(&mut self);
/// Drains all pending events to be published.
fn drain_events(&mut self) -> Vec<ConsolidatedEvent>;
/// Returns the number of blocks the implementation currently tracks.
fn num_blocks(&self) -> usize;
}
impl Default for CacheStatusTracker {
fn default() -> Self {
Self::new()
}
/// Deduplicating cache-status tracker.
///
/// Blocks are keyed by a `SequenceHash`. A STORE is queued only when a block is
/// first seen from any source, and a REMOVE is queued only once the block has
/// been removed from every source that stored it.
#[derive(Debug, Default)]
pub struct DedupCacheStatusTracker {
/// Map of `SequenceHash` -> `BlockMetadata`, tracking which sources hold the block.
blocks: HashMap<SequenceHash, BlockMetadata>,
/// Reverse mapping from a source's external block hash to the `SequenceHash`
/// computed for it, so that remove events can locate the block by external hash.
hash_mapping: HashMap<String, SequenceHash>,
/// Queue of events waiting to be published.
event_queue: Vec<ConsolidatedEvent>,
}
impl CacheStatusTracker {
impl DedupCacheStatusTracker {
pub fn new() -> Self {
Self {
blocks: HashMap::new(),
hash_mapping: HashMap::new(),
event_queue: Vec::new(),
}
Self::default()
}
/// Handle a STORE event
///
/// Returns true if a consolidated STORE event should be published.
/// Only publishes when a block is stored for the FIRST TIME from ANY source.
///
/// # Arguments
/// * `block_hash` - The external block hash (from vLLM or KVBM)
/// * `source` - The event source (vLLM or KVBM) that stored this block
/// * `token_ids` - The token IDs in this block (for content-based deduplication)
/// * `tier` - Optional storage tier information (for metadata/debugging)
#[allow(clippy::too_many_arguments)]
pub fn handle_store(
&mut self,
......@@ -266,16 +255,65 @@ impl CacheStatusTracker {
tier: Option<StorageTier>,
data_parallel_rank: Option<i32>,
) -> bool {
// Compute LocalBlockHash from token IDs (content only)
let local_block_hash = compute_local_block_hash(&token_ids);
CacheStatusTracker::handle_store(
self,
StoreEventInput {
block_hash,
source,
token_ids,
parent_hash,
block_size,
lora_name,
tier,
data_parallel_rank,
},
)
}
pub fn handle_remove(
&mut self,
block_hash: &str,
source: EventSource,
tier: Option<StorageTier>,
) -> bool {
CacheStatusTracker::handle_remove(
self,
RemoveEventInput {
block_hash: block_hash.to_string(),
source,
tier,
},
)
}
pub fn get_block_sources(&self, external_block_hash: &str) -> Option<&HashSet<EventSource>> {
let local_hash = self.hash_mapping.get(external_block_hash)?;
self.blocks.get(local_hash).map(|m| &m.sources)
}
#[deprecated(note = "Use get_block_sources instead")]
pub fn get_block_tiers(&self, block_hash: &str) -> Option<&HashSet<EventSource>> {
self.get_block_sources(block_hash)
}
}
impl CacheStatusTracker for DedupCacheStatusTracker {
fn handle_store(&mut self, event: StoreEventInput) -> bool {
let StoreEventInput {
block_hash,
source,
token_ids,
parent_hash,
block_size,
lora_name,
tier,
data_parallel_rank,
} = event;
// Resolve parent sequence hash from parent's external hash (if provided)
let local_block_hash = compute_local_block_hash(&token_ids);
let parent_sequence_hash = parent_hash
.as_ref()
.and_then(|ph| self.hash_mapping.get(ph).copied());
// Compute SequenceHash using KVBM's hashing method (position-aware)
// This ensures consistent deduplication regardless of source
let sequence_hash = compute_sequence_hash(parent_sequence_hash, local_block_hash);
tracing::debug!(
......@@ -286,11 +324,7 @@ impl CacheStatusTracker {
);
if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
// Block already exists from another source (deduplication!), just add the new source
let is_new_source = metadata.add_source(source);
// Add this external hash to the mapping so remove events from this source can find the block
// Multiple external hashes (from different sources) can map to the same SequenceHash
self.hash_mapping.insert(block_hash.clone(), sequence_hash);
if is_new_source {
......@@ -314,10 +348,8 @@ impl CacheStatusTracker {
&token_ids
);
}
// Don't publish a new STORE event (block already exists)
false
} else {
// First time seeing this block from any source - create metadata and queue STORE event
let metadata = BlockMetadata::new(source, block_hash.clone());
tracing::debug!(
......@@ -338,34 +370,16 @@ impl CacheStatusTracker {
);
self.blocks.insert(sequence_hash, metadata);
// Add to hash mapping so remove events can find the block by external hash
self.hash_mapping.insert(block_hash.clone(), sequence_hash);
// Resolve parent_hash to first_block_hash if parent was deduplicated
//
// Problem: When the same block is stored from multiple sources (deduplication),
// each source may use a different external hash for the same logical block.
// Example:
// - Source A (TRTLLM) stores parent with hash "hash_A"
// - Source B (KVBM) stores same parent with hash "hash_B" (different format/algorithm)
// - Router only received STORE event with "hash_A" (first source)
// - When Source B stores child with parent_hash="hash_B", router won't recognize it
//
// Resolve the parent's external hash to its first_block_hash (the hash
// that was sent to the router in the first STORE event) so the router can find it.
let resolved_parent_hash = parent_hash.and_then(|ph| {
// Look up parent's sequence hash from its external hash
self.hash_mapping.get(&ph).and_then(|&parent_seq_hash| {
// Get parent's metadata to find first_block_hash
self.blocks
.get(&parent_seq_hash)
.map(|parent_metadata| parent_metadata.first_block_hash.clone())
})
});
// Queue a STORE event with full metadata
// Use resolved_parent_hash (first_block_hash) so router can find the parent
self.event_queue.push(ConsolidatedEvent::Store {
block_hash: block_hash.clone(),
parent_hash: resolved_parent_hash,
......@@ -373,6 +387,7 @@ impl CacheStatusTracker {
block_size,
lora_name,
source: source.to_str().to_string(),
tier,
});
tracing::debug!(
......@@ -387,17 +402,14 @@ impl CacheStatusTracker {
}
}
/// Handle a REMOVE event
///
/// Returns true if a consolidated REMOVE event should be published.
/// Only publishes when a block is removed from ALL sources.
///
/// # Arguments
/// * `block_hash` - The external block hash to remove
/// * `source` - The event source (vLLM or KVBM) that removed this block
pub fn handle_remove(&mut self, block_hash: &str, source: EventSource) -> bool {
// Look up the SequenceHash from the external block hash
let sequence_hash = match self.hash_mapping.get(block_hash) {
fn handle_remove(&mut self, event: RemoveEventInput) -> bool {
let RemoveEventInput {
block_hash,
source,
tier,
} = event;
let sequence_hash = match self.hash_mapping.get(&block_hash) {
Some(&hash) => hash,
None => {
tracing::warn!(
......@@ -410,7 +422,6 @@ impl CacheStatusTracker {
};
if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
// Remove the source
let was_removed = metadata.remove_source(source);
if !was_removed {
tracing::warn!(
......@@ -421,11 +432,7 @@ impl CacheStatusTracker {
return false;
}
// Remove this external hash immediately when the source removes it
// This keeps hash_mapping clean
// Each external hash belongs to exactly one source, so when that source
// removes the block, we can safely remove the hash_mapping entry
self.hash_mapping.remove(block_hash);
self.hash_mapping.remove(&block_hash);
tracing::debug!(
"Removed hash_mapping entry for {} (hash_mapping size: {})",
......@@ -433,17 +440,13 @@ impl CacheStatusTracker {
self.hash_mapping.len()
);
// Check if this was the last source
if !metadata.exists_in_any_source() {
// Block is gone from all sources - remove from tracker and publish REMOVE
let first_block_hash = metadata.first_block_hash.clone();
self.blocks.remove(&sequence_hash);
// Double-check: clean up any stray hash mappings (should be empty by now)
// This is a safety check
let stray_count_before = self.hash_mapping.len();
self.hash_mapping
.retain(|_ext_hash, &mut seq_hash| seq_hash != sequence_hash);
.retain(|_ext_hash, seq_hash| *seq_hash != sequence_hash);
let stray_count = stray_count_before - self.hash_mapping.len();
if stray_count > 0 {
......@@ -458,6 +461,7 @@ impl CacheStatusTracker {
self.event_queue.push(ConsolidatedEvent::Remove {
block_hash: first_block_hash.clone(),
source: source.to_str().to_string(),
tier,
});
tracing::debug!(
......@@ -470,7 +474,6 @@ impl CacheStatusTracker {
);
true
} else {
// Block still exists in other sources
tracing::debug!(
"Block {} (seq_hash={}) removed from source {:?}, still in {} source(s): {:?} (hash_mapping: {})",
&metadata.first_block_hash[..16.min(metadata.first_block_hash.len())],
......@@ -492,8 +495,7 @@ impl CacheStatusTracker {
}
}
/// Handle a CLEAR_ALL event
pub fn handle_clear_all(&mut self) {
fn handle_clear_all(&mut self) {
let num_blocks = self.blocks.len();
tracing::debug!("Clearing all {} blocks from tracker", num_blocks);
self.blocks.clear();
......@@ -501,8 +503,7 @@ impl CacheStatusTracker {
self.event_queue.push(ConsolidatedEvent::ClearAll);
}
/// Drain all pending events to be published
pub fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
let events = std::mem::take(&mut self.event_queue);
if !events.is_empty() {
tracing::debug!(
......@@ -513,22 +514,106 @@ impl CacheStatusTracker {
events
}
/// Get the number of tracked blocks
pub fn num_blocks(&self) -> usize {
fn num_blocks(&self) -> usize {
self.blocks.len()
}
}
/// Get sources for a specific block by external block hash
pub fn get_block_sources(&self, external_block_hash: &str) -> Option<&HashSet<EventSource>> {
// Look up the local hash from external hash, then get sources
let local_hash = self.hash_mapping.get(external_block_hash)?;
self.blocks.get(local_hash).map(|m| &m.sources)
/// Pass-through cache-status tracker.
///
/// Keeps no per-block state: it only holds the queue of events waiting to be
/// published.
#[derive(Debug, Default)]
pub struct PassthroughCacheStatusTracker {
/// Queue of events waiting to be published.
event_queue: Vec<ConsolidatedEvent>,
}
impl PassthroughCacheStatusTracker {
pub fn new() -> Self {
Self::default()
}
/// Legacy method for backwards compatibility
#[deprecated(note = "Use get_block_sources instead")]
pub fn get_block_tiers(&self, block_hash: &str) -> Option<&HashSet<EventSource>> {
self.get_block_sources(block_hash)
#[allow(clippy::too_many_arguments)]
pub fn handle_store(
&mut self,
block_hash: String,
source: EventSource,
token_ids: Vec<u32>,
parent_hash: Option<String>,
block_size: usize,
lora_name: Option<String>,
tier: Option<StorageTier>,
data_parallel_rank: Option<i32>,
) -> bool {
CacheStatusTracker::handle_store(
self,
StoreEventInput {
block_hash,
source,
token_ids,
parent_hash,
block_size,
lora_name,
tier,
data_parallel_rank,
},
)
}
pub fn handle_remove(
&mut self,
block_hash: &str,
source: EventSource,
tier: Option<StorageTier>,
) -> bool {
CacheStatusTracker::handle_remove(
self,
RemoveEventInput {
block_hash: block_hash.to_string(),
source,
tier,
},
)
}
}
impl CacheStatusTracker for PassthroughCacheStatusTracker {
fn handle_store(&mut self, event: StoreEventInput) -> bool {
self.event_queue.push(ConsolidatedEvent::Store {
block_hash: event.block_hash,
parent_hash: event.parent_hash,
token_ids: event.token_ids,
block_size: event.block_size,
lora_name: event.lora_name,
source: event.source.to_str().to_string(),
tier: event.tier,
});
true
}
fn handle_remove(&mut self, event: RemoveEventInput) -> bool {
self.event_queue.push(ConsolidatedEvent::Remove {
block_hash: event.block_hash,
source: event.source.to_str().to_string(),
tier: event.tier,
});
true
}
fn handle_clear_all(&mut self) {
self.event_queue.push(ConsolidatedEvent::ClearAll);
}
fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
let events = std::mem::take(&mut self.event_queue);
if !events.is_empty() {
tracing::debug!(
"Draining {} pending kv event(s) for publishing",
events.len()
);
}
events
}
fn num_blocks(&self) -> usize {
0
}
}
......@@ -536,9 +621,11 @@ impl CacheStatusTracker {
mod tests {
use super::*;
type TestTracker = DedupCacheStatusTracker;
#[test]
fn test_first_store_publishes() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
let should_publish = tracker.handle_store(
"hash1".to_string(),
......@@ -558,7 +645,7 @@ mod tests {
#[test]
fn test_duplicate_store_no_publish() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
tracker.handle_store(
"hash1".to_string(),
......@@ -589,7 +676,7 @@ mod tests {
#[test]
fn test_multi_source_store() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
// First store from vLLM
tracker.handle_store(
......@@ -624,7 +711,7 @@ mod tests {
#[test]
fn test_remove_from_single_source_publishes() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
tracker.handle_store(
"hash1".to_string(),
......@@ -638,18 +725,24 @@ mod tests {
);
tracker.drain_events();
let should_publish = tracker.handle_remove("hash1", EventSource::Vllm);
let should_publish =
tracker.handle_remove("hash1", EventSource::Vllm, Some(StorageTier::Device));
assert!(should_publish);
assert_eq!(tracker.num_blocks(), 0);
let events = tracker.drain_events();
assert_eq!(events.len(), 1);
matches!(events[0], ConsolidatedEvent::Remove { .. });
match &events[0] {
ConsolidatedEvent::Remove { tier, .. } => {
assert_eq!(*tier, Some(StorageTier::Device));
}
other => panic!("expected Remove event, got: {:?}", other),
}
}
#[test]
fn test_remove_from_multi_source_no_publish() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
// Store from vLLM - first STORE event published
tracker.handle_store(
......@@ -676,14 +769,19 @@ mod tests {
tracker.drain_events();
// Remove from vLLM - should not publish (still in KVBM)
let should_publish = tracker.handle_remove("vllm_hash1", EventSource::Vllm);
let should_publish =
tracker.handle_remove("vllm_hash1", EventSource::Vllm, Some(StorageTier::Device));
assert!(!should_publish);
assert_eq!(tracker.num_blocks(), 1);
assert_eq!(tracker.drain_events().len(), 0);
// Remove from KVBM (last source) - should publish REMOVE event
let should_publish = tracker.handle_remove("kvbm_hash1", EventSource::Kvbm);
let should_publish = tracker.handle_remove(
"kvbm_hash1",
EventSource::Kvbm,
Some(StorageTier::HostPinned),
);
assert!(should_publish);
assert_eq!(tracker.num_blocks(), 0);
......@@ -691,7 +789,7 @@ mod tests {
#[test]
fn test_sequence_hash_first_block() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
// First block (no parent)
let should_publish = tracker.handle_store(
......@@ -714,7 +812,7 @@ mod tests {
#[test]
fn test_sequence_hash_with_parent() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
// First block
tracker.handle_store(
......@@ -747,7 +845,7 @@ mod tests {
#[test]
fn test_same_tokens_different_position_different_blocks() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
// First occurrence: tokens [1,2,3,4] at position 0 (no parent)
tracker.handle_store(
......@@ -782,7 +880,7 @@ mod tests {
#[test]
fn test_clear_all() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
// Add multiple blocks
tracker.handle_store(
......@@ -815,13 +913,14 @@ mod tests {
assert_eq!(tracker.num_blocks(), 0);
// Verify hash_mapping is also cleared
let should_publish = tracker.handle_remove("block1", EventSource::Vllm);
let should_publish =
tracker.handle_remove("block1", EventSource::Vllm, Some(StorageTier::Device));
assert!(!should_publish); // Should fail because block is gone
}
#[test]
fn test_deduplication_across_sources_with_parent() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
// vLLM stores block 1 (parent)
tracker.handle_store(
......@@ -870,9 +969,10 @@ mod tests {
#[test]
fn test_remove_non_existent_block() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
let should_publish = tracker.handle_remove("non_existent", EventSource::Vllm);
let should_publish =
tracker.handle_remove("non_existent", EventSource::Vllm, Some(StorageTier::Device));
assert!(!should_publish);
assert_eq!(tracker.num_blocks(), 0);
......@@ -896,7 +996,7 @@ mod tests {
#[test]
fn test_lora_name_round_trip_through_tracker() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
let should_publish = tracker.handle_store(
"hash_lora".to_string(),
......@@ -916,10 +1016,12 @@ mod tests {
ConsolidatedEvent::Store {
lora_name,
token_ids,
tier,
..
} => {
assert_eq!(lora_name.as_deref(), Some("my-adapter"));
assert_eq!(token_ids, &[1, 2, 3, 4]);
assert_eq!(*tier, Some(StorageTier::Device));
}
other => panic!("expected Store event, got: {:?}", other),
}
......@@ -927,7 +1029,7 @@ mod tests {
#[test]
fn test_lora_name_none_for_base_model() {
let mut tracker = CacheStatusTracker::new();
let mut tracker = TestTracker::new();
tracker.handle_store(
"hash_base".to_string(),
......@@ -943,8 +1045,11 @@ mod tests {
let events = tracker.drain_events();
assert_eq!(events.len(), 1);
match &events[0] {
ConsolidatedEvent::Store { lora_name, .. } => {
ConsolidatedEvent::Store {
lora_name, tier, ..
} => {
assert!(lora_name.is_none());
assert_eq!(*tier, Some(StorageTier::Device));
}
other => panic!("expected Store event, got: {:?}", other),
}
......@@ -971,4 +1076,53 @@ mod tests {
let seq_hash2_different = compute_sequence_hash(Some(different_parent), block_hash2);
assert_ne!(seq_hash2_v1, seq_hash2_different);
}
/// Storing the identical block twice must queue two events in pass-through mode.
#[test]
fn test_passthrough_tracker_forwards_duplicate_store() {
    let mut tracker = PassthroughCacheStatusTracker::new();

    // Submit the same store twice; each submission must report success.
    for _ in 0..2 {
        assert!(tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        ));
    }

    let queued = tracker.drain_events();
    assert_eq!(queued.len(), 2);
}
/// A remove for a never-stored block and a clear-all must both be forwarded,
/// in order, and the tracker must report zero tracked blocks.
#[test]
fn test_passthrough_tracker_remove_and_clear() {
    let mut tracker = PassthroughCacheStatusTracker::new();

    // No store preceded this remove; pass-through mode forwards it regardless.
    assert!(tracker.handle_remove("hash1", EventSource::Kvbm, Some(StorageTier::HostPinned),));
    tracker.handle_clear_all();

    let queued = tracker.drain_events();
    assert_eq!(queued.len(), 2);

    let first = &queued[0];
    let second = &queued[1];
    assert!(matches!(
        first,
        ConsolidatedEvent::Remove {
            tier: Some(StorageTier::HostPinned),
            ..
        }
    ));
    assert!(matches!(second, ConsolidatedEvent::ClearAll));

    assert_eq!(tracker.num_blocks(), 0);
}
}
......@@ -51,6 +51,7 @@ pub mod inactive;
pub mod priority_key;
pub mod state;
use crate::block_manager::kv_consolidator::StorageTier;
use active::ActiveBlockPool;
use inactive::InactiveBlockPool;
......@@ -72,6 +73,9 @@ pub struct ManagedBlockPoolArgs<S: Storage, L: LocalityProvider, M: BlockMetadat
#[builder(default = "Handle::current()")]
async_runtime: Handle,
#[builder(default = "StorageTier::Device")]
storage_tier: StorageTier,
#[builder(default = "BlockRegistrationDuplicationSetting::Disabled")]
default_duplication_setting: BlockRegistrationDuplicationSetting,
}
......@@ -85,6 +89,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuil
blocks,
global_registry,
async_runtime,
storage_tier,
default_duplication_setting,
) = args.dissolve();
......@@ -95,6 +100,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuil
blocks,
global_registry,
async_runtime,
storage_tier,
default_duplication_setting,
);
......@@ -176,6 +182,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
storage_tier: StorageTier,
default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> Self {
let (pool, progress_engine) = Self::with_progress_engine(
......@@ -184,6 +191,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks,
global_registry,
async_runtime,
storage_tier,
default_duplication_setting,
);
......@@ -228,6 +236,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
storage_tier: StorageTier,
default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> (Self, ProgressEngine<S, L, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
......@@ -241,6 +250,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks,
global_registry,
async_runtime,
storage_tier,
);
let available_blocks_counter = progress_engine.available_blocks_counter.clone();
......@@ -515,10 +525,16 @@ impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> ProgressEngine
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
storage_tier: StorageTier,
) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state =
State::<S, L, M>::new(event_manager, return_tx, global_registry, async_runtime);
let mut state = State::<S, L, M>::new(
event_manager,
return_tx,
global_registry,
async_runtime,
storage_tier,
);
let count = blocks.len();
......@@ -589,6 +605,7 @@ mod tests {
blocks,
global_registry,
async_runtime,
storage_tier,
default_duplication_setting,
) = args.dissolve();
......@@ -598,6 +615,7 @@ mod tests {
blocks,
global_registry,
async_runtime,
storage_tier,
default_duplication_setting,
);
......
......@@ -532,6 +532,7 @@ pub(crate) mod tests {
state::CompleteState,
},
events::NullEventManager,
kv_consolidator::StorageTier,
layout::{BlockLayout, FullyContiguous, LayoutConfigBuilder},
storage::tests::{NullDeviceAllocator, NullDeviceStorage},
},
......@@ -701,8 +702,12 @@ pub(crate) mod tests {
let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap();
let event_manager = NullEventManager::new();
let mut registry =
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime);
let mut registry = BlockRegistry::new(
event_manager,
GlobalRegistry::default(),
async_runtime,
StorageTier::Device,
);
// Iterate through the generated TokenBlocks and the template Blocks,
// setting the state and registering each one.
......@@ -745,8 +750,12 @@ pub(crate) mod tests {
let matched_block_count = matched_blocks.len();
let event_manager = NullEventManager::new();
let mut registry =
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime);
let mut registry = BlockRegistry::new(
event_manager,
GlobalRegistry::default(),
async_runtime,
StorageTier::Device,
);
// all matched blocks should be in the complete or registered state
for block in &mut matched_blocks {
......
......@@ -17,11 +17,17 @@ impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> State<S, L, M>
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
storage_tier: StorageTier,
) -> Self {
Self {
active: ActiveBlockPool::new(),
inactive: InactiveBlockPool::new(),
registry: BlockRegistry::new(event_manager.clone(), global_registry, async_runtime),
registry: BlockRegistry::new(
event_manager.clone(),
global_registry,
async_runtime,
storage_tier,
),
return_tx,
event_manager,
}
......
......@@ -26,6 +26,8 @@ use std::sync::Arc;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
use crate::block_manager::kv_consolidator::StorageTier;
pub(crate) struct Resources {
pub worker_id: WorkerID,
pub cancellation_token: CancellationToken,
......@@ -253,7 +255,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<locality::Local, Metadata> {
let (device_pool, device_blocks, device_offload_filter) = match device_factory {
Some(factory) => {
let (pool, blocks, offload_filter) =
create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?;
create_block_pool::<_, _, Metadata>(factory, &resources, "device")?;
(Some(pool), Some(blocks), offload_filter)
}
None => {
......@@ -523,17 +525,25 @@ impl<Locality: LocalityProvider, Metadata: BlockMetadata> std::fmt::Debug
pub(crate) fn create_block_pool<S: Storage, L: LocalityProvider, M: BlockMetadata>(
factory: impl IntoBlocks<S, L>,
resources: &Resources,
_pool_name: &str,
pool_name: &str,
) -> Result<(
Arc<dyn BlockPool<S, L, M>>,
Vec<Block<S, L, M>>,
Option<Arc<dyn OffloadFilter>>,
)> {
let storage_tier = match pool_name {
"device" => StorageTier::Device,
"host" => StorageTier::HostPinned,
"disk" => StorageTier::Disk,
_ => anyhow::bail!("unsupported block pool tier: {}", pool_name),
};
let pool = ManagedBlockPool::<S, L, M>::builder()
.cancel_token(resources.cancellation_token.clone())
.global_registry(resources.global_registry.clone())
.async_runtime(resources.async_rt_handle.clone())
.event_manager(resources.event_manager.clone())
.storage_tier(storage_tier)
.build()?;
let offload_filter = factory.offload_filter();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Instant,
};
use anyhow::Result;
use dynamo_kv_router::{
......@@ -15,6 +18,7 @@ use dynamo_kv_router::{
RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq,
},
scheduling::TierOverlapBlocks,
};
use dynamo_runtime::{
component::{Client, Endpoint},
......@@ -63,8 +67,6 @@ use crate::{
local_model::runtime_config::ModelRuntimeConfig,
};
use std::collections::HashSet;
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
......@@ -85,6 +87,144 @@ pub const RADIX_STATE_FILE: &str = "radix-state";
// for worker-local kvindexer query
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer
/// Cache-hit estimate for a single worker.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct WorkerCacheHitEstimate {
    /// Estimated overlap in blocks; stored as a float, so it may be fractional.
    pub effective_overlap_blocks: f64,
    /// Estimated number of tokens already cached for the request.
    pub cached_tokens: usize,
}

impl WorkerCacheHitEstimate {
    /// Returns the effective overlap rounded to the nearest whole block.
    ///
    /// The float-to-integer `as` cast saturates: negative values and NaN become `0`,
    /// and values above `u32::MAX` become `u32::MAX`.
    pub fn rounded_overlap_blocks(self) -> u32 {
        let Self {
            effective_overlap_blocks,
            ..
        } = self;
        effective_overlap_blocks.round() as u32
    }
}
/// Per-worker cache-hit estimates, kept as two maps keyed by worker.
#[derive(Debug, Clone, Default)]
struct CacheHitEstimates {
// Effective (possibly fractional) overlap in blocks for each worker.
effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
// Estimated cached token count for each worker.
cached_tokens: HashMap<WorkerWithDpRank, usize>,
}
/// The selected worker together with its cache-hit estimate.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BestMatchDetails {
// Worker (with dp_rank) that was selected.
pub worker: WorkerWithDpRank,
// Cache-hit estimate reported for that worker.
pub cache_hit: WorkerCacheHitEstimate,
}
/// Returns the weight applied to a cache hit found in `storage_tier`.
///
/// Device hits always count fully (`1.0`). Host-pinned hits use
/// `host_cache_hit_weight`; disk and external hits both use `disk_cache_hit_weight`.
fn cache_hit_weight_for_tier(
    kv_router_config: &KvRouterConfig,
    storage_tier: dynamo_kv_router::protocols::StorageTier,
) -> f64 {
    use dynamo_kv_router::protocols::StorageTier as Tier;

    match storage_tier {
        Tier::Device => 1.0,
        Tier::HostPinned => kv_router_config.host_cache_hit_weight,
        Tier::Disk | Tier::External => kv_router_config.disk_cache_hit_weight,
    }
}
/// Converts an effective overlap measured in blocks into a whole token count.
///
/// The product `effective_overlap_blocks * block_size` is rounded to the nearest
/// integer and floored at zero before the cast.
fn cached_tokens_from_effective_overlap(block_size: u32, effective_overlap_blocks: f64) -> usize {
    let tokens_per_block = f64::from(block_size);
    let estimated_tokens = (effective_overlap_blocks * tokens_per_block).round();
    // `f64::max` returns the non-NaN operand, so both negative and NaN estimates become 0.
    estimated_tokens.max(0.0) as usize
}
fn cache_hit_estimates_from_tiered_matches(
kv_router_config: &KvRouterConfig,
block_size: u32,
tiered_matches: &indexer::TieredMatchDetails,
) -> CacheHitEstimates {
let mut effective_overlap_blocks = HashMap::new();
for (worker, overlap) in &tiered_matches.device.overlap_scores.scores {
effective_overlap_blocks.insert(*worker, *overlap as f64);
}
for (storage_tier, tier_matches) in &tiered_matches.lower_tier {
let weight = cache_hit_weight_for_tier(kv_router_config, *storage_tier);
if weight == 0.0 {
continue;
}
for (worker, hits) in &tier_matches.hits {
if *hits == 0 {
continue;
}
*effective_overlap_blocks.entry(*worker).or_insert(0.0) += *hits as f64 * weight;
}
}
let cached_tokens = effective_overlap_blocks
.iter()
.map(|(worker, overlap)| {
(
*worker,
cached_tokens_from_effective_overlap(block_size, *overlap),
)
})
.collect();
CacheHitEstimates {
effective_overlap_blocks,
cached_tokens,
}
}
fn cache_hit_for_worker(
cache_hit_estimates: &CacheHitEstimates,
worker: WorkerWithDpRank,
) -> WorkerCacheHitEstimate {
WorkerCacheHitEstimate {
effective_overlap_blocks: cache_hit_estimates
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0),
cached_tokens: cache_hit_estimates
.cached_tokens
.get(&worker)
.copied()
.unwrap_or(0),
}
}
/// Collects raw (unweighted) lower-tier hit counts per worker into `TierOverlapBlocks`.
///
/// Host-pinned hits go into `host_pinned`; disk and external hits are summed into `disk`.
fn tier_overlap_blocks_from_tiered_matches(
    tiered_matches: &indexer::TieredMatchDetails,
) -> TierOverlapBlocks {
    use dynamo_kv_router::protocols::StorageTier as Tier;

    let lower_tier = &tiered_matches.lower_tier;
    let mut overlap = TierOverlapBlocks::default();

    if let Some(host) = lower_tier.get(&Tier::HostPinned) {
        overlap
            .host_pinned
            .extend(host.hits.iter().map(|(worker, hits)| (*worker, *hits)));
    }

    // Disk and External are given the same weight by `cache_hit_weight_for_tier`,
    // so both are accumulated into the single disk bucket.
    let disk_like = [Tier::Disk, Tier::External];
    for matches in disk_like.iter().filter_map(|tier| lower_tier.get(tier)) {
        for (worker, hits) in &matches.hits {
            *overlap.disk.entry(*worker).or_default() += *hits;
        }
    }

    overlap
}
/// Generates a dp_rank-specific endpoint name for the worker KV indexer query service.
/// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
......@@ -275,6 +415,25 @@ where
self.is_eagle
}
/// Method form of the free function of the same name: computes cache-hit estimates
/// for `tiered_matches` using this router's `kv_router_config` and `block_size`.
fn cache_hit_estimates_from_tiered_matches(
&self,
tiered_matches: &indexer::TieredMatchDetails,
) -> CacheHitEstimates {
// The bare name resolves to the module-level function, not to this method.
cache_hit_estimates_from_tiered_matches(
&self.kv_router_config,
self.block_size,
tiered_matches,
)
}
/// Method form of the free function of the same name: returns the estimate for
/// `worker` from `cache_hit_estimates`. Does not read any state from `self`.
fn cache_hit_for_worker(
&self,
cache_hit_estimates: &CacheHitEstimates,
worker: WorkerWithDpRank,
) -> WorkerCacheHitEstimate {
// The bare name resolves to the module-level function, not to this method.
cache_hit_for_worker(cache_hit_estimates, worker)
}
pub async fn record_routing_decision(
&self,
mut tokens_with_hashes: TokensWithHashes,
......@@ -285,16 +444,15 @@ where
.await
}
/// Give these tokens, find the worker with the best match in it's KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking.
/// Given these tokens, find the worker with the best weighted cache hit.
/// Returns the full match details for the selected worker.
///
/// When `pinned_worker` is Some, scheduling and queueing are constrained to
/// that exact worker/rank.
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
#[allow(clippy::too_many_arguments)]
pub async fn find_best_match(
pub(crate) async fn find_best_match_details(
&self,
context_id: Option<&str>,
tokens: &[u32],
......@@ -306,7 +464,7 @@ where
expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
) -> anyhow::Result<BestMatchDetails> {
let start = Instant::now();
if update_states && context_id.is_none() {
......@@ -336,13 +494,13 @@ where
});
let seq_hash_elapsed = start.elapsed();
// Query indexer and shared cache in parallel when shared cache is configured.
// Query indexer (tiered) and shared cache in parallel when shared cache is configured.
// Time each independently so metrics can separate indexer vs shared cache latency.
let (overlap_scores, shared_cache_hits, indexer_duration, shared_cache_duration) =
let (tiered_matches, shared_cache_hits, indexer_duration, shared_cache_duration) =
if let Some(ref shared_cache) = self.shared_cache {
let indexer_fut = self
.indexer
.find_matches(block_hashes.clone())
.find_matches_by_tier(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"));
let shared_fut = shared_cache
.check_blocks(tokens, self.block_size)
......@@ -361,7 +519,7 @@ where
let ((indexer_result, idx_dur), (shared_result, sc_dur)) =
tokio::join!(indexer_timed, shared_timed);
let overlaps = indexer_result?;
let tiered = indexer_result?;
// Shared cache failure is non-fatal: log warning and fall back to empty hits.
let hits = match shared_result {
Ok(hits) => Some(hits),
......@@ -373,16 +531,26 @@ where
None
}
};
(overlaps, hits, idx_dur, Some(sc_dur))
(tiered, hits, idx_dur, Some(sc_dur))
} else {
let t = Instant::now();
let overlaps = self
let tiered = self
.indexer
.find_matches(block_hashes)
.find_matches_by_tier(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"))
.await?;
(overlaps, None, t.elapsed(), None)
(tiered, None, t.elapsed(), None)
};
let tier_overlap_blocks = tier_overlap_blocks_from_tiered_matches(&tiered_matches);
let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
let tree_sizes: HashMap<_, _> = tiered_matches
.device
.overlap_scores
.tree_sizes
.iter()
.map(|(k, v)| (*k, *v))
.collect();
let find_matches_elapsed = start.elapsed();
// Capture shared cache info for metrics before moving into schedule().
......@@ -397,7 +565,10 @@ where
context_id.map(|s| s.to_string()),
isl_tokens,
maybe_seq_hashes,
overlap_scores,
tier_overlap_blocks,
cache_hit_estimates.effective_overlap_blocks,
cache_hit_estimates.cached_tokens,
tree_sizes,
router_config_override,
update_states,
lora_name,
......@@ -430,7 +601,7 @@ where
m.shared_cache_hit_rate
.observe(hits.total_hits as f64 / num_blocks as f64);
}
let beyond = hits.hits_beyond(response.overlap_blocks);
let beyond = hits.hits_beyond(response.effective_overlap_blocks.round() as u32);
m.shared_cache_beyond_blocks.observe(beyond as f64);
}
......@@ -445,7 +616,45 @@ where
"find_best_match completed"
);
Ok((response.best_worker, response.overlap_blocks))
Ok(BestMatchDetails {
worker: response.best_worker,
cache_hit: WorkerCacheHitEstimate {
effective_overlap_blocks: response.effective_overlap_blocks,
cached_tokens: response.cached_tokens,
},
})
}
/// Give these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and approximate effective overlap in blocks.
#[allow(clippy::too_many_arguments)]
pub async fn find_best_match(
&self,
context_id: Option<&str>,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
expected_output_tokens: Option<u32>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
let result = self
.find_best_match_details(
context_id,
tokens,
block_mm_infos,
router_config_override,
update_states,
lora_name,
priority_jump,
expected_output_tokens,
None,
allowed_worker_ids,
)
.await?;
Ok((result.worker, result.cache_hit.rounded_overlap_blocks()))
}
/// Register externally-provided workers in the slot tracker.
......@@ -459,7 +668,7 @@ where
request_id: String,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
overlap_blocks: u32,
cached_tokens: usize,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
lora_name: Option<String>,
......@@ -483,7 +692,7 @@ where
.kv_router_config
.track_prefill_tokens(router_config_override);
let prefill_load_hint =
self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
self.prefill_load_hint_for(isl_tokens, cached_tokens, track_prefill_tokens);
if let Err(e) = self
.scheduler
......@@ -518,14 +727,14 @@ where
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
cached_tokens: usize,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
......@@ -578,7 +787,7 @@ where
}
/// Compute the overlap blocks for a given token sequence and worker.
/// This queries the indexer to find how many blocks are already cached.
/// This queries the indexer to find the effective weighted cache hit.
pub async fn get_overlap_blocks(
&self,
tokens: &[u32],
......@@ -586,6 +795,19 @@ where
worker: WorkerWithDpRank,
lora_name: Option<&str>,
) -> Result<u32, KvRouterError> {
Ok(self
.get_cache_hit_estimate(tokens, block_mm_infos, worker, lora_name)
.await?
.rounded_overlap_blocks())
}
pub(crate) async fn get_cache_hit_estimate(
&self,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
worker: WorkerWithDpRank,
lora_name: Option<&str>,
) -> Result<WorkerCacheHitEstimate, KvRouterError> {
let block_hashes = compute_block_hash_for_seq(
tokens,
self.block_size,
......@@ -595,9 +817,9 @@ where
is_eagle: Some(self.is_eagle),
},
);
log_routing_input_hashes(None, self.block_size, tokens, &block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
let tiered_matches = self.indexer.find_matches_by_tier(block_hashes).await?;
let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
Ok(self.cache_hit_for_worker(&cache_hit_estimates, worker))
}
/// Get potential prefill and decode loads for all workers
......@@ -626,12 +848,13 @@ where
let track_prefill_tokens = self
.kv_router_config
.track_prefill_tokens(router_config_override);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let tiered_matches = self.indexer.find_matches_by_tier(block_hashes).await?;
let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
Ok(self.scheduler.get_potential_loads(
maybe_seq_hashes,
isl_tokens,
overlap_scores,
cache_hit_estimates.cached_tokens,
track_prefill_tokens,
))
}
......@@ -673,7 +896,6 @@ where
0.0,
None,
None,
None,
)
.await?;
......@@ -719,12 +941,57 @@ mod tests {
use std::collections::HashMap;
use async_trait::async_trait;
use dynamo_kv_router::{
indexer::{LowerTierMatchDetails, MatchDetails},
protocols::{OverlapScores, StorageTier},
};
use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig};
use tokio::sync::watch;
use crate::kv_router::scheduler::KvSchedulerError;
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// Lower-tier hits must be added, weighted, on top of the device overlap.
///
/// The expected values below imply default weights of 0.75 for host-pinned and 0.25
/// for disk hits (worker_2: 1 * 0.75; worker_1: 2 + 1 * 0.75 + 2 * 0.25).
#[test]
fn weighted_cache_hit_estimates_include_lower_tiers() {
    let (worker_1, worker_2) = (WorkerWithDpRank::new(1, 0), WorkerWithDpRank::new(2, 0));

    // Device tier: only worker_1 has matches.
    let mut device_scores = OverlapScores::new();
    device_scores.scores.insert(worker_1, 2);

    // Host tier: one hit for each worker.
    let mut host_hits = LowerTierMatchDetails::default();
    host_hits.hits.insert(worker_1, 1);
    host_hits.hits.insert(worker_2, 1);

    // Disk tier: two hits for worker_1 only.
    let mut disk_hits = LowerTierMatchDetails::default();
    disk_hits.hits.insert(worker_1, 2);

    let mut lower_tier = HashMap::new();
    lower_tier.insert(StorageTier::HostPinned, host_hits);
    lower_tier.insert(StorageTier::Disk, disk_hits);

    let tiered_matches = indexer::TieredMatchDetails {
        device: MatchDetails {
            overlap_scores: device_scores,
            ..Default::default()
        },
        lower_tier,
    };

    let estimates =
        cache_hit_estimates_from_tiered_matches(&KvRouterConfig::default(), 16, &tiered_matches);
    let blocks = &estimates.effective_overlap_blocks;
    let tokens = &estimates.cached_tokens;

    assert_eq!(blocks.get(&worker_1).copied(), Some(3.25));
    assert_eq!(tokens.get(&worker_1).copied(), Some(52));
    assert_eq!(blocks.get(&worker_2).copied(), Some(0.75));
    assert_eq!(tokens.get(&worker_2).copied(), Some(12));
}
struct FakeSharedCache {
hits: Option<dynamo_kv_router::protocols::SharedCacheHits>,
should_error: bool,
......@@ -766,7 +1033,8 @@ mod tests {
Ok(dynamo_kv_router::protocols::WorkerSelectionResult {
worker: self.selected_worker,
required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
overlap_blocks: 0,
effective_overlap_blocks: 0.0,
cached_tokens: 0,
})
}
}
......@@ -855,7 +1123,6 @@ mod tests {
0.0,
None,
None,
None,
)
.await
.unwrap();
......@@ -889,7 +1156,6 @@ mod tests {
0.0,
None,
None,
None,
)
.await
.unwrap();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Duration;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
time::Duration,
};
use anyhow::Result;
use dynamo_kv_router::{
ConcurrentRadixTreeCompressed, ThreadPoolIndexer,
ConcurrentRadixTreeCompressed, LowerTierIndexer, ThreadPoolIndexer,
approx::PruneConfig,
config::KvRouterConfig,
indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError},
indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError, LowerTierContinuation,
LowerTierMatchDetails, MatchDetails,
},
protocols::{
DpRank, LocalBlockHash, OverlapScores, RouterEvent, TokensWithHashes, WorkerId,
WorkerWithDpRank,
DpRank, LocalBlockHash, OverlapScores, RouterEvent, StorageTier, TokensWithHashes,
WorkerId, WorkerWithDpRank,
},
};
use dynamo_runtime::{component::Component, traits::DistributedRuntimeProvider};
......@@ -24,6 +30,123 @@ pub mod remote;
mod subscriber;
mod worker_query;
/// Lazily populated set of per-storage-tier indexers.
///
/// Clones share the same underlying map, since it sits behind an `Arc`.
#[derive(Clone)]
pub struct LowerTierIndexers {
    /// Thread count handed to every indexer created by `get_or_create`.
    num_threads: usize,
    /// Block size handed to every indexer created by `get_or_create`.
    block_size: u32,
    /// Indexers created so far, keyed by storage tier.
    indexers: Arc<RwLock<HashMap<StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>>>>,
}

impl LowerTierIndexers {
    /// Creates an empty set; no indexer exists until a tier is first requested.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero.
    pub(crate) fn new(num_threads: usize, block_size: u32) -> Self {
        assert!(
            num_threads > 0,
            "lower-tier indexer threads must be non-zero"
        );
        let indexers = Arc::new(RwLock::new(HashMap::new()));
        Self {
            num_threads,
            block_size,
            indexers,
        }
    }

    /// Returns the indexer for `storage_tier`, creating it on first use.
    ///
    /// Panics if the lock is poisoned.
    fn get_or_create(&self, storage_tier: StorageTier) -> Arc<ThreadPoolIndexer<LowerTierIndexer>> {
        debug_assert!(!storage_tier.is_gpu());

        // Fast path: only a read lock is needed when the indexer already exists.
        if let Some(existing) = self.get(storage_tier) {
            return existing;
        }

        // Slow path: `entry` re-checks under the write lock, so a concurrent creator
        // cannot cause a second indexer to be inserted for the same tier.
        let mut indexers = self.indexers.write().unwrap();
        let indexer = indexers.entry(storage_tier).or_insert_with(|| {
            Arc::new(ThreadPoolIndexer::new(
                LowerTierIndexer::new(),
                self.num_threads,
                self.block_size,
            ))
        });
        Arc::clone(indexer)
    }

    /// Returns handles to every indexer created so far (in unspecified order).
    fn all(&self) -> Vec<Arc<ThreadPoolIndexer<LowerTierIndexer>>> {
        let indexers = self.indexers.read().unwrap();
        indexers.values().map(Arc::clone).collect()
    }

    /// Returns the indexer for `storage_tier` if one has been created.
    fn get(&self, storage_tier: StorageTier) -> Option<Arc<ThreadPoolIndexer<LowerTierIndexer>>> {
        let indexers = self.indexers.read().unwrap();
        indexers.get(&storage_tier).map(Arc::clone)
    }
}
/// Order in which lower tiers are queried: host-pinned, then disk, then external.
fn lower_tier_query_order() -> [StorageTier; 3] {
    const QUERY_ORDER: [StorageTier; 3] = [
        StorageTier::HostPinned,
        StorageTier::Disk,
        StorageTier::External,
    ];
    QUERY_ORDER
}
/// Queries each existing lower-tier indexer in `lower_tier_query_order()` and returns
/// the match details per tier.
///
/// Every worker's query is seeded from its device match and then carried from tier to
/// tier: the `next_continuations` returned by one tier are the input to the next.
/// Tiers for which no indexer has been created are skipped and get no entry.
fn query_lower_tiers(
indexers: &LowerTierIndexers,
sequence: &[LocalBlockHash],
device_matches: &MatchDetails,
) -> HashMap<StorageTier, LowerTierMatchDetails> {
// Start from an empty continuation map of the same type the tier queries return.
let mut continuations = LowerTierMatchDetails::default().next_continuations;
// Seed one continuation per worker that has a device overlap score.
for (worker, matched_blocks) in &device_matches.overlap_scores.scores {
// A score without a last matched hash is unexpected: this asserts in debug builds
// and silently skips the worker in release builds.
let Some(last_hash) = device_matches.last_matched_hashes.get(worker).copied() else {
debug_assert!(
false,
"device match result missing last matched hash for worker {worker:?}"
);
continue;
};
continuations.insert(
*worker,
LowerTierContinuation::new(*matched_blocks as usize, last_hash),
);
}
let mut lower_tier_matches = HashMap::new();
for storage_tier in lower_tier_query_order() {
let Some(indexer) = indexers.get(storage_tier) else {
continue;
};
// Workers the tier reports for the first hash, and that have no continuation yet,
// start from the root — presumably these hold the first block only in this tier.
if let Some(&first_hash) = sequence.first() {
let root_workers: Vec<_> = indexer.backend().root_workers(first_hash);
for worker in root_workers.iter() {
continuations
.entry(*worker)
.or_insert_with(|| LowerTierContinuation::from_root(0));
}
}
let tier_matches = indexer
.backend()
.query_match_details(sequence, &continuations);
// Counted for the debug log below only.
let matched_workers = tier_matches.hits.values().filter(|&&hits| hits > 0).count();
tracing::debug!(
?storage_tier,
queried_workers = continuations.len(),
matched_workers,
"Queried lower-tier indexer"
);
// Carry this tier's continuations into the next tier; cloned because
// `tier_matches` is moved into the result map on the next line.
continuations = tier_matches.next_continuations.clone();
lower_tier_matches.insert(storage_tier, tier_matches);
}
lower_tier_matches
}
/// Match results split by storage tier: the device-tier match plus one entry per
/// queried lower tier.
#[derive(Debug, Clone, Default)]
pub(crate) struct TieredMatchDetails {
// Match details from the device-tier (primary) indexer.
pub device: MatchDetails,
// Match details per lower tier; tiers that were not queried have no entry.
// NOTE(review): the dead_code allowance suggests this field was once read only in
// tests — confirm whether it is still needed now that routing code consumes it.
#[cfg_attr(not(test), allow(dead_code))]
pub lower_tier: HashMap<StorageTier, LowerTierMatchDetails>,
}
use self::remote::RemoteIndexer;
pub use self::remote::{ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
pub(crate) use subscriber::start_subscriber;
......@@ -31,8 +154,14 @@ pub(crate) use worker_query::start_worker_kv_query_endpoint;
#[derive(Clone)]
pub enum Indexer {
KvIndexer(KvIndexer),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>),
KvIndexer {
primary: KvIndexer,
lower_tier: LowerTierIndexers,
},
Concurrent {
primary: Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>,
lower_tier: LowerTierIndexers,
},
Remote(Arc<RemoteIndexer>),
None,
}
......@@ -73,57 +202,98 @@ impl Indexer {
max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio,
});
return Ok(Self::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None,
block_size,
kv_indexer_metrics,
prune_config,
)));
return Ok(Self::KvIndexer {
primary: KvIndexer::new_with_frequency(
cancellation_token,
None,
block_size,
kv_indexer_metrics,
prune_config,
),
lower_tier: LowerTierIndexers::new(1, block_size),
});
}
if kv_router_config.router_event_threads > 1 {
let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
return Ok(Self::Concurrent(Arc::new(
ThreadPoolIndexer::new_with_metrics(
return Ok(Self::Concurrent {
primary: Arc::new(ThreadPoolIndexer::new_with_metrics(
ConcurrentRadixTreeCompressed::new(),
kv_router_config.router_event_threads as usize,
block_size,
Some(kv_indexer_metrics),
)),
lower_tier: LowerTierIndexers::new(
kv_router_config.router_event_threads as usize,
block_size,
),
)));
});
}
let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
let cancellation_token = component.drt().primary_token();
Ok(Self::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None,
block_size,
kv_indexer_metrics,
None,
)))
Ok(Self::KvIndexer {
primary: KvIndexer::new_with_frequency(
cancellation_token,
None,
block_size,
kv_indexer_metrics,
None,
),
lower_tier: LowerTierIndexers::new(1, block_size),
})
}
/// Returns only the overlap scores for `sequence`, discarding the rest of the
/// match details produced by `find_match_details`.
#[allow(dead_code)]
pub(crate) async fn find_matches(
    &self,
    sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
    let details = self.find_match_details(sequence).await?;
    Ok(details.overlap_scores)
}
pub(crate) async fn find_match_details(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<MatchDetails, KvRouterError> {
match self {
Self::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Self::Concurrent(tpi) => tpi.find_matches(sequence).await,
Self::Remote(remote) => match remote.find_matches(sequence).await {
Ok(scores) => Ok(scores),
Err(error) => {
tracing::warn!(error = %error, "Remote indexer query failed");
Ok(OverlapScores::new())
}
},
Self::None => Ok(OverlapScores::new()),
Self::KvIndexer { primary, .. } => primary.find_match_details(sequence).await,
Self::Concurrent { primary, .. } => {
Ok(primary.backend().find_match_details_impl(&sequence, false))
}
Self::Remote(remote) => remote
.find_matches(sequence)
.await
.map(|overlap_scores| MatchDetails {
overlap_scores,
..Default::default()
})
.map_err(|e| {
tracing::warn!(error = %e, "Remote indexer query failed");
KvRouterError::IndexerOffline
}),
Self::None => Ok(MatchDetails::new()),
}
}
/// Queries the device-tier index and, for local indexers, every lower-tier index.
///
/// For `KvIndexer` and `Concurrent` the lower tiers are walked with
/// `query_lower_tiers`, seeded from the device result. `Remote` and `None` have no
/// lower-tier indexers, so their `lower_tier` map is empty.
///
/// # Errors
///
/// Propagates any error returned by `find_match_details`.
pub(crate) async fn find_matches_by_tier(
    &self,
    sequence: Vec<LocalBlockHash>,
) -> Result<TieredMatchDetails, KvRouterError> {
    match self {
        Self::KvIndexer { lower_tier, .. } | Self::Concurrent { lower_tier, .. } => {
            // The device query takes the hashes by value while the lower-tier walk
            // still needs them afterwards, so only this path pays for a clone.
            let device = self.find_match_details(sequence.clone()).await?;
            let lower_tier = query_lower_tiers(lower_tier, &sequence, &device);
            Ok(TieredMatchDetails { device, lower_tier })
        }
        Self::Remote(_) | Self::None => {
            // Nothing reads the hashes after the device query: move them, don't clone.
            let device = self.find_match_details(sequence).await?;
            Ok(TieredMatchDetails {
                device,
                lower_tier: HashMap::new(),
            })
        }
    }
}
pub(crate) async fn record_hashed_routing_decision(
&self,
worker: WorkerWithDpRank,
......@@ -131,12 +301,12 @@ impl Indexer {
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
match self {
Self::KvIndexer(indexer) => {
indexer
Self::KvIndexer { primary, .. } => {
primary
.process_routing_decision_with_hashes(worker, local_hashes, sequence_hashes)
.await
}
Self::Concurrent(_) => {
Self::Concurrent { .. } => {
tracing::warn!(
"Hashed routing-decision recording is unsupported for concurrent indexers"
);
......@@ -155,8 +325,8 @@ impl Indexer {
pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Self::KvIndexer(indexer) => indexer.dump_events().await,
Self::Concurrent(tpi) => tpi.dump_events().await,
Self::KvIndexer { primary, .. } => primary.dump_events().await,
Self::Concurrent { primary, .. } => primary.dump_events().await,
Self::Remote(_) => Ok(Vec::new()),
Self::None => {
panic!(
......@@ -172,14 +342,15 @@ impl Indexer {
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
match self {
Self::KvIndexer(_) | Self::Remote(_) => {
Self::KvIndexer { .. } | Self::Remote(_) => {
let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
self.record_hashed_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
Self::Concurrent(tpi) => {
tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
Self::Concurrent { primary, .. } => {
primary
.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Self::None => Ok(()),
......@@ -188,25 +359,77 @@ impl Indexer {
pub(crate) async fn apply_event(&self, event: RouterEvent) {
match self {
Self::KvIndexer(indexer) => {
if let Err(e) = indexer.event_sender().send(event).await {
tracing::warn!("Failed to send event to indexer: {e}");
Self::KvIndexer {
primary,
lower_tier,
} => match &event.event.data {
dynamo_kv_router::protocols::KvCacheEventData::Cleared => {
if let Err(e) = primary.event_sender().send(event.clone()).await {
tracing::warn!("Failed to send event to indexer: {e}");
}
for indexer in lower_tier.all() {
indexer.apply_event(event.clone()).await;
}
}
}
Self::Concurrent(tpi) => tpi.apply_event(event).await,
_ if event.storage_tier.is_gpu() => {
if let Err(e) = primary.event_sender().send(event).await {
tracing::warn!("Failed to send event to indexer: {e}");
}
}
_ => {
lower_tier
.get_or_create(event.storage_tier)
.apply_event(event)
.await;
}
},
Self::Concurrent {
primary,
lower_tier,
} => match &event.event.data {
dynamo_kv_router::protocols::KvCacheEventData::Cleared => {
primary.apply_event(event.clone()).await;
for indexer in lower_tier.all() {
indexer.apply_event(event.clone()).await;
}
}
_ if event.storage_tier.is_gpu() => {
primary.apply_event(event).await;
}
_ => {
lower_tier
.get_or_create(event.storage_tier)
.apply_event(event)
.await;
}
},
Self::Remote(_) | Self::None => {}
}
}
pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
match self {
Self::KvIndexer(indexer) => {
if let Err(e) = indexer.remove_worker_sender().send(worker_id).await {
Self::KvIndexer {
primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
indexer.remove_worker(worker_id).await;
}
if let Err(e) = primary.remove_worker_sender().send(worker_id).await {
tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
}
}
Self::Concurrent(tpi) => {
KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
Self::Concurrent {
primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
indexer.remove_worker(worker_id).await;
}
KvIndexerInterface::remove_worker(primary.as_ref(), worker_id).await;
}
Self::Remote(_) | Self::None => {}
}
......@@ -214,11 +437,24 @@ impl Indexer {
pub(crate) async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: DpRank) {
match self {
Self::KvIndexer(indexer) => {
KvIndexerInterface::remove_worker_dp_rank(indexer, worker_id, dp_rank).await;
Self::KvIndexer {
primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
KvIndexerInterface::remove_worker_dp_rank(&*indexer, worker_id, dp_rank).await;
}
KvIndexerInterface::remove_worker_dp_rank(primary, worker_id, dp_rank).await;
}
Self::Concurrent(tpi) => {
KvIndexerInterface::remove_worker_dp_rank(tpi.as_ref(), worker_id, dp_rank).await;
Self::Concurrent {
primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
KvIndexerInterface::remove_worker_dp_rank(&*indexer, worker_id, dp_rank).await;
}
KvIndexerInterface::remove_worker_dp_rank(primary.as_ref(), worker_id, dp_rank)
.await;
}
Self::Remote(_) | Self::None => {}
}
......@@ -226,17 +462,472 @@ impl Indexer {
pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
match self {
Self::KvIndexer(indexer) => {
Self::KvIndexer { primary, .. } => {
let (resp_tx, resp_rx) = oneshot::channel();
let req = dynamo_kv_router::indexer::GetWorkersRequest { resp: resp_tx };
if let Err(e) = indexer.get_workers_sender().send(req).await {
if let Err(e) = primary.get_workers_sender().send(req).await {
tracing::warn!("Failed to send get_workers request: {e}");
return Vec::new();
}
resp_rx.await.unwrap_or_default()
}
Self::Concurrent(tpi) => tpi.backend().get_workers(),
Self::Concurrent { primary, .. } => primary.backend().get_workers(),
Self::Remote(_) | Self::None => Vec::new(),
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use super::{Indexer, LowerTierIndexers};
use dynamo_kv_router::{
ConcurrentRadixTreeCompressed, ThreadPoolIndexer,
indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics},
protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, StorageTier, WorkerWithDpRank,
compute_seq_hash_for_block,
},
};
/// Builds an `Indexer::KvIndexer` fixture: a primary `KvIndexer` backed by
/// unregistered metrics and its own cancellation token, paired with a fresh
/// `LowerTierIndexers` instance.
fn make_test_indexer() -> Indexer {
    let token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let primary = KvIndexer::new(token, 4, metrics);
    let lower_tier = LowerTierIndexers::new(1, 4);

    Indexer::KvIndexer {
        primary,
        lower_tier,
    }
}
/// Builds an `Indexer::Concurrent` fixture: a `ThreadPoolIndexer` wrapping a
/// new `ConcurrentRadixTreeCompressed` as the primary, paired with a fresh
/// `LowerTierIndexers` instance.
fn make_test_concurrent_indexer() -> Indexer {
    let radix_tree = ConcurrentRadixTreeCompressed::new();
    let pool_indexer = ThreadPoolIndexer::new(radix_tree, 2, 4);
    let lower_tier = LowerTierIndexers::new(2, 4);

    Indexer::Concurrent {
        primary: Arc::new(pool_indexer),
        lower_tier,
    }
}
async fn flush_indexer(indexer: &Indexer) {
match indexer {
Indexer::KvIndexer {
primary,
lower_tier,
} => {
let _ = primary.flush().await;
for indexer in lower_tier.all() {
let _ = indexer.dump_events().await.unwrap();
}
}
Indexer::Concurrent {
primary,
lower_tier,
} => {
primary.flush().await;
for indexer in lower_tier.all() {
let _ = indexer.dump_events().await.unwrap();
}
}
Indexer::Remote(_) | Indexer::None => {}
}
}
/// Builds a `Stored` [`RouterEvent`] tagged with `storage_tier`.
///
/// * `prefix_hashes` - local hashes of the blocks that precede the new ones.
///   They only feed the hash chain: the last sequence hash of the prefix
///   becomes `parent_hash` (`None` when the prefix yields no sequence hashes).
/// * `local_hashes` - local hashes of the blocks actually stored by the event.
///   Each one is paired with the sequence hash computed over prefix + locals.
///
/// `start_position` is always `None` and no multimodal info is attached.
fn store_event(
    worker_id: u64,
    dp_rank: u32,
    event_id: u64,
    prefix_hashes: &[u64],
    local_hashes: &[u64],
    storage_tier: StorageTier,
) -> RouterEvent {
    // Hash chain over the prefix alone: its tail is the parent of the new blocks.
    let prefix_blocks: Vec<LocalBlockHash> =
        prefix_hashes.iter().copied().map(LocalBlockHash).collect();
    let parent_hash = compute_seq_hash_for_block(&prefix_blocks)
        .last()
        .copied()
        .map(ExternalSequenceBlockHash);

    // Hash chain over prefix + new blocks; everything past the prefix is new.
    let mut chain_blocks = prefix_blocks;
    chain_blocks.extend(local_hashes.iter().copied().map(LocalBlockHash));
    let chain_seq_hashes = compute_seq_hash_for_block(&chain_blocks);
    let fresh_seq_hashes = &chain_seq_hashes[prefix_hashes.len()..];

    let mut blocks = Vec::with_capacity(local_hashes.len());
    for (&tokens, &seq) in local_hashes.iter().zip(fresh_seq_hashes) {
        blocks.push(KvCacheStoredBlockData {
            block_hash: ExternalSequenceBlockHash(seq),
            tokens_hash: LocalBlockHash(tokens),
            mm_extra_info: None,
        });
    }

    let stored = KvCacheStoreData {
        parent_hash,
        start_position: None,
        blocks,
    };
    let event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Stored(stored),
        dp_rank,
    };
    RouterEvent::with_storage_tier(worker_id, event, storage_tier)
}
#[tokio::test]
async fn tiered_query_chains_device_host_and_disk() {
let indexer = make_test_indexer();
let worker = WorkerWithDpRank::new(7, 0);
indexer
.apply_event(store_event(7, 0, 1, &[], &[11, 12], StorageTier::Device))
.await;
indexer
.apply_event(store_event(
7,
0,
2,
&[11, 12],
&[13],
StorageTier::HostPinned,
))
.await;
indexer
.apply_event(store_event(
7,
0,
3,
&[11, 12, 13],
&[14],
StorageTier::Disk,
))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![
LocalBlockHash(11),
LocalBlockHash(12),
LocalBlockHash(13),
LocalBlockHash(14),
])
.await
.unwrap();
assert_eq!(matches.device.overlap_scores.scores.get(&worker), Some(&2));
assert_eq!(
matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&worker)),
Some(&1)
);
assert_eq!(
matches
.lower_tier
.get(&StorageTier::Disk)
.and_then(|tier| tier.hits.get(&worker)),
Some(&1)
);
}
/// Three workers store the same block, each on a different tier. Only the
/// device worker may appear in the device overlap scores; the host-only and
/// disk-only workers must show up solely as hits in their own tier.
#[tokio::test]
async fn tiered_query_seeds_lower_tier_only_workers_without_affecting_device_scores() {
    let indexer = make_test_indexer();
    let device_worker = WorkerWithDpRank::new(10, 0);
    let host_only_worker = WorkerWithDpRank::new(20, 0);
    let disk_only_worker = WorkerWithDpRank::new(30, 0);

    // (worker id, event id, tier) for the single shared block with hash 21.
    let placements = [
        (10, 1, StorageTier::Device),
        (20, 2, StorageTier::HostPinned),
        (30, 3, StorageTier::Disk),
    ];
    for (worker_id, event_id, tier) in placements {
        indexer
            .apply_event(store_event(worker_id, 0, event_id, &[], &[21], tier))
            .await;
    }
    flush_indexer(&indexer).await;

    let matches = indexer
        .find_matches_by_tier(vec![LocalBlockHash(21)])
        .await
        .unwrap();

    let device_scores = &matches.device.overlap_scores.scores;
    assert_eq!(device_scores.get(&device_worker), Some(&1));
    for lower_only in [&host_only_worker, &disk_only_worker] {
        assert!(!device_scores.contains_key(lower_only));
    }

    let expected_lower_hits = [
        (StorageTier::HostPinned, &host_only_worker),
        (StorageTier::Disk, &disk_only_worker),
    ];
    for (tier, owner) in expected_lower_hits {
        let hits = matches
            .lower_tier
            .get(&tier)
            .and_then(|entry| entry.hits.get(owner));
        assert_eq!(hits, Some(&1));
    }
}
/// Two host-pinned workers store different root blocks. Querying for one of
/// those blocks must yield a hit for the worker that owns it and nothing for
/// the worker whose root block differs.
#[tokio::test]
async fn tiered_query_only_seeds_matching_root_workers() {
    let indexer = make_test_indexer();
    let matching_host_worker = WorkerWithDpRank::new(20, 0);
    let nonmatching_host_worker = WorkerWithDpRank::new(21, 0);

    let matching_store = store_event(20, 0, 1, &[], &[31], StorageTier::HostPinned);
    let other_store = store_event(21, 0, 2, &[], &[32], StorageTier::HostPinned);
    indexer.apply_event(matching_store).await;
    indexer.apply_event(other_store).await;
    flush_indexer(&indexer).await;

    let matches = indexer
        .find_matches_by_tier(vec![LocalBlockHash(31)])
        .await
        .unwrap();
    let host_tier = matches.lower_tier.get(&StorageTier::HostPinned);

    let matching_hits = host_tier.and_then(|tier| tier.hits.get(&matching_host_worker));
    assert_eq!(matching_hits, Some(&1));

    let nonmatching_seeded =
        host_tier.is_some_and(|tier| tier.hits.contains_key(&nonmatching_host_worker));
    assert!(!nonmatching_seeded);
}
#[tokio::test]
async fn concurrent_tiered_query_chains_device_and_lower_tier_matches() {
let indexer = make_test_concurrent_indexer();
let worker = WorkerWithDpRank::new(7, 0);
indexer
.apply_event(store_event(7, 0, 1, &[], &[11, 12], StorageTier::Device))
.await;
indexer
.apply_event(store_event(
7,
0,
2,
&[11, 12],
&[13],
StorageTier::HostPinned,
))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![
LocalBlockHash(11),
LocalBlockHash(12),
LocalBlockHash(13),
])
.await
.unwrap();
assert_eq!(matches.device.overlap_scores.scores.get(&worker), Some(&2));
assert_eq!(
matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&worker)),
Some(&1)
);
}
/// Concurrent-indexer variant: the same block is stored by a device worker, a
/// host-only worker and a disk-only worker. Device scores must contain only
/// the device worker, while each lower-tier worker is reported in its own tier.
#[tokio::test]
async fn concurrent_tiered_query_seeds_lower_tier_only_workers_without_affecting_device_scores()
{
    let indexer = make_test_concurrent_indexer();
    let device_worker = WorkerWithDpRank::new(10, 0);
    let host_only_worker = WorkerWithDpRank::new(20, 0);
    let disk_only_worker = WorkerWithDpRank::new(30, 0);

    let stores = [
        store_event(10, 0, 1, &[], &[21], StorageTier::Device),
        store_event(20, 0, 2, &[], &[21], StorageTier::HostPinned),
        store_event(30, 0, 3, &[], &[21], StorageTier::Disk),
    ];
    for store in stores {
        indexer.apply_event(store).await;
    }
    flush_indexer(&indexer).await;

    let matches = indexer
        .find_matches_by_tier(vec![LocalBlockHash(21)])
        .await
        .unwrap();

    // Device tier: only the worker that stored on device is scored.
    let device_scores = &matches.device.overlap_scores.scores;
    assert_eq!(device_scores.get(&device_worker), Some(&1));
    assert!(!device_scores.contains_key(&host_only_worker));
    assert!(!device_scores.contains_key(&disk_only_worker));

    // Lower tiers: each worker is seeded in the tier it stored to.
    let host_hits = matches
        .lower_tier
        .get(&StorageTier::HostPinned)
        .and_then(|tier| tier.hits.get(&host_only_worker));
    assert_eq!(host_hits, Some(&1));

    let disk_hits = matches
        .lower_tier
        .get(&StorageTier::Disk)
        .and_then(|tier| tier.hits.get(&disk_only_worker));
    assert_eq!(disk_hits, Some(&1));
}
/// Regression test: when a worker has blocks in both device and lower-tier
/// storage (e.g. same prefix stored on GPU and offloaded to host), the
/// Concurrent indexer doesn't return last_matched_hashes. Without the fix,
/// query_lower_tiers would re-query that worker from root in the lower tier,
/// double-counting overlap blocks and producing cached_tokens > ISL.
#[tokio::test]
async fn concurrent_tiered_query_does_not_double_count_device_and_lower_tier_overlap() {
let indexer = make_test_concurrent_indexer();
let worker = WorkerWithDpRank::new(7, 0);
// Worker has the same blocks in both device and host-pinned storage.
indexer
.apply_event(store_event(
7,
0,
1,
&[],
&[11, 12, 13],
StorageTier::Device,
))
.await;
indexer
.apply_event(store_event(
7,
0,
2,
&[],
&[11, 12, 13],
StorageTier::HostPinned,
))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![
LocalBlockHash(11),
LocalBlockHash(12),
LocalBlockHash(13),
])
.await
.unwrap();
// Device overlap should be 3 blocks.
assert_eq!(matches.device.overlap_scores.scores.get(&worker), Some(&3));
// Lower-tier must NOT report additional hits for the same worker
// whose blocks are already fully accounted for in the device tier.
let host_hits = matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&worker).copied())
.unwrap_or(0);
assert_eq!(
host_hits, 0,
"lower-tier should not double-count blocks already matched in device tier \
(got {host_hits} host-pinned hits for a worker with full device overlap)"
);
}
/// Removing a worker from a Concurrent indexer must also drop the state it
/// holds in lower-tier indexers: a host-pinned hit that is visible before
/// `remove_worker` must be gone afterwards.
#[tokio::test]
async fn concurrent_remove_worker_removes_lower_tier_state() {
    let indexer = make_test_concurrent_indexer();
    let worker = WorkerWithDpRank::new(20, 0);

    let host_store = store_event(20, 0, 1, &[], &[31], StorageTier::HostPinned);
    indexer.apply_event(host_store).await;
    flush_indexer(&indexer).await;

    // Sanity check: the block is found in the host-pinned tier before removal.
    let before = indexer
        .find_matches_by_tier(vec![LocalBlockHash(31)])
        .await
        .unwrap();
    let hits_before = before
        .lower_tier
        .get(&StorageTier::HostPinned)
        .and_then(|tier| tier.hits.get(&worker));
    assert_eq!(hits_before, Some(&1));

    indexer.remove_worker(20).await;
    flush_indexer(&indexer).await;

    let after = indexer
        .find_matches_by_tier(vec![LocalBlockHash(31)])
        .await
        .unwrap();
    let still_present = after
        .lower_tier
        .get(&StorageTier::HostPinned)
        .is_some_and(|tier| tier.hits.contains_key(&worker));
    assert!(!still_present);
}
}
......@@ -823,7 +823,7 @@ impl Drop for SlowQueryGuard {
#[cfg(test)]
mod tests {
use super::*;
use crate::kv_router::Indexer;
use crate::kv_router::{Indexer, indexer::LowerTierIndexers};
use dynamo_kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
......@@ -923,7 +923,13 @@ mod tests {
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let kv_indexer = KvIndexer::new(token, 4, metrics);
(kv_indexer.clone(), Indexer::KvIndexer(kv_indexer))
(
kv_indexer.clone(),
Indexer::KvIndexer {
primary: kv_indexer,
lower_tier: LowerTierIndexers::new(1, 4),
},
)
}
async fn make_test_client(
......
......@@ -287,7 +287,6 @@ impl PrefillRouter {
lora_name,
priority_jump,
None,
None,
allowed_worker_ids,
)
.await?;
......
......@@ -31,35 +31,48 @@ use super::{DEFAULT_MAX_BATCH_BLOCKS, kv_publisher_metrics};
/// - **Remove**: only passes through when refcount decrements to 0.
/// - **Cleared**: resets refcounts for all ranks.
pub(super) struct EventDedupFilter {
/// Per-dp-rank refcounts.
per_rank: HashMap<u32, HashMap<ExternalSequenceBlockHash, usize>>,
/// Per-(dp_rank, storage_tier) refcounts.
per_rank_tier: HashMap<(u32, StorageTier), HashMap<ExternalSequenceBlockHash, usize>>,
}
impl EventDedupFilter {
pub(super) fn new() -> Self {
Self {
per_rank: HashMap::new(),
per_rank_tier: HashMap::new(),
}
}
/// Track a store event. Increments refcount for each block hash on the
/// given DP rank. Stores always pass through — this only updates bookkeeping.
pub(super) fn track_store(&mut self, dp_rank: u32, data: &KvCacheStoreData) {
let refcounts = self.per_rank.entry(dp_rank).or_default();
/// given (DP rank, storage tier). Stores always pass through — this only
/// updates bookkeeping.
pub(super) fn track_store(
&mut self,
dp_rank: u32,
storage_tier: StorageTier,
data: &KvCacheStoreData,
) {
let refcounts = self
.per_rank_tier
.entry((dp_rank, storage_tier))
.or_default();
for block in &data.blocks {
*refcounts.entry(block.block_hash).or_insert(0) += 1;
}
}
/// Filter a remove event. Retains only block hashes whose refcount on the
/// given DP rank decrements to 0 (removing them from the map). Returns
/// `None` if no hashes survive filtering.
/// given (DP rank, storage tier) decrements to 0 (removing them from the
/// map). Returns `None` if no hashes survive filtering.
pub(super) fn filter_remove(
&mut self,
dp_rank: u32,
storage_tier: StorageTier,
mut data: KvCacheRemoveData,
) -> Option<KvCacheRemoveData> {
let refcounts = self.per_rank.entry(dp_rank).or_default();
let refcounts = self
.per_rank_tier
.entry((dp_rank, storage_tier))
.or_default();
data.block_hashes.retain(|hash| {
match refcounts.entry(*hash) {
Entry::Occupied(mut entry) => {
......@@ -83,11 +96,11 @@ impl EventDedupFilter {
}
}
/// Clear refcounts for all DP ranks. A `Cleared` event from any rank
/// causes the indexer to wipe all blocks for the entire worker, so we
/// must reset all ranks' refcounts to stay consistent.
/// Clear refcounts for all DP ranks and tiers. A `Cleared` event from any
/// rank causes the indexer to wipe all blocks for the entire worker, so we
/// must reset all refcounts to stay consistent.
pub(super) fn clear(&mut self) {
self.per_rank.clear();
self.per_rank_tier.clear();
}
}
......@@ -99,6 +112,7 @@ pub(super) struct BatchingState {
pub(super) pending_stored: Option<KvCacheStoreData>,
pub(super) next_publish_id: u64,
pub(super) last_dp_rank: u32,
pub(super) last_storage_tier: StorageTier,
pub(super) last_flush_time: Instant,
}
......@@ -109,6 +123,7 @@ impl BatchingState {
pending_stored: None,
next_publish_id: 1,
last_dp_rank: 0,
last_storage_tier: StorageTier::Device,
last_flush_time: Instant::now(),
}
}
......@@ -160,12 +175,13 @@ impl BatchingState {
let dp_rank = self.last_dp_rank;
let mut emitted = false;
if let Some(data) = self.pending_removed.take()
&& let Some(filtered) = dedup.filter_remove(dp_rank, data)
&& let Some(filtered) = dedup.filter_remove(dp_rank, self.last_storage_tier, data)
{
emit(
publisher,
local_indexer,
worker_id,
self.last_storage_tier,
KvCacheEvent {
event_id: self.next_publish_id,
data: KvCacheEventData::Removed(filtered),
......@@ -176,11 +192,12 @@ impl BatchingState {
emitted = true;
}
if let Some(data) = self.pending_stored.take() {
dedup.track_store(dp_rank, &data);
dedup.track_store(dp_rank, self.last_storage_tier, &data);
emit(
publisher,
local_indexer,
worker_id,
self.last_storage_tier,
KvCacheEvent {
event_id: self.next_publish_id,
data: KvCacheEventData::Stored(data),
......@@ -217,9 +234,10 @@ async fn emit<P: RouterEventSink>(
publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>,
worker_id: u64,
storage_tier: StorageTier,
event: KvCacheEvent,
) {
let router_event = RouterEvent::new(worker_id, event);
let router_event = RouterEvent::with_storage_tier(worker_id, event, storage_tier);
if let Some(indexer) = local_indexer
&& let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await
{
......@@ -281,16 +299,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
}
last_raw_input_id = Some(raw_event_id);
if !placement_event.placement.is_local_gpu() {
tracing::trace!(
worker_id,
?placement_event.placement,
event_id = placement_event.event.event_id,
"Skipping non-local-GPU placement event"
);
continue;
}
let storage_tier = placement_event.placement.tier;
let event = placement_event.event;
tracing::trace!(
"Event processor for worker_id {} processing event: {:?}",
......@@ -300,10 +309,15 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
let dp_rank_changed =
batching_state.has_pending() && event.dp_rank != batching_state.last_dp_rank;
let storage_tier_changed =
batching_state.has_pending() && storage_tier != batching_state.last_storage_tier;
match event.data {
KvCacheEventData::Removed(data) => {
if batching_state.pending_stored.is_some() || dp_rank_changed {
if batching_state.pending_stored.is_some()
|| dp_rank_changed
|| storage_tier_changed
{
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
}
match &mut batching_state.pending_removed {
......@@ -315,6 +329,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
}
KvCacheEventData::Stored(data) => {
let should_flush = dp_rank_changed
|| storage_tier_changed
|| batching_state.pending_removed.is_some()
|| batching_state.pending_stored.as_ref().is_some_and(|p| {
data.parent_hash != p.blocks.last().map(|b| b.block_hash)
......@@ -336,6 +351,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
&publisher,
&local_indexer,
worker_id,
storage_tier,
KvCacheEvent {
event_id: batching_state.next_publish_id,
data: KvCacheEventData::Cleared,
......@@ -348,6 +364,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
}
batching_state.last_dp_rank = event.dp_rank;
batching_state.last_storage_tier = storage_tier;
if batching_state.has_pending()
&& (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms))
......
......@@ -1314,15 +1314,15 @@ mod test_event_dedup_filter {
let data = store_data(&[1, 2, 3]);
// Store same hashes twice — refcount should be 2
filter.track_store(0, &data);
filter.track_store(0, &data);
filter.track_store(0, StorageTier::Device, &data);
filter.track_store(0, StorageTier::Device, &data);
// First remove — refcounts 2→1, all filtered out
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1, 2, 3]));
assert!(result.is_none());
// Second remove — refcounts 1→0, all pass through
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1, 2, 3]));
assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 3);
}
......@@ -1332,15 +1332,15 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new();
// Store same hash twice
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// First remove — refcount 2→1, filtered out
let result = filter.filter_remove(0, remove_data(&[1]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_none());
// Second remove — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 1);
}
......@@ -1350,17 +1350,17 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new();
// Store hash 1
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// Remove hash 1 — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some());
// Store hash 1 again — refcount starts fresh at 1
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// Remove again — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some());
}
......@@ -1369,20 +1369,20 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new();
// Store on rank 0 and rank 1
filter.track_store(0, &store_data(&[1, 2]));
filter.track_store(0, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2]));
filter.track_store(0, StorageTier::Device, &store_data(&[1, 2]));
filter.track_store(0, StorageTier::Device, &store_data(&[1, 2]));
filter.track_store(1, StorageTier::Device, &store_data(&[1, 2]));
filter.track_store(1, StorageTier::Device, &store_data(&[1, 2]));
// Clear wipes all ranks (matches indexer semantics where Cleared
// from any rank removes all blocks for the entire worker).
filter.clear();
// Both ranks pass through defensively after clear
let result = filter.filter_remove(0, remove_data(&[1]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some());
let result = filter.filter_remove(1, remove_data(&[1]));
let result = filter.filter_remove(1, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some());
}
......@@ -1391,18 +1391,18 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new();
// Hash 1: stored twice (refcount 2)
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// Hash 2: stored once (refcount 1)
filter.track_store(0, &store_data(&[2]));
filter.track_store(0, StorageTier::Device, &store_data(&[2]));
// Hash 3: stored twice (refcount 2)
filter.track_store(0, &store_data(&[3]));
filter.track_store(0, &store_data(&[3]));
filter.track_store(0, StorageTier::Device, &store_data(&[3]));
filter.track_store(0, StorageTier::Device, &store_data(&[3]));
// Remove all three — only hash 2 (refcount 1→0) passes through
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1, 2, 3]));
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.block_hashes.len(), 1);
......@@ -1414,20 +1414,20 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new();
// Store hash 1 on rank 0 (twice) and rank 1 (once)
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
filter.track_store(1, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(1, StorageTier::Device, &store_data(&[1]));
// Remove hash 1 on rank 1 — refcount 1→0, passes through
let result = filter.filter_remove(1, remove_data(&[1]));
let result = filter.filter_remove(1, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some());
// Remove hash 1 on rank 0 — refcount 2→1, filtered out
let result = filter.filter_remove(0, remove_data(&[1]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_none());
// Remove hash 1 on rank 0 again — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some());
}
}
......@@ -1724,6 +1724,13 @@ mod event_processor_tests {
PlacementEvent::local_gpu(1, event)
}
/// Wraps `event` in a [`PlacementEvent`] whose placement is a local worker
/// (id 1) on the `HostPinned` tier, reusing the event's own `dp_rank`.
fn local_host_event(event: KvCacheEvent) -> PlacementEvent {
    let placement = Placement::local_worker(1, event.dp_rank, StorageTier::HostPinned);
    PlacementEvent::new(placement, event)
}
/// Test that pushing N removed events results in batched output
/// Uses a 10ms timeout to ensure events are batched (events sent rapidly)
#[tokio::test]
......@@ -2287,6 +2294,106 @@ mod event_processor_tests {
);
}
/// A single host-pinned `Removed` event fed through the processor loop must be
/// published exactly once, still tagged `HostPinned` and carrying the same
/// block hash.
#[tokio::test]
async fn test_host_tier_events_are_published_and_preserved() {
    let publisher = MockPublisher::new();
    let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();

    let loop_publisher = publisher.clone();
    let token = CancellationToken::new();
    let processor = tokio::spawn(async move {
        run_event_processor_loop(
            loop_publisher,
            1,
            token,
            rx,
            None,
            Some(100),
            DEFAULT_MAX_BATCH_BLOCKS,
        )
        .await
    });

    let removal = KvCacheEvent {
        event_id: 0,
        data: KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(42)],
        }),
        dp_rank: 0,
    };
    tx.send(local_host_event(removal)).unwrap();

    // Drop the only sender, then wait for the processor task to finish.
    drop(tx);
    processor.await.unwrap();

    let events = publisher.get_events();
    assert_eq!(
        events.len(),
        1,
        "Expected a single published host-tier event"
    );

    let published = &events[0];
    assert_eq!(published.storage_tier, StorageTier::HostPinned);
    match &published.event.data {
        KvCacheEventData::Removed(removed) => {
            assert_eq!(removed.block_hashes, vec![ExternalSequenceBlockHash(42)]);
        }
        _ => panic!("Expected Removed event"),
    }
}
#[tokio::test]
async fn test_storage_tier_change_causes_flush() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_host_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1)],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(2)],
}),
dp_rank: 0,
}))
.unwrap();
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Changing storage tier should flush the current batch"
);
assert_eq!(events[0].storage_tier, StorageTier::HostPinned);
assert_eq!(events[1].storage_tier, StorageTier::Device);
}
/// Test that dp_rank change causes immediate flush
#[tokio::test]
async fn test_dp_rank_change_causes_flush() {
......
......@@ -45,26 +45,13 @@ pub struct KvPushRouter {
/// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection {
instance_id: u64,
backend_dp_rank: Option<u32>,
bookkeeping_dp_rank: Option<u32>,
overlap_amount: Option<u32>,
}
fn pinned_worker_hint(
phase: RequestPhase,
routing: Option<&RoutingHints>,
) -> Option<(u64, Option<u32>)> {
let routing = routing?;
let worker_id = match phase {
RequestPhase::Prefill => routing.prefill_worker_id.or(routing.backend_instance_id),
RequestPhase::Decode => routing.decode_worker_id.or(routing.backend_instance_id),
RequestPhase::Aggregated => routing.backend_instance_id,
}?;
let dp_rank = match phase {
RequestPhase::Prefill => routing.prefill_dp_rank.or(routing.dp_rank),
RequestPhase::Decode | RequestPhase::Aggregated => routing.dp_rank,
};
Some((worker_id, dp_rank))
dp_rank: u32,
overlap_amount: u32,
effective_overlap_blocks: f64,
cached_tokens: usize,
/// Whether the scheduler is tracking this request (add_request or
/// find_best_match_details with update_states=true was called).
scheduler_tracked: bool,
}
/// Drop guard that manages the full lifecycle of a routed request:
......@@ -318,9 +305,9 @@ impl KvPushRouter {
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let Some((pinned_worker_id, requested_dp_rank)) = pinned_worker_hint(phase, routing) else {
let _nvtx_kv = dynamo_nvtx_range!("route.kv_match");
let (best_worker, overlap_amount) = self
let selection = self
.chooser
.find_best_match(
.find_best_match_details(
Some(context_id),
routing_token_ids,
block_mm_infos,
......@@ -333,6 +320,10 @@ impl KvPushRouter {
allowed_worker_ids,
)
.await?;
let best_worker = selection.worker;
let effective_overlap_blocks = selection.cache_hit.effective_overlap_blocks;
let cached_tokens = selection.cache_hit.cached_tokens;
let overlap_amount = selection.cache_hit.rounded_overlap_blocks();
if !is_query_only {
let total_blocks = routing_token_ids
......@@ -357,20 +348,22 @@ impl KvPushRouter {
return Ok(WorkerSelection {
instance_id: best_worker.worker_id,
backend_dp_rank: Some(best_worker.dp_rank),
bookkeeping_dp_rank: Some(best_worker.dp_rank),
overlap_amount: Some(overlap_amount),
dp_rank: best_worker.dp_rank,
overlap_amount,
effective_overlap_blocks,
cached_tokens,
scheduler_tracked: !is_query_only,
});
};
let resolved_pinned_worker = requested_dp_rank
let resolved_pinned_worker: Option<WorkerWithDpRank> = requested_dp_rank
.or_else(|| self.chooser.unique_dp_rank_for_worker(pinned_worker_id))
.map(|dp_rank| WorkerWithDpRank::new(pinned_worker_id, dp_rank));
if !is_query_only && let Some(pinned_worker) = resolved_pinned_worker {
let (best_worker, overlap_amount) = self
let selection = self
.chooser
.find_best_match(
.find_best_match_details(
Some(context_id),
routing_token_ids,
block_mm_infos,
......@@ -383,43 +376,60 @@ impl KvPushRouter {
allowed_worker_ids,
)
.await?;
let best_worker = selection.worker;
let effective_overlap_blocks = selection.cache_hit.effective_overlap_blocks;
let cached_tokens = selection.cache_hit.cached_tokens;
let overlap_amount = selection.cache_hit.rounded_overlap_blocks();
return Ok(WorkerSelection {
instance_id: best_worker.worker_id,
backend_dp_rank: Some(best_worker.dp_rank),
bookkeeping_dp_rank: Some(best_worker.dp_rank),
overlap_amount: Some(overlap_amount),
dp_rank: best_worker.dp_rank,
overlap_amount,
effective_overlap_blocks,
cached_tokens,
scheduler_tracked: true,
});
}
let backend_dp_rank = resolved_pinned_worker.map(|worker| worker.dp_rank);
// Fallback: pinned worker hint was present but dp_rank could not be
// resolved (or this is a query-only request that skipped the scheduler
// path above). Estimate cache hit directly and, when possible, register
// the request with the scheduler for bookkeeping.
let resolved_dp_rank: Option<u32> = resolved_pinned_worker.map(|w| w.dp_rank);
tracing::debug!(
worker_id = pinned_worker_id,
dp_rank = ?backend_dp_rank,
dp_rank = ?resolved_dp_rank,
?phase,
"Routing to specified worker"
);
let (bookkeeping_dp_rank, overlap_amount) = if let Some(dp_rank) = backend_dp_rank {
let worker = WorkerWithDpRank::new(pinned_worker_id, dp_rank);
let overlap_blocks = self
.chooser
.get_overlap_blocks(
routing_token_ids,
block_mm_infos,
worker,
lora_name.as_deref(),
)
.await?;
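// Build a WorkerWithDpRank; use 0 as a fallback dp_rank when it
// couldn't be resolved. Scheduler bookkeeping (add_request) is skipped in
// that case, but note that the fallback rank is still returned to the
// caller in WorkerSelection::dp_rank, not just used for this estimate.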
let effective_dp_rank = resolved_dp_rank.unwrap_or(0);
let worker = WorkerWithDpRank::new(pinned_worker_id, effective_dp_rank);
let cache_hit = self
.chooser
.get_cache_hit_estimate(
routing_token_ids,
block_mm_infos,
worker,
lora_name.as_deref(),
)
.await?;
let effective_overlap_blocks = cache_hit.effective_overlap_blocks;
let cached_tokens = cache_hit.cached_tokens;
let overlap_blocks = cache_hit.rounded_overlap_blocks();
if !is_query_only {
if !is_query_only {
if let Some(_dp_rank) = resolved_dp_rank {
self.chooser
.add_request(
context_id.to_string(),
routing_token_ids,
block_mm_infos,
overlap_blocks,
cached_tokens,
expected_output_tokens,
worker,
lora_name,
......@@ -430,27 +440,26 @@ impl KvPushRouter {
tracing::debug!(
request_id = %context_id,
worker_id = pinned_worker_id,
dp_rank = dp_rank,
"Skipping add_request - query-only request"
?phase,
"Routing to specified worker without resolved dp_rank; skipping scheduler bookkeeping"
);
}
(Some(dp_rank), Some(overlap_blocks))
} else {
tracing::debug!(
request_id = %context_id,
worker_id = pinned_worker_id,
?phase,
"Routing to specified worker without resolved dp_rank; skipping scheduler bookkeeping"
dp_rank = ?resolved_dp_rank,
"Skipping add_request - query-only request"
);
(None, None)
};
}
Ok(WorkerSelection {
instance_id: pinned_worker_id,
backend_dp_rank,
bookkeeping_dp_rank,
overlap_amount,
dp_rank: effective_dp_rank,
overlap_amount: overlap_blocks,
effective_overlap_blocks,
cached_tokens,
scheduler_tracked: !is_query_only && resolved_dp_rank.is_some(),
})
}
}
......@@ -522,47 +531,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.await?;
let WorkerSelection {
instance_id,
backend_dp_rank,
bookkeeping_dp_rank,
dp_rank,
overlap_amount,
effective_overlap_blocks,
cached_tokens,
scheduler_tracked,
} = selection;
let scheduler_tracked = !is_query_only && bookkeeping_dp_rank.is_some();
// In approximate mode (use_kv_events=false), record the routing decision
// so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
if let Some(dp_rank) = bookkeeping_dp_rank {
let lora_name = request.routing.as_ref().and_then(|r| r.lora_name.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(routing_token_ids.to_vec(), self.chooser.block_size())
.with_is_eagle(self.chooser.is_eagle());
if let Some(infos) = block_mm_infos {
tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.to_vec());
}
if let Some(lora_name) = lora_name {
tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name);
}
if let Err(e) = self
.chooser
.record_routing_decision(tokens_with_hashes, worker)
.await
{
tracing::warn!(
request_id = %context_id,
worker_id = instance_id,
dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
);
}
} else {
tracing::debug!(
let lora_name = request.routing.as_ref().and_then(|r| r.lora_name.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(routing_token_ids.to_vec(), self.chooser.block_size())
.with_is_eagle(self.chooser.is_eagle());
if let Some(infos) = block_mm_infos {
tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.to_vec());
}
if let Some(lora_name) = lora_name {
tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name);
}
if let Err(e) = self
.chooser
.record_routing_decision(tokens_with_hashes, worker)
.await
{
tracing::warn!(
request_id = %context_id,
worker_id = instance_id,
"Skipping approximate-mode routing decision for unresolved dp_rank"
dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
);
}
}
......@@ -573,14 +575,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
if let Some(ref tracker) = request.tracker {
let (routing_token_ids, _) = request.block_mm_routing_info();
let isl_blocks = routing_token_ids.len().div_ceil(block_size);
if let Some(overlap_amount) = overlap_amount {
tracker.record_kv_hit(overlap_amount, isl_blocks);
}
tracker.record_isl(
routing_token_ids.len(),
overlap_amount.map(|overlap| overlap as usize * block_size),
);
tracker.record_worker(instance_id, backend_dp_rank, self.chooser.worker_type());
tracker.record_kv_hit(effective_overlap_blocks, isl_blocks);
tracker.record_isl(routing_token_ids.len(), Some(cached_tokens));
tracker.record_worker(instance_id, Some(dp_rank), self.chooser.worker_type());
tracker.record_router_queue_depth(self.chooser.pending_count());
if let Some(hit_rate) = tracker.kv_hit_rate() {
request_metrics.kv_hit_rate.observe(hit_rate);
......@@ -641,7 +638,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.await?;
let (mut backend_input, context) = request.into_parts();
backend_input.routing_mut().dp_rank = backend_dp_rank;
backend_input.routing_mut().dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input);
// Record prefill start right before pushing to backend (OnceLock: first call wins).
......@@ -691,8 +688,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
"kv_router.route_request",
request_id = %context_id,
worker_id = instance_id,
dp_rank = ?backend_dp_rank,
overlap_blocks = ?overlap_amount,
dp_rank = dp_rank,
overlap_blocks = overlap_amount,
phase = ?phase,
))
.await?;
......@@ -734,6 +731,35 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
}
/// Resolve the `(worker_id, dp_rank)` pin that applies to `phase`, if any.
///
/// The worker id is looked up per phase:
/// - `Prefill`: `prefill_worker_id`, falling back to `backend_instance_id`;
///   the rank is `prefill_dp_rank`, falling back to `dp_rank`.
/// - `Decode`: `decode_worker_id`, falling back to `backend_instance_id`;
///   the rank is `dp_rank`.
/// - `Aggregated`: `backend_instance_id`; the rank is `dp_rank`.
///
/// Returns `None` when there are no routing hints or no worker id can be
/// resolved for the phase, in which case the caller should use the normal
/// KV-overlap selection path. The rank half of the tuple stays optional.
fn pinned_worker_hint(
    phase: RequestPhase,
    routing: Option<&RoutingHints>,
) -> Option<(u64, Option<u32>)> {
    let hints = routing?;

    // Pick the candidate worker and rank for the phase first; whether a pin
    // exists at all is decided solely by the worker id below.
    let (worker_id, dp_rank) = match phase {
        RequestPhase::Prefill => (
            hints.prefill_worker_id.or(hints.backend_instance_id),
            hints.prefill_dp_rank.or(hints.dp_rank),
        ),
        RequestPhase::Decode => (
            hints.decode_worker_id.or(hints.backend_instance_id),
            hints.dp_rank,
        ),
        RequestPhase::Aggregated => (hints.backend_instance_id, hints.dp_rank),
    };

    worker_id.map(|id| (id, dp_rank))
}
/// A direct routing wrapper for `RouterMode::Direct`.
///
/// This wraps a `PushRouter` and reads worker IDs from each request's routing hints,
......
......@@ -5,6 +5,7 @@ use dynamo_kv_router::protocols::SharedCacheHits;
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
pub use dynamo_kv_router::scheduling::{
KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse,
TierOverlapBlocks,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;
use dynamo_kv_router::selector::WorkerSelector as WorkerSelectorTrait;
......@@ -19,7 +20,7 @@ use anyhow::Result;
use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride},
protocols::{OverlapScores, WorkerId, WorkerWithDpRank},
protocols::{WorkerId, WorkerWithDpRank},
};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
......@@ -70,8 +71,7 @@ where
tracing::info!("skipping discovery-based worker monitoring");
}
let policy =
RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
let policy = RouterSchedulingPolicy::new(kv_router_config.router_queue_policy);
tracing::info!(
"Router queue policy: {}",
kv_router_config.router_queue_policy
......@@ -131,7 +131,10 @@ where
maybe_request_id: Option<String>,
isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
tier_overlap_blocks: TierOverlapBlocks,
effective_overlap_blocks: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, f64>,
effective_cached_tokens: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
tree_sizes: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
......@@ -147,7 +150,10 @@ where
maybe_request_id,
isl_tokens,
token_seq,
overlaps,
tier_overlap_blocks,
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes,
router_config_override,
update_states,
lora_name,
......@@ -209,11 +215,15 @@ where
&self,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
effective_cached_tokens: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
track_prefill_tokens: bool,
) -> Vec<PotentialLoad> {
self.inner
.get_potential_loads(token_seq, isl_tokens, overlaps, track_prefill_tokens)
self.inner.get_potential_loads(
token_seq,
isl_tokens,
effective_cached_tokens,
track_prefill_tokens,
)
}
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
......
......@@ -196,6 +196,7 @@ mod tests {
.await?;
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
let decay_now = Instant::now();
seq_manager_1.add_request(
SequenceRequest {
......@@ -207,7 +208,7 @@ mod tests {
worker: WorkerWithDpRank::new(0, 0),
lora_name: None,
},
Instant::now(),
decay_now,
)?;
seq_manager_1.add_request(
......@@ -220,7 +221,7 @@ mod tests {
worker: WorkerWithDpRank::new(0, 1),
lora_name: None,
},
Instant::now(),
decay_now,
)?;
seq_manager_2.add_request(
......@@ -233,7 +234,7 @@ mod tests {
worker: WorkerWithDpRank::new(1, 0),
lora_name: None,
},
Instant::now(),
decay_now,
)?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......@@ -349,6 +350,7 @@ mod tests {
.await?;
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
let decay_now = Instant::now();
seq_manager_1.add_request(
SequenceRequest {
......@@ -360,7 +362,7 @@ mod tests {
worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None,
},
Instant::now(),
decay_now,
)?;
seq_manager_1.add_request(
......@@ -373,7 +375,7 @@ mod tests {
worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None,
},
Instant::now(),
decay_now,
)?;
seq_manager_2.add_request(
......@@ -386,7 +388,7 @@ mod tests {
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None,
},
Instant::now(),
decay_now,
)?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......
......@@ -105,8 +105,8 @@ pub struct RequestTracker {
/// record the final finish time.
request_finish_time: Mutex<Option<Instant>>,
/// KV cache overlap blocks (prefix cache hits) - set once via OnceLock
kv_overlap_blocks: OnceLock<u32>,
/// Effective KV cache overlap blocks (weighted prefix cache hits) - set once via OnceLock
kv_overlap_blocks: OnceLock<f64>,
/// Input sequence length in blocks (for hit rate calculation) - set once via OnceLock
isl_blocks: OnceLock<usize>,
......@@ -114,7 +114,7 @@ pub struct RequestTracker {
/// Input sequence length in tokens - set once via OnceLock
isl_tokens: OnceLock<usize>,
/// Number of cached tokens (overlap_blocks * block_size) - set once via OnceLock
/// Number of cached tokens derived from the effective cache hit - set once via OnceLock
cached_tokens: OnceLock<usize>,
/// Output sequence length in tokens - updated atomically as tokens stream back
......@@ -226,7 +226,7 @@ impl RequestTracker {
}
/// Record KV cache hit information. Returns true if this was the first call.
pub fn record_kv_hit(&self, overlap_blocks: u32, isl_blocks: usize) -> bool {
pub fn record_kv_hit(&self, overlap_blocks: f64, isl_blocks: usize) -> bool {
let overlap_set = self.kv_overlap_blocks.set(overlap_blocks).is_ok();
let isl_set = self.isl_blocks.set(isl_blocks).is_ok();
overlap_set && isl_set
......@@ -311,7 +311,7 @@ impl RequestTracker {
if isl == 0 {
return None;
}
Some(overlap as f64 / isl as f64)
Some(overlap / isl as f64)
}
/// Set the request phase and return a permit that blocks subsequent phase changes.
......@@ -707,7 +707,7 @@ mod tests {
#[test]
fn test_kv_hit_rate() {
let tracker = RequestTracker::new();
tracker.record_kv_hit(3, 10);
tracker.record_kv_hit(3.0, 10);
let rate = tracker.kv_hit_rate().unwrap();
assert!(
......@@ -719,7 +719,7 @@ mod tests {
#[test]
fn test_kv_hit_rate_zero_isl() {
let tracker = RequestTracker::new();
tracker.record_kv_hit(0, 0);
tracker.record_kv_hit(0.0, 0);
assert!(
tracker.kv_hit_rate().is_none(),
"KV hit rate should be None when isl_blocks is 0"
......
......@@ -389,6 +389,8 @@ pub struct TokenBlock {
block_hash: BlockHash,
sequence_hash: SequenceHash,
parent_sequence_hash: Option<SequenceHash>,
external_sequence_hash: Option<SequenceHash>,
external_parent_sequence_hash: Option<SequenceHash>,
}
impl TokenBlock {
......@@ -425,6 +427,8 @@ impl TokenBlock {
block_hash: chunk.block_hash,
sequence_hash,
parent_sequence_hash,
external_sequence_hash: None,
external_parent_sequence_hash: None,
}
}
......@@ -453,6 +457,61 @@ impl TokenBlock {
self.parent_sequence_hash
}
/// Returns the TRT-LLM/framework sequence hash for this block, if assigned.
///
/// This is a plain read of the `external_sequence_hash` field and yields
/// `None` while that field holds no value.
pub fn external_sequence_hash(&self) -> Option<SequenceHash> {
self.external_sequence_hash
}
/// Returns the TRT-LLM/framework parent sequence hash for this block, if assigned.
///
/// This is a plain read of the `external_parent_sequence_hash` field and
/// yields `None` while that field holds no value.
pub fn external_parent_sequence_hash(&self) -> Option<SequenceHash> {
self.external_parent_sequence_hash
}
/// Records the TRT-LLM/framework hash chain entry for this block.
///
/// The first call stores both `external_sequence_hash` and
/// `external_parent_sequence_hash`. Later calls never overwrite: they only
/// verify that the supplied values equal the stored ones, so repeating an
/// identical assignment is a no-op.
///
/// # Panics
///
/// Panics if the block already carries an external sequence hash and either
/// supplied value differs from what is stored, mirroring the invariant that
/// `sync_external_sequence_hashes` enforces.
pub fn assign_external_hashes(
    &mut self,
    external_sequence_hash: SequenceHash,
    external_parent_sequence_hash: Option<SequenceHash>,
) {
    match self.external_sequence_hash {
        // Already assigned: validate only, leave the stored chain untouched.
        Some(existing) => {
            assert_eq!(
                existing, external_sequence_hash,
                "external_sequence_hash re-assignment mismatch",
            );
            assert_eq!(
                self.external_parent_sequence_hash, external_parent_sequence_hash,
                "external_parent_sequence_hash re-assignment mismatch",
            );
        }
        // First assignment: store both halves of the chain entry.
        None => {
            self.external_sequence_hash = Some(external_sequence_hash);
            self.external_parent_sequence_hash = external_parent_sequence_hash;
        }
    }
}
/// Checks that this complete block carries a consistent TRT-LLM/framework hash chain.
///
/// # Panics
///
/// Panics when:
/// - `external_sequence_hash` is unset;
/// - the block has a `parent_sequence_hash` but no `external_parent_sequence_hash`;
/// - the block has no `parent_sequence_hash` yet an `external_parent_sequence_hash` is set.
pub fn assert_external_hashes_assigned(&self) {
    assert!(
        self.external_sequence_hash.is_some(),
        "complete block is missing external_sequence_hash"
    );

    let has_external_parent = self.external_parent_sequence_hash.is_some();
    match self.parent_sequence_hash {
        // Non-root block: the external chain must also name a parent.
        Some(_) => assert!(
            has_external_parent,
            "non-root complete block is missing external_parent_sequence_hash"
        ),
        // Root block: there must be no external parent.
        None => assert!(
            !has_external_parent,
            "root complete block must not have external_parent_sequence_hash"
        ),
    }
}
/// Returns the number of tokens in the block.
pub fn block_size(&self) -> usize {
self.tokens.0.len()
......@@ -836,6 +895,45 @@ impl TokenBlockSequence {
Tokens::from(result)
}
/// Applies a TRT-LLM/framework sequence hash chain to every completed block.
///
/// `external_sequence_hashes` holds one hash per completed block, in sequence
/// order. The external parent of block `i` is the hash at `i - 1`; the first
/// block has no external parent. Blocks that already carry an external hash
/// are validated against the supplied chain and left unchanged; blocks
/// without one receive their assignment.
///
/// # Panics
///
/// Panics if the slice length differs from the number of completed blocks, or
/// if a previously assigned block disagrees with the supplied chain.
pub fn sync_external_sequence_hashes(&mut self, external_sequence_hashes: &[SequenceHash]) {
    assert_eq!(
        external_sequence_hashes.len(),
        self.blocks.len(),
        "external_sequence_hashes length ({}) must match completed block count ({})",
        external_sequence_hashes.len(),
        self.blocks.len()
    );

    // Hash of the previously visited block; `None` for the first block.
    let mut expected_parent: Option<SequenceHash> = None;
    let paired = self.blocks.iter_mut().zip(external_sequence_hashes);

    for (idx, (block, &expected_hash)) in paired.enumerate() {
        if let Some(existing) = block.external_sequence_hash() {
            assert_eq!(
                existing, expected_hash,
                "external_sequence_hash mismatch at block index {}",
                idx
            );
            assert_eq!(
                block.external_parent_sequence_hash(),
                expected_parent,
                "external_parent_sequence_hash mismatch at block index {}",
                idx
            );
        } else {
            block.assign_external_hashes(expected_hash, expected_parent);
        }
        expected_parent = Some(expected_hash);
    }
}
/// Splits a [`Tokens`] object into a vector of completed blocks and a final partial block.
///
/// This is primarily used internally by [`TokenBlockSequence::new`] but can be used externally.
......@@ -1575,4 +1673,28 @@ mod tests {
assert_eq!(partial.tokens.len(), 4);
assert_eq!(remaining.len(), 6);
}
/// Syncing onto a fresh two-block sequence assigns each block its own hash
/// and links the second block to the first.
#[test]
fn test_sync_external_sequence_hashes_assigns_chain() {
    let mut seq = create_test_sequence(&[1, 2, 3, 4, 5, 6, 7, 8], 4, Some(TEST_SALT_HASH));
    seq.sync_external_sequence_hashes(&[100_u64, 200_u64]);

    // (own external hash, external parent hash) expected for each completed block.
    let expected = [(Some(100_u64), None), (Some(200_u64), Some(100_u64))];

    for (idx, (own_hash, parent_hash)) in expected.into_iter().enumerate() {
        let block = &seq.blocks[idx];
        assert_eq!(block.external_sequence_hash(), own_hash);
        assert_eq!(block.external_parent_sequence_hash(), parent_hash);
        block.assert_external_hashes_assigned();
    }
}
/// Re-syncing with a chain that disagrees with an existing assignment must panic.
#[test]
#[should_panic(expected = "external_sequence_hash mismatch")]
fn test_sync_external_sequence_hashes_rejects_mismatched_existing_chain() {
    let mut seq = create_test_sequence(&[1, 2, 3, 4, 5, 6, 7, 8], 4, Some(TEST_SALT_HASH));

    let original_chain = [100_u64, 200_u64];
    // Same first hash, different second hash: conflicts at block index 1.
    let conflicting_chain = [100_u64, 201_u64];

    seq.sync_external_sequence_hashes(&original_chain);
    seq.sync_external_sequence_hashes(&conflicting_chain);
}
}
......@@ -17,6 +17,7 @@ use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use dynamo_kv_router::{
ActiveSequencesMultiWorker, DefaultWorkerSelector, RadixTree, RouterSchedulingPolicy,
SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector,
scheduling::TierOverlapBlocks,
};
use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap;
......@@ -135,14 +136,35 @@ impl PendingRequest {
fn scheduling_request(
&self,
block_size: usize,
decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
) -> SchedulingRequest {
let effective_overlap_blocks = self
.overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as f64))
.collect();
let effective_cached_tokens = self
.overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as usize * block_size))
.collect();
SchedulingRequest {
maybe_request_id: Some(self.request_id()),
token_seq: self.token_seq.clone(),
isl_tokens: self.isl_tokens,
overlaps: self.overlaps.clone(),
tier_overlap_blocks: TierOverlapBlocks::default(),
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes: self
.overlaps
.tree_sizes
.iter()
.map(|(k, v)| (*k, *v))
.collect(),
decode_blocks,
prefill_tokens,
track_prefill_tokens: self.track_prefill_tokens,
......@@ -216,7 +238,7 @@ impl OfflineReplayRouter {
let workers_with_configs = replay_workers_with_configs(args, num_workers);
let slots = replay_slots(args, &workers_with_configs);
let selector = replay_selector(&config);
let policy = replay_policy(&config, args);
let policy = replay_policy(&config);
let queue_threshold = config.router_queue_threshold;
Ok(Self {
......@@ -423,7 +445,11 @@ impl OfflineReplayRouter {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key(
arrival_offset,
&request.scheduling_request(FxHashMap::default(), FxHashMap::default()),
&request.scheduling_request(
self.block_size as usize,
FxHashMap::default(),
FxHashMap::default(),
),
)
}
......@@ -495,11 +521,19 @@ impl OfflineReplayRouter {
.potential_blocks_and_tokens_with_prefill_tracking(
request.token_seq.as_deref(),
request.isl_tokens,
request.overlaps.clone(),
request
.overlaps
.scores
.iter()
.map(|(worker, overlap)| {
(*worker, *overlap as usize * self.block_size as usize)
})
.collect(),
request.track_prefill_tokens,
decay_now,
);
let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens);
let scheduling_request =
request.scheduling_request(self.block_size as usize, decode_blocks, prefill_tokens);
let selection = self.selector.select_worker(
&self.workers_with_configs,
&scheduling_request,
......@@ -510,13 +544,13 @@ impl OfflineReplayRouter {
let request_id = request.request_id();
let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens,
selection.overlap_blocks,
selection.cached_tokens,
request.track_prefill_tokens,
);
let isl_blocks = u32::try_from(request.isl_tokens.div_ceil(self.block_size as usize))
.unwrap_or(u32::MAX);
let overlap_blocks = selection.overlap_blocks;
let overlap_blocks = selection.effective_overlap_blocks.floor() as u32;
self.slots
.add_request(
......@@ -584,14 +618,14 @@ impl OfflineReplayRouter {
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
cached_tokens: usize,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment