Unverified commit 6271a31f, authored by Yuewei Na, committed by GitHub
Browse files

feat: Implement priority-based KV cache offload filtering (#5563)


Signed-off-by: Yuewei Na <248773860+nv-yna@users.noreply.github.com>
Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Signed-off-by: yna <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
parent 2b6d1338
...@@ -86,6 +86,7 @@ Note that the default pip wheel built is not compatible with CUDA 13 at the mome ...@@ -86,6 +86,7 @@ Note that the default pip wheel built is not compatible with CUDA 13 at the mome
| `DYN_KVBM_METRICS` | Enable metrics endpoint | `false` | | `DYN_KVBM_METRICS` | Enable metrics endpoint | `false` |
| `DYN_KVBM_METRICS_PORT` | Metrics port | `6880` | | `DYN_KVBM_METRICS_PORT` | Metrics port | `6880` |
| `DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER` | Disable disk offload filtering to remove SSD lifespan protection | `false` | | `DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER` | Disable disk offload filtering to remove SSD lifespan protection | `false` |
| `DYN_KVBM_HOST_OFFLOAD_PREFIX_MIN_PRIORITY` | Minimum priority (0-100) for CPU offload with contiguous (prefix) semantics: offloading stops at the first block below threshold, and all subsequent blocks are also skipped. Used for priority-based filtering. | `0` (no filtering) |
#### Disk Storage Configuration #### Disk Storage Configuration
......
...@@ -130,6 +130,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): ...@@ -130,6 +130,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_tokens, req.new_tokens,
req.new_block_ids, req.new_block_ids,
req.computed_position, req.computed_position,
req.priorities, # Pass retention priorities for offload filtering
) )
resumed_from_preemption = False resumed_from_preemption = False
...@@ -140,6 +141,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): ...@@ -140,6 +141,7 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.new_tokens, req.new_tokens,
req.new_block_ids, req.new_block_ids,
req.computed_position, req.computed_position,
req.priorities, # Pass retention priorities for offload filtering
) )
output.add_num_scheduled_tokens( output.add_num_scheduled_tokens(
......
...@@ -42,22 +42,26 @@ impl SchedulerOutput { ...@@ -42,22 +42,26 @@ impl SchedulerOutput {
// I am surprised that vLLM's NewRequestData does not include the salt hash. // I am surprised that vLLM's NewRequestData does not include the salt hash.
// It has almost everything else to compute the block hashes worker side. // It has almost everything else to compute the block hashes worker side.
#[pyo3(signature = (request_id, prompt_token_ids, block_ids, num_computed_tokens, priorities=None))]
pub fn add_new_request( pub fn add_new_request(
&mut self, &mut self,
request_id: String, request_id: String,
prompt_token_ids: Vec<u32>, prompt_token_ids: Vec<u32>,
block_ids: Vec<BlockId>, block_ids: Vec<BlockId>,
num_computed_tokens: usize, num_computed_tokens: usize,
priorities: Option<Vec<u32>>,
) { ) {
self.new_requests.push(NewRequestData { self.new_requests.push(NewRequestData {
request_id, request_id,
prompt_token_ids, prompt_token_ids,
block_ids, block_ids,
num_computed_tokens, num_computed_tokens,
priorities,
}); });
} }
/// This is called by the leader to update the cached requests /// This is called by the leader to update the cached requests
#[pyo3(signature = (request_id, resumed_from_preemption, new_token_ids, new_block_ids, num_computed_tokens, priorities=None))]
pub fn add_cached_request( pub fn add_cached_request(
&mut self, &mut self,
request_id: String, request_id: String,
...@@ -65,6 +69,7 @@ impl SchedulerOutput { ...@@ -65,6 +69,7 @@ impl SchedulerOutput {
new_token_ids: Vec<u32>, new_token_ids: Vec<u32>,
new_block_ids: Vec<BlockId>, new_block_ids: Vec<BlockId>,
num_computed_tokens: usize, num_computed_tokens: usize,
priorities: Option<Vec<u32>>,
) { ) {
self.cached_requests.push(CachedRequestData { self.cached_requests.push(CachedRequestData {
request_id, request_id,
...@@ -72,6 +77,7 @@ impl SchedulerOutput { ...@@ -72,6 +77,7 @@ impl SchedulerOutput {
new_token_ids, new_token_ids,
new_block_ids, new_block_ids,
num_computed_tokens, num_computed_tokens,
priorities,
}); });
} }
...@@ -99,6 +105,9 @@ pub struct NewRequestData { ...@@ -99,6 +105,9 @@ pub struct NewRequestData {
pub prompt_token_ids: Vec<u32>, pub prompt_token_ids: Vec<u32>,
pub block_ids: Vec<BlockId>, pub block_ids: Vec<BlockId>,
pub num_computed_tokens: usize, pub num_computed_tokens: usize,
/// Retention priorities for each block (same length as block_ids).
/// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>,
} }
impl std::fmt::Debug for NewRequestData { impl std::fmt::Debug for NewRequestData {
...@@ -119,6 +128,9 @@ pub struct CachedRequestData { ...@@ -119,6 +128,9 @@ pub struct CachedRequestData {
pub new_token_ids: Vec<u32>, pub new_token_ids: Vec<u32>,
pub new_block_ids: Vec<BlockId>, pub new_block_ids: Vec<BlockId>,
pub num_computed_tokens: usize, pub num_computed_tokens: usize,
/// Retention priorities for each new block (same length as new_block_ids).
/// Used for priority-based offload filtering.
pub priorities: Option<Vec<u32>>,
} }
impl std::fmt::Debug for CachedRequestData { impl std::fmt::Debug for CachedRequestData {
......
...@@ -427,7 +427,7 @@ impl Leader for KvConnectorLeader { ...@@ -427,7 +427,7 @@ impl Leader for KvConnectorLeader {
.get(request_id) .get(request_id)
.unwrap_or(&0); .unwrap_or(&0);
slot.apply_scheduler_output(&[], &[], new_req.num_computed_tokens, scheduled_tokens)?; slot.apply_scheduler_output(&[], &[], new_req.num_computed_tokens, scheduled_tokens, None)?;
let pending_ops_opt = slot.take_pending_operations(); let pending_ops_opt = slot.take_pending_operations();
...@@ -497,6 +497,7 @@ impl Leader for KvConnectorLeader { ...@@ -497,6 +497,7 @@ impl Leader for KvConnectorLeader {
&cached_req.new_block_ids, &cached_req.new_block_ids,
cached_req.num_computed_tokens, cached_req.num_computed_tokens,
scheduled_tokens, scheduled_tokens,
None,
)?; )?;
if let Some(pending_ops) = slot.take_pending_operations() { if let Some(pending_ops) = slot.take_pending_operations() {
......
...@@ -108,6 +108,7 @@ pub trait Slot: std::fmt::Debug { ...@@ -108,6 +108,7 @@ pub trait Slot: std::fmt::Debug {
block_ids: &[usize], block_ids: &[usize],
num_computed_tokens: usize, num_computed_tokens: usize,
num_scheduled_tokens: usize, num_scheduled_tokens: usize,
priorities: Option<&[u32]>,
) -> Result<(), SlotError>; ) -> Result<(), SlotError>;
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>; fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;
...@@ -366,6 +367,14 @@ pub struct VllmConnectorSlot { ...@@ -366,6 +367,14 @@ pub struct VllmConnectorSlot {
/// Cache statistics tracker for this KVBM instance /// Cache statistics tracker for this KVBM instance
cache_stats: Arc<CacheStatsTracker>, cache_stats: Arc<CacheStatsTracker>,
/// Minimum priority threshold for offload filtering.
/// Offloading stops at the first block whose priority is below this threshold;
/// that block and every later block are not offloaded. A value of 0 disables filtering.
offload_min_priority: u32,
/// Block index where offload was terminated due to priority filtering.
/// When Some, no further blocks will be offloaded to ensure global contiguity.
offload_terminated_at_block: Option<usize>,
} }
impl VllmConnectorSlot { impl VllmConnectorSlot {
...@@ -403,6 +412,11 @@ impl VllmConnectorSlot { ...@@ -403,6 +412,11 @@ impl VllmConnectorSlot {
performed_cache_lookup: false, performed_cache_lookup: false,
total_blocks_queried: 0, total_blocks_queried: 0,
cache_stats, cache_stats,
offload_min_priority: std::env::var("DYN_KVBM_HOST_OFFLOAD_PREFIX_MIN_PRIORITY")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(0),
offload_terminated_at_block: None,
} }
} }
...@@ -480,6 +494,7 @@ impl Slot for VllmConnectorSlot { ...@@ -480,6 +494,7 @@ impl Slot for VllmConnectorSlot {
self.tokens_cached_from_disk = 0; self.tokens_cached_from_disk = 0;
self.performed_cache_lookup = false; self.performed_cache_lookup = false;
self.total_blocks_queried = 0; self.total_blocks_queried = 0;
self.offload_terminated_at_block = None;
} }
fn reset(&mut self) { fn reset(&mut self) {
...@@ -519,7 +534,19 @@ impl Slot for VllmConnectorSlot { ...@@ -519,7 +534,19 @@ impl Slot for VllmConnectorSlot {
block_ids: &[BlockId], block_ids: &[BlockId],
num_computed_tokens: usize, num_computed_tokens: usize,
num_scheduled_tokens: usize, num_scheduled_tokens: usize,
priorities: Option<&[u32]>,
) -> Result<(), SlotError> { ) -> Result<(), SlotError> {
// Validate contract: priorities must match block_ids length when provided
if let Some(prios) = priorities {
assert_eq!(
prios.len(),
block_ids.len(),
"priorities length ({}) must match block_ids length ({})",
prios.len(),
block_ids.len()
);
}
if !tokens.is_empty() { if !tokens.is_empty() {
tracing::debug!( tracing::debug!(
"appending {} newly decoded tokens to sequence", "appending {} newly decoded tokens to sequence",
...@@ -542,6 +569,18 @@ impl Slot for VllmConnectorSlot { ...@@ -542,6 +569,18 @@ impl Slot for VllmConnectorSlot {
self.device_blocks.extend(block_ids); self.device_blocks.extend(block_ids);
} }
// Early exit if offload has been permanently terminated.
// This ensures global contiguity: once a gap is created by priority filtering,
// no subsequent blocks will be offloaded for this request.
if let Some(terminated_at) = self.offload_terminated_at_block {
tracing::debug!(
"offload terminated at block {}; skipping offload evaluation",
terminated_at
);
self.current_position += num_scheduled_tokens;
return Ok(());
}
// we should have enough device blocks to cover the newly scheduled tokens // we should have enough device blocks to cover the newly scheduled tokens
let next_position = self.current_position + num_scheduled_tokens; let next_position = self.current_position + num_scheduled_tokens;
assert!( assert!(
...@@ -585,34 +624,121 @@ impl Slot for VllmConnectorSlot { ...@@ -585,34 +624,121 @@ impl Slot for VllmConnectorSlot {
); );
if num_candidate_blocks != 0 { if num_candidate_blocks != 0 {
// do we have a mechanism for skipping gpu cache hit blocks? not sure yet. // Get candidate block IDs
// for now, offload all the blocks to the host let candidate_block_ids: Vec<usize> = self
let offload_block_ids: Vec<usize> = self
.device_blocks .device_blocks
.iter() .iter()
.skip(self.evaluated_blocks) .skip(self.evaluated_blocks)
.take(num_candidate_blocks) .take(num_candidate_blocks)
.copied() .copied()
.collect::<Vec<_>>(); .collect();
// Get candidate priorities from the priorities parameter.
// When priorities are provided, extract priorities for candidate blocks.
let candidate_priorities: Vec<u32> = if let Some(prios) = priorities {
let new_blocks_start = self.device_blocks.len() - block_ids.len();
let candidate_start = self.evaluated_blocks;
if candidate_start >= new_blocks_start {
let prio_offset = candidate_start - new_blocks_start;
debug_assert!(
prio_offset + num_candidate_blocks <= prios.len(),
"prio_offset ({}) + num_candidate_blocks ({}) > prios.len() ({}); \
candidate_start={}, new_blocks_start={}, device_blocks.len()={}, block_ids.len()={}",
prio_offset,
num_candidate_blocks,
prios.len(),
candidate_start,
new_blocks_start,
self.device_blocks.len(),
block_ids.len()
);
prios
.iter()
.skip(prio_offset)
.take(num_candidate_blocks)
.copied()
.collect()
} else {
vec![0; num_candidate_blocks]
}
} else {
vec![0; num_candidate_blocks]
};
assert_eq!( assert_eq!(
offload_block_ids.len(), candidate_block_ids.len(),
num_candidate_blocks, num_candidate_blocks,
"device block overflow - candidate blocks exceed block count at offset {}", "device block overflow - candidate blocks exceed block count at offset {}",
self.evaluated_blocks self.evaluated_blocks
); );
// Apply contiguous priority filtering: find how many blocks from the start
// meet the minimum priority threshold. Stop at first block below threshold.
let num_blocks_to_offload = if self.offload_min_priority > 0 {
candidate_priorities
.iter()
.take_while(|&&priority| priority >= self.offload_min_priority)
.count()
} else {
num_candidate_blocks
};
if num_blocks_to_offload > 0 {
if self.offload_min_priority > 0 {
tracing::debug!(
"priority filtering: offloading {}/{} blocks (threshold={})",
num_blocks_to_offload,
num_candidate_blocks,
self.offload_min_priority
);
}
let offload_block_ids: Vec<usize> = candidate_block_ids
.into_iter()
.take(num_blocks_to_offload)
.collect();
let offload_token_blocks: Vec<TokenBlock> = self let offload_token_blocks: Vec<TokenBlock> = self
.sequence .sequence
.blocks() .blocks()
.iter() .iter()
.skip(self.evaluated_blocks) .skip(self.evaluated_blocks)
.take(num_candidate_blocks) .take(num_blocks_to_offload)
.cloned() .cloned()
.collect::<Vec<_>>(); .collect();
self.offload_blocks(&offload_block_ids, &offload_token_blocks) let offload_priorities: Vec<u32> = candidate_priorities
.iter()
.take(num_blocks_to_offload)
.copied()
.collect();
self.offload_blocks(&offload_block_ids, &offload_token_blocks, &offload_priorities)
.expect("failed to offload blocks"); .expect("failed to offload blocks");
} else if self.offload_min_priority > 0 {
tracing::debug!(
"priority filtering: skipping all {} candidate blocks (threshold={})",
num_candidate_blocks,
self.offload_min_priority
);
}
// Check if we skipped any blocks due to priority filtering.
// If so, terminate offloading for this request to ensure global contiguity.
if num_blocks_to_offload < num_candidate_blocks {
let termination_index = self.evaluated_blocks + num_blocks_to_offload;
self.offload_terminated_at_block = Some(termination_index);
tracing::info!(
request_id = %self.request_id,
"offload terminated at block {}: priority {} < threshold {}; \
no further blocks will be offloaded",
termination_index,
candidate_priorities.get(num_blocks_to_offload).copied().unwrap_or(0),
self.offload_min_priority
);
}
self.evaluated_blocks += num_candidate_blocks; self.evaluated_blocks += num_candidate_blocks;
} }
...@@ -980,6 +1106,7 @@ impl VllmConnectorSlot { ...@@ -980,6 +1106,7 @@ impl VllmConnectorSlot {
&mut self, &mut self,
block_ids: &[BlockId], block_ids: &[BlockId],
token_blocks: &[TokenBlock], token_blocks: &[TokenBlock],
priorities: &[u32],
) -> Result<(), SlotError> { ) -> Result<(), SlotError> {
// Check if slot is in Finishing state before creating operations // Check if slot is in Finishing state before creating operations
// If we're finishing, don't create new operations // If we're finishing, don't create new operations
...@@ -988,12 +1115,14 @@ impl VllmConnectorSlot { ...@@ -988,12 +1115,14 @@ impl VllmConnectorSlot {
} }
assert!(block_ids.len() == token_blocks.len()); assert!(block_ids.len() == token_blocks.len());
assert!(block_ids.len() == priorities.len());
let operation_id = uuid::Uuid::new_v4(); let operation_id = uuid::Uuid::new_v4();
let xfer_req = LocalTransferRequest::Offload(LocalOffloadRequest::new( let xfer_req = LocalTransferRequest::Offload(LocalOffloadRequest::new(
self.request_id.clone(), self.request_id.clone(),
block_ids.to_vec(), block_ids.to_vec(),
token_blocks.to_vec(), token_blocks.to_vec(),
priorities.to_vec(),
operation_id, operation_id,
)); ));
...@@ -1088,6 +1217,8 @@ struct LocalOffloadRequest { ...@@ -1088,6 +1217,8 @@ struct LocalOffloadRequest {
request_id: String, request_id: String,
block_ids: Vec<BlockId>, block_ids: Vec<BlockId>,
token_blocks: Vec<TokenBlock>, token_blocks: Vec<TokenBlock>,
/// Priorities for each block, used to set BasicMetadata.priority during offload.
priorities: Vec<u32>,
operation_id: uuid::Uuid, operation_id: uuid::Uuid,
} }
...@@ -1096,13 +1227,16 @@ impl LocalOffloadRequest { ...@@ -1096,13 +1227,16 @@ impl LocalOffloadRequest {
request_id: String, request_id: String,
block_ids: Vec<BlockId>, block_ids: Vec<BlockId>,
token_blocks: Vec<TokenBlock>, token_blocks: Vec<TokenBlock>,
priorities: Vec<u32>,
operation_id: uuid::Uuid, operation_id: uuid::Uuid,
) -> Self { ) -> Self {
debug_assert!(block_ids.len() == token_blocks.len()); debug_assert!(block_ids.len() == token_blocks.len());
debug_assert!(block_ids.len() == priorities.len());
Self { Self {
request_id, request_id,
block_ids, block_ids,
token_blocks, token_blocks,
priorities,
operation_id, operation_id,
} }
} }
...@@ -1403,15 +1537,23 @@ where ...@@ -1403,15 +1537,23 @@ where
storage_name storage_name
); );
// 2. Apply token blocks // 2. Apply token blocks and set priorities
let mut blocks_to_register = Vec::new(); let mut blocks_to_register = Vec::new();
let zipped_blocks = blocks.into_iter().zip(token_blocks.into_iter()); let priorities = offload_req.priorities;
for (mut mutable_block, token_block) in zipped_blocks { for ((mut mutable_block, token_block), priority) in blocks
.into_iter()
.zip(token_blocks.into_iter())
.zip(priorities.into_iter())
{
mutable_block mutable_block
.apply_token_block(token_block.clone()) .apply_token_block(token_block.clone())
.map_err(|e| anyhow::anyhow!("failed to apply token block: {:?}", e))?; .map_err(|e| anyhow::anyhow!("failed to apply token block: {:?}", e))?;
// Set the priority on the block's metadata so it flows through to downstream processing
let updated_metadata = mutable_block.metadata().with_priority(priority);
mutable_block.update_metadata(updated_metadata);
blocks_to_register.push(mutable_block); blocks_to_register.push(mutable_block);
} }
tracing::debug!( tracing::debug!(
......
...@@ -378,7 +378,7 @@ impl Leader for KvConnectorLeader { ...@@ -378,7 +378,7 @@ impl Leader for KvConnectorLeader {
let scheduled_tokens = *scheduler_output let scheduled_tokens = *scheduler_output
.num_scheduled_tokens .num_scheduled_tokens
.get(request_id) .get(&new_req.request_id)
.unwrap_or(&0); .unwrap_or(&0);
slot.apply_scheduler_output( slot.apply_scheduler_output(
...@@ -386,6 +386,7 @@ impl Leader for KvConnectorLeader { ...@@ -386,6 +386,7 @@ impl Leader for KvConnectorLeader {
&new_req.block_ids, &new_req.block_ids,
new_req.num_computed_tokens, new_req.num_computed_tokens,
scheduled_tokens, scheduled_tokens,
new_req.priorities.as_deref(),
)?; )?;
let pending_ops_opt = slot.take_pending_operations(); let pending_ops_opt = slot.take_pending_operations();
...@@ -428,7 +429,7 @@ impl Leader for KvConnectorLeader { ...@@ -428,7 +429,7 @@ impl Leader for KvConnectorLeader {
let scheduled_tokens = *scheduler_output let scheduled_tokens = *scheduler_output
.num_scheduled_tokens .num_scheduled_tokens
.get(request_id) .get(&cached_req.request_id)
.unwrap_or(&0); .unwrap_or(&0);
slot.apply_scheduler_output( slot.apply_scheduler_output(
...@@ -436,6 +437,7 @@ impl Leader for KvConnectorLeader { ...@@ -436,6 +437,7 @@ impl Leader for KvConnectorLeader {
&cached_req.new_block_ids, &cached_req.new_block_ids,
cached_req.num_computed_tokens, cached_req.num_computed_tokens,
scheduled_tokens, scheduled_tokens,
cached_req.priorities.as_deref(),
)?; )?;
if let Some(pending_ops) = slot.take_pending_operations() { if let Some(pending_ops) = slot.take_pending_operations() {
......
...@@ -95,6 +95,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync + ...@@ -95,6 +95,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
/// The offload priority of the block. Higher priority blocks are offloaded first. /// The offload priority of the block. Higher priority blocks are offloaded first.
/// If the block should not be offloaded, return None. /// If the block should not be offloaded, return None.
fn offload_priority(&self) -> Option<u64>; fn offload_priority(&self) -> Option<u64>;
/// Returns a new metadata instance with the specified priority.
/// Used to carry priority through the block lifecycle for offload filtering.
fn with_priority(&self, priority: u32) -> Self;
} }
/// A trait for blocks that can be returned to the pool. /// A trait for blocks that can be returned to the pool.
...@@ -524,7 +528,38 @@ impl BlockMetadata for BasicMetadata { ...@@ -524,7 +528,38 @@ impl BlockMetadata for BasicMetadata {
fn offload_priority(&self) -> Option<u64> { fn offload_priority(&self) -> Option<u64> {
Some(self.priority as u64) Some(self.priority as u64)
} }
/// Returns a metadata value carrying `priority`, produced by delegating to
/// `update_priority`; `self` is only borrowed and is not modified.
fn with_priority(&self, priority: u32) -> Self {
self.update_priority(priority)
}
} }
#[cfg(test)]
mod basic_metadata_tests {
    use super::*;

    /// The value passed to `with_priority` must be visible through `offload_priority`.
    #[test]
    fn test_basic_metadata_with_priority() {
        let reprioritized = BasicMetadata::default().with_priority(75);
        assert_eq!(reprioritized.offload_priority(), Some(75));
    }

    /// Changing the priority must leave the recorded acquire/return ticks intact.
    #[test]
    fn test_basic_metadata_with_priority_preserves_ticks() {
        let mut base = BasicMetadata::default();
        base.on_acquired(100);
        base.on_returned(200);

        let reprioritized = base.with_priority(50);

        assert_eq!(reprioritized.priority(), 50);
        assert_eq!(reprioritized.acquired_tick(), 100);
        assert_eq!(reprioritized.returned_tick(), 200);
    }
}
/// Collection that holds shared storage and layout /// Collection that holds shared storage and layout
#[derive(Debug)] #[derive(Debug)]
pub struct Blocks<L: BlockLayout, M: BlockMetadata> { pub struct Blocks<L: BlockLayout, M: BlockMetadata> {
......
...@@ -2267,6 +2267,163 @@ mod tests { ...@@ -2267,6 +2267,163 @@ mod tests {
Ok(()) Ok(())
} }
/// Test that metadata (priority) transfers correctly through the full G1→G2→G3 chain.
///
/// Per the step comments below, G1 is the device pool, G2 the host pool and G3 the
/// disk pool. A block is given priority 42 on the device, offloaded to host and then
/// to disk, and the priority is checked on the block found in each pool.
#[tokio::test]
async fn test_offload_transfer_metadata_to_disk() -> Result<()> {
// NOTE(review): the build_pools arguments presumably size the device/host/disk
// pools; confirm against the helper's signature.
let (offload_manager, device_pool, host_pool, disk_pool) =
build_pools(4, Some(4), Some(4), None)?;
// All three pools are required here; unwrap panics if any of them is missing.
let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().unwrap();
// Create device block with non-default priority
let mut device_block = completed_block(device_pool, [0; 4]).await?;
populate_block(&device_block, 42)?;
let new_metadata = device_block.metadata().update_priority(42);
device_block.update_metadata(new_metadata);
// Register the block and keep the single registered handle that comes back.
let immutable_device_block = device_pool
.register_blocks(vec![device_block])
.await?
.into_iter()
.next()
.unwrap();
// Step 1: Offload G1→G2 (device to host)
offload_manager.offload(&immutable_device_block, 0).await?;
// NOTE(review): fixed sleep presumably gives the offload time to complete in the
// background; this is timing-dependent — confirm there is no completion signal to await.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Look the block up in the host pool by the device block's sequence hash.
let host_blocks = host_pool
.match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
.await?;
assert_eq!(host_blocks.len(), 1);
assert_eq!(
host_blocks[0].metadata().priority(),
42,
"G1→G2: Priority should transfer to host block"
);
// Step 2: Offload G2→G3 (host to disk)
offload_manager.offload(&host_blocks[0], 0).await?;
// Longer fixed wait than the host step (500 ms vs 100 ms) before checking the disk pool.
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
let disk_blocks = disk_pool
.match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
.await?;
assert_eq!(disk_blocks.len(), 1);
assert_eq!(
disk_blocks[0].metadata().priority(),
42,
"G2→G3: Priority should transfer to disk block"
);
Ok(())
}
/// Test that metadata (priority) transfers correctly when onboarding from G2→G1.
///
/// A block is created directly in the host pool with priority 42, registered, and
/// onboarded; the block returned by `onboard` must report the same priority.
#[tokio::test]
async fn test_onboard_transfer_metadata_from_host() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
// The device pool is not used directly; the unwrap only checks that one was built.
let _device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
// Create host block with non-default priority
let mut host_block = completed_block(host_pool, [0; 4]).await?;
populate_block(&host_block, 42)?;
let new_metadata = host_block.metadata().update_priority(42);
host_block.update_metadata(new_metadata);
// Register the block and keep the single registered handle that comes back.
let immutable_host_block = host_pool
.register_blocks(vec![host_block])
.await?
.into_iter()
.next()
.unwrap();
// Sanity check: the priority must already be present before onboarding.
assert_eq!(
immutable_host_block.metadata().priority(),
42,
"Host block should have priority=42 before onboard"
);
// Onboard G2→G1 (host to device)
// The double `?` unwraps two nested results returned by awaiting `onboard`.
let onboarded_blocks = offload_manager
.onboard(vec![immutable_host_block.clone()], None)
.await??;
assert_eq!(onboarded_blocks.len(), 1);
assert_eq!(
onboarded_blocks[0].metadata().priority(),
42,
"G2→G1: Priority should transfer to device block after onboard"
);
Ok(())
}
/// Test that metadata is preserved through a full G1→G2→G1 cycle.
///
/// A device block with priority 42 is offloaded to the host pool, the original device
/// block is released, and the host copy is onboarded again; the onboarded block must
/// still report priority 42.
#[tokio::test]
async fn test_offload_onboard_preserves_metadata() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
// Create device block with non-default priority
let mut device_block = completed_block(device_pool, [0; 4]).await?;
populate_block(&device_block, 42)?;
let new_metadata = device_block.metadata().update_priority(42);
device_block.update_metadata(new_metadata);
// Register the block and keep the single registered handle that comes back.
let immutable_device_block = device_pool
.register_blocks(vec![device_block])
.await?
.into_iter()
.next()
.unwrap();
// Step 1: Offload G1→G2
offload_manager.offload(&immutable_device_block, 0).await?;
// NOTE(review): fixed sleep presumably gives the offload time to complete in the
// background; timing-dependent — confirm.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let host_blocks = host_pool
.match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
.await?;
assert_eq!(host_blocks.len(), 1);
assert_eq!(
host_blocks[0].metadata().priority(),
42,
"G1→G2: Priority should transfer to host block"
);
// Drop device block and allocate new ones to evict it from device pool
drop(immutable_device_block);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocating the same count passed to build_pools (4) and dropping the blocks again
// is what is intended to push the old block out, so the onboard below cannot simply
// reuse it.
let temp_blocks = device_pool.allocate_blocks(4).await?;
drop(temp_blocks);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Step 2: Onboard G2→G1
let onboarded_blocks = offload_manager
.onboard(vec![host_blocks[0].clone()], None)
.await??;
assert_eq!(onboarded_blocks.len(), 1);
assert_eq!(
onboarded_blocks[0].metadata().priority(),
42,
"G2→G1: Priority should be preserved through full cycle"
);
Ok(())
}
#[tokio::test] #[tokio::test]
async fn test_onboard_duplicate() -> Result<()> { async fn test_onboard_duplicate() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
......
...@@ -132,15 +132,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> InactiveBlockPool<S, L, ...@@ -132,15 +132,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> InactiveBlockPool<S, L,
// If we already have an entry for this sequence hash or the block is reset, // If we already have an entry for this sequence hash or the block is reset,
// we need to move it to the uninitialized set // we need to move it to the uninitialized set
match block.state() { match block.state() {
BlockState::Reset => { BlockState::Reset | BlockState::Partial(_) | BlockState::Complete(_) => {
self.uninitialized_set.push_back(block);
}
BlockState::Partial(_) => {
let mut block = block;
block.reset();
self.uninitialized_set.push_back(block);
}
BlockState::Complete(_) => {
let mut block = block; let mut block = block;
block.reset(); block.reset();
self.uninitialized_set.push_back(block); self.uninitialized_set.push_back(block);
...@@ -571,6 +563,14 @@ pub(crate) mod tests { ...@@ -571,6 +563,14 @@ pub(crate) mod tests {
fn offload_priority(&self) -> Option<u64> { fn offload_priority(&self) -> Option<u64> {
Some(self.priority as u64) Some(self.priority as u64)
} }
/// Builds a copy of this metadata with `priority` replaced; `returned_tick` and
/// `acquired_tick` are carried over from `self` unchanged.
fn with_priority(&self, priority: u32) -> Self {
    Self {
        priority,
        ..self.clone()
    }
}
} }
type TestPriorityKey = PriorityKey<TestMetadata>; type TestPriorityKey = PriorityKey<TestMetadata>;
...@@ -622,6 +622,42 @@ pub(crate) mod tests { ...@@ -622,6 +622,42 @@ pub(crate) mod tests {
assert!(map.is_empty()); assert!(map.is_empty());
} }
/// `with_priority` swaps in the new priority and carries both ticks over untouched.
#[test]
fn test_with_priority_updates_priority() {
    let before = TestMetadata {
        acquired_tick: 50,
        returned_tick: 100,
        priority: 10,
    };

    let after = before.with_priority(80);

    // Only the priority field is allowed to change.
    assert_eq!(after.priority, 80);
    assert_eq!(after.returned_tick, 100);
    assert_eq!(after.acquired_tick, 50);
}
/// `with_priority` borrows its receiver, so the source value must keep its old priority.
#[test]
fn test_with_priority_immutability() {
    let source = TestMetadata {
        acquired_tick: 5,
        returned_tick: 10,
        priority: 35,
    };

    let derived = source.with_priority(100);

    assert_eq!(source.priority, 35);
    assert_eq!(derived.priority, 100);
}
/// The lowest value, the documented upper bound (100) and `u32::MAX` must all be
/// stored verbatim by `with_priority`.
#[test]
fn test_with_priority_boundary_values() {
    let base = TestMetadata::default();
    for boundary in [0, 100, u32::MAX] {
        assert_eq!(base.with_priority(boundary).priority, boundary);
    }
}
// Helper function to create a sequence of tokens // Helper function to create a sequence of tokens
pub fn create_token_sequence(values: &[u32]) -> Tokens { pub fn create_token_sequence(values: &[u32]) -> Tokens {
let tokens: Vec<Token> = values.iter().map(|&v| Token::from(v)).collect(); let tokens: Vec<Token> = values.iter().map(|&v| Token::from(v)).collect();
...@@ -881,4 +917,144 @@ pub(crate) mod tests { ...@@ -881,4 +917,144 @@ pub(crate) mod tests {
pool.available_blocks() pool.available_blocks()
); );
} }
/// Test that validates blocks allocated from the pool always have the default
/// priority (0), regardless of what priority they had in a previous allocation.
///
/// Regression test: a block returned in Reset state used to be pushed onto
/// uninitialized_set as-is, so a non-default priority set on it survived
/// re-acquisition. The return path now calls block.reset() for Reset blocks as
/// well, which is expected to restore the default priority.
#[test]
fn test_allocated_blocks_have_default_priority() {
let mut pool = create_block_pool(3);
// Step 1: Acquire blocks (they come from uninitialized_set in Reset state)
let mut blocks = pool.acquire_free_blocks(3).unwrap();
assert_eq!(blocks.len(), 3);
// Verify initial priority is 0 (default)
for block in &blocks {
assert_eq!(
block.metadata().offload_priority(),
Some(0),
"Newly acquired block should have default priority"
);
}
// Step 2: Set non-default priority on blocks (keep them in Reset state)
for block in &mut blocks {
let updated_metadata = block.metadata().with_priority(100);
block.update_metadata(updated_metadata);
assert_eq!(block.metadata().offload_priority(), Some(100));
}
// Step 3: Return blocks to inactive pool
// The blocks are still in Reset state; the pool is expected to reset them
// (clearing the priority) before placing them in uninitialized_set
pool.return_blocks(blocks);
assert_eq!(pool.available_blocks(), 3);
// Step 4: Acquire blocks again
let reacquired_blocks = pool.acquire_free_blocks(3).unwrap();
// Step 5: Verify priority is reset to default (0)
for (i, block) in reacquired_blocks.iter().enumerate() {
assert_eq!(
block.metadata().offload_priority(),
Some(0),
"Block {} should have default priority after reallocation, but has {:?}",
i,
block.metadata().offload_priority()
);
}
}
/// Validates that after pool.reset(), all blocks have default priority
/// regardless of what priority they had when registered.
///
/// The test proceeds as follows:
/// 1. Create a tokens sequence
/// 2. Allocate mutable blocks
/// 3. Apply the tokens sequence and some non-default priority
/// 4. Release them to the inactive pool (they go to priority_set as Registered)
/// 5. Reset the inactive pool
/// 6. Validate all blocks have default priority
///
/// The expectation is that blocks evicted from priority_set go through
/// block.reset(), which clears the priority.
#[test]
fn test_pool_reset_clears_priority_on_registered_blocks() {
// acquire_blocks takes a runtime handle, so a Tokio runtime is created even though
// the test itself is synchronous.
let async_runtime = tokio::runtime::Runtime::new().unwrap();
const BLOCK_SIZE: u32 = 4;
let mut pool = create_block_pool(3);
assert_eq!(pool.available_blocks(), 3);
// Step 1 & 2: Create tokens and allocate blocks
// 12 tokens at BLOCK_SIZE = 4 yield the 3 blocks asserted below.
let tokens = create_token_sequence(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
let (mut blocks, _matched) = acquire_blocks(
tokens,
BLOCK_SIZE,
&mut pool,
async_runtime.handle().clone(),
);
assert_eq!(blocks.len(), 3);
// Verify blocks are in Registered state
for block in &blocks {
assert!(
block.state().is_registered(),
"Block should be in Registered state after acquire_blocks"
);
}
// Step 3: Set non-default priority on blocks
for block in &mut blocks {
let updated_metadata = block.metadata().with_priority(100);
block.update_metadata(updated_metadata);
assert_eq!(
block.metadata().offload_priority(),
Some(100),
"Priority should be set to 100"
);
}
// Step 4: Release blocks to inactive pool
// Since blocks are Registered, they go to priority_set
pool.return_blocks(blocks);
assert_eq!(pool.available_blocks(), 3);
// Verify blocks are in priority_set (not uninitialized_set)
// status() is read here as (priority_set count, uninitialized_set count).
let (priority_count, uninit_count) = pool.status();
assert_eq!(priority_count, 3, "All blocks should be in priority_set");
assert_eq!(uninit_count, 0, "No blocks should be in uninitialized_set");
// Step 5: Reset the pool
// NOTE(review): presumably this acquires every free block, which evicts from
// priority_set and calls block.reset() on each, then returns them — confirm
// against the pool's reset() implementation.
pool.reset().expect("Pool reset should succeed");
// After reset, all blocks should be in uninitialized_set
let (priority_count, uninit_count) = pool.status();
assert_eq!(
priority_count, 0,
"priority_set should be empty after reset"
);
assert_eq!(
uninit_count, 3,
"All blocks should be in uninitialized_set after reset"
);
// Step 6: Acquire all blocks and verify priority is default (0)
let reset_blocks = pool.acquire_free_blocks(3).unwrap();
for (i, block) in reset_blocks.iter().enumerate() {
assert_eq!(
block.metadata().offload_priority(),
Some(0),
"Block {} should have default priority after pool reset, but has {:?}",
i,
block.metadata().offload_priority()
);
}
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment