"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/nni.git" did not exist on "12410686a0d6da73e449291d41cc9ebcac19615a"
Commit e811db50 (unverified), authored by Yan Ru Pei, committed by GitHub
Browse files

feat(kv-router): plumb medium into worker-side placements (#7462)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 9681225a
...@@ -623,9 +623,9 @@ impl SequenceData { ...@@ -623,9 +623,9 @@ impl SequenceData {
/// Convert to a store event. /// Convert to a store event.
pub fn to_store_event(&self, event_id: u64) -> RouterEvent { pub fn to_store_event(&self, event_id: u64) -> RouterEvent {
RouterEvent { RouterEvent::new(
worker_id: self.worker_id, self.worker_id,
event: KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
...@@ -642,21 +642,21 @@ impl SequenceData { ...@@ -642,21 +642,21 @@ impl SequenceData {
}), }),
dp_rank: 0, dp_rank: 0,
}, },
} )
} }
/// Convert to a remove event. /// Convert to a remove event.
pub fn to_remove_event(&self, event_id: u64) -> RouterEvent { pub fn to_remove_event(&self, event_id: u64) -> RouterEvent {
RouterEvent { RouterEvent::new(
worker_id: self.worker_id, self.worker_id,
event: KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: self.external_hashes.clone(), block_hashes: self.external_hashes.clone(),
}), }),
dp_rank: 0, dp_rank: 0,
}, },
} )
} }
} }
......
...@@ -297,10 +297,7 @@ async fn run_benchmark( ...@@ -297,10 +297,7 @@ async fn run_benchmark(
} }
WorkerTraceEntry::Event(event) => { WorkerTraceEntry::Event(event) => {
indexer indexer
.apply_event(RouterEvent { .apply_event(RouterEvent::new(worker_id as u64, event))
worker_id: worker_id as u64,
event,
})
.await; .await;
Ok(None) Ok(None)
} }
......
...@@ -590,6 +590,7 @@ impl ConcurrentRadixTree { ...@@ -590,6 +590,7 @@ impl ConcurrentRadixTree {
// Create a store event for this worker // Create a store event for this worker
let event = RouterEvent { let event = RouterEvent {
worker_id: worker.worker_id, worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
......
...@@ -432,6 +432,7 @@ impl PositionalIndexer { ...@@ -432,6 +432,7 @@ impl PositionalIndexer {
events.push(RouterEvent { events.push(RouterEvent {
worker_id: worker.worker_id, worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
......
...@@ -557,6 +557,7 @@ impl RadixTree { ...@@ -557,6 +557,7 @@ impl RadixTree {
// Create a store event for this worker // Create a store event for this worker
let event = RouterEvent { let event = RouterEvent {
worker_id: worker.worker_id, worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
......
...@@ -13,6 +13,7 @@ use super::concurrent_radix_tree::ConcurrentRadixTree; ...@@ -13,6 +13,7 @@ use super::concurrent_radix_tree::ConcurrentRadixTree;
use super::positional::PositionalIndexer; use super::positional::PositionalIndexer;
use super::*; use super::*;
use crate::protocols::*; use crate::protocols::*;
use crate::test_utils::{remove_event, router_event, stored_blocks_with_sequence_hashes};
// ============================================================================ // ============================================================================
// Helper functions // Helper functions
...@@ -63,25 +64,15 @@ fn make_store_event_with_parent( ...@@ -63,25 +64,15 @@ fn make_store_event_with_parent(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect(); local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let new_seq_hashes = &full_seq_hashes[prefix_hashes.len()..]; let new_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent { router_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash, parent_hash,
blocks: new_block_hashes blocks: stored_blocks_with_sequence_hashes(&new_block_hashes, new_seq_hashes),
.iter() }),
.zip(new_seq_hashes.iter()) )
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
} }
/// Create a store event with all options. /// Create a store event with all options.
...@@ -95,25 +86,15 @@ fn make_store_event_full( ...@@ -95,25 +86,15 @@ fn make_store_event_full(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect(); local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes); let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent { router_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, dp_rank,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash, parent_hash,
blocks: local_block_hashes blocks: stored_blocks_with_sequence_hashes(&local_block_hashes, &seq_hashes),
.iter() }),
.zip(seq_hashes.iter()) )
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank,
},
}
} }
/// Create a remove event for blocks with given local hashes. /// Create a remove event for blocks with given local hashes.
...@@ -131,19 +112,15 @@ fn make_remove_event_with_dp_rank( ...@@ -131,19 +112,15 @@ fn make_remove_event_with_dp_rank(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect(); local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes); let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent { remove_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, dp_rank,
data: KvCacheEventData::Removed(KvCacheRemoveData { seq_hashes
block_hashes: seq_hashes .iter()
.iter() .map(|&h| ExternalSequenceBlockHash(h))
.map(|&h| ExternalSequenceBlockHash(h)) .collect(),
.collect(), )
}),
dp_rank,
},
}
} }
/// Create a remove event with parent hash for continuation sequences. /// Create a remove event with parent hash for continuation sequences.
...@@ -165,19 +142,15 @@ fn make_remove_event_with_parent( ...@@ -165,19 +142,15 @@ fn make_remove_event_with_parent(
let suffix_seq_hashes = &full_seq_hashes[prefix_hashes.len()..]; let suffix_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent { remove_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Removed(KvCacheRemoveData { suffix_seq_hashes
block_hashes: suffix_seq_hashes .iter()
.iter() .map(|&h| ExternalSequenceBlockHash(h))
.map(|&h| ExternalSequenceBlockHash(h)) .collect(),
.collect(), )
}),
dp_rank: 0,
},
}
} }
/// Snapshot the tree state for deterministic comparison. /// Snapshot the tree state for deterministic comparison.
...@@ -222,14 +195,7 @@ fn make_clear_event(worker_id: u64) -> RouterEvent { ...@@ -222,14 +195,7 @@ fn make_clear_event(worker_id: u64) -> RouterEvent {
/// Create a clear event with a specific dp_rank. /// Create a clear event with a specific dp_rank.
fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent { fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
RouterEvent { router_event(worker_id, 0, dp_rank, KvCacheEventData::Cleared)
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Cleared,
dp_rank,
},
}
} }
// ============================================================================ // ============================================================================
...@@ -646,16 +612,7 @@ async fn test_partial_block_removal(variant: &str) { ...@@ -646,16 +612,7 @@ async fn test_partial_block_removal(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes); let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
let remove_event = RouterEvent { let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await; index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -698,16 +655,7 @@ async fn test_remove_mid_chain_block(variant: &str) { ...@@ -698,16 +655,7 @@ async fn test_remove_mid_chain_block(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes); let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]);
let remove_event = RouterEvent { let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await; index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -895,47 +843,27 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) { ...@@ -895,47 +843,27 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let lora_seq = compute_seq_hash_for_block(&lora_hashes); let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0 // Store base-model blocks on worker 0
let base_event = RouterEvent { let base_event = router_event(
worker_id: 0, 0,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: base_hashes blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
.iter() }),
.zip(base_seq.iter()) );
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
index.apply_event(base_event).await; index.apply_event(base_event).await;
// Store LoRA blocks on worker 1 // Store LoRA blocks on worker 1
let lora_event = RouterEvent { let lora_event = router_event(
worker_id: 1, 1,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: lora_hashes blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
.iter() }),
.zip(lora_seq.iter()) );
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
index.apply_event(lora_event).await; index.apply_event(lora_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1003,49 +931,29 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) { ...@@ -1003,49 +931,29 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
// Worker 0: base model // Worker 0: base model
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 0, 0,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: base_local blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
.iter() }),
.zip(base_seq.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to // Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node. // a separate tree path instead of colliding with worker 0's node.
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 1, 1,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: lora_local blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
.iter() }),
.zip(lora_seq.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1094,48 +1002,28 @@ async fn test_different_lora_adapters_do_not_conflict(variant: &str) { ...@@ -1094,48 +1002,28 @@ async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
// Store adapter-a blocks on worker 0 // Store adapter-a blocks on worker 0
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 0, 0,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: hashes_a blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a),
.iter() }),
.zip(seq_a.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
// Store adapter-b blocks on worker 1 // Store adapter-b blocks on worker 1
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 1, 1,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: hashes_b blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b),
.iter() }),
.zip(seq_b.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1317,16 +1205,7 @@ async fn test_long_sequence_partial_removal(variant: &str) { ...@@ -1317,16 +1205,7 @@ async fn test_long_sequence_partial_removal(variant: &str) {
.map(|&h| ExternalSequenceBlockHash(h)) .map(|&h| ExternalSequenceBlockHash(h))
.collect(); .collect();
let remove_event = RouterEvent { let remove_event = remove_event(0, 0, 0, remove_hashes);
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: remove_hashes,
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await; index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1868,6 +1747,7 @@ async fn test_frequency(variant: &str) { ...@@ -1868,6 +1747,7 @@ async fn test_frequency(variant: &str) {
// KvIndexerMetrics tests // KvIndexerMetrics tests
// ============================================================================ // ============================================================================
#[cfg(feature = "metrics")]
#[test] #[test]
fn test_increment_event_applied() { fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered(); let metrics = KvIndexerMetrics::new_unregistered();
......
...@@ -133,6 +133,94 @@ impl WorkerWithDpRank { ...@@ -133,6 +133,94 @@ impl WorkerWithDpRank {
} }
} }
/// Storage tier that a KV cache event or placement refers to.
///
/// Variants are (de)serialized using `snake_case` names, and `Device` is the
/// `Default` value.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum StorageTier {
/// Default tier; parsed from the medium strings `"GPU"` and `"DEVICE"`.
#[default]
Device,
/// Parsed from the medium strings `"CPU_PINNED"` and `"CPU_TIER1"`.
HostPinned,
/// Parsed from the medium strings `"CPU_TIER2"`, `"DISK"` and `"NVME"`.
Disk,
/// Parsed from the medium strings `"EXTERNAL"`, `"NETWORK"`, `"REMOTE"` and `"SHARED"`.
External,
}
impl StorageTier {
    /// Parse a KV event `medium` string into a storage tier.
    ///
    /// Matching is exact and case-sensitive. Returns `None` when the string is
    /// not one of the recognized medium names.
    pub fn from_kv_medium(medium: &str) -> Option<Self> {
        let tier = match medium {
            "GPU" | "DEVICE" => Self::Device,
            "CPU_PINNED" | "CPU_TIER1" => Self::HostPinned,
            "CPU_TIER2" | "DISK" | "NVME" => Self::Disk,
            "EXTERNAL" | "NETWORK" | "REMOTE" | "SHARED" => Self::External,
            _ => return None,
        };
        Some(tier)
    }

    /// Parse an optional `medium` string, falling back to [`StorageTier::Device`]
    /// when the medium is absent or not recognized.
    pub fn from_kv_medium_or_default(medium: Option<&str>) -> Self {
        match medium.and_then(Self::from_kv_medium) {
            Some(tier) => tier,
            None => Self::Device,
        }
    }

    /// Returns `true` only for [`StorageTier::Device`].
    pub fn is_gpu(self) -> bool {
        match self {
            Self::Device => true,
            Self::HostPinned | Self::Disk | Self::External => false,
        }
    }
}
/// Who owns the blocks described by a [`Placement`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PlacementOwner {
/// Blocks attributed to a specific worker / data-parallel rank pair.
LocalWorker(WorkerWithDpRank),
/// Blocks not attributed to a single local worker; a `PlacementEvent` with
/// this owner does not convert into a `RouterEvent`.
Shared,
}
/// Where a set of KV cache blocks lives: the owner plus the storage tier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Placement {
/// The owner of the placement.
pub owner: PlacementOwner,
/// The storage tier of the placement.
pub tier: StorageTier,
}
impl Placement {
    /// Build a placement owned by the given worker / dp-rank pair on `tier`.
    pub fn local_worker(worker_id: WorkerId, dp_rank: DpRank, tier: StorageTier) -> Self {
        let worker = WorkerWithDpRank::new(worker_id, dp_rank);
        Self {
            owner: PlacementOwner::LocalWorker(worker),
            tier,
        }
    }

    /// Build a worker-owned placement on [`StorageTier::Device`].
    pub fn local_gpu(worker_id: WorkerId, dp_rank: DpRank) -> Self {
        let tier = StorageTier::Device;
        Self::local_worker(worker_id, dp_rank, tier)
    }

    /// Returns `true` when the owner is a local worker and the tier is the GPU tier.
    pub fn is_local_gpu(&self) -> bool {
        match self.owner {
            PlacementOwner::LocalWorker(_) => self.tier.is_gpu(),
            PlacementOwner::Shared => false,
        }
    }
}
/// A [`KvCacheEvent`] paired with the [`Placement`] it applies to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlacementEvent {
/// The placement (owner and storage tier) the event refers to.
pub placement: Placement,
/// The cache event itself.
pub event: KvCacheEvent,
}
impl PlacementEvent {
pub fn new(placement: Placement, event: KvCacheEvent) -> Self {
Self { placement, event }
}
pub fn local_gpu(worker_id: WorkerId, event: KvCacheEvent) -> Self {
Self::new(Placement::local_gpu(worker_id, event.dp_rank), event)
}
pub fn into_router_event(self) -> Option<RouterEvent> {
let PlacementOwner::LocalWorker(worker) = self.placement.owner else {
return None;
};
Some(RouterEvent::with_storage_tier(
worker.worker_id,
self.event,
self.placement.tier,
))
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")] #[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest { pub enum RouterRequest {
...@@ -512,6 +600,9 @@ pub enum KvCacheEventError { ...@@ -512,6 +600,9 @@ pub enum KvCacheEventError {
pub struct RouterEvent { pub struct RouterEvent {
/// The ID of the worker emitting the event. /// The ID of the worker emitting the event.
pub worker_id: WorkerId, pub worker_id: WorkerId,
/// The storage tier associated with the event.
#[serde(default)]
pub storage_tier: StorageTier,
/// The cache event associated with the worker. /// The cache event associated with the worker.
pub event: KvCacheEvent, pub event: KvCacheEvent,
} }
...@@ -528,7 +619,19 @@ impl RouterEvent { ...@@ -528,7 +619,19 @@ impl RouterEvent {
/// ///
/// A new `RouterEvent`. /// A new `RouterEvent`.
pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self { pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
Self { worker_id, event } Self::with_storage_tier(worker_id, event, StorageTier::Device)
}
/// Create a new `RouterEvent` with an explicit storage tier.
///
/// # Arguments
///
/// * `worker_id` - The ID of the worker emitting the event.
/// * `event` - The cache event.
/// * `storage_tier` - The storage tier associated with the event.
pub fn with_storage_tier(
worker_id: WorkerId,
event: KvCacheEvent,
storage_tier: StorageTier,
) -> Self {
Self {
worker_id,
storage_tier,
event,
}
}
} }
......
...@@ -11,7 +11,7 @@ use tokio::sync::watch; ...@@ -11,7 +11,7 @@ use tokio::sync::watch;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket}; use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket};
use crate::protocols::{RouterEvent, WorkerId}; use crate::protocols::{RouterEvent, WorkerId, WorkerWithDpRank};
use crate::zmq_wire::{KvEventBatch, convert_event}; use crate::zmq_wire::{KvEventBatch, convert_event};
use super::indexer::Indexer; use super::indexer::Indexer;
...@@ -111,9 +111,19 @@ async fn replay_gap( ...@@ -111,9 +111,19 @@ async fn replay_gap(
.data_parallel_rank .data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned()); .map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events { for raw_event in batch.events {
let kv_event = let placement_event = convert_event(
convert_event(raw_event, seq, block_size, effective_dp_rank, warning_count); raw_event,
let router_event = RouterEvent::new(worker_id, kv_event); seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await; indexer.apply_event(router_event).await;
} }
watermark.store(seq, Ordering::Release); watermark.store(seq, Ordering::Release);
...@@ -381,9 +391,19 @@ async fn zmq_recv_loop( ...@@ -381,9 +391,19 @@ async fn zmq_recv_loop(
.data_parallel_rank .data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned()); .map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events { for raw_event in batch.events {
let kv_event = let placement_event = convert_event(
convert_event(raw_event, seq, block_size, effective_dp_rank, &warning_count); raw_event,
let router_event = RouterEvent::new(worker_id, kv_event); seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
&warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await; indexer.apply_event(router_event).await;
messages_processed += 1; messages_processed += 1;
} }
......
...@@ -12,6 +12,51 @@ use crate::protocols::{ ...@@ -12,6 +12,51 @@ use crate::protocols::{
}; };
use crate::sequences::SequencePublisher; use crate::sequences::SequencePublisher;
/// Builds a [`RouterEvent`] for `worker_id` from the individual parts of a
/// [`KvCacheEvent`], using `RouterEvent::new`.
pub fn router_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    data: KvCacheEventData,
) -> RouterEvent {
    let cache_event = KvCacheEvent {
        event_id,
        data,
        dp_rank,
    };
    RouterEvent::new(worker_id, cache_event)
}
/// Pairs each local block hash with the sequence hash at the same index and
/// builds the corresponding stored-block entries (no multimodal info).
///
/// If the slices differ in length, the extra entries of the longer one are
/// ignored.
pub fn stored_blocks_with_sequence_hashes(
    local_hashes: &[LocalBlockHash],
    seq_hashes: &[u64],
) -> Vec<KvCacheStoredBlockData> {
    let pair_count = local_hashes.len().min(seq_hashes.len());
    let mut blocks = Vec::with_capacity(pair_count);
    for (&tokens_hash, &seq_hash) in local_hashes.iter().zip(seq_hashes) {
        blocks.push(KvCacheStoredBlockData {
            tokens_hash,
            block_hash: ExternalSequenceBlockHash(seq_hash),
            mm_extra_info: None,
        });
    }
    blocks
}
/// Builds a [`RouterEvent`] whose payload removes the given block hashes.
pub fn remove_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    block_hashes: Vec<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let removal = KvCacheRemoveData { block_hashes };
    let data = KvCacheEventData::Removed(removal);
    router_event(worker_id, event_id, dp_rank, data)
}
/// Creates blocks with artificial hash mapping (hash * 100) for testing. /// Creates blocks with artificial hash mapping (hash * 100) for testing.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> { pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes hashes
...@@ -40,30 +85,19 @@ pub fn create_store_event( ...@@ -40,30 +85,19 @@ pub fn create_store_event(
hashes: Vec<u64>, hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>, parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent { ) -> RouterEvent {
RouterEvent { router_event(worker_id, event_id, 0, add_blocks(hashes, parent))
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
} }
pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent { pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent { remove_event(
worker_id, worker_id,
event: KvCacheEvent { event_id,
event_id, 0,
data: KvCacheEventData::Removed(KvCacheRemoveData { hashes
block_hashes: hashes .iter()
.iter() .map(|i| ExternalSequenceBlockHash(*i * 100))
.map(|i| ExternalSequenceBlockHash(*i * 100)) .collect(),
.collect(), )
}),
dp_rank: 0,
},
}
} }
/// No-op [`SequencePublisher`] for tests and benchmarks that don't need event transport. /// No-op [`SequencePublisher`] for tests and benchmarks that don't need event transport.
......
...@@ -18,7 +18,8 @@ use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor}; ...@@ -18,7 +18,8 @@ use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use crate::protocols::{ use crate::protocols::{
BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, compute_block_hash_for_seq, KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement, PlacementEvent,
StorageTier, WorkerWithDpRank, compute_block_hash_for_seq,
}; };
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
...@@ -335,16 +336,22 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -335,16 +336,22 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
// Event conversion -------------------------------------------------------- // Event conversion --------------------------------------------------------
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
/// Convert a raw event coming from the ZMQ channel into the internal /// Convert a raw event coming from the ZMQ channel into a placement-aware worker event.
/// [`KvCacheEvent`] representation used by the router.
pub fn convert_event( pub fn convert_event(
raw: RawKvEvent, raw: RawKvEvent,
event_id: u64, event_id: u64,
kv_block_size: u32, kv_block_size: u32,
dp_rank: u32, worker: WorkerWithDpRank,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent { ) -> PlacementEvent {
match raw { let storage_tier = match &raw {
RawKvEvent::BlockStored { medium, .. } | RawKvEvent::BlockRemoved { medium, .. } => {
StorageTier::from_kv_medium_or_default(medium.as_deref())
}
RawKvEvent::AllBlocksCleared => StorageTier::Device,
};
let dp_rank = worker.dp_rank;
let event = match raw {
RawKvEvent::BlockStored { RawKvEvent::BlockStored {
block_hashes, block_hashes,
parent_block_hash, parent_block_hash,
...@@ -369,13 +376,16 @@ pub fn convert_event( ...@@ -369,13 +376,16 @@ pub fn convert_event(
// Return an empty Removed instead of Cleared to avoid nuking // Return an empty Removed instead of Cleared to avoid nuking
// the worker's entire index state. An empty Removed is a no-op // the worker's entire index state. An empty Removed is a no-op
// in the radix tree (zero iterations, returns Ok(())). // in the radix tree (zero iterations, returns Ok(())).
return KvCacheEvent { return PlacementEvent::new(
event_id, Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
data: KvCacheEventData::Removed(KvCacheRemoveData { KvCacheEvent {
block_hashes: vec![], event_id,
}), data: KvCacheEventData::Removed(KvCacheRemoveData {
dp_rank, block_hashes: vec![],
}; }),
dp_rank,
},
);
} }
} }
...@@ -422,7 +432,12 @@ pub fn convert_event( ...@@ -422,7 +432,12 @@ pub fn convert_event(
data: KvCacheEventData::Cleared, data: KvCacheEventData::Cleared,
dp_rank, dp_rank,
}, },
} };
PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
event,
)
} }
pub fn create_stored_block_from_parts( pub fn create_stored_block_from_parts(
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use dynamo_kv_router::protocols::XXH3_SEED; use dynamo_kv_router::protocols::{StorageTier as RouterStorageTier, XXH3_SEED};
/// LocalBlockHash type (content hash from tokens only) /// LocalBlockHash type (content hash from tokens only)
type LocalBlockHash = u64; type LocalBlockHash = u64;
...@@ -108,12 +108,7 @@ pub enum StorageTier { ...@@ -108,12 +108,7 @@ pub enum StorageTier {
impl StorageTier { impl StorageTier {
/// Parse from vLLM's medium string (e.g., "GPU", "CPU_TIER1", "CPU_TIER2") /// Parse from vLLM's medium string (e.g., "GPU", "CPU_TIER1", "CPU_TIER2")
pub fn from_vllm_medium(s: &str) -> Option<Self> { pub fn from_vllm_medium(s: &str) -> Option<Self> {
match s { RouterStorageTier::from_kv_medium(s).map(Into::into)
"GPU" => Some(StorageTier::Device),
"CPU_TIER1" => Some(StorageTier::HostPinned),
"CPU_TIER2" => Some(StorageTier::Disk),
_ => None,
}
} }
/// Convert to vLLM's medium string /// Convert to vLLM's medium string
...@@ -135,6 +130,16 @@ impl StorageTier { ...@@ -135,6 +130,16 @@ impl StorageTier {
} }
} }
/// Maps the router's storage tier onto this module's [`StorageTier`].
///
/// The conversion is lossy: both `Disk` and `External` router tiers become
/// `Disk` here.
impl From<RouterStorageTier> for StorageTier {
    fn from(router_tier: RouterStorageTier) -> Self {
        match router_tier {
            RouterStorageTier::Device => Self::Device,
            RouterStorageTier::HostPinned => Self::HostPinned,
            RouterStorageTier::Disk => Self::Disk,
            // External is folded into Disk by this conversion.
            RouterStorageTier::External => Self::Disk,
        }
    }
}
/// Legacy type alias for backward compatibility /// Legacy type alias for backward compatibility
#[deprecated(note = "Use StorageTier instead")] #[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier; pub type StorageMedium = StorageTier;
......
...@@ -225,10 +225,11 @@ impl KvEventSource { ...@@ -225,10 +225,11 @@ impl KvEventSource {
/// Start the event source from a [`KvEventSourceConfig`]. /// Start the event source from a [`KvEventSourceConfig`].
fn start( fn start(
component: Component, component: Component,
worker_id: WorkerId,
kv_block_size: u32, kv_block_size: u32,
source_config: KvEventSourceConfig, source_config: KvEventSourceConfig,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<PlacementEvent>,
next_event_id: Arc<AtomicU64>, next_event_id: Arc<AtomicU64>,
) -> Result<Self> { ) -> Result<Self> {
match source_config { match source_config {
...@@ -240,6 +241,7 @@ impl KvEventSource { ...@@ -240,6 +241,7 @@ impl KvEventSource {
.spawn(start_zmq_listener( .spawn(start_zmq_listener(
endpoint, endpoint,
topic, topic,
worker_id,
tx, tx,
cancellation_token.clone(), cancellation_token.clone(),
kv_block_size, kv_block_size,
...@@ -269,8 +271,10 @@ pub struct KvEventPublisher { ...@@ -269,8 +271,10 @@ pub struct KvEventPublisher {
source: Option<KvEventSource>, source: Option<KvEventSource>,
/// The cancellation token. /// The cancellation token.
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
/// The ID of the local worker emitting placement events.
worker_id: WorkerId,
/// The channel to send events to. /// The channel to send events to.
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<PlacementEvent>,
/// Internal monotonic event ID counter - ensures each event gets a unique, incrementing ID. /// Internal monotonic event ID counter - ensures each event gets a unique, incrementing ID.
/// Shared with the ZMQ listener (if any) to maintain consistency. /// Shared with the ZMQ listener (if any) to maintain consistency.
next_event_id: Arc<AtomicU64>, next_event_id: Arc<AtomicU64>,
...@@ -317,7 +321,7 @@ impl KvEventPublisher { ...@@ -317,7 +321,7 @@ impl KvEventPublisher {
}) })
.map(|ms| ms.min(MAX_BATCHING_TIMEOUT_MS)); .map(|ms| ms.min(MAX_BATCHING_TIMEOUT_MS));
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
// Infer worker_id from component's connection // Infer worker_id from component's connection
let worker_id = component.drt().connection_id(); let worker_id = component.drt().connection_id();
...@@ -345,6 +349,7 @@ impl KvEventPublisher { ...@@ -345,6 +349,7 @@ impl KvEventPublisher {
if let Some(config) = source_config { if let Some(config) = source_config {
source = Some(KvEventSource::start( source = Some(KvEventSource::start(
component.clone(), component.clone(),
worker_id,
kv_block_size, kv_block_size,
config, config,
cancellation_token.clone(), cancellation_token.clone(),
...@@ -444,13 +449,18 @@ impl KvEventPublisher { ...@@ -444,13 +449,18 @@ impl KvEventPublisher {
kv_block_size, kv_block_size,
source, source,
cancellation_token, cancellation_token,
worker_id,
tx, tx,
next_event_id, next_event_id,
}) })
} }
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> { pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
self.tx.send(event) let placement_event = PlacementEvent::local_gpu(self.worker_id, event);
match self.tx.send(placement_event) {
Ok(()) => Ok(()),
Err(err) => Err(mpsc::error::SendError(err.0.event)),
}
} }
/// Get and increment the next event ID atomically. /// Get and increment the next event ID atomically.
...@@ -575,7 +585,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>( ...@@ -575,7 +585,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
publisher: P, publisher: P,
worker_id: u64, worker_id: u64,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>, mut rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>, local_indexer: Option<Arc<LocalKvIndexer>>,
timeout_ms: Option<u64>, timeout_ms: Option<u64>,
max_batch_blocks: usize, max_batch_blocks: usize,
...@@ -594,7 +604,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>( ...@@ -594,7 +604,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
break; break;
} }
event = rx.recv() => { event = rx.recv() => {
let Some(event) = event else { let Some(placement_event) = event else {
tracing::debug!("Event processor channel closed."); tracing::debug!("Event processor channel closed.");
batching_state.flush(&publisher, &local_indexer, worker_id).await; batching_state.flush(&publisher, &local_indexer, worker_id).await;
break; break;
...@@ -602,7 +612,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>( ...@@ -602,7 +612,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
// Warn if the raw input event_id is not consecutive — events were dropped // Warn if the raw input event_id is not consecutive — events were dropped
// (e.g. channel send error) before they reached the batching layer. // (e.g. channel send error) before they reached the batching layer.
let raw_event_id = event.event_id; let raw_event_id = placement_event.event.event_id;
if let Some(last_id) = last_raw_input_id if let Some(last_id) = last_raw_input_id
&& raw_event_id > last_id + 1 && raw_event_id > last_id + 1
{ {
...@@ -627,6 +637,17 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>( ...@@ -627,6 +637,17 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
} }
last_raw_input_id = Some(raw_event_id); last_raw_input_id = Some(raw_event_id);
if !placement_event.placement.is_local_gpu() {
tracing::trace!(
worker_id,
?placement_event.placement,
event_id = placement_event.event.event_id,
"Skipping non-local-GPU placement event"
);
continue;
}
let event = placement_event.event;
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data); tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let dp_rank_changed = batching_state.has_pending() let dp_rank_changed = batching_state.has_pending()
...@@ -702,7 +723,7 @@ async fn start_event_processor<P: EventSink + Send + Sync + 'static>( ...@@ -702,7 +723,7 @@ async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
publisher: P, publisher: P,
worker_id: u64, worker_id: u64,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<KvCacheEvent>, rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>, local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>, batching_timeout_ms: Option<u64>,
) { ) {
...@@ -723,7 +744,7 @@ async fn start_event_processor_jetstream<P: EventSink + Send + Sync + 'static>( ...@@ -723,7 +744,7 @@ async fn start_event_processor_jetstream<P: EventSink + Send + Sync + 'static>(
publisher: P, publisher: P,
worker_id: u64, worker_id: u64,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<KvCacheEvent>, rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>, local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>, batching_timeout_ms: Option<u64>,
) { ) {
...@@ -750,7 +771,8 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 { ...@@ -750,7 +771,8 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
pub async fn start_zmq_listener( pub async fn start_zmq_listener(
zmq_endpoint: String, zmq_endpoint: String,
zmq_topic: String, zmq_topic: String,
tx: mpsc::UnboundedSender<KvCacheEvent>, worker_id: WorkerId,
tx: mpsc::UnboundedSender<PlacementEvent>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
kv_block_size: u32, kv_block_size: u32,
next_event_id: Arc<AtomicU64>, next_event_id: Arc<AtomicU64>,
...@@ -868,8 +890,14 @@ pub async fn start_zmq_listener( ...@@ -868,8 +890,14 @@ pub async fn start_zmq_listener(
for raw_event in batch.events.into_iter() { for raw_event in batch.events.into_iter() {
// Use shared monotonic event_id counter instead of engine's sequence number // Use shared monotonic event_id counter instead of engine's sequence number
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst); let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let event = convert_event(raw_event, event_id, kv_block_size, dp_rank, &warning_count); let event = convert_event(
raw_event,
event_id,
kv_block_size,
worker,
&warning_count,
);
if tx.send(event).is_err() { if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped"); tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped"; exit_reason = "channel receiver dropped";
...@@ -1105,8 +1133,14 @@ mod test_event_processing { ...@@ -1105,8 +1133,14 @@ mod test_event_processing {
block_mm_infos: None, block_mm_infos: None,
}; };
let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0))); let out = convert_event(
assert!(matches!(out.data, KvCacheEventData::Stored(_))); raw_evt,
42,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.event.data, KvCacheEventData::Stored(_)));
} }
#[test] #[test]
...@@ -1134,14 +1168,26 @@ mod test_event_processing { ...@@ -1134,14 +1168,26 @@ mod test_event_processing {
}; };
let wc = Arc::new(AtomicU32::new(0)); let wc = Arc::new(AtomicU32::new(0));
let base_out = convert_event(base_evt, 1, kv_block_size, 0, &wc); let base_out = convert_event(
let lora_out = convert_event(lora_evt, 2, kv_block_size, 0, &wc); base_evt,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let lora_out = convert_event(
lora_evt,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let base_hash = match &base_out.data { let base_hash = match &base_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash, KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"), _ => panic!("expected Stored"),
}; };
let lora_hash = match &lora_out.data { let lora_hash = match &lora_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash, KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"), _ => panic!("expected Stored"),
}; };
...@@ -1176,14 +1222,26 @@ mod test_event_processing { ...@@ -1176,14 +1222,26 @@ mod test_event_processing {
block_mm_infos: None, block_mm_infos: None,
}; };
let out1 = convert_event(evt1, 1, kv_block_size, 0, &wc); let out1 = convert_event(
let out2 = convert_event(evt2, 2, kv_block_size, 0, &wc); evt1,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let out2 = convert_event(
evt2,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let hash1 = match &out1.data { let hash1 = match &out1.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash, KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"), _ => panic!("expected Stored"),
}; };
let hash2 = match &out2.data { let hash2 = match &out2.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash, KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"), _ => panic!("expected Stored"),
}; };
...@@ -1256,17 +1314,29 @@ mod test_event_processing { ...@@ -1256,17 +1314,29 @@ mod test_event_processing {
block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)], block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
medium: None, medium: None,
}; };
let out = convert_event(raw_evt, 7, kv_block_size, 0, &Arc::new(AtomicU32::new(0))); let out = convert_event(
raw_evt,
7,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.data, KvCacheEventData::Removed(_))); assert!(matches!(out.event.data, KvCacheEventData::Removed(_)));
} }
/// `RawKvEvent::AllBlocksCleared` must convert to a placement event whose
/// inner KV cache event carries `KvCacheEventData::Cleared`.
#[test]
fn test_convert_event_all_blocks_cleared() {
    let kv_block_size = 4;
    let raw_evt = RawKvEvent::AllBlocksCleared;
    let out = convert_event(
        raw_evt,
        1,
        kv_block_size,
        WorkerWithDpRank::from_worker_id(1),
        &Arc::new(AtomicU32::new(0)),
    );
    // The converted value wraps the KV cache event, so the payload is under `.event`.
    assert!(matches!(out.event.data, KvCacheEventData::Cleared));
}
#[test] #[test]
...@@ -1425,6 +1495,10 @@ mod tests_startup_helpers { ...@@ -1425,6 +1495,10 @@ mod tests_startup_helpers {
} }
} }
/// Test helper: wraps `event` in a `PlacementEvent` by delegating to
/// `PlacementEvent::local_gpu` with the given `worker_id`, so tests can feed
/// plain `KvCacheEvent`s into channels that carry `PlacementEvent`.
fn local_gpu_event(worker_id: WorkerId, event: KvCacheEvent) -> PlacementEvent {
PlacementEvent::local_gpu(worker_id, event)
}
//-------------------------------------------------------------------- //--------------------------------------------------------------------
// Test start_event_processor // Test start_event_processor
//-------------------------------------------------------------------- //--------------------------------------------------------------------
...@@ -1441,8 +1515,8 @@ mod tests_startup_helpers { ...@@ -1441,8 +1515,8 @@ mod tests_startup_helpers {
}; };
let token = CancellationToken::new(); let token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(event).unwrap(); tx.send(local_gpu_event(1, event)).unwrap();
drop(tx); drop(tx);
let handle = tokio::spawn(start_event_processor( let handle = tokio::spawn(start_event_processor(
...@@ -1498,8 +1572,8 @@ mod tests_startup_helpers { ...@@ -1498,8 +1572,8 @@ mod tests_startup_helpers {
dp_rank: 0, dp_rank: 0,
}; };
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(event).unwrap(); tx.send(local_gpu_event(1, event)).unwrap();
drop(tx); drop(tx);
// Start event processor with local indexer // Start event processor with local indexer
...@@ -1583,8 +1657,8 @@ mod tests_startup_helpers { ...@@ -1583,8 +1657,8 @@ mod tests_startup_helpers {
dp_rank: 0, dp_rank: 0,
}; };
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(store_event).unwrap(); tx.send(local_gpu_event(1, store_event)).unwrap();
// Start event processor with local indexer // Start event processor with local indexer
let handle = tokio::spawn(start_event_processor( let handle = tokio::spawn(start_event_processor(
...@@ -1604,7 +1678,7 @@ mod tests_startup_helpers { ...@@ -1604,7 +1678,7 @@ mod tests_startup_helpers {
}), }),
dp_rank: 0, dp_rank: 0,
}; };
tx.send(remove_event).unwrap(); tx.send(local_gpu_event(1, remove_event)).unwrap();
drop(tx); drop(tx);
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle) tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
...@@ -1665,8 +1739,8 @@ mod tests_startup_helpers { ...@@ -1665,8 +1739,8 @@ mod tests_startup_helpers {
dp_rank: 0, dp_rank: 0,
}; };
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(store_event).unwrap(); tx.send(local_gpu_event(1, store_event)).unwrap();
// Clear all blocks // Clear all blocks
let clear_event = KvCacheEvent { let clear_event = KvCacheEvent {
...@@ -1674,7 +1748,7 @@ mod tests_startup_helpers { ...@@ -1674,7 +1748,7 @@ mod tests_startup_helpers {
data: KvCacheEventData::Cleared, data: KvCacheEventData::Cleared,
dp_rank: 0, dp_rank: 0,
}; };
tx.send(clear_event).unwrap(); tx.send(local_gpu_event(1, clear_event)).unwrap();
drop(tx); drop(tx);
// Create event processor and wait // Create event processor and wait
...@@ -1743,8 +1817,8 @@ mod tests_startup_helpers { ...@@ -1743,8 +1817,8 @@ mod tests_startup_helpers {
}; };
let new_token = CancellationToken::new(); let new_token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(event).unwrap(); tx.send(local_gpu_event(1, event)).unwrap();
drop(tx); drop(tx);
// Despite local indexer being cancelled, event processor should continue // Despite local indexer being cancelled, event processor should continue
...@@ -1774,7 +1848,7 @@ mod tests_startup_helpers { ...@@ -1774,7 +1848,7 @@ mod tests_startup_helpers {
#[tokio::test] #[tokio::test]
async fn test_start_zmq_listener_pushes_to_channel() { async fn test_start_zmq_listener_pushes_to_channel() {
// Prepare channel that listener should fill // Prepare channel that listener should fill
let (tx, mut rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, mut rx) = mpsc::unbounded_channel::<PlacementEvent>();
// ZMQ TCP endpoint using localhost with fixed port // ZMQ TCP endpoint using localhost with fixed port
let endpoint = "tcp://127.0.0.1:15555"; let endpoint = "tcp://127.0.0.1:15555";
...@@ -1792,7 +1866,7 @@ mod tests_startup_helpers { ...@@ -1792,7 +1866,7 @@ mod tests_startup_helpers {
// Spawn async listener (connects to publisher bound above) // Spawn async listener (connects to publisher bound above)
let listener_handle = tokio::spawn({ let listener_handle = tokio::spawn({
let token = token.clone(); let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4, next_event_id) start_zmq_listener(endpoint.to_string(), topic, 1, tx, token, 4, next_event_id)
}); });
// Give time for the connection to establish // Give time for the connection to establish
...@@ -1835,7 +1909,7 @@ mod tests_startup_helpers { ...@@ -1835,7 +1909,7 @@ mod tests_startup_helpers {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Check that we received the message // Check that we received the message
let event = rx.try_recv().expect("no message received"); let event = rx.try_recv().expect("no message received").event;
let KvCacheEventData::Stored(KvCacheStoreData { let KvCacheEventData::Stored(KvCacheStoreData {
parent_hash, parent_hash,
...@@ -1872,7 +1946,7 @@ mod tests_startup_helpers { ...@@ -1872,7 +1946,7 @@ mod tests_startup_helpers {
100, // buffer size 100, // buffer size
)); ));
let (worker_tx, worker_rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (worker_tx, worker_rx) = mpsc::unbounded_channel::<PlacementEvent>();
// Start worker's event processor // Start worker's event processor
tokio::spawn(start_event_processor( tokio::spawn(start_event_processor(
...@@ -1912,7 +1986,9 @@ mod tests_startup_helpers { ...@@ -1912,7 +1986,9 @@ mod tests_startup_helpers {
dp_rank: 0, dp_rank: 0,
}; };
worker_tx.send(event_1.clone()).unwrap(); worker_tx
.send(local_gpu_event(worker_1_id, event_1.clone()))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Simulate JetStream: forward worker's published event to router // Simulate JetStream: forward worker's published event to router
...@@ -1978,7 +2054,9 @@ mod tests_startup_helpers { ...@@ -1978,7 +2054,9 @@ mod tests_startup_helpers {
dp_rank: 0, dp_rank: 0,
}; };
worker_tx.send(event_2.clone()).unwrap(); // send to worker but not to router worker_tx
.send(local_gpu_event(worker_1_id, event_2.clone()))
.unwrap(); // send to worker but not to router
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// assert: Worker published event_2 to "NATS" (MockComponent) // assert: Worker published event_2 to "NATS" (MockComponent)
...@@ -2395,6 +2473,10 @@ mod event_processor_tests { ...@@ -2395,6 +2473,10 @@ mod event_processor_tests {
} }
} }
/// Test helper: wraps `event` in a `PlacementEvent` via
/// `PlacementEvent::local_gpu`, always using the fixed worker id `1`.
fn local_gpu_event(event: KvCacheEvent) -> PlacementEvent {
PlacementEvent::local_gpu(1, event)
}
/// Test that pushing N removed events results in batched output /// Test that pushing N removed events results in batched output
/// Uses a 10ms timeout to ensure events are batched (events sent rapidly) /// Uses a 10ms timeout to ensure events are batched (events sent rapidly)
#[tokio::test] #[tokio::test]
...@@ -2419,7 +2501,7 @@ mod event_processor_tests { ...@@ -2419,7 +2501,7 @@ mod event_processor_tests {
/// Helper function to test removed events batching with configurable count and timeout /// Helper function to test removed events batching with configurable count and timeout
async fn test_removed_events_batching(event_count: usize, timeout_ms: Option<u64>) { async fn test_removed_events_batching(event_count: usize, timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -2445,7 +2527,7 @@ mod event_processor_tests { ...@@ -2445,7 +2527,7 @@ mod event_processor_tests {
}), }),
dp_rank: 0, dp_rank: 0,
}; };
tx.send(event).unwrap(); tx.send(local_gpu_event(event)).unwrap();
// Yield to allow event processor to process the event // Yield to allow event processor to process the event
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
...@@ -2516,7 +2598,7 @@ mod event_processor_tests { ...@@ -2516,7 +2598,7 @@ mod event_processor_tests {
/// Helper function to test stored events batching with configurable count and timeout /// Helper function to test stored events batching with configurable count and timeout
async fn test_stored_events_batching(event_count: usize, timeout_ms: Option<u64>) { async fn test_stored_events_batching(event_count: usize, timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -2554,7 +2636,7 @@ mod event_processor_tests { ...@@ -2554,7 +2636,7 @@ mod event_processor_tests {
}), }),
dp_rank: 0, dp_rank: 0,
}; };
tx.send(event).unwrap(); tx.send(local_gpu_event(event)).unwrap();
// Small sleep to allow event processor to batch events // Small sleep to allow event processor to batch events
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
} }
...@@ -2614,7 +2696,7 @@ mod event_processor_tests { ...@@ -2614,7 +2696,7 @@ mod event_processor_tests {
async fn test_run_event_processor_loop_non_sequential_flush() { async fn test_run_event_processor_loop_non_sequential_flush() {
let timeout_ms = Some(100); // 100ms timeout let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -2646,7 +2728,7 @@ mod event_processor_tests { ...@@ -2646,7 +2728,7 @@ mod event_processor_tests {
}), }),
dp_rank: 0, dp_rank: 0,
}; };
tx.send(event).unwrap(); tx.send(local_gpu_event(event)).unwrap();
} }
drop(tx); drop(tx);
...@@ -2697,7 +2779,7 @@ mod event_processor_tests { ...@@ -2697,7 +2779,7 @@ mod event_processor_tests {
/// Helper function to test no batching with slow input /// Helper function to test no batching with slow input
async fn test_no_batching_with_slow_input(timeout_ms: Option<u64>) { async fn test_no_batching_with_slow_input(timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -2725,7 +2807,7 @@ mod event_processor_tests { ...@@ -2725,7 +2807,7 @@ mod event_processor_tests {
}), }),
dp_rank: 0, dp_rank: 0,
}; };
tx.send(event).unwrap(); tx.send(local_gpu_event(event)).unwrap();
// Wait 2ms between events (much longer than the timeout) // Wait 2ms between events (much longer than the timeout)
// This ensures each event times out before the next one arrives // This ensures each event times out before the next one arrives
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await; tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
...@@ -2770,7 +2852,7 @@ mod event_processor_tests { ...@@ -2770,7 +2852,7 @@ mod event_processor_tests {
async fn test_event_type_switching_causes_flush() { async fn test_event_type_switching_causes_flush() {
let timeout_ms = Some(100); // 100ms timeout let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -2789,20 +2871,20 @@ mod event_processor_tests { ...@@ -2789,20 +2871,20 @@ mod event_processor_tests {
}); });
// Send a Removed event // Send a Removed event
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 0, event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(0)], block_hashes: vec![ExternalSequenceBlockHash(0)],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
// Small sleep // Small sleep
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
// Send a Stored event (should cause flush of the Removed event) // Send a Stored event (should cause flush of the Removed event)
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 1, event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)), parent_hash: Some(ExternalSequenceBlockHash(0)),
...@@ -2813,7 +2895,7 @@ mod event_processor_tests { ...@@ -2813,7 +2895,7 @@ mod event_processor_tests {
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
// Give time for processing // Give time for processing
...@@ -2837,7 +2919,7 @@ mod event_processor_tests { ...@@ -2837,7 +2919,7 @@ mod event_processor_tests {
async fn test_dp_rank_change_causes_flush() { async fn test_dp_rank_change_causes_flush() {
let timeout_ms = Some(100); // 100ms timeout let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -2857,26 +2939,26 @@ mod event_processor_tests { ...@@ -2857,26 +2939,26 @@ mod event_processor_tests {
// Send events with dp_rank=0 // Send events with dp_rank=0
for i in 0..3 { for i in 0..3 {
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64, event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)], block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
// Send events with dp_rank=1 (should cause flush of previous batch) // Send events with dp_rank=1 (should cause flush of previous batch)
for i in 3..6 { for i in 3..6 {
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64, event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)], block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}), }),
dp_rank: 1, dp_rank: 1,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
...@@ -2929,7 +3011,7 @@ mod event_processor_tests { ...@@ -2929,7 +3011,7 @@ mod event_processor_tests {
async fn test_flushed_events_have_correct_metadata() { async fn test_flushed_events_have_correct_metadata() {
let timeout_ms = Some(100); // 100ms timeout let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -2949,13 +3031,13 @@ mod event_processor_tests { ...@@ -2949,13 +3031,13 @@ mod event_processor_tests {
// Send first batch: 3 events with dp_rank=0, event_ids 10-12 // Send first batch: 3 events with dp_rank=0, event_ids 10-12
for i in 0..3 { for i in 0..3 {
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 10 + i as u64, event_id: 10 + i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)], block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
...@@ -2963,13 +3045,13 @@ mod event_processor_tests { ...@@ -2963,13 +3045,13 @@ mod event_processor_tests {
// Send second batch: 2 events with dp_rank=1, event_ids 20-21 // Send second batch: 2 events with dp_rank=1, event_ids 20-21
// This should flush the first batch with dp_rank=0 // This should flush the first batch with dp_rank=0
for i in 0..2 { for i in 0..2 {
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 20 + i as u64, event_id: 20 + i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash((i + 3) as u64)], block_hashes: vec![ExternalSequenceBlockHash((i + 3) as u64)],
}), }),
dp_rank: 1, dp_rank: 1,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
...@@ -3015,7 +3097,7 @@ mod event_processor_tests { ...@@ -3015,7 +3097,7 @@ mod event_processor_tests {
async fn test_first_event_after_idle_flushes_immediately_then_batches() { async fn test_first_event_after_idle_flushes_immediately_then_batches() {
let timeout_ms = Some(50); // 50ms timeout let timeout_ms = Some(50); // 50ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -3039,13 +3121,13 @@ mod event_processor_tests { ...@@ -3039,13 +3121,13 @@ mod event_processor_tests {
// Send 3 events rapidly - first should flush immediately (stale timer), // Send 3 events rapidly - first should flush immediately (stale timer),
// remaining 2 should batch together // remaining 2 should batch together
for i in 0..3 { for i in 0..3 {
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64, event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)], block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
...@@ -3085,7 +3167,7 @@ mod event_processor_tests { ...@@ -3085,7 +3167,7 @@ mod event_processor_tests {
async fn test_stored_events_with_dp_rank_change_correct_metadata() { async fn test_stored_events_with_dp_rank_change_correct_metadata() {
let timeout_ms = Some(100); // 100ms timeout let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -3104,7 +3186,7 @@ mod event_processor_tests { ...@@ -3104,7 +3186,7 @@ mod event_processor_tests {
}); });
// Send first batch: 2 sequential stored events with dp_rank=0, event_ids 100-101 // Send first batch: 2 sequential stored events with dp_rank=0, event_ids 100-101
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 100, event_id: 100,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)), parent_hash: Some(ExternalSequenceBlockHash(0)),
...@@ -3115,11 +3197,11 @@ mod event_processor_tests { ...@@ -3115,11 +3197,11 @@ mod event_processor_tests {
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 101, event_id: 101,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)), parent_hash: Some(ExternalSequenceBlockHash(1)),
...@@ -3130,13 +3212,13 @@ mod event_processor_tests { ...@@ -3130,13 +3212,13 @@ mod event_processor_tests {
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
// Send second batch: 1 event with dp_rank=1, event_id=200 // Send second batch: 1 event with dp_rank=1, event_id=200
// This should flush the first batch with dp_rank=0, event_id=101 // This should flush the first batch with dp_rank=0, event_id=101
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 200, event_id: 200,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)), parent_hash: Some(ExternalSequenceBlockHash(0)),
...@@ -3147,7 +3229,7 @@ mod event_processor_tests { ...@@ -3147,7 +3229,7 @@ mod event_processor_tests {
}], }],
}), }),
dp_rank: 1, dp_rank: 1,
}) }))
.unwrap(); .unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await; tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
...@@ -3202,7 +3284,7 @@ mod event_processor_tests { ...@@ -3202,7 +3284,7 @@ mod event_processor_tests {
async fn test_batch_parent_hash_preserved_when_extending() { async fn test_batch_parent_hash_preserved_when_extending() {
let timeout_ms = Some(100); // 100ms timeout let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new(); let publisher = MockPublisher::new();
let publisher_clone = publisher.clone(); let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -3221,7 +3303,7 @@ mod event_processor_tests { ...@@ -3221,7 +3303,7 @@ mod event_processor_tests {
}); });
// First event: parent_hash=None, block_hash=1 // First event: parent_hash=None, block_hash=1
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 0, event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, // Root block with no parent parent_hash: None, // Root block with no parent
...@@ -3232,12 +3314,12 @@ mod event_processor_tests { ...@@ -3232,12 +3314,12 @@ mod event_processor_tests {
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
// Second event: parent_hash=Some(1), block_hash=2 (sequential) // Second event: parent_hash=Some(1), block_hash=2 (sequential)
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 1, event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)), // Points to previous block parent_hash: Some(ExternalSequenceBlockHash(1)), // Points to previous block
...@@ -3248,12 +3330,12 @@ mod event_processor_tests { ...@@ -3248,12 +3330,12 @@ mod event_processor_tests {
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::task::yield_now().await; tokio::task::yield_now().await;
// Third event: parent_hash=Some(2), block_hash=3 (sequential) // Third event: parent_hash=Some(2), block_hash=3 (sequential)
tx.send(KvCacheEvent { tx.send(local_gpu_event(KvCacheEvent {
event_id: 2, event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(2)), parent_hash: Some(ExternalSequenceBlockHash(2)),
...@@ -3264,7 +3346,7 @@ mod event_processor_tests { ...@@ -3264,7 +3346,7 @@ mod event_processor_tests {
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}) }))
.unwrap(); .unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await; tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment