Unverified commit 9b3e9249, authored by jthomson04 and committed by GitHub
Browse files

feat(kv-router): add dedup filter for duplicate vLLM KV block events (#8012)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent f45a6985
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
...@@ -19,6 +21,76 @@ use crate::kv_router::KV_EVENT_SUBJECT; ...@@ -19,6 +21,76 @@ use crate::kv_router::KV_EVENT_SUBJECT;
use super::{DEFAULT_MAX_BATCH_BLOCKS, kv_publisher_metrics}; use super::{DEFAULT_MAX_BATCH_BLOCKS, kv_publisher_metrics};
/// Reference-counting filter that deduplicates KV cache events.
///
/// vLLM can emit multiple store/remove events for the same block hash.
/// Refcounts are tracked **per DP rank** because identical block hashes
/// on different ranks represent independent blocks.
///
/// - **Store**: always passes through; increments refcount for the rank.
/// - **Remove**: only passes through when refcount decrements to 0. A hash
///   with no recorded store on that rank is passed through unchanged rather
///   than dropped.
/// - **Cleared**: resets refcounts for all ranks.
///
/// The filter only performs bookkeeping; emitting the (filtered) events is
/// the caller's responsibility.
pub(super) struct EventDedupFilter {
/// Per-dp-rank refcounts: outer key is the DP rank, inner map goes from
/// block hash to the number of stores not yet matched by a remove. A hash
/// is present only while its count is non-zero.
per_rank: HashMap<u32, HashMap<ExternalSequenceBlockHash, usize>>,
}
impl EventDedupFilter {
    /// Creates a filter with no tracked ranks or block hashes.
    pub(super) fn new() -> Self {
        Self {
            per_rank: HashMap::new(),
        }
    }

    /// Track a store event.
    ///
    /// Increments the refcount of every block hash in `data` on the given
    /// `dp_rank`. Stores always pass through to the sink, so this only
    /// updates bookkeeping and returns nothing.
    pub(super) fn track_store(&mut self, dp_rank: u32, data: &KvCacheStoreData) {
        // Nothing to count: avoid creating an empty per-rank map.
        if data.blocks.is_empty() {
            return;
        }
        let refcounts = self.per_rank.entry(dp_rank).or_default();
        for block in &data.blocks {
            *refcounts.entry(block.block_hash).or_insert(0) += 1;
        }
    }

    /// Filter a remove event.
    ///
    /// Decrements the refcount of each hash in `data` on the given `dp_rank`
    /// and retains only the hashes that should be forwarded:
    ///
    /// - tracked hashes whose refcount reaches 0 (they are dropped from the
    ///   map at that point), and
    /// - hashes that are not tracked at all, which pass through defensively.
    ///
    /// Hashes that still have outstanding references are filtered out.
    /// Returns `None` if no hashes survive filtering.
    pub(super) fn filter_remove(
        &mut self,
        dp_rank: u32,
        mut data: KvCacheRemoveData,
    ) -> Option<KvCacheRemoveData> {
        // Look the rank up without inserting: a remove for a rank that never
        // stored anything has nothing to deduplicate against, so every hash
        // passes through and no empty map is left behind.
        if let Some(refcounts) = self.per_rank.get_mut(&dp_rank) {
            data.block_hashes.retain(|hash| match refcounts.entry(*hash) {
                Entry::Occupied(mut entry) => {
                    *entry.get_mut() -= 1;
                    if *entry.get() == 0 {
                        // Last reference released: forget the hash and forward the remove.
                        entry.remove();
                        true
                    } else {
                        // Still referenced by an earlier duplicate store.
                        false
                    }
                }
                // Not tracked on this rank: pass through defensively.
                Entry::Vacant(_) => true,
            });

            // Drop the rank's map once it is fully drained so idle ranks do
            // not accumulate.
            if refcounts.is_empty() {
                self.per_rank.remove(&dp_rank);
            }
        }

        (!data.block_hashes.is_empty()).then_some(data)
    }

    /// Clear refcounts for all DP ranks.
    ///
    /// A `Cleared` event from any rank causes the indexer to wipe all blocks
    /// for the entire worker, so every rank's refcounts are reset to stay
    /// consistent with it.
    pub(super) fn clear(&mut self) {
        self.per_rank.clear();
    }
}
/// Accumulator for in-flight KV cache events that will be merged into a single /// Accumulator for in-flight KV cache events that will be merged into a single
/// [`RouterEvent`] before being forwarded to the event sink. /// [`RouterEvent`] before being forwarded to the event sink.
#[derive(Debug)] #[derive(Debug)]
...@@ -80,39 +152,47 @@ impl BatchingState { ...@@ -80,39 +152,47 @@ impl BatchingState {
publisher: &P, publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>, local_indexer: &Option<Arc<LocalKvIndexer>>,
worker_id: u64, worker_id: u64,
dedup: &mut EventDedupFilter,
) { ) {
if !self.has_pending() { if !self.has_pending() {
return; return;
} }
let id = self.next_publish_id;
let dp_rank = self.last_dp_rank; let dp_rank = self.last_dp_rank;
if let Some(data) = self.pending_removed.take() { let mut emitted = false;
if let Some(data) = self.pending_removed.take()
&& let Some(filtered) = dedup.filter_remove(dp_rank, data)
{
emit( emit(
publisher, publisher,
local_indexer, local_indexer,
worker_id, worker_id,
KvCacheEvent { KvCacheEvent {
event_id: id, event_id: self.next_publish_id,
data: KvCacheEventData::Removed(data), data: KvCacheEventData::Removed(filtered),
dp_rank, dp_rank,
}, },
) )
.await; .await;
emitted = true;
} }
if let Some(data) = self.pending_stored.take() { if let Some(data) = self.pending_stored.take() {
dedup.track_store(dp_rank, &data);
emit( emit(
publisher, publisher,
local_indexer, local_indexer,
worker_id, worker_id,
KvCacheEvent { KvCacheEvent {
event_id: id, event_id: self.next_publish_id,
data: KvCacheEventData::Stored(data), data: KvCacheEventData::Stored(data),
dp_rank, dp_rank,
}, },
) )
.await; .await;
emitted = true;
}
if emitted {
self.next_publish_id += 1;
} }
self.next_publish_id += 1;
self.record_flush_time(); self.record_flush_time();
} }
} }
...@@ -160,19 +240,20 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -160,19 +240,20 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
max_batch_blocks: usize, max_batch_blocks: usize,
) { ) {
let mut batching_state = BatchingState::new(); let mut batching_state = BatchingState::new();
let mut dedup = EventDedupFilter::new();
let mut last_raw_input_id: Option<u64> = None; let mut last_raw_input_id: Option<u64> = None;
loop { loop {
tokio::select! { tokio::select! {
_ = cancellation_token.cancelled() => { _ = cancellation_token.cancelled() => {
tracing::info!("KV Event source received cancellation signal"); tracing::info!("KV Event source received cancellation signal");
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
break; break;
} }
event = rx.recv() => { event = rx.recv() => {
let Some(placement_event) = event else { let Some(placement_event) = event else {
tracing::debug!("Event processor channel closed."); tracing::debug!("Event processor channel closed.");
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
break; break;
}; };
...@@ -223,7 +304,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -223,7 +304,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
match event.data { match event.data {
KvCacheEventData::Removed(data) => { KvCacheEventData::Removed(data) => {
if batching_state.pending_stored.is_some() || dp_rank_changed { if batching_state.pending_stored.is_some() || dp_rank_changed {
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
} }
match &mut batching_state.pending_removed { match &mut batching_state.pending_removed {
Some(pending) => pending.block_hashes.extend(data.block_hashes), Some(pending) => pending.block_hashes.extend(data.block_hashes),
...@@ -239,7 +320,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -239,7 +320,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
data.parent_hash != p.blocks.last().map(|b| b.block_hash) data.parent_hash != p.blocks.last().map(|b| b.block_hash)
}); });
if should_flush { if should_flush {
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
} }
match &mut batching_state.pending_stored { match &mut batching_state.pending_stored {
Some(pending) => pending.blocks.extend(data.blocks), Some(pending) => pending.blocks.extend(data.blocks),
...@@ -249,7 +330,8 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -249,7 +330,8 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
} }
} }
KvCacheEventData::Cleared => { KvCacheEventData::Cleared => {
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
dedup.clear();
emit( emit(
&publisher, &publisher,
&local_indexer, &local_indexer,
...@@ -271,7 +353,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -271,7 +353,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
&& (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms)) && (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms))
|| batching_state.pending_block_count() > max_batch_blocks) || batching_state.pending_block_count() > max_batch_blocks)
{ {
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
} }
} }
_ = tokio::time::sleep( _ = tokio::time::sleep(
...@@ -279,7 +361,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + ...@@ -279,7 +361,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
.map(|ms| batching_state.remaining_timeout(ms)) .map(|ms| batching_state.remaining_timeout(ms))
.unwrap_or(Duration::from_secs(3600)) .unwrap_or(Duration::from_secs(3600))
), if timeout_ms.is_some() && batching_state.has_pending() => { ), if timeout_ms.is_some() && batching_state.has_pending() => {
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
} }
} }
} }
......
...@@ -34,7 +34,7 @@ mod worker_metrics; ...@@ -34,7 +34,7 @@ mod worker_metrics;
mod zmq_listener; mod zmq_listener;
#[cfg(test)] #[cfg(test)]
use event_processor::{BatchingState, run_event_processor_loop}; use event_processor::{BatchingState, EventDedupFilter, run_event_processor_loop};
use event_processor::{ use event_processor::{
EventPlanePublisher, start_event_processor, start_event_processor_jetstream, EventPlanePublisher, start_event_processor, start_event_processor_jetstream,
}; };
......
...@@ -1263,6 +1263,157 @@ mod tests_startup_helpers { ...@@ -1263,6 +1263,157 @@ mod tests_startup_helpers {
} }
} }
#[cfg(test)]
mod test_event_dedup_filter {
    use super::*;

    /// Builds a parentless store payload with one block per entry in `hashes`.
    /// Each block's token hash is derived as `hash * 10`.
    fn store_data(hashes: &[u64]) -> KvCacheStoreData {
        let blocks = hashes
            .iter()
            .copied()
            .map(|h| KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(h),
                tokens_hash: LocalBlockHash(h * 10),
                mm_extra_info: None,
            })
            .collect();
        KvCacheStoreData {
            parent_hash: None,
            blocks,
        }
    }

    /// Builds a remove payload listing `hashes` in order.
    fn remove_data(hashes: &[u64]) -> KvCacheRemoveData {
        let block_hashes = hashes
            .iter()
            .copied()
            .map(ExternalSequenceBlockHash)
            .collect();
        KvCacheRemoveData { block_hashes }
    }

    /// Number of hashes that survived filtering, or `None` if the whole
    /// remove event was swallowed.
    fn surviving(outcome: Option<KvCacheRemoveData>) -> Option<usize> {
        outcome.map(|kept| kept.block_hashes.len())
    }

    #[test]
    fn stores_track_refcounts_for_removes() {
        let mut filter = EventDedupFilter::new();
        let payload = store_data(&[1, 2, 3]);

        // Two identical stores leave every hash at refcount 2.
        for _ in 0..2 {
            filter.track_store(0, &payload);
        }

        // 2 -> 1: nothing may be forwarded yet.
        assert!(filter.filter_remove(0, remove_data(&[1, 2, 3])).is_none());

        // 1 -> 0: all three hashes are forwarded.
        let outcome = filter.filter_remove(0, remove_data(&[1, 2, 3]));
        assert!(outcome.is_some());
        assert_eq!(surviving(outcome), Some(3));
    }

    #[test]
    fn duplicate_removes_are_filtered() {
        let mut filter = EventDedupFilter::new();

        // Hash 1 stored twice.
        for _ in 0..2 {
            filter.track_store(0, &store_data(&[1]));
        }

        // 2 -> 1: swallowed.
        assert!(filter.filter_remove(0, remove_data(&[1])).is_none());

        // 1 -> 0: forwarded.
        let outcome = filter.filter_remove(0, remove_data(&[1]));
        assert!(outcome.is_some());
        assert_eq!(surviving(outcome), Some(1));
    }

    #[test]
    fn store_remove_store_cycle() {
        let mut filter = EventDedupFilter::new();

        // Each store/remove pair is independent: the count goes 0 -> 1 -> 0
        // and the remove is forwarded both times.
        for _ in 0..2 {
            filter.track_store(0, &store_data(&[1]));
            assert!(filter.filter_remove(0, remove_data(&[1])).is_some());
        }
    }

    #[test]
    fn clear_resets_all_ranks() {
        let mut filter = EventDedupFilter::new();

        // Refcount 2 for hashes 1 and 2 on both rank 0 and rank 1.
        for rank in [0, 0, 1, 1] {
            filter.track_store(rank, &store_data(&[1, 2]));
        }

        // Clear wipes every rank (mirrors the indexer, where Cleared from any
        // rank removes all blocks for the whole worker).
        filter.clear();

        // With nothing tracked any more, removes pass through defensively.
        assert!(filter.filter_remove(0, remove_data(&[1])).is_some());
        assert!(filter.filter_remove(1, remove_data(&[1])).is_some());
    }

    #[test]
    fn mixed_blocks_in_single_remove() {
        let mut filter = EventDedupFilter::new();

        // Hashes 1 and 3 end at refcount 2, hash 2 at refcount 1.
        for hash in [1, 1, 2, 3, 3] {
            filter.track_store(0, &store_data(&[hash]));
        }

        // Only hash 2 reaches zero, so it is the sole survivor.
        let outcome = filter.filter_remove(0, remove_data(&[1, 2, 3]));
        assert!(outcome.is_some());
        let survivors = outcome.unwrap().block_hashes;
        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0], ExternalSequenceBlockHash(2));
    }

    #[test]
    fn same_hash_on_different_ranks_are_independent() {
        let mut filter = EventDedupFilter::new();

        // Hash 1: refcount 2 on rank 0, refcount 1 on rank 1.
        for rank in [0, 0, 1] {
            filter.track_store(rank, &store_data(&[1]));
        }

        // Rank 1: 1 -> 0, forwarded.
        assert!(filter.filter_remove(1, remove_data(&[1])).is_some());

        // Rank 0: 2 -> 1, swallowed; rank 1's remove did not touch it.
        assert!(filter.filter_remove(0, remove_data(&[1])).is_none());

        // Rank 0: 1 -> 0, forwarded.
        assert!(filter.filter_remove(0, remove_data(&[1])).is_some());
    }
}
#[cfg(all(test, feature = "integration"))] #[cfg(all(test, feature = "integration"))]
mod test_integration_publisher { mod test_integration_publisher {
use super::*; use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment