"tests/fault_tolerance/vscode:/vscode.git/clone" did not exist on "fdcf611ffcd4535ec3a6491d1fdd972120408064"
Unverified commit e811db50, authored by Yan Ru Pei, committed by GitHub
Browse files

feat(kv-router): plumb medium into worker-side placements (#7462)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 9681225a
......@@ -623,9 +623,9 @@ impl SequenceData {
/// Convert to a store event.
pub fn to_store_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
RouterEvent::new(
self.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
......@@ -642,21 +642,21 @@ impl SequenceData {
}),
dp_rank: 0,
},
}
)
}
/// Convert to a remove event.
pub fn to_remove_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
RouterEvent::new(
self.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: self.external_hashes.clone(),
}),
dp_rank: 0,
},
}
)
}
}
......
......@@ -297,10 +297,7 @@ async fn run_benchmark(
}
WorkerTraceEntry::Event(event) => {
indexer
.apply_event(RouterEvent {
worker_id: worker_id as u64,
event,
})
.apply_event(RouterEvent::new(worker_id as u64, event))
.await;
Ok(None)
}
......
......@@ -590,6 +590,7 @@ impl ConcurrentRadixTree {
// Create a store event for this worker
let event = RouterEvent {
worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......
......@@ -432,6 +432,7 @@ impl PositionalIndexer {
events.push(RouterEvent {
worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......
......@@ -557,6 +557,7 @@ impl RadixTree {
// Create a store event for this worker
let event = RouterEvent {
worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......
......@@ -13,6 +13,7 @@ use super::concurrent_radix_tree::ConcurrentRadixTree;
use super::positional::PositionalIndexer;
use super::*;
use crate::protocols::*;
use crate::test_utils::{remove_event, router_event, stored_blocks_with_sequence_hashes};
// ============================================================================
// Helper functions
......@@ -63,25 +64,15 @@ fn make_store_event_with_parent(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let new_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent {
router_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: new_block_hashes
.iter()
.zip(new_seq_hashes.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: stored_blocks_with_sequence_hashes(&new_block_hashes, new_seq_hashes),
}),
)
}
/// Create a store event with all options.
......@@ -95,25 +86,15 @@ fn make_store_event_full(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
router_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: local_block_hashes
.iter()
.zip(seq_hashes.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank,
},
}
0,
dp_rank,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: stored_blocks_with_sequence_hashes(&local_block_hashes, &seq_hashes),
}),
)
}
/// Create a remove event for blocks with given local hashes.
......@@ -131,19 +112,15 @@ fn make_remove_event_with_dp_rank(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
remove_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}),
dp_rank,
},
}
0,
dp_rank,
seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
)
}
/// Create a remove event with parent hash for continuation sequences.
......@@ -165,19 +142,15 @@ fn make_remove_event_with_parent(
let suffix_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent {
remove_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: suffix_seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}),
dp_rank: 0,
},
}
0,
0,
suffix_seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
)
}
/// Snapshot the tree state for deterministic comparison.
......@@ -222,14 +195,7 @@ fn make_clear_event(worker_id: u64) -> RouterEvent {
/// Create a clear event with a specific dp_rank.
fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Cleared,
dp_rank,
},
}
router_event(worker_id, 0, dp_rank, KvCacheEventData::Cleared)
}
// ============================================================================
......@@ -646,16 +612,7 @@ async fn test_partial_block_removal(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -698,16 +655,7 @@ async fn test_remove_mid_chain_block(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]);
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -895,47 +843,27 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0
let base_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: base_hashes
.iter()
.zip(base_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
let base_event = router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
}),
);
index.apply_event(base_event).await;
// Store LoRA blocks on worker 1
let lora_event = RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: lora_hashes
.iter()
.zip(lora_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
let lora_event = router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
}),
);
index.apply_event(lora_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -1003,49 +931,29 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
// Worker 0: base model
index
.apply_event(RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: base_local
.iter()
.zip(base_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
}),
))
.await;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event(RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: lora_local
.iter()
.zip(lora_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
}),
))
.await;
flush_and_settle(index.as_ref()).await;
......@@ -1094,48 +1002,28 @@ async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
// Store adapter-a blocks on worker 0
index
.apply_event(RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes_a
.iter()
.zip(seq_a.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a),
}),
))
.await;
// Store adapter-b blocks on worker 1
index
.apply_event(RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes_b
.iter()
.zip(seq_b.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b),
}),
))
.await;
flush_and_settle(index.as_ref()).await;
......@@ -1317,16 +1205,7 @@ async fn test_long_sequence_partial_removal(variant: &str) {
.map(|&h| ExternalSequenceBlockHash(h))
.collect();
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: remove_hashes,
}),
dp_rank: 0,
},
};
let remove_event = remove_event(0, 0, 0, remove_hashes);
index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -1868,6 +1747,7 @@ async fn test_frequency(variant: &str) {
// KvIndexerMetrics tests
// ============================================================================
#[cfg(feature = "metrics")]
#[test]
fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered();
......
......@@ -133,6 +133,94 @@ impl WorkerWithDpRank {
}
}
/// Storage tier that a KV cache event or placement refers to.
///
/// Variants serialize under snake_case names; `Device` is the `Default` value.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum StorageTier {
/// Default tier. `from_kv_medium` maps "GPU" and "DEVICE" here, and it is the
/// only tier for which `is_gpu` returns `true`.
#[default]
Device,
/// `from_kv_medium` maps "CPU_PINNED" and "CPU_TIER1" here.
HostPinned,
/// `from_kv_medium` maps "CPU_TIER2", "DISK" and "NVME" here.
Disk,
/// `from_kv_medium` maps "EXTERNAL", "NETWORK", "REMOTE" and "SHARED" here.
External,
}
impl StorageTier {
    /// Parses a KV-event medium string into a storage tier.
    ///
    /// Matching is exact and case-sensitive. Returns `None` for any string
    /// that is not one of the recognized medium names.
    pub fn from_kv_medium(medium: &str) -> Option<Self> {
        let tier = match medium {
            "GPU" | "DEVICE" => Self::Device,
            "CPU_PINNED" | "CPU_TIER1" => Self::HostPinned,
            "CPU_TIER2" | "DISK" | "NVME" => Self::Disk,
            "EXTERNAL" | "NETWORK" | "REMOTE" | "SHARED" => Self::External,
            _ => return None,
        };
        Some(tier)
    }

    /// Resolves an optional medium string to a tier.
    ///
    /// Both a missing medium and an unrecognized medium resolve to
    /// `StorageTier::Device`.
    pub fn from_kv_medium_or_default(medium: Option<&str>) -> Self {
        match medium {
            Some(name) => Self::from_kv_medium(name).unwrap_or(Self::Device),
            None => Self::Device,
        }
    }

    /// Returns `true` only for the `Device` tier.
    pub fn is_gpu(self) -> bool {
        self == Self::Device
    }
}
/// Identifies who holds the blocks described by a [`Placement`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PlacementOwner {
/// Held by a specific worker, identified together with its dp rank.
LocalWorker(WorkerWithDpRank),
/// Carries no worker identity; `PlacementEvent::into_router_event` returns
/// `None` for this owner.
Shared,
}
/// Where a set of KV blocks lives: the owner plus the storage tier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Placement {
/// The owner of the placement (a local worker or `Shared`).
pub owner: PlacementOwner,
/// The storage tier the blocks reside in.
pub tier: StorageTier,
}
impl Placement {
    /// Builds a placement owned by the given worker / dp rank on `tier`.
    pub fn local_worker(worker_id: WorkerId, dp_rank: DpRank, tier: StorageTier) -> Self {
        let worker = WorkerWithDpRank::new(worker_id, dp_rank);
        Self {
            owner: PlacementOwner::LocalWorker(worker),
            tier,
        }
    }

    /// Builds a worker-owned placement on the `Device` tier.
    pub fn local_gpu(worker_id: WorkerId, dp_rank: DpRank) -> Self {
        Self::local_worker(worker_id, dp_rank, StorageTier::Device)
    }

    /// Returns `true` when the owner is a local worker and the tier is the
    /// GPU (`Device`) tier.
    pub fn is_local_gpu(&self) -> bool {
        match self.owner {
            PlacementOwner::LocalWorker(_) => self.tier.is_gpu(),
            PlacementOwner::Shared => false,
        }
    }
}
/// A KV cache event paired with the placement it applies to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlacementEvent {
/// Owner and storage tier the event refers to.
pub placement: Placement,
/// The underlying cache event.
pub event: KvCacheEvent,
}
impl PlacementEvent {
    /// Pairs `event` with an explicit `placement`.
    pub fn new(placement: Placement, event: KvCacheEvent) -> Self {
        Self { placement, event }
    }

    /// Wraps `event` in a worker-owned `Device` placement, taking the dp rank
    /// from the event itself.
    pub fn local_gpu(worker_id: WorkerId, event: KvCacheEvent) -> Self {
        let placement = Placement::local_gpu(worker_id, event.dp_rank);
        Self::new(placement, event)
    }

    /// Converts into a [`RouterEvent`] carrying the placement's tier.
    ///
    /// Returns `None` when the placement is not owned by a local worker.
    pub fn into_router_event(self) -> Option<RouterEvent> {
        let Self { placement, event } = self;
        match placement.owner {
            PlacementOwner::LocalWorker(worker) => Some(RouterEvent::with_storage_tier(
                worker.worker_id,
                event,
                placement.tier,
            )),
            PlacementOwner::Shared => None,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest {
......@@ -512,6 +600,9 @@ pub enum KvCacheEventError {
pub struct RouterEvent {
/// The ID of the worker emitting the event.
pub worker_id: WorkerId,
/// The storage tier associated with the event.
///
/// Marked `#[serde(default)]`, so input that omits this field falls back to
/// `StorageTier::default()`, which is `Device`.
#[serde(default)]
pub storage_tier: StorageTier,
/// The cache event associated with the worker.
pub event: KvCacheEvent,
}
......@@ -528,7 +619,19 @@ impl RouterEvent {
///
/// A new `RouterEvent`.
pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
Self { worker_id, event }
Self::with_storage_tier(worker_id, event, StorageTier::Device)
}
/// Create a new `RouterEvent` tagged with an explicit storage tier.
///
/// ### Arguments
///
/// * `worker_id` - The ID of the worker emitting the event.
/// * `event` - The cache event.
/// * `storage_tier` - The storage tier the event refers to.
pub fn with_storage_tier(
    worker_id: WorkerId,
    event: KvCacheEvent,
    storage_tier: StorageTier,
) -> Self {
    Self {
        event,
        storage_tier,
        worker_id,
    }
}
}
......
......@@ -11,7 +11,7 @@ use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket};
use crate::protocols::{RouterEvent, WorkerId};
use crate::protocols::{RouterEvent, WorkerId, WorkerWithDpRank};
use crate::zmq_wire::{KvEventBatch, convert_event};
use super::indexer::Indexer;
......@@ -111,9 +111,19 @@ async fn replay_gap(
.data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events {
let kv_event =
convert_event(raw_event, seq, block_size, effective_dp_rank, warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
let placement_event = convert_event(
raw_event,
seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await;
}
watermark.store(seq, Ordering::Release);
......@@ -381,9 +391,19 @@ async fn zmq_recv_loop(
.data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events {
let kv_event =
convert_event(raw_event, seq, block_size, effective_dp_rank, &warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
let placement_event = convert_event(
raw_event,
seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
&warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await;
messages_processed += 1;
}
......
......@@ -12,6 +12,51 @@ use crate::protocols::{
};
use crate::sequences::SequencePublisher;
/// Builds a [`RouterEvent`] for `worker_id` from the individual parts of a
/// [`KvCacheEvent`]; the event is created through `RouterEvent::new`.
pub fn router_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    data: KvCacheEventData,
) -> RouterEvent {
    let cache_event = KvCacheEvent {
        event_id,
        data,
        dp_rank,
    };
    RouterEvent::new(worker_id, cache_event)
}
/// Pairs each local block hash with the sequence hash at the same index and
/// builds the corresponding stored-block records (`mm_extra_info` is `None`).
///
/// If the slices differ in length, the extra entries of the longer one are
/// ignored.
pub fn stored_blocks_with_sequence_hashes(
    local_hashes: &[LocalBlockHash],
    seq_hashes: &[u64],
) -> Vec<KvCacheStoredBlockData> {
    let pairs = local_hashes
        .iter()
        .copied()
        .zip(seq_hashes.iter().copied());

    pairs
        .map(|(tokens_hash, seq_hash)| KvCacheStoredBlockData {
            tokens_hash,
            block_hash: ExternalSequenceBlockHash(seq_hash),
            mm_extra_info: None,
        })
        .collect()
}
/// Builds a [`RouterEvent`] whose payload is a `Removed` event for the given
/// block hashes.
pub fn remove_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    block_hashes: Vec<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let removal = KvCacheEventData::Removed(KvCacheRemoveData { block_hashes });
    router_event(worker_id, event_id, dp_rank, removal)
}
/// Creates blocks with artificial hash mapping (hash * 100) for testing.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
......@@ -40,30 +85,19 @@ pub fn create_store_event(
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
router_event(worker_id, event_id, 0, add_blocks(hashes, parent))
}
pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
remove_event(
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
event_id,
0,
hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
)
}
/// No-op [`SequencePublisher`] for tests and benchmarks that don't need event transport.
......
......@@ -18,7 +18,8 @@ use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use crate::protocols::{
BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, compute_block_hash_for_seq,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement, PlacementEvent,
StorageTier, WorkerWithDpRank, compute_block_hash_for_seq,
};
// -------------------------------------------------------------------------
......@@ -335,16 +336,22 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
// Event conversion --------------------------------------------------------
// -------------------------------------------------------------------------
/// Convert a raw event coming from the ZMQ channel into the internal
/// [`KvCacheEvent`] representation used by the router.
/// Convert a raw event coming from the ZMQ channel into a placement-aware worker event.
pub fn convert_event(
raw: RawKvEvent,
event_id: u64,
kv_block_size: u32,
dp_rank: u32,
worker: WorkerWithDpRank,
warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent {
match raw {
) -> PlacementEvent {
let storage_tier = match &raw {
RawKvEvent::BlockStored { medium, .. } | RawKvEvent::BlockRemoved { medium, .. } => {
StorageTier::from_kv_medium_or_default(medium.as_deref())
}
RawKvEvent::AllBlocksCleared => StorageTier::Device,
};
let dp_rank = worker.dp_rank;
let event = match raw {
RawKvEvent::BlockStored {
block_hashes,
parent_block_hash,
......@@ -369,13 +376,16 @@ pub fn convert_event(
// Return an empty Removed instead of Cleared to avoid nuking
// the worker's entire index state. An empty Removed is a no-op
// in the radix tree (zero iterations, returns Ok(())).
return KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![],
}),
dp_rank,
};
return PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![],
}),
dp_rank,
},
);
}
}
......@@ -422,7 +432,12 @@ pub fn convert_event(
data: KvCacheEventData::Cleared,
dp_rank,
},
}
};
PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
event,
)
}
pub fn create_stored_block_from_parts(
......
......@@ -17,7 +17,7 @@
use std::collections::{HashMap, HashSet};
use dynamo_kv_router::protocols::XXH3_SEED;
use dynamo_kv_router::protocols::{StorageTier as RouterStorageTier, XXH3_SEED};
/// LocalBlockHash type (content hash from tokens only)
type LocalBlockHash = u64;
......@@ -108,12 +108,7 @@ pub enum StorageTier {
impl StorageTier {
/// Parse from vLLM's medium string (e.g., "GPU", "CPU_TIER1", "CPU_TIER2")
pub fn from_vllm_medium(s: &str) -> Option<Self> {
match s {
"GPU" => Some(StorageTier::Device),
"CPU_TIER1" => Some(StorageTier::HostPinned),
"CPU_TIER2" => Some(StorageTier::Disk),
_ => None,
}
RouterStorageTier::from_kv_medium(s).map(Into::into)
}
/// Convert to vLLM's medium string
......@@ -135,6 +130,16 @@ impl StorageTier {
}
}
impl From<RouterStorageTier> for StorageTier {
fn from(value: RouterStorageTier) -> Self {
match value {
RouterStorageTier::Device => Self::Device,
RouterStorageTier::HostPinned => Self::HostPinned,
RouterStorageTier::Disk | RouterStorageTier::External => Self::Disk,
}
}
}
/// Legacy type alias for backward compatibility
#[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier;
......
......@@ -225,10 +225,11 @@ impl KvEventSource {
/// Start the event source from a [`KvEventSourceConfig`].
fn start(
component: Component,
worker_id: WorkerId,
kv_block_size: u32,
source_config: KvEventSourceConfig,
cancellation_token: CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>,
tx: mpsc::UnboundedSender<PlacementEvent>,
next_event_id: Arc<AtomicU64>,
) -> Result<Self> {
match source_config {
......@@ -240,6 +241,7 @@ impl KvEventSource {
.spawn(start_zmq_listener(
endpoint,
topic,
worker_id,
tx,
cancellation_token.clone(),
kv_block_size,
......@@ -269,8 +271,10 @@ pub struct KvEventPublisher {
source: Option<KvEventSource>,
/// The cancellation token.
cancellation_token: CancellationToken,
/// The ID of the local worker emitting placement events.
worker_id: WorkerId,
/// The channel to send events to.
tx: mpsc::UnboundedSender<KvCacheEvent>,
tx: mpsc::UnboundedSender<PlacementEvent>,
/// Internal monotonic event ID counter - ensures each event gets a unique, incrementing ID.
/// Shared with the ZMQ listener (if any) to maintain consistency.
next_event_id: Arc<AtomicU64>,
......@@ -317,7 +321,7 @@ impl KvEventPublisher {
})
.map(|ms| ms.min(MAX_BATCHING_TIMEOUT_MS));
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
// Infer worker_id from component's connection
let worker_id = component.drt().connection_id();
......@@ -345,6 +349,7 @@ impl KvEventPublisher {
if let Some(config) = source_config {
source = Some(KvEventSource::start(
component.clone(),
worker_id,
kv_block_size,
config,
cancellation_token.clone(),
......@@ -444,13 +449,18 @@ impl KvEventPublisher {
kv_block_size,
source,
cancellation_token,
worker_id,
tx,
next_event_id,
})
}
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
self.tx.send(event)
let placement_event = PlacementEvent::local_gpu(self.worker_id, event);
match self.tx.send(placement_event) {
Ok(()) => Ok(()),
Err(err) => Err(mpsc::error::SendError(err.0.event)),
}
}
/// Get and increment the next event ID atomically.
......@@ -575,7 +585,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
mut rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
timeout_ms: Option<u64>,
max_batch_blocks: usize,
......@@ -594,7 +604,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
break;
}
event = rx.recv() => {
let Some(event) = event else {
let Some(placement_event) = event else {
tracing::debug!("Event processor channel closed.");
batching_state.flush(&publisher, &local_indexer, worker_id).await;
break;
......@@ -602,7 +612,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
// Warn if the raw input event_id is not consecutive — events were dropped
// (e.g. channel send error) before they reached the batching layer.
let raw_event_id = event.event_id;
let raw_event_id = placement_event.event.event_id;
if let Some(last_id) = last_raw_input_id
&& raw_event_id > last_id + 1
{
......@@ -627,6 +637,17 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
}
last_raw_input_id = Some(raw_event_id);
if !placement_event.placement.is_local_gpu() {
tracing::trace!(
worker_id,
?placement_event.placement,
event_id = placement_event.event.event_id,
"Skipping non-local-GPU placement event"
);
continue;
}
let event = placement_event.event;
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let dp_rank_changed = batching_state.has_pending()
......@@ -702,7 +723,7 @@ async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<KvCacheEvent>,
rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>,
) {
......@@ -723,7 +744,7 @@ async fn start_event_processor_jetstream<P: EventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<KvCacheEvent>,
rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>,
) {
......@@ -750,7 +771,8 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
pub async fn start_zmq_listener(
zmq_endpoint: String,
zmq_topic: String,
tx: mpsc::UnboundedSender<KvCacheEvent>,
worker_id: WorkerId,
tx: mpsc::UnboundedSender<PlacementEvent>,
cancellation_token: CancellationToken,
kv_block_size: u32,
next_event_id: Arc<AtomicU64>,
......@@ -868,8 +890,14 @@ pub async fn start_zmq_listener(
for raw_event in batch.events.into_iter() {
// Use shared monotonic event_id counter instead of engine's sequence number
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
let event = convert_event(raw_event, event_id, kv_block_size, dp_rank, &warning_count);
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let event = convert_event(
raw_event,
event_id,
kv_block_size,
worker,
&warning_count,
);
if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped";
......@@ -1105,8 +1133,14 @@ mod test_event_processing {
block_mm_infos: None,
};
let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Stored(_)));
let out = convert_event(
raw_evt,
42,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.event.data, KvCacheEventData::Stored(_)));
}
#[test]
......@@ -1134,14 +1168,26 @@ mod test_event_processing {
};
let wc = Arc::new(AtomicU32::new(0));
let base_out = convert_event(base_evt, 1, kv_block_size, 0, &wc);
let lora_out = convert_event(lora_evt, 2, kv_block_size, 0, &wc);
let base_out = convert_event(
base_evt,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let lora_out = convert_event(
lora_evt,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let base_hash = match &base_out.data {
let base_hash = match &base_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
let lora_hash = match &lora_out.data {
let lora_hash = match &lora_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
......@@ -1176,14 +1222,26 @@ mod test_event_processing {
block_mm_infos: None,
};
let out1 = convert_event(evt1, 1, kv_block_size, 0, &wc);
let out2 = convert_event(evt2, 2, kv_block_size, 0, &wc);
let out1 = convert_event(
evt1,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let out2 = convert_event(
evt2,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let hash1 = match &out1.data {
let hash1 = match &out1.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
let hash2 = match &out2.data {
let hash2 = match &out2.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
......@@ -1256,17 +1314,29 @@ mod test_event_processing {
block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
medium: None,
};
let out = convert_event(raw_evt, 7, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
let out = convert_event(
raw_evt,
7,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.data, KvCacheEventData::Removed(_)));
assert!(matches!(out.event.data, KvCacheEventData::Removed(_)));
}
#[test]
fn test_convert_event_all_blocks_cleared() {
let kv_block_size = 4;
let raw_evt = RawKvEvent::AllBlocksCleared;
let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared));
let out = convert_event(
raw_evt,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.event.data, KvCacheEventData::Cleared));
}
#[test]
......@@ -1425,6 +1495,10 @@ mod tests_startup_helpers {
}
}
/// Test helper: wraps `event` for `worker_id` by delegating to
/// `PlacementEvent::local_gpu`.
fn local_gpu_event(worker_id: WorkerId, event: KvCacheEvent) -> PlacementEvent {
PlacementEvent::local_gpu(worker_id, event)
}
//--------------------------------------------------------------------
// Test start_event_processor
//--------------------------------------------------------------------
......@@ -1441,8 +1515,8 @@ mod tests_startup_helpers {
};
let token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(event).unwrap();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, event)).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(
......@@ -1498,8 +1572,8 @@ mod tests_startup_helpers {
dp_rank: 0,
};
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(event).unwrap();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, event)).unwrap();
drop(tx);
// Start event processor with local indexer
......@@ -1583,8 +1657,8 @@ mod tests_startup_helpers {
dp_rank: 0,
};
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(store_event).unwrap();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, store_event)).unwrap();
// Start event processor with local indexer
let handle = tokio::spawn(start_event_processor(
......@@ -1604,7 +1678,7 @@ mod tests_startup_helpers {
}),
dp_rank: 0,
};
tx.send(remove_event).unwrap();
tx.send(local_gpu_event(1, remove_event)).unwrap();
drop(tx);
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
......@@ -1665,8 +1739,8 @@ mod tests_startup_helpers {
dp_rank: 0,
};
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(store_event).unwrap();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, store_event)).unwrap();
// Clear all blocks
let clear_event = KvCacheEvent {
......@@ -1674,7 +1748,7 @@ mod tests_startup_helpers {
data: KvCacheEventData::Cleared,
dp_rank: 0,
};
tx.send(clear_event).unwrap();
tx.send(local_gpu_event(1, clear_event)).unwrap();
drop(tx);
// Create event processor and wait
......@@ -1743,8 +1817,8 @@ mod tests_startup_helpers {
};
let new_token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(event).unwrap();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, event)).unwrap();
drop(tx);
// Despite local indexer being cancelled, event processor should continue
......@@ -1774,7 +1848,7 @@ mod tests_startup_helpers {
#[tokio::test]
async fn test_start_zmq_listener_pushes_to_channel() {
// Prepare channel that listener should fill
let (tx, mut rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, mut rx) = mpsc::unbounded_channel::<PlacementEvent>();
// ZMQ TCP endpoint using localhost with fixed port
let endpoint = "tcp://127.0.0.1:15555";
......@@ -1792,7 +1866,7 @@ mod tests_startup_helpers {
// Spawn async listener (connects to publisher bound above)
let listener_handle = tokio::spawn({
let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4, next_event_id)
start_zmq_listener(endpoint.to_string(), topic, 1, tx, token, 4, next_event_id)
});
// Give time for the connection to establish
......@@ -1835,7 +1909,7 @@ mod tests_startup_helpers {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Check that we received the message
let event = rx.try_recv().expect("no message received");
let event = rx.try_recv().expect("no message received").event;
let KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
......@@ -1872,7 +1946,7 @@ mod tests_startup_helpers {
100, // buffer size
));
let (worker_tx, worker_rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (worker_tx, worker_rx) = mpsc::unbounded_channel::<PlacementEvent>();
// Start worker's event processor
tokio::spawn(start_event_processor(
......@@ -1912,7 +1986,9 @@ mod tests_startup_helpers {
dp_rank: 0,
};
worker_tx.send(event_1.clone()).unwrap();
worker_tx
.send(local_gpu_event(worker_1_id, event_1.clone()))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Simulate JetStream: forward worker's published event to router
......@@ -1978,7 +2054,9 @@ mod tests_startup_helpers {
dp_rank: 0,
};
worker_tx.send(event_2.clone()).unwrap(); // send to worker but not to router
worker_tx
.send(local_gpu_event(worker_1_id, event_2.clone()))
.unwrap(); // send to worker but not to router
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// assert: Worker published event_2 to "NATS" (MockComponent)
......@@ -2395,6 +2473,10 @@ mod event_processor_tests {
}
}
/// Wraps `event` in a [`PlacementEvent`] using a fixed worker id of `1`.
///
/// Thin helper over `PlacementEvent::local_gpu` so callers only need to build
/// the `KvCacheEvent` payload before sending it on a `PlacementEvent` channel.
fn local_gpu_event(event: KvCacheEvent) -> PlacementEvent {
// Worker id is hard-coded to 1; callers of this helper never vary it.
// NOTE(review): presumably `local_gpu` marks the placement as worker-local on
// the GPU/device tier — confirm against the `PlacementEvent` definition.
PlacementEvent::local_gpu(1, event)
}
/// Test that pushing N removed events results in batched output
/// Uses a 10ms timeout to ensure events are batched (events sent rapidly)
#[tokio::test]
......@@ -2419,7 +2501,7 @@ mod event_processor_tests {
/// Helper function to test removed events batching with configurable count and timeout
async fn test_removed_events_batching(event_count: usize, timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -2445,7 +2527,7 @@ mod event_processor_tests {
}),
dp_rank: 0,
};
tx.send(event).unwrap();
tx.send(local_gpu_event(event)).unwrap();
// Yield to allow event processor to process the event
tokio::task::yield_now().await;
}
......@@ -2516,7 +2598,7 @@ mod event_processor_tests {
/// Helper function to test stored events batching with configurable count and timeout
async fn test_stored_events_batching(event_count: usize, timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -2554,7 +2636,7 @@ mod event_processor_tests {
}),
dp_rank: 0,
};
tx.send(event).unwrap();
tx.send(local_gpu_event(event)).unwrap();
// Small sleep to allow event processor to batch events
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
......@@ -2614,7 +2696,7 @@ mod event_processor_tests {
async fn test_run_event_processor_loop_non_sequential_flush() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -2646,7 +2728,7 @@ mod event_processor_tests {
}),
dp_rank: 0,
};
tx.send(event).unwrap();
tx.send(local_gpu_event(event)).unwrap();
}
drop(tx);
......@@ -2697,7 +2779,7 @@ mod event_processor_tests {
/// Helper function to test no batching with slow input
async fn test_no_batching_with_slow_input(timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -2725,7 +2807,7 @@ mod event_processor_tests {
}),
dp_rank: 0,
};
tx.send(event).unwrap();
tx.send(local_gpu_event(event)).unwrap();
// Wait 2ms between events (much longer than the timeout)
// This ensures each event times out before the next one arrives
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
......@@ -2770,7 +2852,7 @@ mod event_processor_tests {
async fn test_event_type_switching_causes_flush() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -2789,20 +2871,20 @@ mod event_processor_tests {
});
// Send a Removed event
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(0)],
}),
dp_rank: 0,
})
}))
.unwrap();
// Small sleep
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
// Send a Stored event (should cause flush of the Removed event)
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)),
......@@ -2813,7 +2895,7 @@ mod event_processor_tests {
}],
}),
dp_rank: 0,
})
}))
.unwrap();
// Give time for processing
......@@ -2837,7 +2919,7 @@ mod event_processor_tests {
async fn test_dp_rank_change_causes_flush() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -2857,26 +2939,26 @@ mod event_processor_tests {
// Send events with dp_rank=0
for i in 0..3 {
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::task::yield_now().await;
}
// Send events with dp_rank=1 (should cause flush of previous batch)
for i in 3..6 {
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 1,
})
}))
.unwrap();
tokio::task::yield_now().await;
}
......@@ -2929,7 +3011,7 @@ mod event_processor_tests {
async fn test_flushed_events_have_correct_metadata() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -2949,13 +3031,13 @@ mod event_processor_tests {
// Send first batch: 3 events with dp_rank=0, event_ids 10-12
for i in 0..3 {
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 10 + i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::task::yield_now().await;
}
......@@ -2963,13 +3045,13 @@ mod event_processor_tests {
// Send second batch: 2 events with dp_rank=1, event_ids 20-21
// This should flush the first batch with dp_rank=0
for i in 0..2 {
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 20 + i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash((i + 3) as u64)],
}),
dp_rank: 1,
})
}))
.unwrap();
tokio::task::yield_now().await;
}
......@@ -3015,7 +3097,7 @@ mod event_processor_tests {
async fn test_first_event_after_idle_flushes_immediately_then_batches() {
let timeout_ms = Some(50); // 50ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -3039,13 +3121,13 @@ mod event_processor_tests {
// Send 3 events rapidly - first should flush immediately (stale timer),
// remaining 2 should batch together
for i in 0..3 {
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::task::yield_now().await;
}
......@@ -3085,7 +3167,7 @@ mod event_processor_tests {
async fn test_stored_events_with_dp_rank_change_correct_metadata() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -3104,7 +3186,7 @@ mod event_processor_tests {
});
// Send first batch: 2 sequential stored events with dp_rank=0, event_ids 100-101
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 100,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)),
......@@ -3115,11 +3197,11 @@ mod event_processor_tests {
}],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 101,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
......@@ -3130,13 +3212,13 @@ mod event_processor_tests {
}],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::task::yield_now().await;
// Send second batch: 1 event with dp_rank=1, event_id=200
// This should flush the first batch with dp_rank=0, event_id=101
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 200,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)),
......@@ -3147,7 +3229,7 @@ mod event_processor_tests {
}],
}),
dp_rank: 1,
})
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
......@@ -3202,7 +3284,7 @@ mod event_processor_tests {
async fn test_batch_parent_hash_preserved_when_extending() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
......@@ -3221,7 +3303,7 @@ mod event_processor_tests {
});
// First event: parent_hash=None, block_hash=1
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, // Root block with no parent
......@@ -3232,12 +3314,12 @@ mod event_processor_tests {
}],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::task::yield_now().await;
// Second event: parent_hash=Some(1), block_hash=2 (sequential)
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)), // Points to previous block
......@@ -3248,12 +3330,12 @@ mod event_processor_tests {
}],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::task::yield_now().await;
// Third event: parent_hash=Some(2), block_hash=3 (sequential)
tx.send(KvCacheEvent {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(2)),
......@@ -3264,7 +3346,7 @@ mod event_processor_tests {
}],
}),
dp_rank: 0,
})
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment