Unverified Commit e3d00b89 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
...@@ -10,9 +10,12 @@ pub mod publisher; ...@@ -10,9 +10,12 @@ pub mod publisher;
pub mod subscriber; pub mod subscriber;
pub mod tracker; pub mod tracker;
pub use config::KvEventConsolidatorConfig; pub use config::{KvEventConsolidationMode, KvEventConsolidatorConfig};
pub use publisher::KvEventConsolidatorPublisher; pub use publisher::KvEventConsolidatorPublisher;
pub use tracker::{CacheStatusTracker, EventSource, StorageTier}; pub use tracker::{
CacheStatusTracker, DedupCacheStatusTracker, EventSource, PassthroughCacheStatusTracker,
StorageTier,
};
use anyhow::Result; use anyhow::Result;
use std::sync::Arc; use std::sync::Arc;
...@@ -21,11 +24,14 @@ use tokio::task::JoinHandle; ...@@ -21,11 +24,14 @@ use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use subscriber::start_simple_zmq_listener; use subscriber::start_simple_zmq_listener;
use tracker::{RemoveEventInput, StoreEventInput};
pub type SharedCacheStatusTracker = Arc<RwLock<Box<dyn CacheStatusTracker>>>;
/// Handle for KVBM to send G2/G3 events directly to the KV Event Consolidator /// Handle for KVBM to send G2/G3 events directly to the KV Event Consolidator
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct KvEventConsolidatorHandle { pub struct KvEventConsolidatorHandle {
pub(crate) tracker: Arc<RwLock<CacheStatusTracker>>, pub(crate) tracker: SharedCacheStatusTracker,
} }
impl KvEventConsolidatorHandle { impl KvEventConsolidatorHandle {
...@@ -45,7 +51,7 @@ impl KvEventConsolidatorHandle { ...@@ -45,7 +51,7 @@ impl KvEventConsolidatorHandle {
data_parallel_rank: Option<i32>, data_parallel_rank: Option<i32>,
) { ) {
let mut tracker = self.tracker.write().await; let mut tracker = self.tracker.write().await;
tracker.handle_store( tracker.handle_store(StoreEventInput {
block_hash, block_hash,
source, source,
token_ids, token_ids,
...@@ -54,15 +60,24 @@ impl KvEventConsolidatorHandle { ...@@ -54,15 +60,24 @@ impl KvEventConsolidatorHandle {
lora_name, lora_name,
tier, tier,
data_parallel_rank, data_parallel_rank,
); });
} }
/// Send a block remove event to the KV Event Consolidator /// Send a block remove event to the KV Event Consolidator
/// ///
/// This is called by KVBM when a block is removed from G2 or G3. /// This is called by KVBM when a block is removed from G2 or G3.
pub async fn handle_remove(&self, block_hash: &str, source: EventSource) { pub async fn handle_remove(
&self,
block_hash: &str,
source: EventSource,
tier: Option<StorageTier>,
) {
let mut tracker = self.tracker.write().await; let mut tracker = self.tracker.write().await;
tracker.handle_remove(block_hash, source); tracker.handle_remove(RemoveEventInput {
block_hash: block_hash.to_string(),
source,
tier,
});
} }
/// Clear all blocks from the KV Event Consolidator /// Clear all blocks from the KV Event Consolidator
...@@ -77,7 +92,7 @@ impl KvEventConsolidatorHandle { ...@@ -77,7 +92,7 @@ impl KvEventConsolidatorHandle {
/// The main KV Event Consolidator that manages the event flow /// The main KV Event Consolidator that manages the event flow
pub struct KvEventConsolidator { pub struct KvEventConsolidator {
config: KvEventConsolidatorConfig, config: KvEventConsolidatorConfig,
tracker: Arc<RwLock<CacheStatusTracker>>, tracker: SharedCacheStatusTracker,
subscriber_handle: Option<JoinHandle<()>>, subscriber_handle: Option<JoinHandle<()>>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
publisher: Option<KvEventConsolidatorPublisher>, publisher: Option<KvEventConsolidatorPublisher>,
...@@ -86,7 +101,11 @@ pub struct KvEventConsolidator { ...@@ -86,7 +101,11 @@ pub struct KvEventConsolidator {
impl KvEventConsolidator { impl KvEventConsolidator {
/// Create a new KV Event Consolidator /// Create a new KV Event Consolidator
pub fn new(config: KvEventConsolidatorConfig) -> Result<Self> { pub fn new(config: KvEventConsolidatorConfig) -> Result<Self> {
let tracker = Arc::new(RwLock::new(CacheStatusTracker::new())); let tracker: Box<dyn CacheStatusTracker> = match config.mode {
KvEventConsolidationMode::Dedup => Box::new(DedupCacheStatusTracker::new()),
KvEventConsolidationMode::Passthrough => Box::new(PassthroughCacheStatusTracker::new()),
};
let tracker = Arc::new(RwLock::new(tracker));
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
Ok(Self { Ok(Self {
...@@ -101,7 +120,8 @@ impl KvEventConsolidator { ...@@ -101,7 +120,8 @@ impl KvEventConsolidator {
/// Start the KV Event Consolidator /// Start the KV Event Consolidator
pub async fn start(&mut self) -> Result<()> { pub async fn start(&mut self) -> Result<()> {
tracing::info!( tracing::info!(
"Starting KV Event Consolidator: subscribe from {}, publish to ZMQ at {}", "Starting KV Event Consolidator in {} mode: subscribe from {}, publish to ZMQ at {}",
self.config.mode.as_str(),
self.config.engine_event_endpoint, self.config.engine_event_endpoint,
self.config.consolidated_event_endpoint self.config.consolidated_event_endpoint
); );
...@@ -152,7 +172,7 @@ impl KvEventConsolidator { ...@@ -152,7 +172,7 @@ impl KvEventConsolidator {
} }
/// Get a reference to the cache status tracker (for debugging/metrics) /// Get a reference to the cache status tracker (for debugging/metrics)
pub fn tracker(&self) -> Arc<RwLock<CacheStatusTracker>> { pub fn tracker(&self) -> SharedCacheStatusTracker {
self.tracker.clone() self.tracker.clone()
} }
......
...@@ -12,10 +12,10 @@ use rmp_serde::Serializer; ...@@ -12,10 +12,10 @@ use rmp_serde::Serializer;
use serde::Serialize; use serde::Serialize;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use super::tracker::{CacheStatusTracker, ConsolidatedEvent}; use super::SharedCacheStatusTracker;
use super::tracker::ConsolidatedEvent;
use crate::utils::zmq::{bind_pub_socket, send_multipart}; use crate::utils::zmq::{bind_pub_socket, send_multipart};
/// Event batch structure matching vLLM's format (array_like=True) /// Event batch structure matching vLLM's format (array_like=True)
...@@ -70,6 +70,7 @@ impl Event { ...@@ -70,6 +70,7 @@ impl Event {
block_size, block_size,
lora_name, lora_name,
source: _, source: _,
tier,
} => { } => {
let parsed_hash = block_hash let parsed_hash = block_hash
.parse::<u64>() .parse::<u64>()
...@@ -106,12 +107,13 @@ impl Event { ...@@ -106,12 +107,13 @@ impl Event {
token_ids: token_ids_i32, token_ids: token_ids_i32,
block_size: block_size_i32, block_size: block_size_i32,
lora_name, lora_name,
medium: None, medium: tier.map(|t| t.to_vllm_medium().to_string()),
}) })
} }
ConsolidatedEvent::Remove { ConsolidatedEvent::Remove {
block_hash, block_hash,
source: _, source: _,
tier,
} => { } => {
// Parse block hash - fail if invalid to prevent corruption // Parse block hash - fail if invalid to prevent corruption
let parsed_hash = block_hash.parse::<u64>().with_context(|| { let parsed_hash = block_hash.parse::<u64>().with_context(|| {
...@@ -120,7 +122,7 @@ impl Event { ...@@ -120,7 +122,7 @@ impl Event {
Ok(Event::BlockRemoved { Ok(Event::BlockRemoved {
block_hashes: vec![parsed_hash], block_hashes: vec![parsed_hash],
medium: None, // Not provided by ConsolidatedEvent medium: tier.map(|t| t.to_vllm_medium().to_string()),
}) })
} }
ConsolidatedEvent::ClearAll => Ok(Event::AllBlocksCleared {}), ConsolidatedEvent::ClearAll => Ok(Event::AllBlocksCleared {}),
...@@ -131,14 +133,14 @@ impl Event { ...@@ -131,14 +133,14 @@ impl Event {
/// ZMQ Publisher for consolidated events /// ZMQ Publisher for consolidated events
pub struct KvEventConsolidatorPublisher { pub struct KvEventConsolidatorPublisher {
endpoint: String, endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>, tracker: SharedCacheStatusTracker,
sequence: Arc<AtomicU64>, sequence: Arc<AtomicU64>,
task_handle: Option<JoinHandle<()>>, task_handle: Option<JoinHandle<()>>,
} }
impl KvEventConsolidatorPublisher { impl KvEventConsolidatorPublisher {
/// Create a new publisher /// Create a new publisher
pub fn new(endpoint: &str, tracker: Arc<RwLock<CacheStatusTracker>>) -> Result<Self> { pub fn new(endpoint: &str, tracker: SharedCacheStatusTracker) -> Result<Self> {
let endpoint = endpoint.to_string(); let endpoint = endpoint.to_string();
let sequence = Arc::new(AtomicU64::new(0)); let sequence = Arc::new(AtomicU64::new(0));
...@@ -177,7 +179,7 @@ impl KvEventConsolidatorPublisher { ...@@ -177,7 +179,7 @@ impl KvEventConsolidatorPublisher {
/// Main publisher loop /// Main publisher loop
async fn run_publisher_loop( async fn run_publisher_loop(
endpoint: String, endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>, tracker: SharedCacheStatusTracker,
sequence: Arc<AtomicU64>, sequence: Arc<AtomicU64>,
) -> Result<()> { ) -> Result<()> {
tracing::info!("Starting consolidated event publisher on {}", endpoint); tracing::info!("Starting consolidated event publisher on {}", endpoint);
...@@ -239,10 +241,13 @@ impl KvEventConsolidatorPublisher { ...@@ -239,10 +241,13 @@ impl KvEventConsolidatorPublisher {
Some(0), // data_parallel_rank (default) Some(0), // data_parallel_rank (default)
); );
// Serialize to msgpack // Serialize to msgpack.
// Keep the outer batch as a tuple [ts, events, dp_rank], but force
// inner struct-like events to use named fields so RawKvEvent's map
// deserializer can safely decode optional fields like `medium`.
let mut payload = Vec::new(); let mut payload = Vec::new();
batch batch
.serialize(&mut Serializer::new(&mut payload)) .serialize(&mut Serializer::new(&mut payload).with_struct_map())
.context("Failed to serialize event batch")?; .context("Failed to serialize event batch")?;
// Get sequence number // Get sequence number
...@@ -270,3 +275,55 @@ impl KvEventConsolidatorPublisher { ...@@ -270,3 +275,55 @@ impl KvEventConsolidatorPublisher {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use dynamo_kv_router::zmq_wire::{KvEventBatch, RawKvEvent};
use rmp_serde as rmps;
#[test]
fn test_block_stored_with_medium_and_no_lora_name_decodes() {
let batch = EventBatch(
0.0,
vec![Event::BlockStored {
block_hashes: vec![42],
parent_block_hash: Some(7),
token_ids: vec![1, 2, 3, 4],
block_size: 4,
lora_name: None,
medium: Some("CPU_TIER1".to_string()),
}],
Some(0),
);
let mut payload = Vec::new();
batch
.serialize(&mut Serializer::new(&mut payload).with_struct_map())
.unwrap();
let decoded: KvEventBatch = rmps::from_slice(&payload).unwrap();
assert_eq!(decoded.events.len(), 1);
assert_eq!(decoded.data_parallel_rank, Some(0));
let RawKvEvent::BlockStored {
block_hashes,
parent_block_hash,
token_ids,
block_size,
medium,
lora_name,
..
} = &decoded.events[0]
else {
panic!("expected BlockStored");
};
assert_eq!(block_hashes.len(), 1);
assert!(parent_block_hash.is_some());
assert_eq!(token_ids, &vec![1, 2, 3, 4]);
assert_eq!(*block_size, 4);
assert_eq!(medium.as_deref(), Some("CPU_TIER1"));
assert_eq!(lora_name.as_deref(), None);
}
}
...@@ -9,14 +9,15 @@ use anyhow::{Context, Result}; ...@@ -9,14 +9,15 @@ use anyhow::{Context, Result};
use futures::StreamExt; use futures::StreamExt;
use rmp_serde::Deserializer; use rmp_serde::Deserializer;
use serde::Deserialize; use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use dynamo_kv_router::zmq_wire::RawKvEvent; use dynamo_kv_router::zmq_wire::RawKvEvent;
use super::tracker::{CacheStatusTracker, EventSource, StorageTier}; use super::SharedCacheStatusTracker;
use super::tracker::{
CacheStatusTracker, EventSource, RemoveEventInput, StorageTier, StoreEventInput,
};
use crate::utils::zmq::{connect_sub_socket, multipart_message}; use crate::utils::zmq::{connect_sub_socket, multipart_message};
/// Event batch received from vLLM/TensorRT-LLM (array format) /// Event batch received from vLLM/TensorRT-LLM (array format)
...@@ -48,7 +49,7 @@ impl VllmEventBatch { ...@@ -48,7 +49,7 @@ impl VllmEventBatch {
/// Start ZMQ listener and process events into tracker /// Start ZMQ listener and process events into tracker
pub async fn start_simple_zmq_listener( pub async fn start_simple_zmq_listener(
endpoint: String, endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>, tracker: SharedCacheStatusTracker,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
engine_source: EventSource, engine_source: EventSource,
) -> Result<JoinHandle<()>> { ) -> Result<JoinHandle<()>> {
...@@ -65,7 +66,7 @@ pub async fn start_simple_zmq_listener( ...@@ -65,7 +66,7 @@ pub async fn start_simple_zmq_listener(
async fn run_listener_loop( async fn run_listener_loop(
endpoint: String, endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>, tracker: SharedCacheStatusTracker,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
engine_source: EventSource, engine_source: EventSource,
) -> Result<()> { ) -> Result<()> {
...@@ -136,7 +137,7 @@ async fn run_listener_loop( ...@@ -136,7 +137,7 @@ async fn run_listener_loop(
// Process events // Process events
let mut tracker_guard = tracker.write().await; let mut tracker_guard = tracker.write().await;
for event in batch.events() { for event in batch.events() {
process_event(&mut tracker_guard, event.clone(), dp_rank, engine_source); process_event(&mut **tracker_guard, event.clone(), dp_rank, engine_source);
} }
} }
} }
...@@ -146,7 +147,7 @@ async fn run_listener_loop( ...@@ -146,7 +147,7 @@ async fn run_listener_loop(
} }
fn process_event( fn process_event(
tracker: &mut CacheStatusTracker, tracker: &mut dyn CacheStatusTracker,
event: RawKvEvent, event: RawKvEvent,
data_parallel_rank: Option<i32>, data_parallel_rank: Option<i32>,
engine_source: EventSource, engine_source: EventSource,
...@@ -204,16 +205,16 @@ fn process_event( ...@@ -204,16 +205,16 @@ fn process_event(
let block_tokens = token_chunks[i].clone(); let block_tokens = token_chunks[i].clone();
let block_hash_u64 = block_hash.into_u64(); let block_hash_u64 = block_hash.into_u64();
tracker.handle_store( tracker.handle_store(StoreEventInput {
block_hash_u64.to_string(), block_hash: block_hash_u64.to_string(),
engine_source, source: engine_source,
block_tokens, token_ids: block_tokens,
current_parent.clone(), parent_hash: current_parent.clone(),
block_size, block_size,
lora_name.clone(), lora_name: lora_name.clone(),
Some(storage_tier), tier: Some(storage_tier),
data_parallel_rank, data_parallel_rank,
); });
// Next block's parent is this block (only if hash was valid) // Next block's parent is this block (only if hash was valid)
current_parent = Some(block_hash_u64.to_string()); current_parent = Some(block_hash_u64.to_string());
...@@ -233,7 +234,11 @@ fn process_event( ...@@ -233,7 +234,11 @@ fn process_event(
); );
for block_hash in block_hashes { for block_hash in block_hashes {
tracker.handle_remove(&block_hash.into_u64().to_string(), engine_source); tracker.handle_remove(RemoveEventInput {
block_hash: block_hash.into_u64().to_string(),
source: engine_source,
tier: Some(storage_tier),
});
} }
} }
......
...@@ -140,10 +140,6 @@ impl From<RouterStorageTier> for StorageTier { ...@@ -140,10 +140,6 @@ impl From<RouterStorageTier> for StorageTier {
} }
} }
/// Legacy type alias for backward compatibility
#[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier;
/// Minimal metadata for tracking which event sources have a block /// Minimal metadata for tracking which event sources have a block
/// All other metadata (tokens, parent, etc.) is stored in the ConsolidatedEvent when queued /// All other metadata (tokens, parent, etc.) is stored in the ConsolidatedEvent when queued
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
...@@ -195,65 +191,58 @@ pub enum ConsolidatedEvent { ...@@ -195,65 +191,58 @@ pub enum ConsolidatedEvent {
block_size: usize, block_size: usize,
lora_name: Option<String>, lora_name: Option<String>,
source: String, source: String,
tier: Option<StorageTier>,
}, },
/// Block removed (removed from all sources) /// Block removed (removed from all sources)
Remove { Remove {
block_hash: String, block_hash: String,
source: String, // The source where it was last removed source: String, // The source where it was last removed
tier: Option<StorageTier>,
}, },
/// All blocks cleared /// All blocks cleared
ClearAll, ClearAll,
} }
/// Cache Status Tracker #[derive(Debug, Clone)]
/// pub struct StoreEventInput {
/// Deduplication logic: pub block_hash: String,
/// - Uses SequenceHash (computed from tokens + parent) as the key for deduplication pub source: EventSource,
/// - SequenceHash is position-aware: same tokens at different positions = different keys pub token_ids: Vec<u32>,
/// - Always uses KVBM's xxHash3 hashing function, regardless of source pub parent_hash: Option<String>,
/// - This allows vLLM and KVBM blocks at the same position to be deduplicated pub block_size: usize,
/// - Emit Store: Only when a block is first stored from ANY source pub lora_name: Option<String>,
/// - Emit Remove: Only when a block is removed from ALL sources pub tier: Option<StorageTier>,
#[derive(Debug)] pub data_parallel_rank: Option<i32>,
pub struct CacheStatusTracker { }
/// Map of SequenceHash -> BlockMetadata (tracking which sources have this block)
/// The key is position-aware: includes parent context
blocks: HashMap<SequenceHash, BlockMetadata>,
/// Reverse mapping: external_block_hash -> SequenceHash (that we computed) #[derive(Debug, Clone)]
/// Needed because remove events only provide external hash, not token IDs pub struct RemoveEventInput {
/// Maps each source's external hash to our computed sequence hash pub block_hash: String,
hash_mapping: HashMap<String, SequenceHash>, pub source: EventSource,
pub tier: Option<StorageTier>,
}
/// Queue of events to be published pub trait CacheStatusTracker: std::fmt::Debug + Send + Sync {
event_queue: Vec<ConsolidatedEvent>, fn handle_store(&mut self, event: StoreEventInput) -> bool;
fn handle_remove(&mut self, event: RemoveEventInput) -> bool;
fn handle_clear_all(&mut self);
fn drain_events(&mut self) -> Vec<ConsolidatedEvent>;
fn num_blocks(&self) -> usize;
} }
impl Default for CacheStatusTracker { /// Deduplicating cache-status tracker.
fn default() -> Self { #[derive(Debug, Default)]
Self::new() pub struct DedupCacheStatusTracker {
} blocks: HashMap<SequenceHash, BlockMetadata>,
hash_mapping: HashMap<String, SequenceHash>,
event_queue: Vec<ConsolidatedEvent>,
} }
impl CacheStatusTracker { impl DedupCacheStatusTracker {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self::default()
blocks: HashMap::new(),
hash_mapping: HashMap::new(),
event_queue: Vec::new(),
}
} }
/// Handle a STORE event
///
/// Returns true if a consolidated STORE event should be published.
/// Only publishes when a block is stored for the FIRST TIME from ANY source.
///
/// # Arguments
/// * `block_hash` - The external block hash (from vLLM or KVBM)
/// * `source` - The event source (vLLM or KVBM) that stored this block
/// * `token_ids` - The token IDs in this block (for content-based deduplication)
/// * `tier` - Optional storage tier information (for metadata/debugging)
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn handle_store( pub fn handle_store(
&mut self, &mut self,
...@@ -266,16 +255,65 @@ impl CacheStatusTracker { ...@@ -266,16 +255,65 @@ impl CacheStatusTracker {
tier: Option<StorageTier>, tier: Option<StorageTier>,
data_parallel_rank: Option<i32>, data_parallel_rank: Option<i32>,
) -> bool { ) -> bool {
// Compute LocalBlockHash from token IDs (content only) CacheStatusTracker::handle_store(
let local_block_hash = compute_local_block_hash(&token_ids); self,
StoreEventInput {
block_hash,
source,
token_ids,
parent_hash,
block_size,
lora_name,
tier,
data_parallel_rank,
},
)
}
pub fn handle_remove(
&mut self,
block_hash: &str,
source: EventSource,
tier: Option<StorageTier>,
) -> bool {
CacheStatusTracker::handle_remove(
self,
RemoveEventInput {
block_hash: block_hash.to_string(),
source,
tier,
},
)
}
pub fn get_block_sources(&self, external_block_hash: &str) -> Option<&HashSet<EventSource>> {
let local_hash = self.hash_mapping.get(external_block_hash)?;
self.blocks.get(local_hash).map(|m| &m.sources)
}
#[deprecated(note = "Use get_block_sources instead")]
pub fn get_block_tiers(&self, block_hash: &str) -> Option<&HashSet<EventSource>> {
self.get_block_sources(block_hash)
}
}
impl CacheStatusTracker for DedupCacheStatusTracker {
fn handle_store(&mut self, event: StoreEventInput) -> bool {
let StoreEventInput {
block_hash,
source,
token_ids,
parent_hash,
block_size,
lora_name,
tier,
data_parallel_rank,
} = event;
// Resolve parent sequence hash from parent's external hash (if provided) let local_block_hash = compute_local_block_hash(&token_ids);
let parent_sequence_hash = parent_hash let parent_sequence_hash = parent_hash
.as_ref() .as_ref()
.and_then(|ph| self.hash_mapping.get(ph).copied()); .and_then(|ph| self.hash_mapping.get(ph).copied());
// Compute SequenceHash using KVBM's hashing method (position-aware)
// This ensures consistent deduplication regardless of source
let sequence_hash = compute_sequence_hash(parent_sequence_hash, local_block_hash); let sequence_hash = compute_sequence_hash(parent_sequence_hash, local_block_hash);
tracing::debug!( tracing::debug!(
...@@ -286,11 +324,7 @@ impl CacheStatusTracker { ...@@ -286,11 +324,7 @@ impl CacheStatusTracker {
); );
if let Some(metadata) = self.blocks.get_mut(&sequence_hash) { if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
// Block already exists from another source (deduplication!), just add the new source
let is_new_source = metadata.add_source(source); let is_new_source = metadata.add_source(source);
// Add this external hash to the mapping so remove events from this source can find the block
// Multiple external hashes (from different sources) can map to the same SequenceHash
self.hash_mapping.insert(block_hash.clone(), sequence_hash); self.hash_mapping.insert(block_hash.clone(), sequence_hash);
if is_new_source { if is_new_source {
...@@ -314,10 +348,8 @@ impl CacheStatusTracker { ...@@ -314,10 +348,8 @@ impl CacheStatusTracker {
&token_ids &token_ids
); );
} }
// Don't publish a new STORE event (block already exists)
false false
} else { } else {
// First time seeing this block from any source - create metadata and queue STORE event
let metadata = BlockMetadata::new(source, block_hash.clone()); let metadata = BlockMetadata::new(source, block_hash.clone());
tracing::debug!( tracing::debug!(
...@@ -338,34 +370,16 @@ impl CacheStatusTracker { ...@@ -338,34 +370,16 @@ impl CacheStatusTracker {
); );
self.blocks.insert(sequence_hash, metadata); self.blocks.insert(sequence_hash, metadata);
// Add to hash mapping so remove events can find the block by external hash
self.hash_mapping.insert(block_hash.clone(), sequence_hash); self.hash_mapping.insert(block_hash.clone(), sequence_hash);
// Resolve parent_hash to first_block_hash if parent was deduplicated
//
// Problem: When the same block is stored from multiple sources (deduplication),
// each source may use a different external hash for the same logical block.
// Example:
// - Source A (TRTLLM) stores parent with hash "hash_A"
// - Source B (KVBM) stores same parent with hash "hash_B" (different format/algorithm)
// - Router only received STORE event with "hash_A" (first source)
// - When Source B stores child with parent_hash="hash_B", router won't recognize it
//
// Resolve the parent's external hash to its first_block_hash (the hash
// that was sent to the router in the first STORE event) so the router can find it.
let resolved_parent_hash = parent_hash.and_then(|ph| { let resolved_parent_hash = parent_hash.and_then(|ph| {
// Look up parent's sequence hash from its external hash
self.hash_mapping.get(&ph).and_then(|&parent_seq_hash| { self.hash_mapping.get(&ph).and_then(|&parent_seq_hash| {
// Get parent's metadata to find first_block_hash
self.blocks self.blocks
.get(&parent_seq_hash) .get(&parent_seq_hash)
.map(|parent_metadata| parent_metadata.first_block_hash.clone()) .map(|parent_metadata| parent_metadata.first_block_hash.clone())
}) })
}); });
// Queue a STORE event with full metadata
// Use resolved_parent_hash (first_block_hash) so router can find the parent
self.event_queue.push(ConsolidatedEvent::Store { self.event_queue.push(ConsolidatedEvent::Store {
block_hash: block_hash.clone(), block_hash: block_hash.clone(),
parent_hash: resolved_parent_hash, parent_hash: resolved_parent_hash,
...@@ -373,6 +387,7 @@ impl CacheStatusTracker { ...@@ -373,6 +387,7 @@ impl CacheStatusTracker {
block_size, block_size,
lora_name, lora_name,
source: source.to_str().to_string(), source: source.to_str().to_string(),
tier,
}); });
tracing::debug!( tracing::debug!(
...@@ -387,17 +402,14 @@ impl CacheStatusTracker { ...@@ -387,17 +402,14 @@ impl CacheStatusTracker {
} }
} }
/// Handle a REMOVE event fn handle_remove(&mut self, event: RemoveEventInput) -> bool {
/// let RemoveEventInput {
/// Returns true if a consolidated REMOVE event should be published. block_hash,
/// Only publishes when a block is removed from ALL sources. source,
/// tier,
/// # Arguments } = event;
/// * `block_hash` - The external block hash to remove
/// * `source` - The event source (vLLM or KVBM) that removed this block let sequence_hash = match self.hash_mapping.get(&block_hash) {
pub fn handle_remove(&mut self, block_hash: &str, source: EventSource) -> bool {
// Look up the SequenceHash from the external block hash
let sequence_hash = match self.hash_mapping.get(block_hash) {
Some(&hash) => hash, Some(&hash) => hash,
None => { None => {
tracing::warn!( tracing::warn!(
...@@ -410,7 +422,6 @@ impl CacheStatusTracker { ...@@ -410,7 +422,6 @@ impl CacheStatusTracker {
}; };
if let Some(metadata) = self.blocks.get_mut(&sequence_hash) { if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
// Remove the source
let was_removed = metadata.remove_source(source); let was_removed = metadata.remove_source(source);
if !was_removed { if !was_removed {
tracing::warn!( tracing::warn!(
...@@ -421,11 +432,7 @@ impl CacheStatusTracker { ...@@ -421,11 +432,7 @@ impl CacheStatusTracker {
return false; return false;
} }
// Remove this external hash immediately when the source removes it self.hash_mapping.remove(&block_hash);
// This keeps hash_mapping clean
// Each external hash belongs to exactly one source, so when that source
// removes the block, we can safely remove the hash_mapping entry
self.hash_mapping.remove(block_hash);
tracing::debug!( tracing::debug!(
"Removed hash_mapping entry for {} (hash_mapping size: {})", "Removed hash_mapping entry for {} (hash_mapping size: {})",
...@@ -433,17 +440,13 @@ impl CacheStatusTracker { ...@@ -433,17 +440,13 @@ impl CacheStatusTracker {
self.hash_mapping.len() self.hash_mapping.len()
); );
// Check if this was the last source
if !metadata.exists_in_any_source() { if !metadata.exists_in_any_source() {
// Block is gone from all sources - remove from tracker and publish REMOVE
let first_block_hash = metadata.first_block_hash.clone(); let first_block_hash = metadata.first_block_hash.clone();
self.blocks.remove(&sequence_hash); self.blocks.remove(&sequence_hash);
// Double-check: clean up any stray hash mappings (should be empty by now)
// This is a safety check
let stray_count_before = self.hash_mapping.len(); let stray_count_before = self.hash_mapping.len();
self.hash_mapping self.hash_mapping
.retain(|_ext_hash, &mut seq_hash| seq_hash != sequence_hash); .retain(|_ext_hash, seq_hash| *seq_hash != sequence_hash);
let stray_count = stray_count_before - self.hash_mapping.len(); let stray_count = stray_count_before - self.hash_mapping.len();
if stray_count > 0 { if stray_count > 0 {
...@@ -458,6 +461,7 @@ impl CacheStatusTracker { ...@@ -458,6 +461,7 @@ impl CacheStatusTracker {
self.event_queue.push(ConsolidatedEvent::Remove { self.event_queue.push(ConsolidatedEvent::Remove {
block_hash: first_block_hash.clone(), block_hash: first_block_hash.clone(),
source: source.to_str().to_string(), source: source.to_str().to_string(),
tier,
}); });
tracing::debug!( tracing::debug!(
...@@ -470,7 +474,6 @@ impl CacheStatusTracker { ...@@ -470,7 +474,6 @@ impl CacheStatusTracker {
); );
true true
} else { } else {
// Block still exists in other sources
tracing::debug!( tracing::debug!(
"Block {} (seq_hash={}) removed from source {:?}, still in {} source(s): {:?} (hash_mapping: {})", "Block {} (seq_hash={}) removed from source {:?}, still in {} source(s): {:?} (hash_mapping: {})",
&metadata.first_block_hash[..16.min(metadata.first_block_hash.len())], &metadata.first_block_hash[..16.min(metadata.first_block_hash.len())],
...@@ -492,8 +495,7 @@ impl CacheStatusTracker { ...@@ -492,8 +495,7 @@ impl CacheStatusTracker {
} }
} }
/// Handle a CLEAR_ALL event fn handle_clear_all(&mut self) {
pub fn handle_clear_all(&mut self) {
let num_blocks = self.blocks.len(); let num_blocks = self.blocks.len();
tracing::debug!("Clearing all {} blocks from tracker", num_blocks); tracing::debug!("Clearing all {} blocks from tracker", num_blocks);
self.blocks.clear(); self.blocks.clear();
...@@ -501,8 +503,7 @@ impl CacheStatusTracker { ...@@ -501,8 +503,7 @@ impl CacheStatusTracker {
self.event_queue.push(ConsolidatedEvent::ClearAll); self.event_queue.push(ConsolidatedEvent::ClearAll);
} }
/// Drain all pending events to be published fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
pub fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
let events = std::mem::take(&mut self.event_queue); let events = std::mem::take(&mut self.event_queue);
if !events.is_empty() { if !events.is_empty() {
tracing::debug!( tracing::debug!(
...@@ -513,22 +514,106 @@ impl CacheStatusTracker { ...@@ -513,22 +514,106 @@ impl CacheStatusTracker {
events events
} }
/// Get the number of tracked blocks fn num_blocks(&self) -> usize {
pub fn num_blocks(&self) -> usize {
self.blocks.len() self.blocks.len()
} }
}
/// Get sources for a specific block by external block hash /// Pass-through cache-status tracker.
pub fn get_block_sources(&self, external_block_hash: &str) -> Option<&HashSet<EventSource>> { #[derive(Debug, Default)]
// Look up the local hash from external hash, then get sources pub struct PassthroughCacheStatusTracker {
let local_hash = self.hash_mapping.get(external_block_hash)?; event_queue: Vec<ConsolidatedEvent>,
self.blocks.get(local_hash).map(|m| &m.sources) }
impl PassthroughCacheStatusTracker {
pub fn new() -> Self {
Self::default()
} }
/// Legacy method for backwards compatibility #[allow(clippy::too_many_arguments)]
#[deprecated(note = "Use get_block_sources instead")] pub fn handle_store(
pub fn get_block_tiers(&self, block_hash: &str) -> Option<&HashSet<EventSource>> { &mut self,
self.get_block_sources(block_hash) block_hash: String,
source: EventSource,
token_ids: Vec<u32>,
parent_hash: Option<String>,
block_size: usize,
lora_name: Option<String>,
tier: Option<StorageTier>,
data_parallel_rank: Option<i32>,
) -> bool {
CacheStatusTracker::handle_store(
self,
StoreEventInput {
block_hash,
source,
token_ids,
parent_hash,
block_size,
lora_name,
tier,
data_parallel_rank,
},
)
}
pub fn handle_remove(
&mut self,
block_hash: &str,
source: EventSource,
tier: Option<StorageTier>,
) -> bool {
CacheStatusTracker::handle_remove(
self,
RemoveEventInput {
block_hash: block_hash.to_string(),
source,
tier,
},
)
}
}
impl CacheStatusTracker for PassthroughCacheStatusTracker {
fn handle_store(&mut self, event: StoreEventInput) -> bool {
self.event_queue.push(ConsolidatedEvent::Store {
block_hash: event.block_hash,
parent_hash: event.parent_hash,
token_ids: event.token_ids,
block_size: event.block_size,
lora_name: event.lora_name,
source: event.source.to_str().to_string(),
tier: event.tier,
});
true
}
fn handle_remove(&mut self, event: RemoveEventInput) -> bool {
self.event_queue.push(ConsolidatedEvent::Remove {
block_hash: event.block_hash,
source: event.source.to_str().to_string(),
tier: event.tier,
});
true
}
fn handle_clear_all(&mut self) {
self.event_queue.push(ConsolidatedEvent::ClearAll);
}
fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
let events = std::mem::take(&mut self.event_queue);
if !events.is_empty() {
tracing::debug!(
"Draining {} pending kv event(s) for publishing",
events.len()
);
}
events
}
fn num_blocks(&self) -> usize {
0
} }
} }
...@@ -536,9 +621,11 @@ impl CacheStatusTracker { ...@@ -536,9 +621,11 @@ impl CacheStatusTracker {
mod tests { mod tests {
use super::*; use super::*;
type TestTracker = DedupCacheStatusTracker;
#[test] #[test]
fn test_first_store_publishes() { fn test_first_store_publishes() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
let should_publish = tracker.handle_store( let should_publish = tracker.handle_store(
"hash1".to_string(), "hash1".to_string(),
...@@ -558,7 +645,7 @@ mod tests { ...@@ -558,7 +645,7 @@ mod tests {
#[test] #[test]
fn test_duplicate_store_no_publish() { fn test_duplicate_store_no_publish() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
tracker.handle_store( tracker.handle_store(
"hash1".to_string(), "hash1".to_string(),
...@@ -589,7 +676,7 @@ mod tests { ...@@ -589,7 +676,7 @@ mod tests {
#[test] #[test]
fn test_multi_source_store() { fn test_multi_source_store() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
// First store from vLLM // First store from vLLM
tracker.handle_store( tracker.handle_store(
...@@ -624,7 +711,7 @@ mod tests { ...@@ -624,7 +711,7 @@ mod tests {
#[test] #[test]
fn test_remove_from_single_source_publishes() { fn test_remove_from_single_source_publishes() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
tracker.handle_store( tracker.handle_store(
"hash1".to_string(), "hash1".to_string(),
...@@ -638,18 +725,24 @@ mod tests { ...@@ -638,18 +725,24 @@ mod tests {
); );
tracker.drain_events(); tracker.drain_events();
let should_publish = tracker.handle_remove("hash1", EventSource::Vllm); let should_publish =
tracker.handle_remove("hash1", EventSource::Vllm, Some(StorageTier::Device));
assert!(should_publish); assert!(should_publish);
assert_eq!(tracker.num_blocks(), 0); assert_eq!(tracker.num_blocks(), 0);
let events = tracker.drain_events(); let events = tracker.drain_events();
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
matches!(events[0], ConsolidatedEvent::Remove { .. }); match &events[0] {
ConsolidatedEvent::Remove { tier, .. } => {
assert_eq!(*tier, Some(StorageTier::Device));
}
other => panic!("expected Remove event, got: {:?}", other),
}
} }
#[test] #[test]
fn test_remove_from_multi_source_no_publish() { fn test_remove_from_multi_source_no_publish() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
// Store from vLLM - first STORE event published // Store from vLLM - first STORE event published
tracker.handle_store( tracker.handle_store(
...@@ -676,14 +769,19 @@ mod tests { ...@@ -676,14 +769,19 @@ mod tests {
tracker.drain_events(); tracker.drain_events();
// Remove from vLLM - should not publish (still in KVBM) // Remove from vLLM - should not publish (still in KVBM)
let should_publish = tracker.handle_remove("vllm_hash1", EventSource::Vllm); let should_publish =
tracker.handle_remove("vllm_hash1", EventSource::Vllm, Some(StorageTier::Device));
assert!(!should_publish); assert!(!should_publish);
assert_eq!(tracker.num_blocks(), 1); assert_eq!(tracker.num_blocks(), 1);
assert_eq!(tracker.drain_events().len(), 0); assert_eq!(tracker.drain_events().len(), 0);
// Remove from KVBM (last source) - should publish REMOVE event // Remove from KVBM (last source) - should publish REMOVE event
let should_publish = tracker.handle_remove("kvbm_hash1", EventSource::Kvbm); let should_publish = tracker.handle_remove(
"kvbm_hash1",
EventSource::Kvbm,
Some(StorageTier::HostPinned),
);
assert!(should_publish); assert!(should_publish);
assert_eq!(tracker.num_blocks(), 0); assert_eq!(tracker.num_blocks(), 0);
...@@ -691,7 +789,7 @@ mod tests { ...@@ -691,7 +789,7 @@ mod tests {
#[test] #[test]
fn test_sequence_hash_first_block() { fn test_sequence_hash_first_block() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
// First block (no parent) // First block (no parent)
let should_publish = tracker.handle_store( let should_publish = tracker.handle_store(
...@@ -714,7 +812,7 @@ mod tests { ...@@ -714,7 +812,7 @@ mod tests {
#[test] #[test]
fn test_sequence_hash_with_parent() { fn test_sequence_hash_with_parent() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
// First block // First block
tracker.handle_store( tracker.handle_store(
...@@ -747,7 +845,7 @@ mod tests { ...@@ -747,7 +845,7 @@ mod tests {
#[test] #[test]
fn test_same_tokens_different_position_different_blocks() { fn test_same_tokens_different_position_different_blocks() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
// First occurrence: tokens [1,2,3,4] at position 0 (no parent) // First occurrence: tokens [1,2,3,4] at position 0 (no parent)
tracker.handle_store( tracker.handle_store(
...@@ -782,7 +880,7 @@ mod tests { ...@@ -782,7 +880,7 @@ mod tests {
#[test] #[test]
fn test_clear_all() { fn test_clear_all() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
// Add multiple blocks // Add multiple blocks
tracker.handle_store( tracker.handle_store(
...@@ -815,13 +913,14 @@ mod tests { ...@@ -815,13 +913,14 @@ mod tests {
assert_eq!(tracker.num_blocks(), 0); assert_eq!(tracker.num_blocks(), 0);
// Verify hash_mapping is also cleared // Verify hash_mapping is also cleared
let should_publish = tracker.handle_remove("block1", EventSource::Vllm); let should_publish =
tracker.handle_remove("block1", EventSource::Vllm, Some(StorageTier::Device));
assert!(!should_publish); // Should fail because block is gone assert!(!should_publish); // Should fail because block is gone
} }
#[test] #[test]
fn test_deduplication_across_sources_with_parent() { fn test_deduplication_across_sources_with_parent() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
// vLLM stores block 1 (parent) // vLLM stores block 1 (parent)
tracker.handle_store( tracker.handle_store(
...@@ -870,9 +969,10 @@ mod tests { ...@@ -870,9 +969,10 @@ mod tests {
#[test] #[test]
fn test_remove_non_existent_block() { fn test_remove_non_existent_block() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
let should_publish = tracker.handle_remove("non_existent", EventSource::Vllm); let should_publish =
tracker.handle_remove("non_existent", EventSource::Vllm, Some(StorageTier::Device));
assert!(!should_publish); assert!(!should_publish);
assert_eq!(tracker.num_blocks(), 0); assert_eq!(tracker.num_blocks(), 0);
...@@ -896,7 +996,7 @@ mod tests { ...@@ -896,7 +996,7 @@ mod tests {
#[test] #[test]
fn test_lora_name_round_trip_through_tracker() { fn test_lora_name_round_trip_through_tracker() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
let should_publish = tracker.handle_store( let should_publish = tracker.handle_store(
"hash_lora".to_string(), "hash_lora".to_string(),
...@@ -916,10 +1016,12 @@ mod tests { ...@@ -916,10 +1016,12 @@ mod tests {
ConsolidatedEvent::Store { ConsolidatedEvent::Store {
lora_name, lora_name,
token_ids, token_ids,
tier,
.. ..
} => { } => {
assert_eq!(lora_name.as_deref(), Some("my-adapter")); assert_eq!(lora_name.as_deref(), Some("my-adapter"));
assert_eq!(token_ids, &[1, 2, 3, 4]); assert_eq!(token_ids, &[1, 2, 3, 4]);
assert_eq!(*tier, Some(StorageTier::Device));
} }
other => panic!("expected Store event, got: {:?}", other), other => panic!("expected Store event, got: {:?}", other),
} }
...@@ -927,7 +1029,7 @@ mod tests { ...@@ -927,7 +1029,7 @@ mod tests {
#[test] #[test]
fn test_lora_name_none_for_base_model() { fn test_lora_name_none_for_base_model() {
let mut tracker = CacheStatusTracker::new(); let mut tracker = TestTracker::new();
tracker.handle_store( tracker.handle_store(
"hash_base".to_string(), "hash_base".to_string(),
...@@ -943,8 +1045,11 @@ mod tests { ...@@ -943,8 +1045,11 @@ mod tests {
let events = tracker.drain_events(); let events = tracker.drain_events();
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
match &events[0] { match &events[0] {
ConsolidatedEvent::Store { lora_name, .. } => { ConsolidatedEvent::Store {
lora_name, tier, ..
} => {
assert!(lora_name.is_none()); assert!(lora_name.is_none());
assert_eq!(*tier, Some(StorageTier::Device));
} }
other => panic!("expected Store event, got: {:?}", other), other => panic!("expected Store event, got: {:?}", other),
} }
...@@ -971,4 +1076,53 @@ mod tests { ...@@ -971,4 +1076,53 @@ mod tests {
let seq_hash2_different = compute_sequence_hash(Some(different_parent), block_hash2); let seq_hash2_different = compute_sequence_hash(Some(different_parent), block_hash2);
assert_ne!(seq_hash2_v1, seq_hash2_different); assert_ne!(seq_hash2_v1, seq_hash2_different);
} }
#[test]
fn test_passthrough_tracker_forwards_duplicate_store() {
let mut tracker = PassthroughCacheStatusTracker::new();
assert!(tracker.handle_store(
"hash1".to_string(),
EventSource::Vllm,
vec![1, 2, 3],
None,
3,
None,
Some(StorageTier::Device),
None,
));
assert!(tracker.handle_store(
"hash1".to_string(),
EventSource::Vllm,
vec![1, 2, 3],
None,
3,
None,
Some(StorageTier::Device),
None,
));
let events = tracker.drain_events();
assert_eq!(events.len(), 2);
}
#[test]
fn test_passthrough_tracker_remove_and_clear() {
let mut tracker = PassthroughCacheStatusTracker::new();
assert!(tracker.handle_remove("hash1", EventSource::Kvbm, Some(StorageTier::HostPinned),));
tracker.handle_clear_all();
let events = tracker.drain_events();
assert_eq!(events.len(), 2);
assert!(matches!(
&events[0],
ConsolidatedEvent::Remove {
tier: Some(StorageTier::HostPinned),
..
}
));
assert!(matches!(&events[1], ConsolidatedEvent::ClearAll));
assert_eq!(tracker.num_blocks(), 0);
}
} }
...@@ -51,6 +51,7 @@ pub mod inactive; ...@@ -51,6 +51,7 @@ pub mod inactive;
pub mod priority_key; pub mod priority_key;
pub mod state; pub mod state;
use crate::block_manager::kv_consolidator::StorageTier;
use active::ActiveBlockPool; use active::ActiveBlockPool;
use inactive::InactiveBlockPool; use inactive::InactiveBlockPool;
...@@ -72,6 +73,9 @@ pub struct ManagedBlockPoolArgs<S: Storage, L: LocalityProvider, M: BlockMetadat ...@@ -72,6 +73,9 @@ pub struct ManagedBlockPoolArgs<S: Storage, L: LocalityProvider, M: BlockMetadat
#[builder(default = "Handle::current()")] #[builder(default = "Handle::current()")]
async_runtime: Handle, async_runtime: Handle,
#[builder(default = "StorageTier::Device")]
storage_tier: StorageTier,
#[builder(default = "BlockRegistrationDuplicationSetting::Disabled")] #[builder(default = "BlockRegistrationDuplicationSetting::Disabled")]
default_duplication_setting: BlockRegistrationDuplicationSetting, default_duplication_setting: BlockRegistrationDuplicationSetting,
} }
...@@ -85,6 +89,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuil ...@@ -85,6 +89,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuil
blocks, blocks,
global_registry, global_registry,
async_runtime, async_runtime,
storage_tier,
default_duplication_setting, default_duplication_setting,
) = args.dissolve(); ) = args.dissolve();
...@@ -95,6 +100,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuil ...@@ -95,6 +100,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuil
blocks, blocks,
global_registry, global_registry,
async_runtime, async_runtime,
storage_tier,
default_duplication_setting, default_duplication_setting,
); );
...@@ -176,6 +182,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M ...@@ -176,6 +182,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks: Vec<Block<S, L, M>>, blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry, global_registry: GlobalRegistry,
async_runtime: Handle, async_runtime: Handle,
storage_tier: StorageTier,
default_duplication_setting: BlockRegistrationDuplicationSetting, default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> Self { ) -> Self {
let (pool, progress_engine) = Self::with_progress_engine( let (pool, progress_engine) = Self::with_progress_engine(
...@@ -184,6 +191,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M ...@@ -184,6 +191,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks, blocks,
global_registry, global_registry,
async_runtime, async_runtime,
storage_tier,
default_duplication_setting, default_duplication_setting,
); );
...@@ -228,6 +236,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M ...@@ -228,6 +236,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks: Vec<Block<S, L, M>>, blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry, global_registry: GlobalRegistry,
async_runtime: Handle, async_runtime: Handle,
storage_tier: StorageTier,
default_duplication_setting: BlockRegistrationDuplicationSetting, default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> (Self, ProgressEngine<S, L, M>) { ) -> (Self, ProgressEngine<S, L, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel(); let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
...@@ -241,6 +250,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M ...@@ -241,6 +250,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M
blocks, blocks,
global_registry, global_registry,
async_runtime, async_runtime,
storage_tier,
); );
let available_blocks_counter = progress_engine.available_blocks_counter.clone(); let available_blocks_counter = progress_engine.available_blocks_counter.clone();
...@@ -515,10 +525,16 @@ impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> ProgressEngine ...@@ -515,10 +525,16 @@ impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> ProgressEngine
blocks: Vec<Block<S, L, M>>, blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry, global_registry: GlobalRegistry,
async_runtime: Handle, async_runtime: Handle,
storage_tier: StorageTier,
) -> Self { ) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel(); let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state = let mut state = State::<S, L, M>::new(
State::<S, L, M>::new(event_manager, return_tx, global_registry, async_runtime); event_manager,
return_tx,
global_registry,
async_runtime,
storage_tier,
);
let count = blocks.len(); let count = blocks.len();
...@@ -589,6 +605,7 @@ mod tests { ...@@ -589,6 +605,7 @@ mod tests {
blocks, blocks,
global_registry, global_registry,
async_runtime, async_runtime,
storage_tier,
default_duplication_setting, default_duplication_setting,
) = args.dissolve(); ) = args.dissolve();
...@@ -598,6 +615,7 @@ mod tests { ...@@ -598,6 +615,7 @@ mod tests {
blocks, blocks,
global_registry, global_registry,
async_runtime, async_runtime,
storage_tier,
default_duplication_setting, default_duplication_setting,
); );
......
...@@ -532,6 +532,7 @@ pub(crate) mod tests { ...@@ -532,6 +532,7 @@ pub(crate) mod tests {
state::CompleteState, state::CompleteState,
}, },
events::NullEventManager, events::NullEventManager,
kv_consolidator::StorageTier,
layout::{BlockLayout, FullyContiguous, LayoutConfigBuilder}, layout::{BlockLayout, FullyContiguous, LayoutConfigBuilder},
storage::tests::{NullDeviceAllocator, NullDeviceStorage}, storage::tests::{NullDeviceAllocator, NullDeviceStorage},
}, },
...@@ -701,8 +702,12 @@ pub(crate) mod tests { ...@@ -701,8 +702,12 @@ pub(crate) mod tests {
let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap(); let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap();
let event_manager = NullEventManager::new(); let event_manager = NullEventManager::new();
let mut registry = let mut registry = BlockRegistry::new(
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime); event_manager,
GlobalRegistry::default(),
async_runtime,
StorageTier::Device,
);
// Iterate through the generated TokenBlocks and the template Blocks, // Iterate through the generated TokenBlocks and the template Blocks,
// setting the state and registering each one. // setting the state and registering each one.
...@@ -745,8 +750,12 @@ pub(crate) mod tests { ...@@ -745,8 +750,12 @@ pub(crate) mod tests {
let matched_block_count = matched_blocks.len(); let matched_block_count = matched_blocks.len();
let event_manager = NullEventManager::new(); let event_manager = NullEventManager::new();
let mut registry = let mut registry = BlockRegistry::new(
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime); event_manager,
GlobalRegistry::default(),
async_runtime,
StorageTier::Device,
);
// all matched blocks should be in the complete or registered state // all matched blocks should be in the complete or registered state
for block in &mut matched_blocks { for block in &mut matched_blocks {
......
...@@ -17,11 +17,17 @@ impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> State<S, L, M> ...@@ -17,11 +17,17 @@ impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> State<S, L, M>
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>, return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
global_registry: GlobalRegistry, global_registry: GlobalRegistry,
async_runtime: Handle, async_runtime: Handle,
storage_tier: StorageTier,
) -> Self { ) -> Self {
Self { Self {
active: ActiveBlockPool::new(), active: ActiveBlockPool::new(),
inactive: InactiveBlockPool::new(), inactive: InactiveBlockPool::new(),
registry: BlockRegistry::new(event_manager.clone(), global_registry, async_runtime), registry: BlockRegistry::new(
event_manager.clone(),
global_registry,
async_runtime,
storage_tier,
),
return_tx, return_tx,
event_manager, event_manager,
} }
......
...@@ -26,6 +26,8 @@ use std::sync::Arc; ...@@ -26,6 +26,8 @@ use std::sync::Arc;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use crate::block_manager::kv_consolidator::StorageTier;
pub(crate) struct Resources { pub(crate) struct Resources {
pub worker_id: WorkerID, pub worker_id: WorkerID,
pub cancellation_token: CancellationToken, pub cancellation_token: CancellationToken,
...@@ -253,7 +255,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<locality::Local, Metadata> { ...@@ -253,7 +255,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<locality::Local, Metadata> {
let (device_pool, device_blocks, device_offload_filter) = match device_factory { let (device_pool, device_blocks, device_offload_filter) = match device_factory {
Some(factory) => { Some(factory) => {
let (pool, blocks, offload_filter) = let (pool, blocks, offload_filter) =
create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?; create_block_pool::<_, _, Metadata>(factory, &resources, "device")?;
(Some(pool), Some(blocks), offload_filter) (Some(pool), Some(blocks), offload_filter)
} }
None => { None => {
...@@ -523,17 +525,25 @@ impl<Locality: LocalityProvider, Metadata: BlockMetadata> std::fmt::Debug ...@@ -523,17 +525,25 @@ impl<Locality: LocalityProvider, Metadata: BlockMetadata> std::fmt::Debug
pub(crate) fn create_block_pool<S: Storage, L: LocalityProvider, M: BlockMetadata>( pub(crate) fn create_block_pool<S: Storage, L: LocalityProvider, M: BlockMetadata>(
factory: impl IntoBlocks<S, L>, factory: impl IntoBlocks<S, L>,
resources: &Resources, resources: &Resources,
_pool_name: &str, pool_name: &str,
) -> Result<( ) -> Result<(
Arc<dyn BlockPool<S, L, M>>, Arc<dyn BlockPool<S, L, M>>,
Vec<Block<S, L, M>>, Vec<Block<S, L, M>>,
Option<Arc<dyn OffloadFilter>>, Option<Arc<dyn OffloadFilter>>,
)> { )> {
let storage_tier = match pool_name {
"device" => StorageTier::Device,
"host" => StorageTier::HostPinned,
"disk" => StorageTier::Disk,
_ => anyhow::bail!("unsupported block pool tier: {}", pool_name),
};
let pool = ManagedBlockPool::<S, L, M>::builder() let pool = ManagedBlockPool::<S, L, M>::builder()
.cancel_token(resources.cancellation_token.clone()) .cancel_token(resources.cancellation_token.clone())
.global_registry(resources.global_registry.clone()) .global_registry(resources.global_registry.clone())
.async_runtime(resources.async_rt_handle.clone()) .async_runtime(resources.async_rt_handle.clone())
.event_manager(resources.event_manager.clone()) .event_manager(resources.event_manager.clone())
.storage_tier(storage_tier)
.build()?; .build()?;
let offload_filter = factory.offload_filter(); let offload_filter = factory.offload_filter();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::{
use std::time::Instant; collections::{HashMap, HashSet},
sync::Arc,
time::Instant,
};
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
...@@ -15,6 +18,7 @@ use dynamo_kv_router::{ ...@@ -15,6 +18,7 @@ use dynamo_kv_router::{
RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq, compute_block_hash_for_seq,
}, },
scheduling::TierOverlapBlocks,
}; };
use dynamo_runtime::{ use dynamo_runtime::{
component::{Client, Endpoint}, component::{Client, Endpoint},
...@@ -63,8 +67,6 @@ use crate::{ ...@@ -63,8 +67,6 @@ use crate::{
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
}; };
use std::collections::HashSet;
// [gluo TODO] shouldn't need to be public // [gluo TODO] shouldn't need to be public
// this should be discovered from the component // this should be discovered from the component
...@@ -85,6 +87,144 @@ pub const RADIX_STATE_FILE: &str = "radix-state"; ...@@ -85,6 +87,144 @@ pub const RADIX_STATE_FILE: &str = "radix-state";
// for worker-local kvindexer query // for worker-local kvindexer query
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct WorkerCacheHitEstimate {
pub effective_overlap_blocks: f64,
pub cached_tokens: usize,
}
impl WorkerCacheHitEstimate {
pub fn rounded_overlap_blocks(self) -> u32 {
self.effective_overlap_blocks.round() as u32
}
}
#[derive(Debug, Clone, Default)]
struct CacheHitEstimates {
effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
cached_tokens: HashMap<WorkerWithDpRank, usize>,
}
#[derive(Debug, Clone, Copy)]
pub(crate) struct BestMatchDetails {
pub worker: WorkerWithDpRank,
pub cache_hit: WorkerCacheHitEstimate,
}
fn cache_hit_weight_for_tier(
kv_router_config: &KvRouterConfig,
storage_tier: dynamo_kv_router::protocols::StorageTier,
) -> f64 {
match storage_tier {
dynamo_kv_router::protocols::StorageTier::Device => 1.0,
dynamo_kv_router::protocols::StorageTier::HostPinned => {
kv_router_config.host_cache_hit_weight
}
dynamo_kv_router::protocols::StorageTier::Disk
| dynamo_kv_router::protocols::StorageTier::External => {
kv_router_config.disk_cache_hit_weight
}
}
}
fn cached_tokens_from_effective_overlap(block_size: u32, effective_overlap_blocks: f64) -> usize {
(effective_overlap_blocks * block_size as f64)
.round()
.max(0.0) as usize
}
fn cache_hit_estimates_from_tiered_matches(
kv_router_config: &KvRouterConfig,
block_size: u32,
tiered_matches: &indexer::TieredMatchDetails,
) -> CacheHitEstimates {
let mut effective_overlap_blocks = HashMap::new();
for (worker, overlap) in &tiered_matches.device.overlap_scores.scores {
effective_overlap_blocks.insert(*worker, *overlap as f64);
}
for (storage_tier, tier_matches) in &tiered_matches.lower_tier {
let weight = cache_hit_weight_for_tier(kv_router_config, *storage_tier);
if weight == 0.0 {
continue;
}
for (worker, hits) in &tier_matches.hits {
if *hits == 0 {
continue;
}
*effective_overlap_blocks.entry(*worker).or_insert(0.0) += *hits as f64 * weight;
}
}
let cached_tokens = effective_overlap_blocks
.iter()
.map(|(worker, overlap)| {
(
*worker,
cached_tokens_from_effective_overlap(block_size, *overlap),
)
})
.collect();
CacheHitEstimates {
effective_overlap_blocks,
cached_tokens,
}
}
fn cache_hit_for_worker(
cache_hit_estimates: &CacheHitEstimates,
worker: WorkerWithDpRank,
) -> WorkerCacheHitEstimate {
WorkerCacheHitEstimate {
effective_overlap_blocks: cache_hit_estimates
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0),
cached_tokens: cache_hit_estimates
.cached_tokens
.get(&worker)
.copied()
.unwrap_or(0),
}
}
fn tier_overlap_blocks_from_tiered_matches(
tiered_matches: &indexer::TieredMatchDetails,
) -> TierOverlapBlocks {
let mut tier_overlap_blocks = TierOverlapBlocks::default();
if let Some(host_matches) = tiered_matches
.lower_tier
.get(&dynamo_kv_router::protocols::StorageTier::HostPinned)
{
tier_overlap_blocks.host_pinned.extend(
host_matches
.hits
.iter()
.map(|(worker, hits)| (*worker, *hits)),
);
}
// Disk and External share the same weighting (see `storage_tier_weight`),
// so accumulate both into the disk bucket.
for tier in [
dynamo_kv_router::protocols::StorageTier::Disk,
dynamo_kv_router::protocols::StorageTier::External,
] {
if let Some(matches) = tiered_matches.lower_tier.get(&tier) {
for (worker, hits) in &matches.hits {
*tier_overlap_blocks.disk.entry(*worker).or_default() += *hits;
}
}
}
tier_overlap_blocks
}
/// Generates a dp_rank-specific endpoint name for the worker KV indexer query service. /// Generates a dp_rank-specific endpoint name for the worker KV indexer query service.
/// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank monotonicity. /// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String { pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
...@@ -275,6 +415,25 @@ where ...@@ -275,6 +415,25 @@ where
self.is_eagle self.is_eagle
} }
fn cache_hit_estimates_from_tiered_matches(
&self,
tiered_matches: &indexer::TieredMatchDetails,
) -> CacheHitEstimates {
cache_hit_estimates_from_tiered_matches(
&self.kv_router_config,
self.block_size,
tiered_matches,
)
}
fn cache_hit_for_worker(
&self,
cache_hit_estimates: &CacheHitEstimates,
worker: WorkerWithDpRank,
) -> WorkerCacheHitEstimate {
cache_hit_for_worker(cache_hit_estimates, worker)
}
pub async fn record_routing_decision( pub async fn record_routing_decision(
&self, &self,
mut tokens_with_hashes: TokensWithHashes, mut tokens_with_hashes: TokensWithHashes,
...@@ -285,16 +444,15 @@ where ...@@ -285,16 +444,15 @@ where
.await .await
} }
/// Give these tokens, find the worker with the best match in it's KV cache. /// Give these tokens, find the worker with the best weighted cache hit.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks. /// Returns the full match details for the selected worker.
/// Now also takes optional context_id for request tracking.
/// ///
/// When `pinned_worker` is Some, scheduling and queueing are constrained to /// When `pinned_worker` is Some, scheduling and queueing are constrained to
/// that exact worker/rank. /// that exact worker/rank.
/// ///
/// When `allowed_worker_ids` is Some, only workers in that set are considered for selection. /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn find_best_match( pub(crate) async fn find_best_match_details(
&self, &self,
context_id: Option<&str>, context_id: Option<&str>,
tokens: &[u32], tokens: &[u32],
...@@ -306,7 +464,7 @@ where ...@@ -306,7 +464,7 @@ where
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>, pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>, allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> { ) -> anyhow::Result<BestMatchDetails> {
let start = Instant::now(); let start = Instant::now();
if update_states && context_id.is_none() { if update_states && context_id.is_none() {
...@@ -336,13 +494,13 @@ where ...@@ -336,13 +494,13 @@ where
}); });
let seq_hash_elapsed = start.elapsed(); let seq_hash_elapsed = start.elapsed();
// Query indexer and shared cache in parallel when shared cache is configured. // Query indexer (tiered) and shared cache in parallel when shared cache is configured.
// Time each independently so metrics can separate indexer vs shared cache latency. // Time each independently so metrics can separate indexer vs shared cache latency.
let (overlap_scores, shared_cache_hits, indexer_duration, shared_cache_duration) = let (tiered_matches, shared_cache_hits, indexer_duration, shared_cache_duration) =
if let Some(ref shared_cache) = self.shared_cache { if let Some(ref shared_cache) = self.shared_cache {
let indexer_fut = self let indexer_fut = self
.indexer .indexer
.find_matches(block_hashes.clone()) .find_matches_by_tier(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches")); .instrument(tracing::info_span!("kv_router.find_matches"));
let shared_fut = shared_cache let shared_fut = shared_cache
.check_blocks(tokens, self.block_size) .check_blocks(tokens, self.block_size)
...@@ -361,7 +519,7 @@ where ...@@ -361,7 +519,7 @@ where
let ((indexer_result, idx_dur), (shared_result, sc_dur)) = let ((indexer_result, idx_dur), (shared_result, sc_dur)) =
tokio::join!(indexer_timed, shared_timed); tokio::join!(indexer_timed, shared_timed);
let overlaps = indexer_result?; let tiered = indexer_result?;
// Shared cache failure is non-fatal: log warning and fall back to empty hits. // Shared cache failure is non-fatal: log warning and fall back to empty hits.
let hits = match shared_result { let hits = match shared_result {
Ok(hits) => Some(hits), Ok(hits) => Some(hits),
...@@ -373,16 +531,26 @@ where ...@@ -373,16 +531,26 @@ where
None None
} }
}; };
(overlaps, hits, idx_dur, Some(sc_dur)) (tiered, hits, idx_dur, Some(sc_dur))
} else { } else {
let t = Instant::now(); let t = Instant::now();
let overlaps = self let tiered = self
.indexer .indexer
.find_matches(block_hashes) .find_matches_by_tier(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches")) .instrument(tracing::info_span!("kv_router.find_matches"))
.await?; .await?;
(overlaps, None, t.elapsed(), None) (tiered, None, t.elapsed(), None)
}; };
let tier_overlap_blocks = tier_overlap_blocks_from_tiered_matches(&tiered_matches);
let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
let tree_sizes: HashMap<_, _> = tiered_matches
.device
.overlap_scores
.tree_sizes
.iter()
.map(|(k, v)| (*k, *v))
.collect();
let find_matches_elapsed = start.elapsed(); let find_matches_elapsed = start.elapsed();
// Capture shared cache info for metrics before moving into schedule(). // Capture shared cache info for metrics before moving into schedule().
...@@ -397,7 +565,10 @@ where ...@@ -397,7 +565,10 @@ where
context_id.map(|s| s.to_string()), context_id.map(|s| s.to_string()),
isl_tokens, isl_tokens,
maybe_seq_hashes, maybe_seq_hashes,
overlap_scores, tier_overlap_blocks,
cache_hit_estimates.effective_overlap_blocks,
cache_hit_estimates.cached_tokens,
tree_sizes,
router_config_override, router_config_override,
update_states, update_states,
lora_name, lora_name,
...@@ -430,7 +601,7 @@ where ...@@ -430,7 +601,7 @@ where
m.shared_cache_hit_rate m.shared_cache_hit_rate
.observe(hits.total_hits as f64 / num_blocks as f64); .observe(hits.total_hits as f64 / num_blocks as f64);
} }
let beyond = hits.hits_beyond(response.overlap_blocks); let beyond = hits.hits_beyond(response.effective_overlap_blocks.round() as u32);
m.shared_cache_beyond_blocks.observe(beyond as f64); m.shared_cache_beyond_blocks.observe(beyond as f64);
} }
...@@ -445,7 +616,45 @@ where ...@@ -445,7 +616,45 @@ where
"find_best_match completed" "find_best_match completed"
); );
Ok((response.best_worker, response.overlap_blocks)) Ok(BestMatchDetails {
worker: response.best_worker,
cache_hit: WorkerCacheHitEstimate {
effective_overlap_blocks: response.effective_overlap_blocks,
cached_tokens: response.cached_tokens,
},
})
}
/// Give these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and approximate effective overlap in blocks.
#[allow(clippy::too_many_arguments)]
pub async fn find_best_match(
&self,
context_id: Option<&str>,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
expected_output_tokens: Option<u32>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
let result = self
.find_best_match_details(
context_id,
tokens,
block_mm_infos,
router_config_override,
update_states,
lora_name,
priority_jump,
expected_output_tokens,
None,
allowed_worker_ids,
)
.await?;
Ok((result.worker, result.cache_hit.rounded_overlap_blocks()))
} }
/// Register externally-provided workers in the slot tracker. /// Register externally-provided workers in the slot tracker.
...@@ -459,7 +668,7 @@ where ...@@ -459,7 +668,7 @@ where
request_id: String, request_id: String,
tokens: &[u32], tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>, block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
overlap_blocks: u32, cached_tokens: usize,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<String>, lora_name: Option<String>,
...@@ -483,7 +692,7 @@ where ...@@ -483,7 +692,7 @@ where
.kv_router_config .kv_router_config
.track_prefill_tokens(router_config_override); .track_prefill_tokens(router_config_override);
let prefill_load_hint = let prefill_load_hint =
self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens); self.prefill_load_hint_for(isl_tokens, cached_tokens, track_prefill_tokens);
if let Err(e) = self if let Err(e) = self
.scheduler .scheduler
...@@ -518,14 +727,14 @@ where ...@@ -518,14 +727,14 @@ where
fn prefill_load_hint_for( fn prefill_load_hint_for(
&self, &self,
isl_tokens: usize, isl_tokens: usize,
overlap_blocks: u32, cached_tokens: usize,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> { ) -> Option<PrefillLoadHint> {
if !track_prefill_tokens { if !track_prefill_tokens {
return None; return None;
} }
let prefix = (overlap_blocks as usize) * (self.block_size as usize); let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix); let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 { if effective_isl == 0 {
return None; return None;
...@@ -578,7 +787,7 @@ where ...@@ -578,7 +787,7 @@ where
} }
/// Compute the overlap blocks for a given token sequence and worker. /// Compute the overlap blocks for a given token sequence and worker.
/// This queries the indexer to find how many blocks are already cached. /// This queries the indexer to find the effective weighted cache hit.
pub async fn get_overlap_blocks( pub async fn get_overlap_blocks(
&self, &self,
tokens: &[u32], tokens: &[u32],
...@@ -586,6 +795,19 @@ where ...@@ -586,6 +795,19 @@ where
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<&str>, lora_name: Option<&str>,
) -> Result<u32, KvRouterError> { ) -> Result<u32, KvRouterError> {
Ok(self
.get_cache_hit_estimate(tokens, block_mm_infos, worker, lora_name)
.await?
.rounded_overlap_blocks())
}
pub(crate) async fn get_cache_hit_estimate(
&self,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
worker: WorkerWithDpRank,
lora_name: Option<&str>,
) -> Result<WorkerCacheHitEstimate, KvRouterError> {
let block_hashes = compute_block_hash_for_seq( let block_hashes = compute_block_hash_for_seq(
tokens, tokens,
self.block_size, self.block_size,
...@@ -595,9 +817,9 @@ where ...@@ -595,9 +817,9 @@ where
is_eagle: Some(self.is_eagle), is_eagle: Some(self.is_eagle),
}, },
); );
log_routing_input_hashes(None, self.block_size, tokens, &block_hashes); let tiered_matches = self.indexer.find_matches_by_tier(block_hashes).await?;
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0)) Ok(self.cache_hit_for_worker(&cache_hit_estimates, worker))
} }
/// Get potential prefill and decode loads for all workers /// Get potential prefill and decode loads for all workers
...@@ -626,12 +848,13 @@ where ...@@ -626,12 +848,13 @@ where
let track_prefill_tokens = self let track_prefill_tokens = self
.kv_router_config .kv_router_config
.track_prefill_tokens(router_config_override); .track_prefill_tokens(router_config_override);
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let tiered_matches = self.indexer.find_matches_by_tier(block_hashes).await?;
let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
Ok(self.scheduler.get_potential_loads( Ok(self.scheduler.get_potential_loads(
maybe_seq_hashes, maybe_seq_hashes,
isl_tokens, isl_tokens,
overlap_scores, cache_hit_estimates.cached_tokens,
track_prefill_tokens, track_prefill_tokens,
)) ))
} }
...@@ -673,7 +896,6 @@ where ...@@ -673,7 +896,6 @@ where
0.0, 0.0,
None, None,
None, None,
None,
) )
.await?; .await?;
...@@ -719,12 +941,57 @@ mod tests { ...@@ -719,12 +941,57 @@ mod tests {
use std::collections::HashMap; use std::collections::HashMap;
use async_trait::async_trait; use async_trait::async_trait;
use dynamo_kv_router::{
indexer::{LowerTierMatchDetails, MatchDetails},
protocols::{OverlapScores, StorageTier},
};
use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig}; use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig};
use tokio::sync::watch; use tokio::sync::watch;
use crate::kv_router::scheduler::KvSchedulerError; use crate::kv_router::scheduler::KvSchedulerError;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
#[test]
fn weighted_cache_hit_estimates_include_lower_tiers() {
let worker_1 = WorkerWithDpRank::new(1, 0);
let worker_2 = WorkerWithDpRank::new(2, 0);
let mut device_overlap_scores = OverlapScores::new();
device_overlap_scores.scores.insert(worker_1, 2);
let mut host_match_details = LowerTierMatchDetails::default();
host_match_details.hits.insert(worker_1, 1);
host_match_details.hits.insert(worker_2, 1);
let mut disk_match_details = LowerTierMatchDetails::default();
disk_match_details.hits.insert(worker_1, 2);
let tiered_matches = indexer::TieredMatchDetails {
device: MatchDetails {
overlap_scores: device_overlap_scores,
..Default::default()
},
lower_tier: HashMap::from([
(StorageTier::HostPinned, host_match_details),
(StorageTier::Disk, disk_match_details),
]),
};
let estimates = cache_hit_estimates_from_tiered_matches(
&KvRouterConfig::default(),
16,
&tiered_matches,
);
assert_eq!(
estimates.effective_overlap_blocks.get(&worker_1),
Some(&3.25)
);
assert_eq!(estimates.cached_tokens.get(&worker_1), Some(&52));
assert_eq!(
estimates.effective_overlap_blocks.get(&worker_2),
Some(&0.75)
);
assert_eq!(estimates.cached_tokens.get(&worker_2), Some(&12));
}
struct FakeSharedCache { struct FakeSharedCache {
hits: Option<dynamo_kv_router::protocols::SharedCacheHits>, hits: Option<dynamo_kv_router::protocols::SharedCacheHits>,
should_error: bool, should_error: bool,
...@@ -766,7 +1033,8 @@ mod tests { ...@@ -766,7 +1033,8 @@ mod tests {
Ok(dynamo_kv_router::protocols::WorkerSelectionResult { Ok(dynamo_kv_router::protocols::WorkerSelectionResult {
worker: self.selected_worker, worker: self.selected_worker,
required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64, required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
overlap_blocks: 0, effective_overlap_blocks: 0.0,
cached_tokens: 0,
}) })
} }
} }
...@@ -855,7 +1123,6 @@ mod tests { ...@@ -855,7 +1123,6 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -889,7 +1156,6 @@ mod tests { ...@@ -889,7 +1156,6 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::{
use std::time::Duration; collections::HashMap,
sync::{Arc, RwLock},
time::Duration,
};
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
ConcurrentRadixTreeCompressed, ThreadPoolIndexer, ConcurrentRadixTreeCompressed, LowerTierIndexer, ThreadPoolIndexer,
approx::PruneConfig, approx::PruneConfig,
config::KvRouterConfig, config::KvRouterConfig,
indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError}, indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError, LowerTierContinuation,
LowerTierMatchDetails, MatchDetails,
},
protocols::{ protocols::{
DpRank, LocalBlockHash, OverlapScores, RouterEvent, TokensWithHashes, WorkerId, DpRank, LocalBlockHash, OverlapScores, RouterEvent, StorageTier, TokensWithHashes,
WorkerWithDpRank, WorkerId, WorkerWithDpRank,
}, },
}; };
use dynamo_runtime::{component::Component, traits::DistributedRuntimeProvider}; use dynamo_runtime::{component::Component, traits::DistributedRuntimeProvider};
...@@ -24,6 +30,123 @@ pub mod remote; ...@@ -24,6 +30,123 @@ pub mod remote;
mod subscriber; mod subscriber;
mod worker_query; mod worker_query;
#[derive(Clone)]
pub struct LowerTierIndexers {
num_threads: usize,
block_size: u32,
indexers: Arc<RwLock<HashMap<StorageTier, Arc<ThreadPoolIndexer<LowerTierIndexer>>>>>,
}
impl LowerTierIndexers {
pub(crate) fn new(num_threads: usize, block_size: u32) -> Self {
assert!(
num_threads > 0,
"lower-tier indexer threads must be non-zero"
);
Self {
num_threads,
block_size,
indexers: Arc::new(RwLock::new(HashMap::new())),
}
}
fn get_or_create(&self, storage_tier: StorageTier) -> Arc<ThreadPoolIndexer<LowerTierIndexer>> {
debug_assert!(!storage_tier.is_gpu());
if let Some(indexer) = self.indexers.read().unwrap().get(&storage_tier).cloned() {
return indexer;
}
self.indexers
.write()
.unwrap()
.entry(storage_tier)
.or_insert_with(|| {
Arc::new(ThreadPoolIndexer::new(
LowerTierIndexer::new(),
self.num_threads,
self.block_size,
))
})
.clone()
}
fn all(&self) -> Vec<Arc<ThreadPoolIndexer<LowerTierIndexer>>> {
self.indexers.read().unwrap().values().cloned().collect()
}
fn get(&self, storage_tier: StorageTier) -> Option<Arc<ThreadPoolIndexer<LowerTierIndexer>>> {
self.indexers.read().unwrap().get(&storage_tier).cloned()
}
}
fn lower_tier_query_order() -> [StorageTier; 3] {
[
StorageTier::HostPinned,
StorageTier::Disk,
StorageTier::External,
]
}
fn query_lower_tiers(
indexers: &LowerTierIndexers,
sequence: &[LocalBlockHash],
device_matches: &MatchDetails,
) -> HashMap<StorageTier, LowerTierMatchDetails> {
let mut continuations = LowerTierMatchDetails::default().next_continuations;
for (worker, matched_blocks) in &device_matches.overlap_scores.scores {
let Some(last_hash) = device_matches.last_matched_hashes.get(worker).copied() else {
debug_assert!(
false,
"device match result missing last matched hash for worker {worker:?}"
);
continue;
};
continuations.insert(
*worker,
LowerTierContinuation::new(*matched_blocks as usize, last_hash),
);
}
let mut lower_tier_matches = HashMap::new();
for storage_tier in lower_tier_query_order() {
let Some(indexer) = indexers.get(storage_tier) else {
continue;
};
if let Some(&first_hash) = sequence.first() {
let root_workers: Vec<_> = indexer.backend().root_workers(first_hash);
for worker in root_workers.iter() {
continuations
.entry(*worker)
.or_insert_with(|| LowerTierContinuation::from_root(0));
}
}
let tier_matches = indexer
.backend()
.query_match_details(sequence, &continuations);
let matched_workers = tier_matches.hits.values().filter(|&&hits| hits > 0).count();
tracing::debug!(
?storage_tier,
queried_workers = continuations.len(),
matched_workers,
"Queried lower-tier indexer"
);
continuations = tier_matches.next_continuations.clone();
lower_tier_matches.insert(storage_tier, tier_matches);
}
lower_tier_matches
}
#[derive(Debug, Clone, Default)]
pub(crate) struct TieredMatchDetails {
pub device: MatchDetails,
#[cfg_attr(not(test), allow(dead_code))]
pub lower_tier: HashMap<StorageTier, LowerTierMatchDetails>,
}
use self::remote::RemoteIndexer; use self::remote::RemoteIndexer;
pub use self::remote::{ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service}; pub use self::remote::{ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
pub(crate) use subscriber::start_subscriber; pub(crate) use subscriber::start_subscriber;
...@@ -31,8 +154,14 @@ pub(crate) use worker_query::start_worker_kv_query_endpoint; ...@@ -31,8 +154,14 @@ pub(crate) use worker_query::start_worker_kv_query_endpoint;
#[derive(Clone)] #[derive(Clone)]
pub enum Indexer { pub enum Indexer {
KvIndexer(KvIndexer), KvIndexer {
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>), primary: KvIndexer,
lower_tier: LowerTierIndexers,
},
Concurrent {
primary: Arc<ThreadPoolIndexer<ConcurrentRadixTreeCompressed>>,
lower_tier: LowerTierIndexers,
},
Remote(Arc<RemoteIndexer>), Remote(Arc<RemoteIndexer>),
None, None,
} }
...@@ -73,57 +202,98 @@ impl Indexer { ...@@ -73,57 +202,98 @@ impl Indexer {
max_tree_size: kv_router_config.router_max_tree_size, max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio, prune_target_ratio: kv_router_config.router_prune_target_ratio,
}); });
return Ok(Self::KvIndexer(KvIndexer::new_with_frequency( return Ok(Self::KvIndexer {
cancellation_token, primary: KvIndexer::new_with_frequency(
None, cancellation_token,
block_size, None,
kv_indexer_metrics, block_size,
prune_config, kv_indexer_metrics,
))); prune_config,
),
lower_tier: LowerTierIndexers::new(1, block_size),
});
} }
if kv_router_config.router_event_threads > 1 { if kv_router_config.router_event_threads > 1 {
let kv_indexer_metrics = KvIndexerMetrics::from_component(component); let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
return Ok(Self::Concurrent(Arc::new( return Ok(Self::Concurrent {
ThreadPoolIndexer::new_with_metrics( primary: Arc::new(ThreadPoolIndexer::new_with_metrics(
ConcurrentRadixTreeCompressed::new(), ConcurrentRadixTreeCompressed::new(),
kv_router_config.router_event_threads as usize, kv_router_config.router_event_threads as usize,
block_size, block_size,
Some(kv_indexer_metrics), Some(kv_indexer_metrics),
)),
lower_tier: LowerTierIndexers::new(
kv_router_config.router_event_threads as usize,
block_size,
), ),
))); });
} }
let kv_indexer_metrics = KvIndexerMetrics::from_component(component); let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
let cancellation_token = component.drt().primary_token(); let cancellation_token = component.drt().primary_token();
Ok(Self::KvIndexer(KvIndexer::new_with_frequency( Ok(Self::KvIndexer {
cancellation_token, primary: KvIndexer::new_with_frequency(
None, cancellation_token,
block_size, None,
kv_indexer_metrics, block_size,
None, kv_indexer_metrics,
))) None,
),
lower_tier: LowerTierIndexers::new(1, block_size),
})
} }
#[allow(dead_code)]
pub(crate) async fn find_matches( pub(crate) async fn find_matches(
&self, &self,
sequence: Vec<LocalBlockHash>, sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
self.find_match_details(sequence)
.await
.map(|details| details.overlap_scores)
}
pub(crate) async fn find_match_details(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<MatchDetails, KvRouterError> {
match self { match self {
Self::KvIndexer(indexer) => indexer.find_matches(sequence).await, Self::KvIndexer { primary, .. } => primary.find_match_details(sequence).await,
Self::Concurrent(tpi) => tpi.find_matches(sequence).await, Self::Concurrent { primary, .. } => {
Self::Remote(remote) => match remote.find_matches(sequence).await { Ok(primary.backend().find_match_details_impl(&sequence, false))
Ok(scores) => Ok(scores), }
Err(error) => { Self::Remote(remote) => remote
tracing::warn!(error = %error, "Remote indexer query failed"); .find_matches(sequence)
Ok(OverlapScores::new()) .await
} .map(|overlap_scores| MatchDetails {
}, overlap_scores,
Self::None => Ok(OverlapScores::new()), ..Default::default()
})
.map_err(|e| {
tracing::warn!(error = %e, "Remote indexer query failed");
KvRouterError::IndexerOffline
}),
Self::None => Ok(MatchDetails::new()),
} }
} }
pub(crate) async fn find_matches_by_tier(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<TieredMatchDetails, KvRouterError> {
let device = self.find_match_details(sequence.clone()).await?;
let lower_tier = match self {
Self::KvIndexer { lower_tier, .. } | Self::Concurrent { lower_tier, .. } => {
query_lower_tiers(lower_tier, &sequence, &device)
}
Self::Remote(_) | Self::None => HashMap::new(),
};
Ok(TieredMatchDetails { device, lower_tier })
}
pub(crate) async fn record_hashed_routing_decision( pub(crate) async fn record_hashed_routing_decision(
&self, &self,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
...@@ -131,12 +301,12 @@ impl Indexer { ...@@ -131,12 +301,12 @@ impl Indexer {
sequence_hashes: Vec<SequenceHash>, sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
match self { match self {
Self::KvIndexer(indexer) => { Self::KvIndexer { primary, .. } => {
indexer primary
.process_routing_decision_with_hashes(worker, local_hashes, sequence_hashes) .process_routing_decision_with_hashes(worker, local_hashes, sequence_hashes)
.await .await
} }
Self::Concurrent(_) => { Self::Concurrent { .. } => {
tracing::warn!( tracing::warn!(
"Hashed routing-decision recording is unsupported for concurrent indexers" "Hashed routing-decision recording is unsupported for concurrent indexers"
); );
...@@ -155,8 +325,8 @@ impl Indexer { ...@@ -155,8 +325,8 @@ impl Indexer {
pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> { pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self { match self {
Self::KvIndexer(indexer) => indexer.dump_events().await, Self::KvIndexer { primary, .. } => primary.dump_events().await,
Self::Concurrent(tpi) => tpi.dump_events().await, Self::Concurrent { primary, .. } => primary.dump_events().await,
Self::Remote(_) => Ok(Vec::new()), Self::Remote(_) => Ok(Vec::new()),
Self::None => { Self::None => {
panic!( panic!(
...@@ -172,14 +342,15 @@ impl Indexer { ...@@ -172,14 +342,15 @@ impl Indexer {
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
match self { match self {
Self::KvIndexer(_) | Self::Remote(_) => { Self::KvIndexer { .. } | Self::Remote(_) => {
let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec(); let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec(); let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
self.record_hashed_routing_decision(worker, local_hashes, sequence_hashes) self.record_hashed_routing_decision(worker, local_hashes, sequence_hashes)
.await .await
} }
Self::Concurrent(tpi) => { Self::Concurrent { primary, .. } => {
tpi.process_routing_decision_for_request(tokens_with_hashes, worker) primary
.process_routing_decision_for_request(tokens_with_hashes, worker)
.await .await
} }
Self::None => Ok(()), Self::None => Ok(()),
...@@ -188,25 +359,77 @@ impl Indexer { ...@@ -188,25 +359,77 @@ impl Indexer {
pub(crate) async fn apply_event(&self, event: RouterEvent) { pub(crate) async fn apply_event(&self, event: RouterEvent) {
match self { match self {
Self::KvIndexer(indexer) => { Self::KvIndexer {
if let Err(e) = indexer.event_sender().send(event).await { primary,
tracing::warn!("Failed to send event to indexer: {e}"); lower_tier,
} => match &event.event.data {
dynamo_kv_router::protocols::KvCacheEventData::Cleared => {
if let Err(e) = primary.event_sender().send(event.clone()).await {
tracing::warn!("Failed to send event to indexer: {e}");
}
for indexer in lower_tier.all() {
indexer.apply_event(event.clone()).await;
}
} }
} _ if event.storage_tier.is_gpu() => {
Self::Concurrent(tpi) => tpi.apply_event(event).await, if let Err(e) = primary.event_sender().send(event).await {
tracing::warn!("Failed to send event to indexer: {e}");
}
}
_ => {
lower_tier
.get_or_create(event.storage_tier)
.apply_event(event)
.await;
}
},
Self::Concurrent {
primary,
lower_tier,
} => match &event.event.data {
dynamo_kv_router::protocols::KvCacheEventData::Cleared => {
primary.apply_event(event.clone()).await;
for indexer in lower_tier.all() {
indexer.apply_event(event.clone()).await;
}
}
_ if event.storage_tier.is_gpu() => {
primary.apply_event(event).await;
}
_ => {
lower_tier
.get_or_create(event.storage_tier)
.apply_event(event)
.await;
}
},
Self::Remote(_) | Self::None => {} Self::Remote(_) | Self::None => {}
} }
} }
pub(crate) async fn remove_worker(&self, worker_id: WorkerId) { pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
match self { match self {
Self::KvIndexer(indexer) => { Self::KvIndexer {
if let Err(e) = indexer.remove_worker_sender().send(worker_id).await { primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
indexer.remove_worker(worker_id).await;
}
if let Err(e) = primary.remove_worker_sender().send(worker_id).await {
tracing::warn!("Failed to send worker removal for {worker_id}: {e}"); tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
} }
} }
Self::Concurrent(tpi) => { Self::Concurrent {
KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await; primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
indexer.remove_worker(worker_id).await;
}
KvIndexerInterface::remove_worker(primary.as_ref(), worker_id).await;
} }
Self::Remote(_) | Self::None => {} Self::Remote(_) | Self::None => {}
} }
...@@ -214,11 +437,24 @@ impl Indexer { ...@@ -214,11 +437,24 @@ impl Indexer {
pub(crate) async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: DpRank) { pub(crate) async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: DpRank) {
match self { match self {
Self::KvIndexer(indexer) => { Self::KvIndexer {
KvIndexerInterface::remove_worker_dp_rank(indexer, worker_id, dp_rank).await; primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
KvIndexerInterface::remove_worker_dp_rank(&*indexer, worker_id, dp_rank).await;
}
KvIndexerInterface::remove_worker_dp_rank(primary, worker_id, dp_rank).await;
} }
Self::Concurrent(tpi) => { Self::Concurrent {
KvIndexerInterface::remove_worker_dp_rank(tpi.as_ref(), worker_id, dp_rank).await; primary,
lower_tier,
} => {
for indexer in lower_tier.all() {
KvIndexerInterface::remove_worker_dp_rank(&*indexer, worker_id, dp_rank).await;
}
KvIndexerInterface::remove_worker_dp_rank(primary.as_ref(), worker_id, dp_rank)
.await;
} }
Self::Remote(_) | Self::None => {} Self::Remote(_) | Self::None => {}
} }
...@@ -226,17 +462,472 @@ impl Indexer { ...@@ -226,17 +462,472 @@ impl Indexer {
pub(crate) async fn get_workers(&self) -> Vec<WorkerId> { pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
match self { match self {
Self::KvIndexer(indexer) => { Self::KvIndexer { primary, .. } => {
let (resp_tx, resp_rx) = oneshot::channel(); let (resp_tx, resp_rx) = oneshot::channel();
let req = dynamo_kv_router::indexer::GetWorkersRequest { resp: resp_tx }; let req = dynamo_kv_router::indexer::GetWorkersRequest { resp: resp_tx };
if let Err(e) = indexer.get_workers_sender().send(req).await { if let Err(e) = primary.get_workers_sender().send(req).await {
tracing::warn!("Failed to send get_workers request: {e}"); tracing::warn!("Failed to send get_workers request: {e}");
return Vec::new(); return Vec::new();
} }
resp_rx.await.unwrap_or_default() resp_rx.await.unwrap_or_default()
} }
Self::Concurrent(tpi) => tpi.backend().get_workers(), Self::Concurrent { primary, .. } => primary.backend().get_workers(),
Self::Remote(_) | Self::None => Vec::new(), Self::Remote(_) | Self::None => Vec::new(),
} }
} }
} }
#[cfg(test)]
mod tests {
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use super::{Indexer, LowerTierIndexers};
use dynamo_kv_router::{
ConcurrentRadixTreeCompressed, ThreadPoolIndexer,
indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics},
protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, StorageTier, WorkerWithDpRank,
compute_seq_hash_for_block,
},
};
fn make_test_indexer() -> Indexer {
Indexer::KvIndexer {
primary: KvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
),
lower_tier: LowerTierIndexers::new(1, 4),
}
}
fn make_test_concurrent_indexer() -> Indexer {
Indexer::Concurrent {
primary: Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTreeCompressed::new(),
2,
4,
)),
lower_tier: LowerTierIndexers::new(2, 4),
}
}
async fn flush_indexer(indexer: &Indexer) {
match indexer {
Indexer::KvIndexer {
primary,
lower_tier,
} => {
let _ = primary.flush().await;
for indexer in lower_tier.all() {
let _ = indexer.dump_events().await.unwrap();
}
}
Indexer::Concurrent {
primary,
lower_tier,
} => {
primary.flush().await;
for indexer in lower_tier.all() {
let _ = indexer.dump_events().await.unwrap();
}
}
Indexer::Remote(_) | Indexer::None => {}
}
}
fn store_event(
worker_id: u64,
dp_rank: u32,
event_id: u64,
prefix_hashes: &[u64],
local_hashes: &[u64],
storage_tier: StorageTier,
) -> RouterEvent {
let prefix_block_hashes: Vec<LocalBlockHash> =
prefix_hashes.iter().copied().map(LocalBlockHash).collect();
let parent_hash = compute_seq_hash_for_block(&prefix_block_hashes)
.last()
.copied()
.map(ExternalSequenceBlockHash);
let full_hashes: Vec<LocalBlockHash> = prefix_hashes
.iter()
.chain(local_hashes.iter())
.copied()
.map(LocalBlockHash)
.collect();
let full_sequence_hashes = compute_seq_hash_for_block(&full_hashes);
let new_sequence_hashes = &full_sequence_hashes[prefix_hashes.len()..];
let blocks = local_hashes
.iter()
.zip(new_sequence_hashes.iter())
.map(|(&local_hash, &sequence_hash)| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(sequence_hash),
tokens_hash: LocalBlockHash(local_hash),
mm_extra_info: None,
})
.collect();
RouterEvent::with_storage_tier(
worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
start_position: None,
blocks,
}),
dp_rank,
},
storage_tier,
)
}
#[tokio::test]
async fn tiered_query_chains_device_host_and_disk() {
let indexer = make_test_indexer();
let worker = WorkerWithDpRank::new(7, 0);
indexer
.apply_event(store_event(7, 0, 1, &[], &[11, 12], StorageTier::Device))
.await;
indexer
.apply_event(store_event(
7,
0,
2,
&[11, 12],
&[13],
StorageTier::HostPinned,
))
.await;
indexer
.apply_event(store_event(
7,
0,
3,
&[11, 12, 13],
&[14],
StorageTier::Disk,
))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![
LocalBlockHash(11),
LocalBlockHash(12),
LocalBlockHash(13),
LocalBlockHash(14),
])
.await
.unwrap();
assert_eq!(matches.device.overlap_scores.scores.get(&worker), Some(&2));
assert_eq!(
matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&worker)),
Some(&1)
);
assert_eq!(
matches
.lower_tier
.get(&StorageTier::Disk)
.and_then(|tier| tier.hits.get(&worker)),
Some(&1)
);
}
#[tokio::test]
async fn tiered_query_seeds_lower_tier_only_workers_without_affecting_device_scores() {
let indexer = make_test_indexer();
let device_worker = WorkerWithDpRank::new(10, 0);
let host_only_worker = WorkerWithDpRank::new(20, 0);
let disk_only_worker = WorkerWithDpRank::new(30, 0);
indexer
.apply_event(store_event(10, 0, 1, &[], &[21], StorageTier::Device))
.await;
indexer
.apply_event(store_event(20, 0, 2, &[], &[21], StorageTier::HostPinned))
.await;
indexer
.apply_event(store_event(30, 0, 3, &[], &[21], StorageTier::Disk))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![LocalBlockHash(21)])
.await
.unwrap();
assert_eq!(
matches.device.overlap_scores.scores.get(&device_worker),
Some(&1)
);
assert!(
!matches
.device
.overlap_scores
.scores
.contains_key(&host_only_worker)
);
assert!(
!matches
.device
.overlap_scores
.scores
.contains_key(&disk_only_worker)
);
assert_eq!(
matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&host_only_worker)),
Some(&1)
);
assert_eq!(
matches
.lower_tier
.get(&StorageTier::Disk)
.and_then(|tier| tier.hits.get(&disk_only_worker)),
Some(&1)
);
}
#[tokio::test]
async fn tiered_query_only_seeds_matching_root_workers() {
let indexer = make_test_indexer();
let matching_host_worker = WorkerWithDpRank::new(20, 0);
let nonmatching_host_worker = WorkerWithDpRank::new(21, 0);
indexer
.apply_event(store_event(20, 0, 1, &[], &[31], StorageTier::HostPinned))
.await;
indexer
.apply_event(store_event(21, 0, 2, &[], &[32], StorageTier::HostPinned))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![LocalBlockHash(31)])
.await
.unwrap();
assert_eq!(
matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&matching_host_worker)),
Some(&1)
);
assert!(
!matches
.lower_tier
.get(&StorageTier::HostPinned)
.is_some_and(|tier| tier.hits.contains_key(&nonmatching_host_worker))
);
}
#[tokio::test]
async fn concurrent_tiered_query_chains_device_and_lower_tier_matches() {
let indexer = make_test_concurrent_indexer();
let worker = WorkerWithDpRank::new(7, 0);
indexer
.apply_event(store_event(7, 0, 1, &[], &[11, 12], StorageTier::Device))
.await;
indexer
.apply_event(store_event(
7,
0,
2,
&[11, 12],
&[13],
StorageTier::HostPinned,
))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![
LocalBlockHash(11),
LocalBlockHash(12),
LocalBlockHash(13),
])
.await
.unwrap();
assert_eq!(matches.device.overlap_scores.scores.get(&worker), Some(&2));
assert_eq!(
matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&worker)),
Some(&1)
);
}
#[tokio::test]
async fn concurrent_tiered_query_seeds_lower_tier_only_workers_without_affecting_device_scores()
{
let indexer = make_test_concurrent_indexer();
let device_worker = WorkerWithDpRank::new(10, 0);
let host_only_worker = WorkerWithDpRank::new(20, 0);
let disk_only_worker = WorkerWithDpRank::new(30, 0);
indexer
.apply_event(store_event(10, 0, 1, &[], &[21], StorageTier::Device))
.await;
indexer
.apply_event(store_event(20, 0, 2, &[], &[21], StorageTier::HostPinned))
.await;
indexer
.apply_event(store_event(30, 0, 3, &[], &[21], StorageTier::Disk))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![LocalBlockHash(21)])
.await
.unwrap();
assert_eq!(
matches.device.overlap_scores.scores.get(&device_worker),
Some(&1)
);
assert!(
!matches
.device
.overlap_scores
.scores
.contains_key(&host_only_worker)
);
assert!(
!matches
.device
.overlap_scores
.scores
.contains_key(&disk_only_worker)
);
assert_eq!(
matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&host_only_worker)),
Some(&1)
);
assert_eq!(
matches
.lower_tier
.get(&StorageTier::Disk)
.and_then(|tier| tier.hits.get(&disk_only_worker)),
Some(&1)
);
}
/// Regression test: when a worker has blocks in both device and lower-tier
/// storage (e.g. same prefix stored on GPU and offloaded to host), the
/// Concurrent indexer doesn't return last_matched_hashes. Without the fix,
/// query_lower_tiers would re-query that worker from root in the lower tier,
/// double-counting overlap blocks and producing cached_tokens > ISL.
#[tokio::test]
async fn concurrent_tiered_query_does_not_double_count_device_and_lower_tier_overlap() {
let indexer = make_test_concurrent_indexer();
let worker = WorkerWithDpRank::new(7, 0);
// Worker has the same blocks in both device and host-pinned storage.
indexer
.apply_event(store_event(
7,
0,
1,
&[],
&[11, 12, 13],
StorageTier::Device,
))
.await;
indexer
.apply_event(store_event(
7,
0,
2,
&[],
&[11, 12, 13],
StorageTier::HostPinned,
))
.await;
flush_indexer(&indexer).await;
let matches = indexer
.find_matches_by_tier(vec![
LocalBlockHash(11),
LocalBlockHash(12),
LocalBlockHash(13),
])
.await
.unwrap();
// Device overlap should be 3 blocks.
assert_eq!(matches.device.overlap_scores.scores.get(&worker), Some(&3));
// Lower-tier must NOT report additional hits for the same worker
// whose blocks are already fully accounted for in the device tier.
let host_hits = matches
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&worker).copied())
.unwrap_or(0);
assert_eq!(
host_hits, 0,
"lower-tier should not double-count blocks already matched in device tier \
(got {host_hits} host-pinned hits for a worker with full device overlap)"
);
}
#[tokio::test]
async fn concurrent_remove_worker_removes_lower_tier_state() {
let indexer = make_test_concurrent_indexer();
let worker = WorkerWithDpRank::new(20, 0);
indexer
.apply_event(store_event(20, 0, 1, &[], &[31], StorageTier::HostPinned))
.await;
flush_indexer(&indexer).await;
let before = indexer
.find_matches_by_tier(vec![LocalBlockHash(31)])
.await
.unwrap();
assert_eq!(
before
.lower_tier
.get(&StorageTier::HostPinned)
.and_then(|tier| tier.hits.get(&worker)),
Some(&1)
);
indexer.remove_worker(20).await;
flush_indexer(&indexer).await;
let after = indexer
.find_matches_by_tier(vec![LocalBlockHash(31)])
.await
.unwrap();
assert!(
!after
.lower_tier
.get(&StorageTier::HostPinned)
.is_some_and(|tier| tier.hits.contains_key(&worker))
);
}
}
...@@ -823,7 +823,7 @@ impl Drop for SlowQueryGuard { ...@@ -823,7 +823,7 @@ impl Drop for SlowQueryGuard {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::kv_router::Indexer; use crate::kv_router::{Indexer, indexer::LowerTierIndexers};
use dynamo_kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics}; use dynamo_kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use dynamo_kv_router::protocols::{ use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
...@@ -923,7 +923,13 @@ mod tests { ...@@ -923,7 +923,13 @@ mod tests {
let token = CancellationToken::new(); let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let kv_indexer = KvIndexer::new(token, 4, metrics); let kv_indexer = KvIndexer::new(token, 4, metrics);
(kv_indexer.clone(), Indexer::KvIndexer(kv_indexer)) (
kv_indexer.clone(),
Indexer::KvIndexer {
primary: kv_indexer,
lower_tier: LowerTierIndexers::new(1, 4),
},
)
} }
async fn make_test_client( async fn make_test_client(
......
...@@ -287,7 +287,6 @@ impl PrefillRouter { ...@@ -287,7 +287,6 @@ impl PrefillRouter {
lora_name, lora_name,
priority_jump, priority_jump,
None, None,
None,
allowed_worker_ids, allowed_worker_ids,
) )
.await?; .await?;
......
...@@ -31,35 +31,48 @@ use super::{DEFAULT_MAX_BATCH_BLOCKS, kv_publisher_metrics}; ...@@ -31,35 +31,48 @@ use super::{DEFAULT_MAX_BATCH_BLOCKS, kv_publisher_metrics};
/// - **Remove**: only passes through when refcount decrements to 0. /// - **Remove**: only passes through when refcount decrements to 0.
/// - **Cleared**: resets refcounts for all ranks. /// - **Cleared**: resets refcounts for all ranks.
pub(super) struct EventDedupFilter { pub(super) struct EventDedupFilter {
/// Per-dp-rank refcounts. /// Per-(dp_rank, storage_tier) refcounts.
per_rank: HashMap<u32, HashMap<ExternalSequenceBlockHash, usize>>, per_rank_tier: HashMap<(u32, StorageTier), HashMap<ExternalSequenceBlockHash, usize>>,
} }
impl EventDedupFilter { impl EventDedupFilter {
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Self { Self {
per_rank: HashMap::new(), per_rank_tier: HashMap::new(),
} }
} }
/// Track a store event. Increments refcount for each block hash on the /// Track a store event. Increments refcount for each block hash on the
/// given DP rank. Stores always pass through — this only updates bookkeeping. /// given (DP rank, storage tier). Stores always pass through — this only
pub(super) fn track_store(&mut self, dp_rank: u32, data: &KvCacheStoreData) { /// updates bookkeeping.
let refcounts = self.per_rank.entry(dp_rank).or_default(); pub(super) fn track_store(
&mut self,
dp_rank: u32,
storage_tier: StorageTier,
data: &KvCacheStoreData,
) {
let refcounts = self
.per_rank_tier
.entry((dp_rank, storage_tier))
.or_default();
for block in &data.blocks { for block in &data.blocks {
*refcounts.entry(block.block_hash).or_insert(0) += 1; *refcounts.entry(block.block_hash).or_insert(0) += 1;
} }
} }
/// Filter a remove event. Retains only block hashes whose refcount on the /// Filter a remove event. Retains only block hashes whose refcount on the
/// given DP rank decrements to 0 (removing them from the map). Returns /// given (DP rank, storage tier) decrements to 0 (removing them from the
/// `None` if no hashes survive filtering. /// map). Returns `None` if no hashes survive filtering.
pub(super) fn filter_remove( pub(super) fn filter_remove(
&mut self, &mut self,
dp_rank: u32, dp_rank: u32,
storage_tier: StorageTier,
mut data: KvCacheRemoveData, mut data: KvCacheRemoveData,
) -> Option<KvCacheRemoveData> { ) -> Option<KvCacheRemoveData> {
let refcounts = self.per_rank.entry(dp_rank).or_default(); let refcounts = self
.per_rank_tier
.entry((dp_rank, storage_tier))
.or_default();
data.block_hashes.retain(|hash| { data.block_hashes.retain(|hash| {
match refcounts.entry(*hash) { match refcounts.entry(*hash) {
Entry::Occupied(mut entry) => { Entry::Occupied(mut entry) => {
...@@ -83,11 +96,11 @@ impl EventDedupFilter { ...@@ -83,11 +96,11 @@ impl EventDedupFilter {
} }
} }
/// Clear refcounts for all DP ranks. A `Cleared` event from any rank /// Clear refcounts for all DP ranks and tiers. A `Cleared` event from any
/// causes the indexer to wipe all blocks for the entire worker, so we /// rank causes the indexer to wipe all blocks for the entire worker, so we
/// must reset all ranks' refcounts to stay consistent. /// must reset all refcounts to stay consistent.
pub(super) fn clear(&mut self) { pub(super) fn clear(&mut self) {
self.per_rank.clear(); self.per_rank_tier.clear();
} }
} }
...@@ -99,6 +112,7 @@ pub(super) struct BatchingState { ...@@ -99,6 +112,7 @@ pub(super) struct BatchingState {
pub(super) pending_stored: Option<KvCacheStoreData>, pub(super) pending_stored: Option<KvCacheStoreData>,
pub(super) next_publish_id: u64, pub(super) next_publish_id: u64,
pub(super) last_dp_rank: u32, pub(super) last_dp_rank: u32,
pub(super) last_storage_tier: StorageTier,
pub(super) last_flush_time: Instant, pub(super) last_flush_time: Instant,
} }
...@@ -109,6 +123,7 @@ impl BatchingState { ...@@ -109,6 +123,7 @@ impl BatchingState {
pending_stored: None, pending_stored: None,
next_publish_id: 1, next_publish_id: 1,
last_dp_rank: 0, last_dp_rank: 0,
last_storage_tier: StorageTier::Device,
last_flush_time: Instant::now(), last_flush_time: Instant::now(),
} }
} }
...@@ -160,12 +175,13 @@ impl BatchingState { ...@@ -160,12 +175,13 @@ impl BatchingState {
let dp_rank = self.last_dp_rank; let dp_rank = self.last_dp_rank;
let mut emitted = false; let mut emitted = false;
if let Some(data) = self.pending_removed.take() if let Some(data) = self.pending_removed.take()
&& let Some(filtered) = dedup.filter_remove(dp_rank, data) && let Some(filtered) = dedup.filter_remove(dp_rank, self.last_storage_tier, data)
{ {
emit( emit(
publisher, publisher,
local_indexer, local_indexer,
worker_id, worker_id,
self.last_storage_tier,
KvCacheEvent { KvCacheEvent {
event_id: self.next_publish_id, event_id: self.next_publish_id,
data: KvCacheEventData::Removed(filtered), data: KvCacheEventData::Removed(filtered),
...@@ -176,11 +192,12 @@ impl BatchingState { ...@@ -176,11 +192,12 @@ impl BatchingState {
emitted = true; emitted = true;
} }
if let Some(data) = self.pending_stored.take() { if let Some(data) = self.pending_stored.take() {
dedup.track_store(dp_rank, &data); dedup.track_store(dp_rank, self.last_storage_tier, &data);
emit( emit(
publisher, publisher,
local_indexer, local_indexer,
worker_id, worker_id,
self.last_storage_tier,
KvCacheEvent { KvCacheEvent {
event_id: self.next_publish_id, event_id: self.next_publish_id,
data: KvCacheEventData::Stored(data), data: KvCacheEventData::Stored(data),
...@@ -217,9 +234,10 @@ async fn emit<P: RouterEventSink>( ...@@ -217,9 +234,10 @@ async fn emit<P: RouterEventSink>(
publisher: &P, publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>, local_indexer: &Option<Arc<LocalKvIndexer>>,
worker_id: u64, worker_id: u64,
storage_tier: StorageTier,
event: KvCacheEvent, event: KvCacheEvent,
) { ) {
let router_event = RouterEvent::new(worker_id, event); let router_event = RouterEvent::with_storage_tier(worker_id, event, storage_tier);
if let Some(indexer) = local_indexer if let Some(indexer) = local_indexer
&& let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await && let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await
{ {
...@@ -281,16 +299,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -281,16 +299,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
} }
last_raw_input_id = Some(raw_event_id); last_raw_input_id = Some(raw_event_id);
if !placement_event.placement.is_local_gpu() { let storage_tier = placement_event.placement.tier;
tracing::trace!(
worker_id,
?placement_event.placement,
event_id = placement_event.event.event_id,
"Skipping non-local-GPU placement event"
);
continue;
}
let event = placement_event.event; let event = placement_event.event;
tracing::trace!( tracing::trace!(
"Event processor for worker_id {} processing event: {:?}", "Event processor for worker_id {} processing event: {:?}",
...@@ -300,10 +309,15 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -300,10 +309,15 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
let dp_rank_changed = let dp_rank_changed =
batching_state.has_pending() && event.dp_rank != batching_state.last_dp_rank; batching_state.has_pending() && event.dp_rank != batching_state.last_dp_rank;
let storage_tier_changed =
batching_state.has_pending() && storage_tier != batching_state.last_storage_tier;
match event.data { match event.data {
KvCacheEventData::Removed(data) => { KvCacheEventData::Removed(data) => {
if batching_state.pending_stored.is_some() || dp_rank_changed { if batching_state.pending_stored.is_some()
|| dp_rank_changed
|| storage_tier_changed
{
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
} }
match &mut batching_state.pending_removed { match &mut batching_state.pending_removed {
...@@ -315,6 +329,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -315,6 +329,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
} }
KvCacheEventData::Stored(data) => { KvCacheEventData::Stored(data) => {
let should_flush = dp_rank_changed let should_flush = dp_rank_changed
|| storage_tier_changed
|| batching_state.pending_removed.is_some() || batching_state.pending_removed.is_some()
|| batching_state.pending_stored.as_ref().is_some_and(|p| { || batching_state.pending_stored.as_ref().is_some_and(|p| {
data.parent_hash != p.blocks.last().map(|b| b.block_hash) data.parent_hash != p.blocks.last().map(|b| b.block_hash)
...@@ -336,6 +351,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -336,6 +351,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
&publisher, &publisher,
&local_indexer, &local_indexer,
worker_id, worker_id,
storage_tier,
KvCacheEvent { KvCacheEvent {
event_id: batching_state.next_publish_id, event_id: batching_state.next_publish_id,
data: KvCacheEventData::Cleared, data: KvCacheEventData::Cleared,
...@@ -348,6 +364,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -348,6 +364,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
} }
batching_state.last_dp_rank = event.dp_rank; batching_state.last_dp_rank = event.dp_rank;
batching_state.last_storage_tier = storage_tier;
if batching_state.has_pending() if batching_state.has_pending()
&& (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms)) && (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms))
......
...@@ -1314,15 +1314,15 @@ mod test_event_dedup_filter { ...@@ -1314,15 +1314,15 @@ mod test_event_dedup_filter {
let data = store_data(&[1, 2, 3]); let data = store_data(&[1, 2, 3]);
// Store same hashes twice — refcount should be 2 // Store same hashes twice — refcount should be 2
filter.track_store(0, &data); filter.track_store(0, StorageTier::Device, &data);
filter.track_store(0, &data); filter.track_store(0, StorageTier::Device, &data);
// First remove — refcounts 2→1, all filtered out // First remove — refcounts 2→1, all filtered out
let result = filter.filter_remove(0, remove_data(&[1, 2, 3])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1, 2, 3]));
assert!(result.is_none()); assert!(result.is_none());
// Second remove — refcounts 1→0, all pass through // Second remove — refcounts 1→0, all pass through
let result = filter.filter_remove(0, remove_data(&[1, 2, 3])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1, 2, 3]));
assert!(result.is_some()); assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 3); assert_eq!(result.unwrap().block_hashes.len(), 3);
} }
...@@ -1332,15 +1332,15 @@ mod test_event_dedup_filter { ...@@ -1332,15 +1332,15 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new(); let mut filter = EventDedupFilter::new();
// Store same hash twice // Store same hash twice
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// First remove — refcount 2→1, filtered out // First remove — refcount 2→1, filtered out
let result = filter.filter_remove(0, remove_data(&[1])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_none()); assert!(result.is_none());
// Second remove — refcount 1→0, passes through // Second remove — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some()); assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 1); assert_eq!(result.unwrap().block_hashes.len(), 1);
} }
...@@ -1350,17 +1350,17 @@ mod test_event_dedup_filter { ...@@ -1350,17 +1350,17 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new(); let mut filter = EventDedupFilter::new();
// Store hash 1 // Store hash 1
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// Remove hash 1 — refcount 1→0, passes through // Remove hash 1 — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some()); assert!(result.is_some());
// Store hash 1 again — refcount starts fresh at 1 // Store hash 1 again — refcount starts fresh at 1
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// Remove again — refcount 1→0, passes through // Remove again — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some()); assert!(result.is_some());
} }
...@@ -1369,20 +1369,20 @@ mod test_event_dedup_filter { ...@@ -1369,20 +1369,20 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new(); let mut filter = EventDedupFilter::new();
// Store on rank 0 and rank 1 // Store on rank 0 and rank 1
filter.track_store(0, &store_data(&[1, 2])); filter.track_store(0, StorageTier::Device, &store_data(&[1, 2]));
filter.track_store(0, &store_data(&[1, 2])); filter.track_store(0, StorageTier::Device, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2])); filter.track_store(1, StorageTier::Device, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2])); filter.track_store(1, StorageTier::Device, &store_data(&[1, 2]));
// Clear wipes all ranks (matches indexer semantics where Cleared // Clear wipes all ranks (matches indexer semantics where Cleared
// from any rank removes all blocks for the entire worker). // from any rank removes all blocks for the entire worker).
filter.clear(); filter.clear();
// Both ranks pass through defensively after clear // Both ranks pass through defensively after clear
let result = filter.filter_remove(0, remove_data(&[1])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some()); assert!(result.is_some());
let result = filter.filter_remove(1, remove_data(&[1])); let result = filter.filter_remove(1, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some()); assert!(result.is_some());
} }
...@@ -1391,18 +1391,18 @@ mod test_event_dedup_filter { ...@@ -1391,18 +1391,18 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new(); let mut filter = EventDedupFilter::new();
// Hash 1: stored twice (refcount 2) // Hash 1: stored twice (refcount 2)
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
// Hash 2: stored once (refcount 1) // Hash 2: stored once (refcount 1)
filter.track_store(0, &store_data(&[2])); filter.track_store(0, StorageTier::Device, &store_data(&[2]));
// Hash 3: stored twice (refcount 2) // Hash 3: stored twice (refcount 2)
filter.track_store(0, &store_data(&[3])); filter.track_store(0, StorageTier::Device, &store_data(&[3]));
filter.track_store(0, &store_data(&[3])); filter.track_store(0, StorageTier::Device, &store_data(&[3]));
// Remove all three — only hash 2 (refcount 1→0) passes through // Remove all three — only hash 2 (refcount 1→0) passes through
let result = filter.filter_remove(0, remove_data(&[1, 2, 3])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1, 2, 3]));
assert!(result.is_some()); assert!(result.is_some());
let result = result.unwrap(); let result = result.unwrap();
assert_eq!(result.block_hashes.len(), 1); assert_eq!(result.block_hashes.len(), 1);
...@@ -1414,20 +1414,20 @@ mod test_event_dedup_filter { ...@@ -1414,20 +1414,20 @@ mod test_event_dedup_filter {
let mut filter = EventDedupFilter::new(); let mut filter = EventDedupFilter::new();
// Store hash 1 on rank 0 (twice) and rank 1 (once) // Store hash 1 on rank 0 (twice) and rank 1 (once)
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(0, &store_data(&[1])); filter.track_store(0, StorageTier::Device, &store_data(&[1]));
filter.track_store(1, &store_data(&[1])); filter.track_store(1, StorageTier::Device, &store_data(&[1]));
// Remove hash 1 on rank 1 — refcount 1→0, passes through // Remove hash 1 on rank 1 — refcount 1→0, passes through
let result = filter.filter_remove(1, remove_data(&[1])); let result = filter.filter_remove(1, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some()); assert!(result.is_some());
// Remove hash 1 on rank 0 — refcount 2→1, filtered out // Remove hash 1 on rank 0 — refcount 2→1, filtered out
let result = filter.filter_remove(0, remove_data(&[1])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_none()); assert!(result.is_none());
// Remove hash 1 on rank 0 again — refcount 1→0, passes through // Remove hash 1 on rank 0 again — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1])); let result = filter.filter_remove(0, StorageTier::Device, remove_data(&[1]));
assert!(result.is_some()); assert!(result.is_some());
} }
} }
...@@ -1724,6 +1724,13 @@ mod event_processor_tests { ...@@ -1724,6 +1724,13 @@ mod event_processor_tests {
PlacementEvent::local_gpu(1, event) PlacementEvent::local_gpu(1, event)
} }
fn local_host_event(event: KvCacheEvent) -> PlacementEvent {
PlacementEvent::new(
Placement::local_worker(1, event.dp_rank, StorageTier::HostPinned),
event,
)
}
/// Test that pushing N removed events results in batched output /// Test that pushing N removed events results in batched output
/// Uses a 10ms timeout to ensure events are batched (events sent rapidly) /// Uses a 10ms timeout to ensure events are batched (events sent rapidly)
#[tokio::test] #[tokio::test]
...@@ -2287,6 +2294,106 @@ mod event_processor_tests { ...@@ -2287,6 +2294,106 @@ mod event_processor_tests {
); );
} }
#[tokio::test]
async fn test_host_tier_events_are_published_and_preserved() {
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
Some(100),
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_host_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(42)],
}),
dp_rank: 0,
}))
.unwrap();
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
1,
"Expected a single published host-tier event"
);
assert_eq!(events[0].storage_tier, StorageTier::HostPinned);
let KvCacheEventData::Removed(data) = &events[0].event.data else {
panic!("Expected Removed event");
};
assert_eq!(data.block_hashes, vec![ExternalSequenceBlockHash(42)]);
}
#[tokio::test]
async fn test_storage_tier_change_causes_flush() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_host_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1)],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(2)],
}),
dp_rank: 0,
}))
.unwrap();
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Changing storage tier should flush the current batch"
);
assert_eq!(events[0].storage_tier, StorageTier::HostPinned);
assert_eq!(events[1].storage_tier, StorageTier::Device);
}
/// Test that dp_rank change causes immediate flush /// Test that dp_rank change causes immediate flush
#[tokio::test] #[tokio::test]
async fn test_dp_rank_change_causes_flush() { async fn test_dp_rank_change_causes_flush() {
......
...@@ -45,26 +45,13 @@ pub struct KvPushRouter { ...@@ -45,26 +45,13 @@ pub struct KvPushRouter {
/// Result of worker selection containing instance ID, dp_rank, and overlap amount. /// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection { struct WorkerSelection {
instance_id: u64, instance_id: u64,
backend_dp_rank: Option<u32>, dp_rank: u32,
bookkeeping_dp_rank: Option<u32>, overlap_amount: u32,
overlap_amount: Option<u32>, effective_overlap_blocks: f64,
} cached_tokens: usize,
/// Whether the scheduler is tracking this request (add_request or
fn pinned_worker_hint( /// find_best_match_details with update_states=true was called).
phase: RequestPhase, scheduler_tracked: bool,
routing: Option<&RoutingHints>,
) -> Option<(u64, Option<u32>)> {
let routing = routing?;
let worker_id = match phase {
RequestPhase::Prefill => routing.prefill_worker_id.or(routing.backend_instance_id),
RequestPhase::Decode => routing.decode_worker_id.or(routing.backend_instance_id),
RequestPhase::Aggregated => routing.backend_instance_id,
}?;
let dp_rank = match phase {
RequestPhase::Prefill => routing.prefill_dp_rank.or(routing.dp_rank),
RequestPhase::Decode | RequestPhase::Aggregated => routing.dp_rank,
};
Some((worker_id, dp_rank))
} }
/// Drop guard that manages the full lifecycle of a routed request: /// Drop guard that manages the full lifecycle of a routed request:
...@@ -318,9 +305,9 @@ impl KvPushRouter { ...@@ -318,9 +305,9 @@ impl KvPushRouter {
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info(); let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let Some((pinned_worker_id, requested_dp_rank)) = pinned_worker_hint(phase, routing) else { let Some((pinned_worker_id, requested_dp_rank)) = pinned_worker_hint(phase, routing) else {
let _nvtx_kv = dynamo_nvtx_range!("route.kv_match"); let _nvtx_kv = dynamo_nvtx_range!("route.kv_match");
let (best_worker, overlap_amount) = self let selection = self
.chooser .chooser
.find_best_match( .find_best_match_details(
Some(context_id), Some(context_id),
routing_token_ids, routing_token_ids,
block_mm_infos, block_mm_infos,
...@@ -333,6 +320,10 @@ impl KvPushRouter { ...@@ -333,6 +320,10 @@ impl KvPushRouter {
allowed_worker_ids, allowed_worker_ids,
) )
.await?; .await?;
let best_worker = selection.worker;
let effective_overlap_blocks = selection.cache_hit.effective_overlap_blocks;
let cached_tokens = selection.cache_hit.cached_tokens;
let overlap_amount = selection.cache_hit.rounded_overlap_blocks();
if !is_query_only { if !is_query_only {
let total_blocks = routing_token_ids let total_blocks = routing_token_ids
...@@ -357,20 +348,22 @@ impl KvPushRouter { ...@@ -357,20 +348,22 @@ impl KvPushRouter {
return Ok(WorkerSelection { return Ok(WorkerSelection {
instance_id: best_worker.worker_id, instance_id: best_worker.worker_id,
backend_dp_rank: Some(best_worker.dp_rank), dp_rank: best_worker.dp_rank,
bookkeeping_dp_rank: Some(best_worker.dp_rank), overlap_amount,
overlap_amount: Some(overlap_amount), effective_overlap_blocks,
cached_tokens,
scheduler_tracked: !is_query_only,
}); });
}; };
let resolved_pinned_worker = requested_dp_rank let resolved_pinned_worker: Option<WorkerWithDpRank> = requested_dp_rank
.or_else(|| self.chooser.unique_dp_rank_for_worker(pinned_worker_id)) .or_else(|| self.chooser.unique_dp_rank_for_worker(pinned_worker_id))
.map(|dp_rank| WorkerWithDpRank::new(pinned_worker_id, dp_rank)); .map(|dp_rank| WorkerWithDpRank::new(pinned_worker_id, dp_rank));
if !is_query_only && let Some(pinned_worker) = resolved_pinned_worker { if !is_query_only && let Some(pinned_worker) = resolved_pinned_worker {
let (best_worker, overlap_amount) = self let selection = self
.chooser .chooser
.find_best_match( .find_best_match_details(
Some(context_id), Some(context_id),
routing_token_ids, routing_token_ids,
block_mm_infos, block_mm_infos,
...@@ -383,43 +376,60 @@ impl KvPushRouter { ...@@ -383,43 +376,60 @@ impl KvPushRouter {
allowed_worker_ids, allowed_worker_ids,
) )
.await?; .await?;
let best_worker = selection.worker;
let effective_overlap_blocks = selection.cache_hit.effective_overlap_blocks;
let cached_tokens = selection.cache_hit.cached_tokens;
let overlap_amount = selection.cache_hit.rounded_overlap_blocks();
return Ok(WorkerSelection { return Ok(WorkerSelection {
instance_id: best_worker.worker_id, instance_id: best_worker.worker_id,
backend_dp_rank: Some(best_worker.dp_rank), dp_rank: best_worker.dp_rank,
bookkeeping_dp_rank: Some(best_worker.dp_rank), overlap_amount,
overlap_amount: Some(overlap_amount), effective_overlap_blocks,
cached_tokens,
scheduler_tracked: true,
}); });
} }
let backend_dp_rank = resolved_pinned_worker.map(|worker| worker.dp_rank); // Fallback: pinned worker hint was present but dp_rank could not be
// resolved (or this is a query-only request that skipped the scheduler
// path above). Estimate cache hit directly and, when possible, register
// the request with the scheduler for bookkeeping.
let resolved_dp_rank: Option<u32> = resolved_pinned_worker.map(|w| w.dp_rank);
tracing::debug!( tracing::debug!(
worker_id = pinned_worker_id, worker_id = pinned_worker_id,
dp_rank = ?backend_dp_rank, dp_rank = ?resolved_dp_rank,
?phase, ?phase,
"Routing to specified worker" "Routing to specified worker"
); );
let (bookkeeping_dp_rank, overlap_amount) = if let Some(dp_rank) = backend_dp_rank { // Build a WorkerWithDpRank; use 0 as a fallback dp_rank when it
let worker = WorkerWithDpRank::new(pinned_worker_id, dp_rank); // couldn't be resolved -- this is only used for the cache-hit
let overlap_blocks = self // estimate query and won't affect scheduler state.
.chooser let effective_dp_rank = resolved_dp_rank.unwrap_or(0);
.get_overlap_blocks( let worker = WorkerWithDpRank::new(pinned_worker_id, effective_dp_rank);
routing_token_ids, let cache_hit = self
block_mm_infos, .chooser
worker, .get_cache_hit_estimate(
lora_name.as_deref(), routing_token_ids,
) block_mm_infos,
.await?; worker,
lora_name.as_deref(),
)
.await?;
let effective_overlap_blocks = cache_hit.effective_overlap_blocks;
let cached_tokens = cache_hit.cached_tokens;
let overlap_blocks = cache_hit.rounded_overlap_blocks();
if !is_query_only { if !is_query_only {
if let Some(_dp_rank) = resolved_dp_rank {
self.chooser self.chooser
.add_request( .add_request(
context_id.to_string(), context_id.to_string(),
routing_token_ids, routing_token_ids,
block_mm_infos, block_mm_infos,
overlap_blocks, cached_tokens,
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name, lora_name,
...@@ -430,27 +440,26 @@ impl KvPushRouter { ...@@ -430,27 +440,26 @@ impl KvPushRouter {
tracing::debug!( tracing::debug!(
request_id = %context_id, request_id = %context_id,
worker_id = pinned_worker_id, worker_id = pinned_worker_id,
dp_rank = dp_rank, ?phase,
"Skipping add_request - query-only request" "Routing to specified worker without resolved dp_rank; skipping scheduler bookkeeping"
); );
} }
(Some(dp_rank), Some(overlap_blocks))
} else { } else {
tracing::debug!( tracing::debug!(
request_id = %context_id, request_id = %context_id,
worker_id = pinned_worker_id, worker_id = pinned_worker_id,
?phase, dp_rank = ?resolved_dp_rank,
"Routing to specified worker without resolved dp_rank; skipping scheduler bookkeeping" "Skipping add_request - query-only request"
); );
(None, None) }
};
Ok(WorkerSelection { Ok(WorkerSelection {
instance_id: pinned_worker_id, instance_id: pinned_worker_id,
backend_dp_rank, dp_rank: effective_dp_rank,
bookkeeping_dp_rank, overlap_amount: overlap_blocks,
overlap_amount, effective_overlap_blocks,
cached_tokens,
scheduler_tracked: !is_query_only && resolved_dp_rank.is_some(),
}) })
} }
} }
...@@ -522,47 +531,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -522,47 +531,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.await?; .await?;
let WorkerSelection { let WorkerSelection {
instance_id, instance_id,
backend_dp_rank, dp_rank,
bookkeeping_dp_rank,
overlap_amount, overlap_amount,
effective_overlap_blocks,
cached_tokens,
scheduler_tracked,
} = selection; } = selection;
let scheduler_tracked = !is_query_only && bookkeeping_dp_rank.is_some();
// In approximate mode (use_kv_events=false), record the routing decision // In approximate mode (use_kv_events=false), record the routing decision
// so the indexer can track cache state based on routing decisions. // so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections. // This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events { if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
if let Some(dp_rank) = bookkeeping_dp_rank { let lora_name = request.routing.as_ref().and_then(|r| r.lora_name.clone());
let lora_name = request.routing.as_ref().and_then(|r| r.lora_name.clone()); let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info(); let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let worker = WorkerWithDpRank::new(instance_id, dp_rank); let mut tokens_with_hashes =
let mut tokens_with_hashes = TokensWithHashes::new(routing_token_ids.to_vec(), self.chooser.block_size())
TokensWithHashes::new(routing_token_ids.to_vec(), self.chooser.block_size()) .with_is_eagle(self.chooser.is_eagle());
.with_is_eagle(self.chooser.is_eagle()); if let Some(infos) = block_mm_infos {
if let Some(infos) = block_mm_infos { tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.to_vec());
tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.to_vec()); }
} if let Some(lora_name) = lora_name {
if let Some(lora_name) = lora_name { tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name);
tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name); }
} if let Err(e) = self
if let Err(e) = self .chooser
.chooser .record_routing_decision(tokens_with_hashes, worker)
.record_routing_decision(tokens_with_hashes, worker) .await
.await {
{ tracing::warn!(
tracing::warn!(
request_id = %context_id,
worker_id = instance_id,
dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
);
}
} else {
tracing::debug!(
request_id = %context_id, request_id = %context_id,
worker_id = instance_id, worker_id = instance_id,
"Skipping approximate-mode routing decision for unresolved dp_rank" dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
); );
} }
} }
...@@ -573,14 +575,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -573,14 +575,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
if let Some(ref tracker) = request.tracker { if let Some(ref tracker) = request.tracker {
let (routing_token_ids, _) = request.block_mm_routing_info(); let (routing_token_ids, _) = request.block_mm_routing_info();
let isl_blocks = routing_token_ids.len().div_ceil(block_size); let isl_blocks = routing_token_ids.len().div_ceil(block_size);
if let Some(overlap_amount) = overlap_amount { tracker.record_kv_hit(effective_overlap_blocks, isl_blocks);
tracker.record_kv_hit(overlap_amount, isl_blocks); tracker.record_isl(routing_token_ids.len(), Some(cached_tokens));
} tracker.record_worker(instance_id, Some(dp_rank), self.chooser.worker_type());
tracker.record_isl(
routing_token_ids.len(),
overlap_amount.map(|overlap| overlap as usize * block_size),
);
tracker.record_worker(instance_id, backend_dp_rank, self.chooser.worker_type());
tracker.record_router_queue_depth(self.chooser.pending_count()); tracker.record_router_queue_depth(self.chooser.pending_count());
if let Some(hit_rate) = tracker.kv_hit_rate() { if let Some(hit_rate) = tracker.kv_hit_rate() {
request_metrics.kv_hit_rate.observe(hit_rate); request_metrics.kv_hit_rate.observe(hit_rate);
...@@ -641,7 +638,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -641,7 +638,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.await?; .await?;
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.routing_mut().dp_rank = backend_dp_rank; backend_input.routing_mut().dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
// Record prefill start right before pushing to backend (OnceLock: first call wins). // Record prefill start right before pushing to backend (OnceLock: first call wins).
...@@ -691,8 +688,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -691,8 +688,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
"kv_router.route_request", "kv_router.route_request",
request_id = %context_id, request_id = %context_id,
worker_id = instance_id, worker_id = instance_id,
dp_rank = ?backend_dp_rank, dp_rank = dp_rank,
overlap_blocks = ?overlap_amount, overlap_blocks = overlap_amount,
phase = ?phase, phase = ?phase,
)) ))
.await?; .await?;
...@@ -734,6 +731,35 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -734,6 +731,35 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
} }
/// Extract a phase-specific (worker_id, dp_rank) pin from routing hints.
///
/// Returns `Some((worker_id, optional_dp_rank))` when the request should be
/// pinned to a particular worker, or `None` when the normal KV-overlap
/// selection path should be used.
fn pinned_worker_hint(
phase: RequestPhase,
routing: Option<&RoutingHints>,
) -> Option<(u64, Option<u32>)> {
let routing = routing?;
match phase {
RequestPhase::Prefill => {
let worker_id = routing.prefill_worker_id.or(routing.backend_instance_id)?;
let dp_rank = routing.prefill_dp_rank.or(routing.dp_rank);
Some((worker_id, dp_rank))
}
RequestPhase::Decode => {
let worker_id = routing.decode_worker_id.or(routing.backend_instance_id)?;
let dp_rank = routing.dp_rank;
Some((worker_id, dp_rank))
}
RequestPhase::Aggregated => {
let worker_id = routing.backend_instance_id?;
let dp_rank = routing.dp_rank;
Some((worker_id, dp_rank))
}
}
}
/// A direct routing wrapper for `RouterMode::Direct`. /// A direct routing wrapper for `RouterMode::Direct`.
/// ///
/// This wraps a `PushRouter` and reads worker IDs from each request's routing hints, /// This wraps a `PushRouter` and reads worker IDs from each request's routing hints,
......
...@@ -5,6 +5,7 @@ use dynamo_kv_router::protocols::SharedCacheHits; ...@@ -5,6 +5,7 @@ use dynamo_kv_router::protocols::SharedCacheHits;
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy; pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
pub use dynamo_kv_router::scheduling::{ pub use dynamo_kv_router::scheduling::{
KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse, KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse,
TierOverlapBlocks,
}; };
pub use dynamo_kv_router::selector::DefaultWorkerSelector; pub use dynamo_kv_router::selector::DefaultWorkerSelector;
use dynamo_kv_router::selector::WorkerSelector as WorkerSelectorTrait; use dynamo_kv_router::selector::WorkerSelector as WorkerSelectorTrait;
...@@ -19,7 +20,7 @@ use anyhow::Result; ...@@ -19,7 +20,7 @@ use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
PrefillLoadEstimator, PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride}, config::{KvRouterConfig, RouterConfigOverride},
protocols::{OverlapScores, WorkerId, WorkerWithDpRank}, protocols::{WorkerId, WorkerWithDpRank},
}; };
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
...@@ -70,8 +71,7 @@ where ...@@ -70,8 +71,7 @@ where
tracing::info!("skipping discovery-based worker monitoring"); tracing::info!("skipping discovery-based worker monitoring");
} }
let policy = let policy = RouterSchedulingPolicy::new(kv_router_config.router_queue_policy);
RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
tracing::info!( tracing::info!(
"Router queue policy: {}", "Router queue policy: {}",
kv_router_config.router_queue_policy kv_router_config.router_queue_policy
...@@ -131,7 +131,10 @@ where ...@@ -131,7 +131,10 @@ where
maybe_request_id: Option<String>, maybe_request_id: Option<String>,
isl_tokens: usize, isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores, tier_overlap_blocks: TierOverlapBlocks,
effective_overlap_blocks: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, f64>,
effective_cached_tokens: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
tree_sizes: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
...@@ -147,7 +150,10 @@ where ...@@ -147,7 +150,10 @@ where
maybe_request_id, maybe_request_id,
isl_tokens, isl_tokens,
token_seq, token_seq,
overlaps, tier_overlap_blocks,
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes,
router_config_override, router_config_override,
update_states, update_states,
lora_name, lora_name,
...@@ -209,11 +215,15 @@ where ...@@ -209,11 +215,15 @@ where
&self, &self,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize, isl_tokens: usize,
overlaps: OverlapScores, effective_cached_tokens: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Vec<PotentialLoad> { ) -> Vec<PotentialLoad> {
self.inner self.inner.get_potential_loads(
.get_potential_loads(token_seq, isl_tokens, overlaps, track_prefill_tokens) token_seq,
isl_tokens,
effective_cached_tokens,
track_prefill_tokens,
)
} }
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> { pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
......
...@@ -196,6 +196,7 @@ mod tests { ...@@ -196,6 +196,7 @@ mod tests {
.await?; .await?;
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
let decay_now = Instant::now();
seq_manager_1.add_request( seq_manager_1.add_request(
SequenceRequest { SequenceRequest {
...@@ -207,7 +208,7 @@ mod tests { ...@@ -207,7 +208,7 @@ mod tests {
worker: WorkerWithDpRank::new(0, 0), worker: WorkerWithDpRank::new(0, 0),
lora_name: None, lora_name: None,
}, },
Instant::now(), decay_now,
)?; )?;
seq_manager_1.add_request( seq_manager_1.add_request(
...@@ -220,7 +221,7 @@ mod tests { ...@@ -220,7 +221,7 @@ mod tests {
worker: WorkerWithDpRank::new(0, 1), worker: WorkerWithDpRank::new(0, 1),
lora_name: None, lora_name: None,
}, },
Instant::now(), decay_now,
)?; )?;
seq_manager_2.add_request( seq_manager_2.add_request(
...@@ -233,7 +234,7 @@ mod tests { ...@@ -233,7 +234,7 @@ mod tests {
worker: WorkerWithDpRank::new(1, 0), worker: WorkerWithDpRank::new(1, 0),
lora_name: None, lora_name: None,
}, },
Instant::now(), decay_now,
)?; )?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
...@@ -349,6 +350,7 @@ mod tests { ...@@ -349,6 +350,7 @@ mod tests {
.await?; .await?;
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
let decay_now = Instant::now();
seq_manager_1.add_request( seq_manager_1.add_request(
SequenceRequest { SequenceRequest {
...@@ -360,7 +362,7 @@ mod tests { ...@@ -360,7 +362,7 @@ mod tests {
worker: WorkerWithDpRank::from_worker_id(0), worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None, lora_name: None,
}, },
Instant::now(), decay_now,
)?; )?;
seq_manager_1.add_request( seq_manager_1.add_request(
...@@ -373,7 +375,7 @@ mod tests { ...@@ -373,7 +375,7 @@ mod tests {
worker: WorkerWithDpRank::from_worker_id(1), worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None, lora_name: None,
}, },
Instant::now(), decay_now,
)?; )?;
seq_manager_2.add_request( seq_manager_2.add_request(
...@@ -386,7 +388,7 @@ mod tests { ...@@ -386,7 +388,7 @@ mod tests {
worker: WorkerWithDpRank::from_worker_id(2), worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None, lora_name: None,
}, },
Instant::now(), decay_now,
)?; )?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......
...@@ -105,8 +105,8 @@ pub struct RequestTracker { ...@@ -105,8 +105,8 @@ pub struct RequestTracker {
/// record the final finish time. /// record the final finish time.
request_finish_time: Mutex<Option<Instant>>, request_finish_time: Mutex<Option<Instant>>,
/// KV cache overlap blocks (prefix cache hits) - set once via OnceLock /// Effective KV cache overlap blocks (weighted prefix cache hits) - set once via OnceLock
kv_overlap_blocks: OnceLock<u32>, kv_overlap_blocks: OnceLock<f64>,
/// Input sequence length in blocks (for hit rate calculation) - set once via OnceLock /// Input sequence length in blocks (for hit rate calculation) - set once via OnceLock
isl_blocks: OnceLock<usize>, isl_blocks: OnceLock<usize>,
...@@ -114,7 +114,7 @@ pub struct RequestTracker { ...@@ -114,7 +114,7 @@ pub struct RequestTracker {
/// Input sequence length in tokens - set once via OnceLock /// Input sequence length in tokens - set once via OnceLock
isl_tokens: OnceLock<usize>, isl_tokens: OnceLock<usize>,
/// Number of cached tokens (overlap_blocks * block_size) - set once via OnceLock /// Number of cached tokens derived from the effective cache hit - set once via OnceLock
cached_tokens: OnceLock<usize>, cached_tokens: OnceLock<usize>,
/// Output sequence length in tokens - updated atomically as tokens stream back /// Output sequence length in tokens - updated atomically as tokens stream back
...@@ -226,7 +226,7 @@ impl RequestTracker { ...@@ -226,7 +226,7 @@ impl RequestTracker {
} }
/// Record KV cache hit information. Returns true if this was the first call. /// Record KV cache hit information. Returns true if this was the first call.
pub fn record_kv_hit(&self, overlap_blocks: u32, isl_blocks: usize) -> bool { pub fn record_kv_hit(&self, overlap_blocks: f64, isl_blocks: usize) -> bool {
let overlap_set = self.kv_overlap_blocks.set(overlap_blocks).is_ok(); let overlap_set = self.kv_overlap_blocks.set(overlap_blocks).is_ok();
let isl_set = self.isl_blocks.set(isl_blocks).is_ok(); let isl_set = self.isl_blocks.set(isl_blocks).is_ok();
overlap_set && isl_set overlap_set && isl_set
...@@ -311,7 +311,7 @@ impl RequestTracker { ...@@ -311,7 +311,7 @@ impl RequestTracker {
if isl == 0 { if isl == 0 {
return None; return None;
} }
Some(overlap as f64 / isl as f64) Some(overlap / isl as f64)
} }
/// Set the request phase and return a permit that blocks subsequent phase changes. /// Set the request phase and return a permit that blocks subsequent phase changes.
...@@ -707,7 +707,7 @@ mod tests { ...@@ -707,7 +707,7 @@ mod tests {
#[test] #[test]
fn test_kv_hit_rate() { fn test_kv_hit_rate() {
let tracker = RequestTracker::new(); let tracker = RequestTracker::new();
tracker.record_kv_hit(3, 10); tracker.record_kv_hit(3.0, 10);
let rate = tracker.kv_hit_rate().unwrap(); let rate = tracker.kv_hit_rate().unwrap();
assert!( assert!(
...@@ -719,7 +719,7 @@ mod tests { ...@@ -719,7 +719,7 @@ mod tests {
#[test] #[test]
fn test_kv_hit_rate_zero_isl() { fn test_kv_hit_rate_zero_isl() {
let tracker = RequestTracker::new(); let tracker = RequestTracker::new();
tracker.record_kv_hit(0, 0); tracker.record_kv_hit(0.0, 0);
assert!( assert!(
tracker.kv_hit_rate().is_none(), tracker.kv_hit_rate().is_none(),
"KV hit rate should be None when isl_blocks is 0" "KV hit rate should be None when isl_blocks is 0"
......
...@@ -389,6 +389,8 @@ pub struct TokenBlock { ...@@ -389,6 +389,8 @@ pub struct TokenBlock {
block_hash: BlockHash, block_hash: BlockHash,
sequence_hash: SequenceHash, sequence_hash: SequenceHash,
parent_sequence_hash: Option<SequenceHash>, parent_sequence_hash: Option<SequenceHash>,
external_sequence_hash: Option<SequenceHash>,
external_parent_sequence_hash: Option<SequenceHash>,
} }
impl TokenBlock { impl TokenBlock {
...@@ -425,6 +427,8 @@ impl TokenBlock { ...@@ -425,6 +427,8 @@ impl TokenBlock {
block_hash: chunk.block_hash, block_hash: chunk.block_hash,
sequence_hash, sequence_hash,
parent_sequence_hash, parent_sequence_hash,
external_sequence_hash: None,
external_parent_sequence_hash: None,
} }
} }
...@@ -453,6 +457,61 @@ impl TokenBlock { ...@@ -453,6 +457,61 @@ impl TokenBlock {
self.parent_sequence_hash self.parent_sequence_hash
} }
/// Returns the TRT-LLM/framework sequence hash for this block, if assigned.
pub fn external_sequence_hash(&self) -> Option<SequenceHash> {
self.external_sequence_hash
}
/// Returns the TRT-LLM/framework parent sequence hash for this block, if assigned.
pub fn external_parent_sequence_hash(&self) -> Option<SequenceHash> {
self.external_parent_sequence_hash
}
/// Assigns the TRT-LLM/framework hash chain for this block.
///
/// Idempotent: calling with the same values on an already-assigned block
/// is a no-op, but re-assigning a different chain panics to match the
/// invariant `sync_external_sequence_hashes` enforces.
pub fn assign_external_hashes(
&mut self,
external_sequence_hash: SequenceHash,
external_parent_sequence_hash: Option<SequenceHash>,
) {
if let Some(existing) = self.external_sequence_hash {
assert_eq!(
existing, external_sequence_hash,
"external_sequence_hash re-assignment mismatch",
);
assert_eq!(
self.external_parent_sequence_hash, external_parent_sequence_hash,
"external_parent_sequence_hash re-assignment mismatch",
);
return;
}
self.external_sequence_hash = Some(external_sequence_hash);
self.external_parent_sequence_hash = external_parent_sequence_hash;
}
/// Ensures that this complete block has an assigned TRT-LLM/framework hash chain.
pub fn assert_external_hashes_assigned(&self) {
assert!(
self.external_sequence_hash.is_some(),
"complete block is missing external_sequence_hash"
);
if self.parent_sequence_hash.is_some() {
assert!(
self.external_parent_sequence_hash.is_some(),
"non-root complete block is missing external_parent_sequence_hash"
);
} else {
assert!(
self.external_parent_sequence_hash.is_none(),
"root complete block must not have external_parent_sequence_hash"
);
}
}
/// Returns the number of tokens in the block. /// Returns the number of tokens in the block.
pub fn block_size(&self) -> usize { pub fn block_size(&self) -> usize {
self.tokens.0.len() self.tokens.0.len()
...@@ -836,6 +895,45 @@ impl TokenBlockSequence { ...@@ -836,6 +895,45 @@ impl TokenBlockSequence {
Tokens::from(result) Tokens::from(result)
} }
/// Synchronize the TRT-LLM/framework sequence hash chain onto all completed blocks.
///
/// `external_sequence_hashes` must contain exactly one hash per completed block in
/// sequence order. Existing assignments are validated and preserved.
pub fn sync_external_sequence_hashes(&mut self, external_sequence_hashes: &[SequenceHash]) {
assert_eq!(
external_sequence_hashes.len(),
self.blocks.len(),
"external_sequence_hashes length ({}) must match completed block count ({})",
external_sequence_hashes.len(),
self.blocks.len()
);
for (idx, block) in self.blocks.iter_mut().enumerate() {
let external_sequence_hash = external_sequence_hashes[idx];
let external_parent_sequence_hash = idx
.checked_sub(1)
.map(|parent_idx| external_sequence_hashes[parent_idx]);
match block.external_sequence_hash() {
Some(existing) => {
assert_eq!(
existing, external_sequence_hash,
"external_sequence_hash mismatch at block index {}",
idx
);
assert_eq!(
block.external_parent_sequence_hash(),
external_parent_sequence_hash,
"external_parent_sequence_hash mismatch at block index {}",
idx
);
}
None => block
.assign_external_hashes(external_sequence_hash, external_parent_sequence_hash),
}
}
}
/// Splits a [`Tokens`] object into a vector of completed blocks and a final partial block. /// Splits a [`Tokens`] object into a vector of completed blocks and a final partial block.
/// ///
/// This is primarily used internally by [`TokenBlockSequence::new`] but can be used externally. /// This is primarily used internally by [`TokenBlockSequence::new`] but can be used externally.
...@@ -1575,4 +1673,28 @@ mod tests { ...@@ -1575,4 +1673,28 @@ mod tests {
assert_eq!(partial.tokens.len(), 4); assert_eq!(partial.tokens.len(), 4);
assert_eq!(remaining.len(), 6); assert_eq!(remaining.len(), 6);
} }
#[test]
fn test_sync_external_sequence_hashes_assigns_chain() {
let mut seq = create_test_sequence(&[1, 2, 3, 4, 5, 6, 7, 8], 4, Some(TEST_SALT_HASH));
let external_hashes = vec![100_u64, 200_u64];
seq.sync_external_sequence_hashes(&external_hashes);
assert_eq!(seq.blocks[0].external_sequence_hash(), Some(100));
assert_eq!(seq.blocks[0].external_parent_sequence_hash(), None);
assert_eq!(seq.blocks[1].external_sequence_hash(), Some(200));
assert_eq!(seq.blocks[1].external_parent_sequence_hash(), Some(100));
seq.blocks[0].assert_external_hashes_assigned();
seq.blocks[1].assert_external_hashes_assigned();
}
#[test]
#[should_panic(expected = "external_sequence_hash mismatch")]
fn test_sync_external_sequence_hashes_rejects_mismatched_existing_chain() {
let mut seq = create_test_sequence(&[1, 2, 3, 4, 5, 6, 7, 8], 4, Some(TEST_SALT_HASH));
seq.sync_external_sequence_hashes(&[100_u64, 200_u64]);
seq.sync_external_sequence_hashes(&[100_u64, 201_u64]);
}
} }
...@@ -17,6 +17,7 @@ use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS; ...@@ -17,6 +17,7 @@ use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use dynamo_kv_router::{ use dynamo_kv_router::{
ActiveSequencesMultiWorker, DefaultWorkerSelector, RadixTree, RouterSchedulingPolicy, ActiveSequencesMultiWorker, DefaultWorkerSelector, RadixTree, RouterSchedulingPolicy,
SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector, SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector,
scheduling::TierOverlapBlocks,
}; };
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
...@@ -135,14 +136,35 @@ impl PendingRequest { ...@@ -135,14 +136,35 @@ impl PendingRequest {
fn scheduling_request( fn scheduling_request(
&self, &self,
block_size: usize,
decode_blocks: FxHashMap<WorkerWithDpRank, usize>, decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
prefill_tokens: FxHashMap<WorkerWithDpRank, usize>, prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
) -> SchedulingRequest { ) -> SchedulingRequest {
let effective_overlap_blocks = self
.overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as f64))
.collect();
let effective_cached_tokens = self
.overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as usize * block_size))
.collect();
SchedulingRequest { SchedulingRequest {
maybe_request_id: Some(self.request_id()), maybe_request_id: Some(self.request_id()),
token_seq: self.token_seq.clone(), token_seq: self.token_seq.clone(),
isl_tokens: self.isl_tokens, isl_tokens: self.isl_tokens,
overlaps: self.overlaps.clone(), tier_overlap_blocks: TierOverlapBlocks::default(),
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes: self
.overlaps
.tree_sizes
.iter()
.map(|(k, v)| (*k, *v))
.collect(),
decode_blocks, decode_blocks,
prefill_tokens, prefill_tokens,
track_prefill_tokens: self.track_prefill_tokens, track_prefill_tokens: self.track_prefill_tokens,
...@@ -216,7 +238,7 @@ impl OfflineReplayRouter { ...@@ -216,7 +238,7 @@ impl OfflineReplayRouter {
let workers_with_configs = replay_workers_with_configs(args, num_workers); let workers_with_configs = replay_workers_with_configs(args, num_workers);
let slots = replay_slots(args, &workers_with_configs); let slots = replay_slots(args, &workers_with_configs);
let selector = replay_selector(&config); let selector = replay_selector(&config);
let policy = replay_policy(&config, args); let policy = replay_policy(&config);
let queue_threshold = config.router_queue_threshold; let queue_threshold = config.router_queue_threshold;
Ok(Self { Ok(Self {
...@@ -423,7 +445,11 @@ impl OfflineReplayRouter { ...@@ -423,7 +445,11 @@ impl OfflineReplayRouter {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0); let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key( self.policy.enqueue_key(
arrival_offset, arrival_offset,
&request.scheduling_request(FxHashMap::default(), FxHashMap::default()), &request.scheduling_request(
self.block_size as usize,
FxHashMap::default(),
FxHashMap::default(),
),
) )
} }
...@@ -495,11 +521,19 @@ impl OfflineReplayRouter { ...@@ -495,11 +521,19 @@ impl OfflineReplayRouter {
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
request.token_seq.as_deref(), request.token_seq.as_deref(),
request.isl_tokens, request.isl_tokens,
request.overlaps.clone(), request
.overlaps
.scores
.iter()
.map(|(worker, overlap)| {
(*worker, *overlap as usize * self.block_size as usize)
})
.collect(),
request.track_prefill_tokens, request.track_prefill_tokens,
decay_now, decay_now,
); );
let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens); let scheduling_request =
request.scheduling_request(self.block_size as usize, decode_blocks, prefill_tokens);
let selection = self.selector.select_worker( let selection = self.selector.select_worker(
&self.workers_with_configs, &self.workers_with_configs,
&scheduling_request, &scheduling_request,
...@@ -510,13 +544,13 @@ impl OfflineReplayRouter { ...@@ -510,13 +544,13 @@ impl OfflineReplayRouter {
let request_id = request.request_id(); let request_id = request.request_id();
let prefill_load_hint = self.prefill_load_hint_for( let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens, request.isl_tokens,
selection.overlap_blocks, selection.cached_tokens,
request.track_prefill_tokens, request.track_prefill_tokens,
); );
let isl_blocks = u32::try_from(request.isl_tokens.div_ceil(self.block_size as usize)) let isl_blocks = u32::try_from(request.isl_tokens.div_ceil(self.block_size as usize))
.unwrap_or(u32::MAX); .unwrap_or(u32::MAX);
let overlap_blocks = selection.overlap_blocks; let overlap_blocks = selection.effective_overlap_blocks.floor() as u32;
self.slots self.slots
.add_request( .add_request(
...@@ -584,14 +618,14 @@ impl OfflineReplayRouter { ...@@ -584,14 +618,14 @@ impl OfflineReplayRouter {
fn prefill_load_hint_for( fn prefill_load_hint_for(
&self, &self,
isl_tokens: usize, isl_tokens: usize,
overlap_blocks: u32, cached_tokens: usize,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> { ) -> Option<PrefillLoadHint> {
if !track_prefill_tokens { if !track_prefill_tokens {
return None; return None;
} }
let prefix = (overlap_blocks as usize) * (self.block_size as usize); let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix); let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 { if effective_isl == 0 {
return None; return None;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment