Unverified Commit 9b3e9249 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(kv-router): add dedup filter for duplicate vLLM KV block events (#8012)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent f45a6985
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
......@@ -19,6 +21,76 @@ use crate::kv_router::KV_EVENT_SUBJECT;
use super::{DEFAULT_MAX_BATCH_BLOCKS, kv_publisher_metrics};
/// Reference-counting filter that deduplicates KV cache events.
///
/// vLLM can emit multiple store/remove events for the same block hash.
/// Refcounts are tracked **per DP rank** because identical block hashes
/// on different ranks represent independent blocks.
///
/// - **Store**: always passes through; increments refcount for the rank.
/// - **Remove**: only passes through when refcount decrements to 0.
/// - **Cleared**: resets refcounts for all ranks.
pub(super) struct EventDedupFilter {
/// Per-dp-rank refcounts.
per_rank: HashMap<u32, HashMap<ExternalSequenceBlockHash, usize>>,
}
impl EventDedupFilter {
pub(super) fn new() -> Self {
Self {
per_rank: HashMap::new(),
}
}
/// Track a store event. Increments refcount for each block hash on the
/// given DP rank. Stores always pass through — this only updates bookkeeping.
pub(super) fn track_store(&mut self, dp_rank: u32, data: &KvCacheStoreData) {
let refcounts = self.per_rank.entry(dp_rank).or_default();
for block in &data.blocks {
*refcounts.entry(block.block_hash).or_insert(0) += 1;
}
}
/// Filter a remove event. Retains only block hashes whose refcount on the
/// given DP rank decrements to 0 (removing them from the map). Returns
/// `None` if no hashes survive filtering.
pub(super) fn filter_remove(
&mut self,
dp_rank: u32,
mut data: KvCacheRemoveData,
) -> Option<KvCacheRemoveData> {
let refcounts = self.per_rank.entry(dp_rank).or_default();
data.block_hashes.retain(|hash| {
match refcounts.entry(*hash) {
Entry::Occupied(mut entry) => {
*entry.get_mut() -= 1;
if *entry.get() == 0 {
entry.remove();
true // refcount hit 0 → pass through
} else {
false // still has references → filter out
}
}
Entry::Vacant(_) => {
true // not tracked → pass through defensively
}
}
});
if data.block_hashes.is_empty() {
None
} else {
Some(data)
}
}
/// Clear refcounts for all DP ranks. A `Cleared` event from any rank
/// causes the indexer to wipe all blocks for the entire worker, so we
/// must reset all ranks' refcounts to stay consistent.
pub(super) fn clear(&mut self) {
self.per_rank.clear();
}
}
/// Accumulator for in-flight KV cache events that will be merged into a single
/// [`RouterEvent`] before being forwarded to the event sink.
#[derive(Debug)]
......@@ -80,39 +152,47 @@ impl BatchingState {
publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>,
worker_id: u64,
dedup: &mut EventDedupFilter,
) {
if !self.has_pending() {
return;
}
let id = self.next_publish_id;
let dp_rank = self.last_dp_rank;
if let Some(data) = self.pending_removed.take() {
let mut emitted = false;
if let Some(data) = self.pending_removed.take()
&& let Some(filtered) = dedup.filter_remove(dp_rank, data)
{
emit(
publisher,
local_indexer,
worker_id,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Removed(data),
event_id: self.next_publish_id,
data: KvCacheEventData::Removed(filtered),
dp_rank,
},
)
.await;
emitted = true;
}
if let Some(data) = self.pending_stored.take() {
dedup.track_store(dp_rank, &data);
emit(
publisher,
local_indexer,
worker_id,
KvCacheEvent {
event_id: id,
event_id: self.next_publish_id,
data: KvCacheEventData::Stored(data),
dp_rank,
},
)
.await;
emitted = true;
}
if emitted {
self.next_publish_id += 1;
}
self.next_publish_id += 1;
self.record_flush_time();
}
}
......@@ -160,19 +240,20 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
max_batch_blocks: usize,
) {
let mut batching_state = BatchingState::new();
let mut dedup = EventDedupFilter::new();
let mut last_raw_input_id: Option<u64> = None;
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::info!("KV Event source received cancellation signal");
batching_state.flush(&publisher, &local_indexer, worker_id).await;
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
break;
}
event = rx.recv() => {
let Some(placement_event) = event else {
tracing::debug!("Event processor channel closed.");
batching_state.flush(&publisher, &local_indexer, worker_id).await;
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
break;
};
......@@ -223,7 +304,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
match event.data {
KvCacheEventData::Removed(data) => {
if batching_state.pending_stored.is_some() || dp_rank_changed {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
}
match &mut batching_state.pending_removed {
Some(pending) => pending.block_hashes.extend(data.block_hashes),
......@@ -239,7 +320,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
data.parent_hash != p.blocks.last().map(|b| b.block_hash)
});
if should_flush {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
}
match &mut batching_state.pending_stored {
Some(pending) => pending.blocks.extend(data.blocks),
......@@ -249,7 +330,8 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
}
}
KvCacheEventData::Cleared => {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
dedup.clear();
emit(
&publisher,
&local_indexer,
......@@ -271,7 +353,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
&& (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms))
|| batching_state.pending_block_count() > max_batch_blocks)
{
batching_state.flush(&publisher, &local_indexer, worker_id).await;
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
}
}
_ = tokio::time::sleep(
......@@ -279,7 +361,7 @@ pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync +
.map(|ms| batching_state.remaining_timeout(ms))
.unwrap_or(Duration::from_secs(3600))
), if timeout_ms.is_some() && batching_state.has_pending() => {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
batching_state.flush(&publisher, &local_indexer, worker_id, &mut dedup).await;
}
}
}
......
......@@ -34,7 +34,7 @@ mod worker_metrics;
mod zmq_listener;
#[cfg(test)]
use event_processor::{BatchingState, run_event_processor_loop};
use event_processor::{BatchingState, EventDedupFilter, run_event_processor_loop};
use event_processor::{
EventPlanePublisher, start_event_processor, start_event_processor_jetstream,
};
......
......@@ -1263,6 +1263,157 @@ mod tests_startup_helpers {
}
}
#[cfg(test)]
mod test_event_dedup_filter {
use super::*;
fn store_data(hashes: &[u64]) -> KvCacheStoreData {
KvCacheStoreData {
parent_hash: None,
blocks: hashes
.iter()
.map(|&h| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(h),
tokens_hash: LocalBlockHash(h * 10),
mm_extra_info: None,
})
.collect(),
}
}
fn remove_data(hashes: &[u64]) -> KvCacheRemoveData {
KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}
}
#[test]
fn stores_track_refcounts_for_removes() {
let mut filter = EventDedupFilter::new();
let data = store_data(&[1, 2, 3]);
// Store same hashes twice — refcount should be 2
filter.track_store(0, &data);
filter.track_store(0, &data);
// First remove — refcounts 2→1, all filtered out
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
assert!(result.is_none());
// Second remove — refcounts 1→0, all pass through
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 3);
}
#[test]
fn duplicate_removes_are_filtered() {
let mut filter = EventDedupFilter::new();
// Store same hash twice
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
// First remove — refcount 2→1, filtered out
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_none());
// Second remove — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 1);
}
#[test]
fn store_remove_store_cycle() {
let mut filter = EventDedupFilter::new();
// Store hash 1
filter.track_store(0, &store_data(&[1]));
// Remove hash 1 — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
// Store hash 1 again — refcount starts fresh at 1
filter.track_store(0, &store_data(&[1]));
// Remove again — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
}
#[test]
fn clear_resets_all_ranks() {
let mut filter = EventDedupFilter::new();
// Store on rank 0 and rank 1
filter.track_store(0, &store_data(&[1, 2]));
filter.track_store(0, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2]));
// Clear wipes all ranks (matches indexer semantics where Cleared
// from any rank removes all blocks for the entire worker).
filter.clear();
// Both ranks pass through defensively after clear
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
let result = filter.filter_remove(1, remove_data(&[1]));
assert!(result.is_some());
}
#[test]
fn mixed_blocks_in_single_remove() {
let mut filter = EventDedupFilter::new();
// Hash 1: stored twice (refcount 2)
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
// Hash 2: stored once (refcount 1)
filter.track_store(0, &store_data(&[2]));
// Hash 3: stored twice (refcount 2)
filter.track_store(0, &store_data(&[3]));
filter.track_store(0, &store_data(&[3]));
// Remove all three — only hash 2 (refcount 1→0) passes through
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.block_hashes.len(), 1);
assert_eq!(result.block_hashes[0], ExternalSequenceBlockHash(2));
}
#[test]
fn same_hash_on_different_ranks_are_independent() {
let mut filter = EventDedupFilter::new();
// Store hash 1 on rank 0 (twice) and rank 1 (once)
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
filter.track_store(1, &store_data(&[1]));
// Remove hash 1 on rank 1 — refcount 1→0, passes through
let result = filter.filter_remove(1, remove_data(&[1]));
assert!(result.is_some());
// Remove hash 1 on rank 0 — refcount 2→1, filtered out
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_none());
// Remove hash 1 on rank 0 again — refcount 1→0, passes through
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
}
}
#[cfg(all(test, feature = "integration"))]
mod test_integration_publisher {
use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment