Unverified commit e811db50, authored by Yan Ru Pei, committed by GitHub
Browse files

feat(kv-router): plumb medium into worker-side placements (#7462)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 9681225a
......@@ -623,9 +623,9 @@ impl SequenceData {
/// Convert to a store event.
pub fn to_store_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
RouterEvent::new(
self.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
......@@ -642,21 +642,21 @@ impl SequenceData {
}),
dp_rank: 0,
},
}
)
}
/// Convert to a remove event.
pub fn to_remove_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
RouterEvent::new(
self.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: self.external_hashes.clone(),
}),
dp_rank: 0,
},
}
)
}
}
......
......@@ -297,10 +297,7 @@ async fn run_benchmark(
}
WorkerTraceEntry::Event(event) => {
indexer
.apply_event(RouterEvent {
worker_id: worker_id as u64,
event,
})
.apply_event(RouterEvent::new(worker_id as u64, event))
.await;
Ok(None)
}
......
......@@ -590,6 +590,7 @@ impl ConcurrentRadixTree {
// Create a store event for this worker
let event = RouterEvent {
worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......
......@@ -432,6 +432,7 @@ impl PositionalIndexer {
events.push(RouterEvent {
worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......
......@@ -557,6 +557,7 @@ impl RadixTree {
// Create a store event for this worker
let event = RouterEvent {
worker_id: worker.worker_id,
storage_tier: crate::protocols::StorageTier::Device,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......
......@@ -13,6 +13,7 @@ use super::concurrent_radix_tree::ConcurrentRadixTree;
use super::positional::PositionalIndexer;
use super::*;
use crate::protocols::*;
use crate::test_utils::{remove_event, router_event, stored_blocks_with_sequence_hashes};
// ============================================================================
// Helper functions
......@@ -63,25 +64,15 @@ fn make_store_event_with_parent(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let new_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent {
router_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: new_block_hashes
.iter()
.zip(new_seq_hashes.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: stored_blocks_with_sequence_hashes(&new_block_hashes, new_seq_hashes),
}),
)
}
/// Create a store event with all options.
......@@ -95,25 +86,15 @@ fn make_store_event_full(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
router_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: local_block_hashes
.iter()
.zip(seq_hashes.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank,
},
}
0,
dp_rank,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: stored_blocks_with_sequence_hashes(&local_block_hashes, &seq_hashes),
}),
)
}
/// Create a remove event for blocks with given local hashes.
......@@ -131,19 +112,15 @@ fn make_remove_event_with_dp_rank(
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
remove_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}),
dp_rank,
},
}
0,
dp_rank,
seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
)
}
/// Create a remove event with parent hash for continuation sequences.
......@@ -165,19 +142,15 @@ fn make_remove_event_with_parent(
let suffix_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent {
remove_event(
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: suffix_seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}),
dp_rank: 0,
},
}
0,
0,
suffix_seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
)
}
/// Snapshot the tree state for deterministic comparison.
......@@ -222,14 +195,7 @@ fn make_clear_event(worker_id: u64) -> RouterEvent {
/// Create a clear event with a specific dp_rank.
fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Cleared,
dp_rank,
},
}
router_event(worker_id, 0, dp_rank, KvCacheEventData::Cleared)
}
// ============================================================================
......@@ -646,16 +612,7 @@ async fn test_partial_block_removal(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -698,16 +655,7 @@ async fn test_remove_mid_chain_block(variant: &str) {
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]);
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -895,47 +843,27 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0
let base_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: base_hashes
.iter()
.zip(base_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
let base_event = router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
}),
);
index.apply_event(base_event).await;
// Store LoRA blocks on worker 1
let lora_event = RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: lora_hashes
.iter()
.zip(lora_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
let lora_event = router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
}),
);
index.apply_event(lora_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -1003,49 +931,29 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
// Worker 0: base model
index
.apply_event(RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: base_local
.iter()
.zip(base_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
}),
))
.await;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event(RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: lora_local
.iter()
.zip(lora_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
}),
))
.await;
flush_and_settle(index.as_ref()).await;
......@@ -1094,48 +1002,28 @@ async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
// Store adapter-a blocks on worker 0
index
.apply_event(RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes_a
.iter()
.zip(seq_a.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a),
}),
))
.await;
// Store adapter-b blocks on worker 1
index
.apply_event(RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes_b
.iter()
.zip(seq_b.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b),
}),
))
.await;
flush_and_settle(index.as_ref()).await;
......@@ -1317,16 +1205,7 @@ async fn test_long_sequence_partial_removal(variant: &str) {
.map(|&h| ExternalSequenceBlockHash(h))
.collect();
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: remove_hashes,
}),
dp_rank: 0,
},
};
let remove_event = remove_event(0, 0, 0, remove_hashes);
index.apply_event(remove_event).await;
flush_and_settle(index.as_ref()).await;
......@@ -1868,6 +1747,7 @@ async fn test_frequency(variant: &str) {
// KvIndexerMetrics tests
// ============================================================================
#[cfg(feature = "metrics")]
#[test]
fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered();
......
......@@ -133,6 +133,94 @@ impl WorkerWithDpRank {
}
}
/// Storage medium that a KV-cache event refers to.
///
/// `Device` is the `Default` variant. Because of `rename_all = "snake_case"`, serde
/// encodes the variants as `device`, `host_pinned`, `disk` and `external`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum StorageTier {
/// Default tier; the only tier for which `is_gpu` returns `true`.
#[default]
Device,
/// Tier selected by the `"CPU_PINNED"` / `"CPU_TIER1"` medium strings.
HostPinned,
/// Tier selected by the `"CPU_TIER2"` / `"DISK"` / `"NVME"` medium strings.
Disk,
/// Tier selected by the `"EXTERNAL"` / `"NETWORK"` / `"REMOTE"` / `"SHARED"` medium strings.
External,
}
impl StorageTier {
    /// Parses the `medium` string carried by a raw KV event.
    ///
    /// Matching is exact and case-sensitive. Returns `None` for any string that is
    /// not one of the literals listed below.
    pub fn from_kv_medium(medium: &str) -> Option<Self> {
        let tier = match medium {
            "GPU" | "DEVICE" => Self::Device,
            "CPU_PINNED" | "CPU_TIER1" => Self::HostPinned,
            "CPU_TIER2" | "DISK" | "NVME" => Self::Disk,
            "EXTERNAL" | "NETWORK" | "REMOTE" | "SHARED" => Self::External,
            _ => return None,
        };
        Some(tier)
    }

    /// Like [`StorageTier::from_kv_medium`], but never fails: a missing medium and an
    /// unrecognized medium both resolve to `Device`.
    ///
    /// NOTE(review): an unrecognized medium string therefore reports `is_gpu() == true`.
    /// Confirm that treating unknown mediums as device memory is intended.
    pub fn from_kv_medium_or_default(medium: Option<&str>) -> Self {
        match medium.and_then(Self::from_kv_medium) {
            Some(tier) => tier,
            None => Self::Device,
        }
    }

    /// Returns `true` only for the `Device` tier.
    pub fn is_gpu(self) -> bool {
        match self {
            Self::Device => true,
            Self::HostPinned | Self::Disk | Self::External => false,
        }
    }
}
/// Who a [`Placement`] is attributed to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PlacementOwner {
/// Attributed to a specific worker / data-parallel-rank pair.
LocalWorker(WorkerWithDpRank),
/// Not attributed to a single worker; `PlacementEvent::into_router_event` returns
/// `None` for this owner. NOTE(review): exact meaning of "shared" is not shown here.
Shared,
}
/// Describes a placement as an owner plus a storage tier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Placement {
/// The owner this placement is attributed to.
pub owner: PlacementOwner,
/// The storage tier of this placement.
pub tier: StorageTier,
}
impl Placement {
    /// Builds a placement owned by the given worker / dp-rank pair on `tier`.
    pub fn local_worker(worker_id: WorkerId, dp_rank: DpRank, tier: StorageTier) -> Self {
        let worker = WorkerWithDpRank::new(worker_id, dp_rank);
        Self {
            owner: PlacementOwner::LocalWorker(worker),
            tier,
        }
    }

    /// Shorthand for [`Placement::local_worker`] with `StorageTier::Device`.
    pub fn local_gpu(worker_id: WorkerId, dp_rank: DpRank) -> Self {
        Self::local_worker(worker_id, dp_rank, StorageTier::Device)
    }

    /// Returns `true` when the owner is a local worker and the tier reports `is_gpu()`.
    pub fn is_local_gpu(&self) -> bool {
        match self.owner {
            PlacementOwner::LocalWorker(_) => self.tier.is_gpu(),
            PlacementOwner::Shared => false,
        }
    }
}
/// A [`KvCacheEvent`] paired with the [`Placement`] it applies to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlacementEvent {
/// Owner and storage tier associated with `event`.
pub placement: Placement,
/// The cache event itself.
pub event: KvCacheEvent,
}
impl PlacementEvent {
    /// Pairs `event` with `placement`.
    pub fn new(placement: Placement, event: KvCacheEvent) -> Self {
        Self { placement, event }
    }

    /// Wraps `event` in a device-tier placement for `worker_id`, taking the
    /// data-parallel rank from `event.dp_rank`.
    pub fn local_gpu(worker_id: WorkerId, event: KvCacheEvent) -> Self {
        let placement = Placement::local_gpu(worker_id, event.dp_rank);
        Self::new(placement, event)
    }

    /// Converts into a [`RouterEvent`] carrying the placement's storage tier.
    ///
    /// Returns `None` when the placement is not owned by a local worker.
    pub fn into_router_event(self) -> Option<RouterEvent> {
        let Self { placement, event } = self;
        match placement.owner {
            PlacementOwner::LocalWorker(worker) => Some(RouterEvent::with_storage_tier(
                worker.worker_id,
                event,
                placement.tier,
            )),
            PlacementOwner::Shared => None,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest {
......@@ -512,6 +600,9 @@ pub enum KvCacheEventError {
pub struct RouterEvent {
/// The ID of the worker emitting the event.
pub worker_id: WorkerId,
/// The storage tier associated with the event.
#[serde(default)]
pub storage_tier: StorageTier,
/// The cache event associated with the worker.
pub event: KvCacheEvent,
}
......@@ -528,7 +619,19 @@ impl RouterEvent {
///
/// A new `RouterEvent`.
pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
Self { worker_id, event }
Self::with_storage_tier(worker_id, event, StorageTier::Device)
}
/// Create a new `RouterEvent` with an explicit storage tier.
///
/// # Arguments
///
/// * `worker_id` - The ID of the worker emitting the event.
/// * `event` - The cache event.
/// * `storage_tier` - The storage tier associated with the event.
///
/// # Returns
///
/// A new `RouterEvent`.
pub fn with_storage_tier(
    worker_id: WorkerId,
    event: KvCacheEvent,
    storage_tier: StorageTier,
) -> Self {
    Self {
        worker_id,
        event,
        storage_tier,
    }
}
}
......
......@@ -11,7 +11,7 @@ use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket};
use crate::protocols::{RouterEvent, WorkerId};
use crate::protocols::{RouterEvent, WorkerId, WorkerWithDpRank};
use crate::zmq_wire::{KvEventBatch, convert_event};
use super::indexer::Indexer;
......@@ -111,9 +111,19 @@ async fn replay_gap(
.data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events {
let kv_event =
convert_event(raw_event, seq, block_size, effective_dp_rank, warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
let placement_event = convert_event(
raw_event,
seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await;
}
watermark.store(seq, Ordering::Release);
......@@ -381,9 +391,19 @@ async fn zmq_recv_loop(
.data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events {
let kv_event =
convert_event(raw_event, seq, block_size, effective_dp_rank, &warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
let placement_event = convert_event(
raw_event,
seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
&warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await;
messages_processed += 1;
}
......
......@@ -12,6 +12,51 @@ use crate::protocols::{
};
use crate::sequences::SequencePublisher;
/// Builds a [`RouterEvent`] for `worker_id` via `RouterEvent::new`, wrapping `data`
/// in a [`KvCacheEvent`] with the given `event_id` and `dp_rank`.
pub fn router_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    data: KvCacheEventData,
) -> RouterEvent {
    let event = KvCacheEvent {
        event_id,
        data,
        dp_rank,
    };
    RouterEvent::new(worker_id, event)
}
/// Pairs each local block hash with the sequence hash at the same index and
/// returns one [`KvCacheStoredBlockData`] per pair, with `mm_extra_info` unset.
///
/// If the slices differ in length, the extra trailing entries of the longer one
/// are ignored.
pub fn stored_blocks_with_sequence_hashes(
    local_hashes: &[LocalBlockHash],
    seq_hashes: &[u64],
) -> Vec<KvCacheStoredBlockData> {
    let pair_count = local_hashes.len().min(seq_hashes.len());
    let mut blocks = Vec::with_capacity(pair_count);
    for (tokens_hash, seq_hash) in local_hashes.iter().copied().zip(seq_hashes.iter().copied()) {
        blocks.push(KvCacheStoredBlockData {
            tokens_hash,
            block_hash: ExternalSequenceBlockHash(seq_hash),
            mm_extra_info: None,
        });
    }
    blocks
}
/// Builds a [`RouterEvent`] whose payload is a `Removed` event for `block_hashes`,
/// delegating to `router_event` for the envelope.
pub fn remove_event(
    worker_id: WorkerId,
    event_id: u64,
    dp_rank: u32,
    block_hashes: Vec<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let data = KvCacheEventData::Removed(KvCacheRemoveData { block_hashes });
    router_event(worker_id, event_id, dp_rank, data)
}
/// Creates blocks with artificial hash mapping (hash * 100) for testing.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
......@@ -40,30 +85,19 @@ pub fn create_store_event(
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
router_event(worker_id, event_id, 0, add_blocks(hashes, parent))
}
pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent {
remove_event(
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
event_id,
0,
hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
)
}
/// No-op [`SequencePublisher`] for tests and benchmarks that don't need event transport.
......
......@@ -18,7 +18,8 @@ use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use crate::protocols::{
BlockExtraInfo, BlockMmObjectInfo, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, compute_block_hash_for_seq,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, Placement, PlacementEvent,
StorageTier, WorkerWithDpRank, compute_block_hash_for_seq,
};
// -------------------------------------------------------------------------
......@@ -335,16 +336,22 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
// Event conversion --------------------------------------------------------
// -------------------------------------------------------------------------
/// Convert a raw event coming from the ZMQ channel into the internal
/// [`KvCacheEvent`] representation used by the router.
/// Convert a raw event coming from the ZMQ channel into a placement-aware worker event.
pub fn convert_event(
raw: RawKvEvent,
event_id: u64,
kv_block_size: u32,
dp_rank: u32,
worker: WorkerWithDpRank,
warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent {
match raw {
) -> PlacementEvent {
let storage_tier = match &raw {
RawKvEvent::BlockStored { medium, .. } | RawKvEvent::BlockRemoved { medium, .. } => {
StorageTier::from_kv_medium_or_default(medium.as_deref())
}
RawKvEvent::AllBlocksCleared => StorageTier::Device,
};
let dp_rank = worker.dp_rank;
let event = match raw {
RawKvEvent::BlockStored {
block_hashes,
parent_block_hash,
......@@ -369,13 +376,16 @@ pub fn convert_event(
// Return an empty Removed instead of Cleared to avoid nuking
// the worker's entire index state. An empty Removed is a no-op
// in the radix tree (zero iterations, returns Ok(())).
return KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![],
}),
dp_rank,
};
return PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![],
}),
dp_rank,
},
);
}
}
......@@ -422,7 +432,12 @@ pub fn convert_event(
data: KvCacheEventData::Cleared,
dp_rank,
},
}
};
PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
event,
)
}
pub fn create_stored_block_from_parts(
......
......@@ -17,7 +17,7 @@
use std::collections::{HashMap, HashSet};
use dynamo_kv_router::protocols::XXH3_SEED;
use dynamo_kv_router::protocols::{StorageTier as RouterStorageTier, XXH3_SEED};
/// LocalBlockHash type (content hash from tokens only)
type LocalBlockHash = u64;
......@@ -108,12 +108,7 @@ pub enum StorageTier {
impl StorageTier {
/// Parse from vLLM's medium string (e.g., "GPU", "CPU_TIER1", "CPU_TIER2")
pub fn from_vllm_medium(s: &str) -> Option<Self> {
match s {
"GPU" => Some(StorageTier::Device),
"CPU_TIER1" => Some(StorageTier::HostPinned),
"CPU_TIER2" => Some(StorageTier::Disk),
_ => None,
}
RouterStorageTier::from_kv_medium(s).map(Into::into)
}
/// Convert to vLLM's medium string
......@@ -135,6 +130,16 @@ impl StorageTier {
}
}
impl From<RouterStorageTier> for StorageTier {
fn from(value: RouterStorageTier) -> Self {
match value {
RouterStorageTier::Device => Self::Device,
RouterStorageTier::HostPinned => Self::HostPinned,
RouterStorageTier::Disk | RouterStorageTier::External => Self::Disk,
}
}
}
/// Legacy type alias for backward compatibility
#[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment