"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "c24882ffb28baeaa3ea7566bfe6e37f209594be4"
Unverified Commit e811db50 authored by Yan Ru Pei, committed by GitHub
Browse files

feat(kv-router): plumb medium into worker-side placements (#7462)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 9681225a
...@@ -623,9 +623,9 @@ impl SequenceData { ...@@ -623,9 +623,9 @@ impl SequenceData {
/// Convert to a store event. /// Convert to a store event.
pub fn to_store_event(&self, event_id: u64) -> RouterEvent { pub fn to_store_event(&self, event_id: u64) -> RouterEvent {
RouterEvent { RouterEvent::new(
worker_id: self.worker_id, self.worker_id,
event: KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
...@@ -642,21 +642,21 @@ impl SequenceData { ...@@ -642,21 +642,21 @@ impl SequenceData {
}), }),
dp_rank: 0, dp_rank: 0,
}, },
} )
} }
/// Convert to a remove event. /// Convert to a remove event.
pub fn to_remove_event(&self, event_id: u64) -> RouterEvent { pub fn to_remove_event(&self, event_id: u64) -> RouterEvent {
RouterEvent { RouterEvent::new(
worker_id: self.worker_id, self.worker_id,
event: KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: self.external_hashes.clone(), block_hashes: self.external_hashes.clone(),
}), }),
dp_rank: 0, dp_rank: 0,
}, },
} )
} }
} }
......
...@@ -297,10 +297,7 @@ async fn run_benchmark( ...@@ -297,10 +297,7 @@ async fn run_benchmark(
} }
WorkerTraceEntry::Event(event) => { WorkerTraceEntry::Event(event) => {
indexer indexer
.apply_event(RouterEvent { .apply_event(RouterEvent::new(worker_id as u64, event))
worker_id: worker_id as u64,
event,
})
.await; .await;
Ok(None) Ok(None)
} }
......
...@@ -590,6 +590,7 @@ impl ConcurrentRadixTree { ...@@ -590,6 +590,7 @@ impl ConcurrentRadixTree {
// Create a store event for this worker // Create a store event for this worker
let event = RouterEvent { let event = RouterEvent {
worker_id: worker.worker_id, worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
......
...@@ -432,6 +432,7 @@ impl PositionalIndexer { ...@@ -432,6 +432,7 @@ impl PositionalIndexer {
events.push(RouterEvent { events.push(RouterEvent {
worker_id: worker.worker_id, worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
......
...@@ -557,6 +557,7 @@ impl RadixTree { ...@@ -557,6 +557,7 @@ impl RadixTree {
// Create a store event for this worker // Create a store event for this worker
let event = RouterEvent { let event = RouterEvent {
worker_id: worker.worker_id, worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
......
...@@ -13,6 +13,7 @@ use super::concurrent_radix_tree::ConcurrentRadixTree; ...@@ -13,6 +13,7 @@ use super::concurrent_radix_tree::ConcurrentRadixTree;
use super::positional::PositionalIndexer; use super::positional::PositionalIndexer;
use super::*; use super::*;
use crate::protocols::*; use crate::protocols::*;
use crate::test_utils::{remove_event, router_event, stored_blocks_with_sequence_hashes};
// ============================================================================ // ============================================================================
// Helper functions // Helper functions
...@@ -63,25 +64,15 @@ fn make_store_event_with_parent( ...@@ -63,25 +64,15 @@ fn make_store_event_with_parent(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect(); local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let new_seq_hashes = &full_seq_hashes[prefix_hashes.len()..]; let new_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent { router_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash, parent_hash,
blocks: new_block_hashes blocks: stored_blocks_with_sequence_hashes(&new_block_hashes, new_seq_hashes),
.iter() }),
.zip(new_seq_hashes.iter()) )
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
} }
/// Create a store event with all options. /// Create a store event with all options.
...@@ -95,25 +86,15 @@ fn make_store_event_full( ...@@ -95,25 +86,15 @@ fn make_store_event_full(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect(); local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes); let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent { router_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, dp_rank,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash, parent_hash,
blocks: local_block_hashes blocks: stored_blocks_with_sequence_hashes(&local_block_hashes, &seq_hashes),
.iter() }),
.zip(seq_hashes.iter()) )
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank,
},
}
} }
/// Create a remove event for blocks with given local hashes. /// Create a remove event for blocks with given local hashes.
...@@ -131,19 +112,15 @@ fn make_remove_event_with_dp_rank( ...@@ -131,19 +112,15 @@ fn make_remove_event_with_dp_rank(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect(); local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes); let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent { remove_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, dp_rank,
data: KvCacheEventData::Removed(KvCacheRemoveData { seq_hashes
block_hashes: seq_hashes .iter()
.iter() .map(|&h| ExternalSequenceBlockHash(h))
.map(|&h| ExternalSequenceBlockHash(h)) .collect(),
.collect(), )
}),
dp_rank,
},
}
} }
/// Create a remove event with parent hash for continuation sequences. /// Create a remove event with parent hash for continuation sequences.
...@@ -165,19 +142,15 @@ fn make_remove_event_with_parent( ...@@ -165,19 +142,15 @@ fn make_remove_event_with_parent(
let suffix_seq_hashes = &full_seq_hashes[prefix_hashes.len()..]; let suffix_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent { remove_event(
worker_id, worker_id,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Removed(KvCacheRemoveData { suffix_seq_hashes
block_hashes: suffix_seq_hashes .iter()
.iter() .map(|&h| ExternalSequenceBlockHash(h))
.map(|&h| ExternalSequenceBlockHash(h)) .collect(),
.collect(), )
}),
dp_rank: 0,
},
}
} }
/// Snapshot the tree state for deterministic comparison. /// Snapshot the tree state for deterministic comparison.
...@@ -222,14 +195,7 @@ fn make_clear_event(worker_id: u64) -> RouterEvent { ...@@ -222,14 +195,7 @@ fn make_clear_event(worker_id: u64) -> RouterEvent {
/// Create a clear event with a specific dp_rank. /// Create a clear event with a specific dp_rank.
fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent { fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
RouterEvent { router_event(worker_id, 0, dp_rank, KvCacheEventData::Cleared)
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Cleared,
dp_rank,
},
}
} }
// ============================================================================ // ============================================================================
...@@ -646,16 +612,7 @@ async fn test_partial_block_removal(variant: &str) { ...@@ -646,16 +612,7 @@ async fn test_partial_block_removal(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes); let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
let remove_event = RouterEvent { let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await; index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -698,16 +655,7 @@ async fn test_remove_mid_chain_block(variant: &str) { ...@@ -698,16 +655,7 @@ async fn test_remove_mid_chain_block(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes); let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]);
let remove_event = RouterEvent { let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await; index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -895,47 +843,27 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) { ...@@ -895,47 +843,27 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let lora_seq = compute_seq_hash_for_block(&lora_hashes); let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0 // Store base-model blocks on worker 0
let base_event = RouterEvent { let base_event = router_event(
worker_id: 0, 0,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: base_hashes blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
.iter() }),
.zip(base_seq.iter()) );
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
index.apply_event(base_event).await; index.apply_event(base_event).await;
// Store LoRA blocks on worker 1 // Store LoRA blocks on worker 1
let lora_event = RouterEvent { let lora_event = router_event(
worker_id: 1, 1,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: lora_hashes blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
.iter() }),
.zip(lora_seq.iter()) );
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
index.apply_event(lora_event).await; index.apply_event(lora_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1003,49 +931,29 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) { ...@@ -1003,49 +931,29 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
// Worker 0: base model // Worker 0: base model
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 0, 0,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: base_local blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
.iter() }),
.zip(base_seq.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to // Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node. // a separate tree path instead of colliding with worker 0's node.
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 1, 1,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: lora_local blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
.iter() }),
.zip(lora_seq.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1094,48 +1002,28 @@ async fn test_different_lora_adapters_do_not_conflict(variant: &str) { ...@@ -1094,48 +1002,28 @@ async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
// Store adapter-a blocks on worker 0 // Store adapter-a blocks on worker 0
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 0, 0,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: hashes_a blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a),
.iter() }),
.zip(seq_a.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
// Store adapter-b blocks on worker 1 // Store adapter-b blocks on worker 1
index index
.apply_event(RouterEvent { .apply_event(router_event(
worker_id: 1, 1,
event: KvCacheEvent { 0,
event_id: 0, 0,
data: KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: hashes_b blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b),
.iter() }),
.zip(seq_b.iter()) ))
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await; .await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1317,16 +1205,7 @@ async fn test_long_sequence_partial_removal(variant: &str) { ...@@ -1317,16 +1205,7 @@ async fn test_long_sequence_partial_removal(variant: &str) {
.map(|&h| ExternalSequenceBlockHash(h)) .map(|&h| ExternalSequenceBlockHash(h))
.collect(); .collect();
let remove_event = RouterEvent { let remove_event = remove_event(0, 0, 0, remove_hashes);
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: remove_hashes,
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await; index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
...@@ -1868,6 +1747,7 @@ async fn test_frequency(variant: &str) { ...@@ -1868,6 +1747,7 @@ async fn test_frequency(variant: &str) {
// KvIndexerMetrics tests // KvIndexerMetrics tests
// ============================================================================ // ============================================================================
#[cfg(feature = "metrics")]
#[test] #[test]
fn test_increment_event_applied() { fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered(); let metrics = KvIndexerMetrics::new_unregistered();
......
...@@ -133,6 +133,94 @@ impl WorkerWithDpRank { ...@@ -133,6 +133,94 @@ impl WorkerWithDpRank {
} }
} }
/// Storage tier associated with a KV cache event or placement.
///
/// Serialized with `snake_case` variant names. `Device` is the `Default`
/// variant, so a `#[serde(default)]` field of this type deserializes to
/// `Device` when the field is absent.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum StorageTier {
/// Parsed from the `"GPU"` / `"DEVICE"` medium strings; the only tier for
/// which `is_gpu` returns `true`.
#[default]
Device,
/// Parsed from the `"CPU_PINNED"` / `"CPU_TIER1"` medium strings.
HostPinned,
/// Parsed from the `"CPU_TIER2"` / `"DISK"` / `"NVME"` medium strings.
Disk,
/// Parsed from the `"EXTERNAL"` / `"NETWORK"` / `"REMOTE"` / `"SHARED"`
/// medium strings.
External,
}
impl StorageTier {
    /// Parses a KV medium string into a storage tier.
    ///
    /// Matching is exact and case-sensitive. Returns `None` for any string
    /// that is not one of the recognized medium names.
    pub fn from_kv_medium(medium: &str) -> Option<Self> {
        let tier = match medium {
            "GPU" | "DEVICE" => Self::Device,
            "CPU_PINNED" | "CPU_TIER1" => Self::HostPinned,
            "CPU_TIER2" | "DISK" | "NVME" => Self::Disk,
            "EXTERNAL" | "NETWORK" | "REMOTE" | "SHARED" => Self::External,
            _ => return None,
        };
        Some(tier)
    }

    /// Parses an optional KV medium string, falling back to `Device`.
    ///
    /// Both a missing medium (`None`) and an unrecognized medium string
    /// resolve to `Self::Device`.
    pub fn from_kv_medium_or_default(medium: Option<&str>) -> Self {
        match medium.and_then(Self::from_kv_medium) {
            Some(tier) => tier,
            None => Self::Device,
        }
    }

    /// Returns `true` only for the `Device` tier.
    pub fn is_gpu(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Device => true,
            Self::HostPinned | Self::Disk | Self::External => false,
        }
    }
}
/// Owner of a KV block placement.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PlacementOwner {
/// Owned by a single worker, identified together with its dp rank.
LocalWorker(WorkerWithDpRank),
/// Not attributed to a single worker; such placements cannot be turned
/// into a `RouterEvent` (see `PlacementEvent::into_router_event`).
/// NOTE(review): the exact sharing semantics are not visible here — confirm.
Shared,
}
/// Where a KV block lives: who owns it and on which storage tier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Placement {
/// The owner of the placement.
pub owner: PlacementOwner,
/// The storage tier of the placement.
pub tier: StorageTier,
}
impl Placement {
    /// Builds a placement owned by the given worker / dp rank on `tier`.
    pub fn local_worker(worker_id: WorkerId, dp_rank: DpRank, tier: StorageTier) -> Self {
        let worker = WorkerWithDpRank::new(worker_id, dp_rank);
        let owner = PlacementOwner::LocalWorker(worker);
        Self { owner, tier }
    }

    /// Builds a worker-owned placement on the `Device` tier.
    pub fn local_gpu(worker_id: WorkerId, dp_rank: DpRank) -> Self {
        let tier = StorageTier::Device;
        Self::local_worker(worker_id, dp_rank, tier)
    }

    /// Returns `true` when the placement is worker-owned and its tier is the
    /// GPU (`Device`) tier.
    pub fn is_local_gpu(&self) -> bool {
        match &self.owner {
            PlacementOwner::LocalWorker(_) => self.tier.is_gpu(),
            PlacementOwner::Shared => false,
        }
    }
}
/// A KV cache event paired with the placement it refers to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlacementEvent {
/// Owner and storage tier the event applies to.
pub placement: Placement,
/// The underlying KV cache event.
pub event: KvCacheEvent,
}
impl PlacementEvent {
    /// Pairs `event` with `placement`.
    pub fn new(placement: Placement, event: KvCacheEvent) -> Self {
        Self { placement, event }
    }

    /// Wraps `event` in a worker-owned `Device`-tier placement, taking the
    /// dp rank from the event itself.
    pub fn local_gpu(worker_id: WorkerId, event: KvCacheEvent) -> Self {
        let placement = Placement::local_gpu(worker_id, event.dp_rank);
        Self::new(placement, event)
    }

    /// Converts into a [`RouterEvent`] carrying the placement's storage tier.
    ///
    /// Returns `None` when the placement is not owned by a local worker.
    pub fn into_router_event(self) -> Option<RouterEvent> {
        let Self { placement, event } = self;
        match placement.owner {
            PlacementOwner::LocalWorker(worker) => Some(RouterEvent::with_storage_tier(
                worker.worker_id,
                event,
                placement.tier,
            )),
            PlacementOwner::Shared => None,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")] #[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest { pub enum RouterRequest {
...@@ -512,6 +600,9 @@ pub enum KvCacheEventError { ...@@ -512,6 +600,9 @@ pub enum KvCacheEventError {
pub struct RouterEvent { pub struct RouterEvent {
/// The ID of the worker emitting the event. /// The ID of the worker emitting the event.
pub worker_id: WorkerId, pub worker_id: WorkerId,
/// The storage tier associated with the event.
#[serde(default)]
pub storage_tier: StorageTier,
/// The cache event associated with the worker. /// The cache event associated with the worker.
pub event: KvCacheEvent, pub event: KvCacheEvent,
} }
...@@ -528,7 +619,19 @@ impl RouterEvent { ...@@ -528,7 +619,19 @@ impl RouterEvent {
/// ///
/// A new `RouterEvent`. /// A new `RouterEvent`.
pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self { pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
Self { worker_id, event } Self::with_storage_tier(worker_id, event, StorageTier::Device)
}
pub fn with_storage_tier(
worker_id: WorkerId,
event: KvCacheEvent,
storage_tier: StorageTier,
) -> Self {
Self {
worker_id,
storage_tier,
event,
}
} }
} }
......
...@@ -11,7 +11,7 @@ use tokio::sync::watch; ...@@ -11,7 +11,7 @@ use tokio::sync::watch;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket}; use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket};
use crate::protocols::{RouterEvent, WorkerId}; use crate::protocols::{RouterEvent, WorkerId, WorkerWithDpRank};
use crate::zmq_wire::{KvEventBatch, convert_event}; use crate::zmq_wire::{KvEventBatch, convert_event};
use super::indexer::Indexer; use super::indexer::Indexer;
...@@ -111,9 +111,19 @@ async fn replay_gap( ...@@ -111,9 +111,19 @@ async fn replay_gap(
.data_parallel_rank .data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned()); .map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events { for raw_event in batch.events {
let kv_event = let placement_event = convert_event(
convert_event(raw_event, seq, block_size, effective_dp_rank, warning_count); raw_event,
let router_event = RouterEvent::new(worker_id, kv_event); seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await; indexer.apply_event(router_event).await;
} }
watermark.store(seq, Ordering::Release); watermark.store(seq, Ordering::Release);
...@@ -381,9 +391,19 @@ async fn zmq_recv_loop( ...@@ -381,9 +391,19 @@ async fn zmq_recv_loop(
.data_parallel_rank .data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned()); .map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events { for raw_event in batch.events {
let kv_event = let placement_event = convert_event(
convert_event(raw_event, seq, block_size, effective_dp_rank, &warning_count); raw_event,
let router_event = RouterEvent::new(worker_id, kv_event); seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
&warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await; indexer.apply_event(router_event).await;
messages_processed += 1; messages_processed += 1;
} }
......
...@@ -12,6 +12,51 @@ use crate::protocols::{ ...@@ -12,6 +12,51 @@ use crate::protocols::{
}; };
use crate::sequences::SequencePublisher; use crate::sequences::SequencePublisher;
/// Builds a [`RouterEvent`] for `worker_id` from the individual parts of a
/// [`KvCacheEvent`], using `RouterEvent::new`.
pub fn router_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    data: KvCacheEventData,
) -> RouterEvent {
    let cache_event = KvCacheEvent {
        event_id,
        data,
        dp_rank,
    };
    RouterEvent::new(worker_id, cache_event)
}
/// Pairs each local block hash with the sequence hash at the same position
/// and builds the corresponding stored-block records.
///
/// Pairing is positional; if the slices differ in length, the extra elements
/// of the longer one are ignored. `mm_extra_info` is always `None`.
pub fn stored_blocks_with_sequence_hashes(
    local_hashes: &[LocalBlockHash],
    seq_hashes: &[u64],
) -> Vec<KvCacheStoredBlockData> {
    let pair_count = local_hashes.len().min(seq_hashes.len());
    let mut blocks = Vec::with_capacity(pair_count);
    for (&tokens_hash, &seq_hash) in local_hashes.iter().zip(seq_hashes) {
        blocks.push(KvCacheStoredBlockData {
            tokens_hash,
            block_hash: ExternalSequenceBlockHash(seq_hash),
            mm_extra_info: None,
        });
    }
    blocks
}
/// Builds a [`RouterEvent`] whose payload is a `Removed` event for the given
/// block hashes.
pub fn remove_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    block_hashes: Vec<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let removed = KvCacheEventData::Removed(KvCacheRemoveData { block_hashes });
    router_event(worker_id, event_id, dp_rank, removed)
}
/// Creates blocks with artificial hash mapping (hash * 100) for testing. /// Creates blocks with artificial hash mapping (hash * 100) for testing.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> { pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes hashes
...@@ -40,30 +85,19 @@ pub fn create_store_event( ...@@ -40,30 +85,19 @@ pub fn create_store_event(
hashes: Vec<u64>, hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>, parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent { ) -> RouterEvent {
RouterEvent { router_event(worker_id, event_id, 0, add_blocks(hashes, parent))
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
} }
pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent { pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent { remove_event(
worker_id, worker_id,
event: KvCacheEvent { event_id,
event_id, 0,
data: KvCacheEventData::Removed(KvCacheRemoveData { hashes
block_hashes: hashes .iter()
.iter() .map(|i| ExternalSequenceBlockHash(*i * 100))
.map(|i| ExternalSequenceBlockHash(*i * 100)) .collect(),
.collect(), )
}),
dp_rank: 0,
},
}
} }
/// No-op [`SequencePublisher`] for tests and benchmarks that don't need event transport. /// No-op [`SequencePublisher`] for tests and benchmarks that don't need event transport.
......
...@@ -18,7 +18,8 @@ use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor}; ...@@ -18,7 +18,8 @@ use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use crate::protocols::{ use crate::protocols::{
BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, compute_block_hash_for_seq, KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement, PlacementEvent,
StorageTier, WorkerWithDpRank, compute_block_hash_for_seq,
}; };
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
...@@ -335,16 +336,22 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -335,16 +336,22 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
// Event conversion -------------------------------------------------------- // Event conversion --------------------------------------------------------
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
/// Convert a raw event coming from the ZMQ channel into the internal /// Convert a raw event coming from the ZMQ channel into a placement-aware worker event.
/// [`KvCacheEvent`] representation used by the router.
pub fn convert_event( pub fn convert_event(
raw: RawKvEvent, raw: RawKvEvent,
event_id: u64, event_id: u64,
kv_block_size: u32, kv_block_size: u32,
dp_rank: u32, worker: WorkerWithDpRank,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent { ) -> PlacementEvent {
match raw { let storage_tier = match &raw {
RawKvEvent::BlockStored { medium, .. } | RawKvEvent::BlockRemoved { medium, .. } => {
StorageTier::from_kv_medium_or_default(medium.as_deref())
}
RawKvEvent::AllBlocksCleared => StorageTier::Device,
};
let dp_rank = worker.dp_rank;
let event = match raw {
RawKvEvent::BlockStored { RawKvEvent::BlockStored {
block_hashes, block_hashes,
parent_block_hash, parent_block_hash,
...@@ -369,13 +376,16 @@ pub fn convert_event( ...@@ -369,13 +376,16 @@ pub fn convert_event(
// Return an empty Removed instead of Cleared to avoid nuking // Return an empty Removed instead of Cleared to avoid nuking
// the worker's entire index state. An empty Removed is a no-op // the worker's entire index state. An empty Removed is a no-op
// in the radix tree (zero iterations, returns Ok(())). // in the radix tree (zero iterations, returns Ok(())).
return KvCacheEvent { return PlacementEvent::new(
event_id, Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
data: KvCacheEventData::Removed(KvCacheRemoveData { KvCacheEvent {
block_hashes: vec![], event_id,
}), data: KvCacheEventData::Removed(KvCacheRemoveData {
dp_rank, block_hashes: vec![],
}; }),
dp_rank,
},
);
} }
} }
...@@ -422,7 +432,12 @@ pub fn convert_event( ...@@ -422,7 +432,12 @@ pub fn convert_event(
data: KvCacheEventData::Cleared, data: KvCacheEventData::Cleared,
dp_rank, dp_rank,
}, },
} };
PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
event,
)
} }
pub fn create_stored_block_from_parts( pub fn create_stored_block_from_parts(
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use dynamo_kv_router::protocols::XXH3_SEED; use dynamo_kv_router::protocols::{StorageTier as RouterStorageTier, XXH3_SEED};
/// LocalBlockHash type (content hash from tokens only) /// LocalBlockHash type (content hash from tokens only)
type LocalBlockHash = u64; type LocalBlockHash = u64;
...@@ -108,12 +108,7 @@ pub enum StorageTier { ...@@ -108,12 +108,7 @@ pub enum StorageTier {
impl StorageTier { impl StorageTier {
/// Parse from vLLM's medium string (e.g., "GPU", "CPU_TIER1", "CPU_TIER2") /// Parse from vLLM's medium string (e.g., "GPU", "CPU_TIER1", "CPU_TIER2")
pub fn from_vllm_medium(s: &str) -> Option<Self> { pub fn from_vllm_medium(s: &str) -> Option<Self> {
match s { RouterStorageTier::from_kv_medium(s).map(Into::into)
"GPU" => Some(StorageTier::Device),
"CPU_TIER1" => Some(StorageTier::HostPinned),
"CPU_TIER2" => Some(StorageTier::Disk),
_ => None,
}
} }
/// Convert to vLLM's medium string /// Convert to vLLM's medium string
...@@ -135,6 +130,16 @@ impl StorageTier { ...@@ -135,6 +130,16 @@ impl StorageTier {
} }
} }
/// Maps the router's storage tier onto this module's `StorageTier`.
impl From<RouterStorageTier> for StorageTier {
    fn from(value: RouterStorageTier) -> Self {
        match value {
            RouterStorageTier::Device => Self::Device,
            RouterStorageTier::HostPinned => Self::HostPinned,
            RouterStorageTier::Disk => Self::Disk,
            // Lossy: the router's external tier is folded into `Disk`.
            // NOTE(review): presumably because this enum has no external
            // variant — confirm against the full enum definition.
            RouterStorageTier::External => Self::Disk,
        }
    }
}
/// Legacy type alias for backward compatibility /// Legacy type alias for backward compatibility
#[deprecated(note = "Use StorageTier instead")] #[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier; pub type StorageMedium = StorageTier;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment